From da787d7d4e79f00835795f3bf53da3d13881bc93 Mon Sep 17 00:00:00 2001 From: Ford Date: Wed, 12 Nov 2025 10:53:31 -0800 Subject: [PATCH 1/9] feat(adminClient): Add auth support for admin requests - Add auth module with models and services - Use new auth for AdminClient requests --- src/amp/admin/client.py | 24 ++++- src/amp/auth/__init__.py | 10 ++ src/amp/auth/models.py | 41 +++++++ src/amp/auth/service.py | 227 +++++++++++++++++++++++++++++++++++++++ src/amp/client.py | 9 +- 5 files changed, 305 insertions(+), 6 deletions(-) create mode 100644 src/amp/auth/__init__.py create mode 100644 src/amp/auth/models.py create mode 100644 src/amp/auth/service.py diff --git a/src/amp/admin/client.py b/src/amp/admin/client.py index 27a6d61..75292e9 100644 --- a/src/amp/admin/client.py +++ b/src/amp/admin/client.py @@ -20,21 +20,39 @@ class AdminClient: Args: base_url: Base URL for Admin API (e.g., 'http://localhost:8080') auth_token: Optional Bearer token for authentication + auth: If True, load auth token from ~/.amp-cli-config (shared with TS CLI) Example: - >>> client = AdminClient('http://localhost:8080') - >>> datasets = client.datasets.list_all() + >>> # Use amp auth system + >>> client = AdminClient('http://localhost:8080', auth=True) + >>> + >>> # Or use manual token + >>> client = AdminClient('http://localhost:8080', auth_token='your-token') """ - def __init__(self, base_url: str, auth_token: Optional[str] = None): + def __init__(self, base_url: str, auth_token: Optional[str] = None, auth: bool = False): """Initialize Admin API client. Args: base_url: Base URL for Admin API (e.g., 'http://localhost:8080') auth_token: Optional Bearer token for authentication + auth: If True, load auth token from ~/.amp-cli-config + + Raises: + ValueError: If both auth=True and auth_token are provided """ + if auth and auth_token: + raise ValueError('Cannot specify both auth=True and auth_token. Choose one authentication method.') + self.base_url = base_url.rstrip('/') + # Load token from amp auth system if requested + if auth: + from amp.auth import AuthService + + auth_service = AuthService() + auth_token = auth_service.get_token() + # Build headers headers = {} if auth_token: diff --git a/src/amp/auth/__init__.py b/src/amp/auth/__init__.py new file mode 100644 index 0000000..1e617c5 --- /dev/null +++ b/src/amp/auth/__init__.py @@ -0,0 +1,10 @@ +"""Authentication module for amp Python client. + +Provides Privy authentication support compatible with the TypeScript CLI. +Reads and manages auth tokens from ~/.amp-cli-config. +""" + +from .models import AuthStorage, RefreshTokenResponse +from .service import AuthService + +__all__ = ['AuthService', 'AuthStorage', 'RefreshTokenResponse'] diff --git a/src/amp/auth/models.py b/src/amp/auth/models.py new file mode 100644 index 0000000..925e2e8 --- /dev/null +++ b/src/amp/auth/models.py @@ -0,0 +1,41 @@ +"""Data models for Privy authentication. + +Models match the TypeScript CLI schema for compatibility. +""" + +from typing import List, Optional + +from pydantic import BaseModel, Field + + +class AuthStorage(BaseModel): + """Auth storage schema for ~/.amp-cli-config file. + + Matches the TypeScript AuthStorageSchema so they can share auth. + """ + + accessToken: str = Field(..., description='Access token for authenticated requests') + refreshToken: str = Field(..., description='Refresh token for renewing access') + userId: str = Field(..., description="User's Privy DID (format: did:privy:c...)") + accounts: Optional[List[str]] = Field(None, description='List of connected accounts/wallets') + expiry: Optional[int] = Field(None, description='Token expiry timestamp in milliseconds') + + +class RefreshTokenResponseUser(BaseModel): + """User information in the refresh token response.""" + + id: str = Field(..., description='User ID (Privy DID)') + accounts: List[str] = Field(..., description='List of connected accounts/wallets') + + +class RefreshTokenResponse(BaseModel): + """Response from the token refresh endpoint. + + Matches the TypeScript RefreshTokenResponse schema. + """ + + token: str = Field(..., description='The refreshed access token') + refresh_token: Optional[str] = Field(None, description='New refresh token (if rotated)') + session_update_action: str = Field(..., description='Session update action') + expires_in: int = Field(..., description='Seconds until token expires') + user: RefreshTokenResponseUser = Field(..., description='User information') diff --git a/src/amp/auth/service.py b/src/amp/auth/service.py new file mode 100644 index 0000000..1b56bbb --- /dev/null +++ b/src/amp/auth/service.py @@ -0,0 +1,227 @@ +"""Auth service for managing Privy authentication. + +Handles loading, refreshing, and persisting auth tokens from ~/.amp-cli-config. +Compatible with the TypeScript CLI authentication system. +""" + +import json +import time +from pathlib import Path +from typing import Optional + +import httpx + +from .models import AuthStorage, RefreshTokenResponse + +# Auth platform URL (matches TypeScript implementation) +AUTH_PLATFORM_URL = 'https://auth.amp.edgeandnode.com/' + +# Storage location (matches TypeScript implementation) +# TypeScript CLI uses: ~/.amp-cli-config/amp_cli_auth (directory with file inside) +AUTH_CONFIG_DIR = Path.home() / '.amp-cli-config' +AUTH_CONFIG_FILE = AUTH_CONFIG_DIR / 'amp_cli_auth' + + +class AuthService: + """Service for managing Privy authentication tokens. + + Loads tokens from ~/.amp-cli-config (shared with TypeScript CLI), + automatically refreshes expired tokens, and persists updates. + + Example: + >>> auth = AuthService() + >>> if auth.is_authenticated(): + ... token = auth.get_token() # Auto-refreshes if needed + """ + + def __init__(self, config_path: Optional[Path] = None): + """Initialize auth service. + + Args: + config_path: Optional custom path to config file (defaults to ~/.amp-cli-config/amp_cli_auth) + """ + self.config_path = config_path or AUTH_CONFIG_FILE + self._http = httpx.Client(timeout=30.0) + + def is_authenticated(self) -> bool: + """Check if user is authenticated. + + Returns: + True if valid auth exists in ~/.amp-cli-config + """ + try: + auth = self.load_auth() + return auth is not None + except Exception: + return False + + def load_auth(self) -> Optional[AuthStorage]: + """Load auth from ~/.amp-cli-config/amp_cli_auth file. + + Returns: + AuthStorage if found, None if not authenticated + + Raises: + FileNotFoundError: If config file doesn't exist + json.JSONDecodeError: If config file is invalid JSON + ValueError: If auth data is invalid + """ + if not self.config_path.exists(): + return None + + with open(self.config_path) as f: + auth_data = json.load(f) + + return AuthStorage.model_validate(auth_data) + + def save_auth(self, auth: AuthStorage) -> None: + """Save auth to ~/.amp-cli-config/amp_cli_auth file. + + Args: + auth: Auth data to persist + + Raises: + IOError: If unable to write to config file + """ + # Ensure directory exists + self.config_path.parent.mkdir(parents=True, exist_ok=True) + + # Write auth data directly to file (no wrapper object) + with open(self.config_path, 'w') as f: + json.dump(auth.model_dump(mode='json', exclude_none=False), f, indent=2) + + def get_token(self) -> str: + """Get valid access token, refreshing if needed. + + Returns: + Valid access token string + + Raises: + FileNotFoundError: If not authenticated (no ~/.amp-cli-config) + ValueError: If auth data is invalid or refresh fails + """ + auth = self.load_auth() + if not auth: + raise FileNotFoundError( + 'Not authenticated. Please run authentication first.\n' + "Use the TypeScript CLI 'amp auth login' or authenticate via the Python client." + ) + + # Check if we need to refresh the token + needs_refresh = self._needs_refresh(auth) + + if needs_refresh: + auth = self.refresh_token(auth) + + return auth.accessToken + + def _needs_refresh(self, auth: AuthStorage) -> bool: + """Check if token needs to be refreshed. + + Token needs refresh if: + - Missing expiry field (old format, need to refresh to populate) + - Missing accounts field (old format, need to refresh to populate) + - Token is expired + - Token is expiring within 5 minutes + + Args: + auth: Auth storage to check + + Returns: + True if token needs refresh + """ + # Missing expiry or accounts - refresh to populate + if auth.expiry is None or auth.accounts is None: + return True + + # Get current time in milliseconds (matching TypeScript) + now_ms = int(time.time() * 1000) + + # Token is expired + if auth.expiry < now_ms: + return True + + # Token is expiring within 5 minutes + five_minutes_ms = 5 * 60 * 1000 + if auth.expiry - now_ms <= five_minutes_ms: + return True + + return False + + def refresh_token(self, auth: AuthStorage) -> AuthStorage: + """Refresh an expired or expiring access token. + + Args: + auth: Current auth storage with refresh token + + Returns: + Updated auth storage with new tokens + + Raises: + httpx.HTTPStatusError: If refresh request fails + ValueError: If response validation fails or user ID mismatch + """ + # Build refresh request (matches TypeScript implementation) + url = f'{AUTH_PLATFORM_URL}api/v1/auth/refresh' + headers = { + 'Authorization': f'Bearer {auth.accessToken}', + 'Content-Type': 'application/json', + 'Accept': 'application/json', + } + body = {'refresh_token': auth.refreshToken, 'user_id': auth.userId} + + # Make request + response = self._http.post(url, headers=headers, json=body) + + # Handle errors + if response.status_code == 401 or response.status_code == 403: + raise ValueError('Token refresh failed: Authentication expired. Please log in again.') + + if response.status_code == 429: + retry_after = response.headers.get('retry-after', '60') + raise ValueError(f'Token refresh rate limited. Retry after {retry_after} seconds.') + + if response.status_code != 200: + error_msg = 'Failed to refresh token' + try: + error_data = response.json() + if 'error_description' in error_data: + error_msg = error_data['error_description'] + except Exception: + pass + raise ValueError(f'{error_msg} (status: {response.status_code})') + + # Parse response + refresh_response = RefreshTokenResponse.model_validate(response.json()) + + # Validate user ID matches (security check) + if refresh_response.user.id != auth.userId: + raise ValueError( + f'User ID mismatch after refresh. Expected {auth.userId}, got {refresh_response.user.id}' + ) + + # Calculate new expiry + now_ms = int(time.time() * 1000) + expiry_ms = now_ms + (refresh_response.expires_in * 1000) + + # Create updated auth storage + refreshed_auth = AuthStorage( + accessToken=refresh_response.token, + refreshToken=refresh_response.refresh_token or auth.refreshToken, + userId=refresh_response.user.id, + accounts=refresh_response.user.accounts, + expiry=expiry_ms, + ) + + # Persist updated tokens + self.save_auth(refreshed_auth) + + return refreshed_auth + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + self._http.close() diff --git a/src/amp/client.py b/src/amp/client.py index 2af91ee..d5aced2 100644 --- a/src/amp/client.py +++ b/src/amp/client.py @@ -238,15 +238,17 @@ class Client: query_url: Query endpoint URL via Flight SQL (e.g., 'grpc://localhost:1602') admin_url: Optional Admin API URL (e.g., 'http://localhost:8080') auth_token: Optional Bearer token for Admin API authentication + auth: If True, load auth token from ~/.amp-cli-config (shared with TS CLI) Example: >>> # Query-only client (backward compatible) >>> client = Client(url='grpc://localhost:1602') >>> - >>> # Client with admin capabilities + >>> # Client with admin capabilities and amp auth >>> client = Client( ... query_url='grpc://localhost:1602', - ... admin_url='http://localhost:8080' + ... admin_url='http://localhost:8080', + ... auth=True ... ) """ @@ -256,6 +258,7 @@ def __init__( query_url: Optional[str] = None, admin_url: Optional[str] = None, auth_token: Optional[str] = None, + auth: bool = False, ): # Backward compatibility: url parameter → query_url if url and not query_url: @@ -276,7 +279,7 @@ def __init__( if admin_url: from amp.admin.client import AdminClient - self._admin_client = AdminClient(admin_url, auth_token) + self._admin_client = AdminClient(admin_url, auth_token=auth_token, auth=auth) else: self._admin_client = None From cd64269d24ccb447cb4d8173654cfcd1f57e8012 Mon Sep 17 00:00:00 2001 From: Ford Date: Wed, 12 Nov 2025 14:39:14 -0800 Subject: [PATCH 2/9] feat(client): Add auth support to query client (FlightSQL gRPC) --- src/amp/client.py | 50 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/src/amp/client.py b/src/amp/client.py index d5aced2..0f090b2 100644 --- a/src/amp/client.py +++ b/src/amp/client.py @@ -4,6 +4,7 @@ import pyarrow as pa from google.protobuf.any_pb2 import Any from pyarrow import flight +from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory from . import FlightSql_pb2 from .config.connection_manager import ConnectionManager @@ -19,6 +20,38 @@ ) +class AuthMiddleware(ClientMiddleware): + """Flight middleware to add Bearer token authentication header.""" + + def __init__(self, token: str): + """Initialize auth middleware. + + Args: + token: Bearer token to add to requests + """ + self.token = token + + def sending_headers(self): + """Add Authorization header to outgoing requests.""" + return {'authorization': f'Bearer {self.token}'} + + +class AuthMiddlewareFactory(ClientMiddlewareFactory): + """Factory for creating auth middleware instances.""" + + def __init__(self, token: str): + """Initialize auth middleware factory. + + Args: + token: Bearer token to use for authentication + """ + self.token = token + + def start_call(self, info): + """Create auth middleware for each call.""" + return AuthMiddleware(self.token) + + class QueryBuilder: """Chainable query builder for data loading operations. @@ -264,9 +297,24 @@ def __init__( if url and not query_url: query_url = url + # Get auth token if using amp auth system + flight_auth_token = None + if auth and not auth_token: + from amp.auth import AuthService + + auth_service = AuthService() + flight_auth_token = auth_service.get_token() + elif auth_token: + flight_auth_token = auth_token + # Initialize Flight SQL client if query_url: - self.conn = flight.connect(query_url) + # Add auth middleware if token is provided + if flight_auth_token: + middleware = [AuthMiddlewareFactory(flight_auth_token)] + self.conn = flight.connect(query_url, middleware=middleware) + else: + self.conn = flight.connect(query_url) else: raise ValueError('Either url or query_url must be provided for Flight SQL connection') From 9e58bb67f929a89bc9ce863d6c43137e3680178a Mon Sep 17 00:00:00 2001 From: Ford Date: Wed, 12 Nov 2025 18:36:05 -0800 Subject: [PATCH 3/9] feat(auth): Add interactive device auth flow via browser --- src/amp/auth/__init__.py | 3 +- src/amp/auth/device_flow.py | 289 ++++++++++++++++++++++++++++++++++++ src/amp/auth/service.py | 33 +++- 3 files changed, 321 insertions(+), 4 deletions(-) create mode 100644 src/amp/auth/device_flow.py diff --git a/src/amp/auth/__init__.py b/src/amp/auth/__init__.py index 1e617c5..8e456c7 100644 --- a/src/amp/auth/__init__.py +++ b/src/amp/auth/__init__.py @@ -4,7 +4,8 @@ Reads and manages auth tokens from ~/.amp-cli-config. """ +from .device_flow import interactive_device_login from .models import AuthStorage, RefreshTokenResponse from .service import AuthService -__all__ = ['AuthService', 'AuthStorage', 'RefreshTokenResponse'] +__all__ = ['AuthService', 'AuthStorage', 'RefreshTokenResponse', 'interactive_device_login'] diff --git a/src/amp/auth/device_flow.py b/src/amp/auth/device_flow.py new file mode 100644 index 0000000..33bb659 --- /dev/null +++ b/src/amp/auth/device_flow.py @@ -0,0 +1,289 @@ +"""OAuth2 Device Authorization Flow for Privy authentication. + +Implements the device authorization grant flow with PKCE for secure authentication. +Matches the TypeScript CLI implementation. +""" + +import base64 +import hashlib +import secrets +import time +import webbrowser +from typing import Optional, Tuple + +import httpx +from pydantic import BaseModel, Field + +from .models import AuthStorage +from .service import AUTH_PLATFORM_URL + + +class DeviceAuthorizationResponse(BaseModel): + """Response from device authorization endpoint.""" + + device_code: str = Field(..., description='Device verification code for polling') + user_code: str = Field(..., description='Code for user to enter in browser') + verification_uri: str = Field(..., description='URL where user enters the code') + expires_in: int = Field(..., description='Seconds until device code expires') + interval: int = Field(..., description='Minimum polling interval in seconds') + + +class DeviceTokenResponse(BaseModel): + """Response from device token endpoint (success case).""" + + access_token: str = Field(..., description='Access token for authenticated requests') + refresh_token: str = Field(..., description='Refresh token for renewing access') + user_id: str = Field(..., description='Authenticated user ID') + user_accounts: list[str] = Field(..., description='List of user accounts/wallets') + expires_in: int = Field(..., description='Seconds until token expires') + + +class DeviceTokenPendingResponse(BaseModel): + """Response when authorization is still pending.""" + + error: str = Field('authorization_pending', description='Error code') + + +class DeviceTokenExpiredResponse(BaseModel): + """Response when device code has expired.""" + + error: str = Field('expired_token', description='Error code') + + +def generate_pkce_pair() -> Tuple[str, str]: + """Generate PKCE code_verifier and code_challenge. + + Returns: + Tuple of (code_verifier, code_challenge) + """ + # Generate cryptographically random code_verifier + # Must be 43-128 characters using unreserved chars [A-Za-z0-9-._~] + code_verifier_bytes = secrets.token_bytes(32) + code_verifier = base64.urlsafe_b64encode(code_verifier_bytes).decode('utf-8').rstrip('=') + + # Generate code_challenge = BASE64URL(SHA256(code_verifier)) + challenge_bytes = hashlib.sha256(code_verifier.encode('utf-8')).digest() + code_challenge = base64.urlsafe_b64encode(challenge_bytes).decode('utf-8').rstrip('=') + + return code_verifier, code_challenge + + +def request_device_authorization(http_client: httpx.Client) -> Tuple[DeviceAuthorizationResponse, str]: + """Request device authorization from auth platform. + + Args: + http_client: HTTP client to use for request + + Returns: + Tuple of (DeviceAuthorizationResponse, code_verifier) + + Raises: + httpx.HTTPStatusError: If request fails + ValueError: If response is invalid + """ + # Generate PKCE parameters + code_verifier, code_challenge = generate_pkce_pair() + + # Request device authorization + url = f'{AUTH_PLATFORM_URL}api/v1/device/authorize' + response = http_client.post( + url, json={'code_challenge': code_challenge, 'code_challenge_method': 'S256'}, timeout=30.0 + ) + + if response.status_code != 200: + raise ValueError(f'Device authorization failed: {response.status_code} - {response.text}') + + device_auth = DeviceAuthorizationResponse.model_validate(response.json()) + return device_auth, code_verifier + + +def poll_for_token(http_client: httpx.Client, device_code: str, code_verifier: str) -> Optional[DeviceTokenResponse]: + """Poll device token endpoint once. + + Args: + http_client: HTTP client to use for request + device_code: Device code from authorization response + code_verifier: PKCE code verifier + + Returns: + DeviceTokenResponse if auth complete, None if still pending + + Raises: + ValueError: If device code expired or other error + """ + url = f'{AUTH_PLATFORM_URL}api/v1/device/token' + response = http_client.get(url, params={'device_code': device_code, 'code_verifier': code_verifier}, timeout=10.0) + + data = response.json() + + # Check for error responses + if 'error' in data: + error = data['error'] + if error == 'authorization_pending': + return None # Still pending + elif error == 'expired_token': + raise ValueError('Device code expired. Please try again.') + else: + raise ValueError(f'Token polling error: {error}') + + # Success - parse token response + return DeviceTokenResponse.model_validate(data) + + +def poll_until_authenticated( + http_client: httpx.Client, + device_code: str, + code_verifier: str, + interval: int, + expires_in: int, + on_poll: Optional[callable] = None, + verbose: bool = False, +) -> DeviceTokenResponse: + """Poll for token until authenticated or timeout. + + Args: + http_client: HTTP client to use for requests + device_code: Device code from authorization + code_verifier: PKCE code verifier + interval: Minimum polling interval in seconds + expires_in: Seconds until device code expires + on_poll: Optional callback called on each poll attempt + + Returns: + DeviceTokenResponse when authentication completes + + Raises: + ValueError: If authentication times out or fails + """ + start_time = time.time() + poll_count = 0 + max_polls = int(expires_in / interval) + 5 # Add some buffer + + while poll_count < max_polls: + elapsed = time.time() - start_time + if elapsed > expires_in: + raise ValueError('Authentication timed out. Please try again.') + + if on_poll: + on_poll(poll_count, elapsed) + + # Poll for token + try: + token_response = poll_for_token(http_client, device_code, code_verifier) + if token_response: + return token_response + except ValueError as e: + if 'expired' in str(e).lower(): + raise + # Other errors, log and continue polling + if verbose: + print(f'\n⚠ Polling error (will retry): {e}') + pass + except Exception as e: + # Log unexpected errors + if verbose: + print(f'\n⚠ Unexpected error (will retry): {type(e).__name__}: {e}') + pass + + # Wait before next poll + time.sleep(interval) + poll_count += 1 + + raise ValueError('Authentication timed out. Please try again.') + + +def open_browser(url: str) -> bool: + """Open URL in browser. + + Args: + url: URL to open + + Returns: + True if browser opened successfully + """ + try: + webbrowser.open(url) + return True + except Exception: + return False + + +def interactive_device_login(verbose: bool = True, auto_open_browser: bool = True) -> AuthStorage: + """Perform interactive device authorization flow. + + Args: + verbose: Print progress messages + auto_open_browser: Automatically open browser for user + + Returns: + AuthStorage with tokens + + Raises: + ValueError: If authentication fails + """ + http_client = httpx.Client() + + try: + # Step 1: Request device authorization + if verbose: + print('🔐 Starting authentication...\n') + + device_auth, code_verifier = request_device_authorization(http_client) + + # Step 2: Display user code and open browser + if verbose: + print(f'📱 Verification Code: {device_auth.user_code}') + print(f'🌐 Verification URL: {device_auth.verification_uri}\n') + + if auto_open_browser: + if verbose: + print('Opening browser...') + if open_browser(device_auth.verification_uri): + if verbose: + print('✓ Browser opened') + else: + if verbose: + print('✗ Could not open browser automatically') + print(f' Please open: {device_auth.verification_uri}') + + if verbose: + print(f'\n⏳ Waiting for authentication (expires in {device_auth.expires_in}s)...') + print(' Complete the authentication in your browser.\n') + + # Step 3: Poll for token + spinner_frames = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'] + + def poll_callback(count: int, elapsed: float): + if verbose: + spinner = spinner_frames[count % len(spinner_frames)] + print(f'\r{spinner} Polling... ({int(elapsed)}s elapsed)', end='', flush=True) + + token_response = poll_until_authenticated( + http_client, + device_auth.device_code, + code_verifier, + device_auth.interval, + device_auth.expires_in, + poll_callback, + verbose=verbose, + ) + + if verbose: + print('\r✓ Authentication successful! \n') + + # Step 4: Create auth storage + now_ms = int(time.time() * 1000) + expiry_ms = now_ms + (token_response.expires_in * 1000) + + auth_storage = AuthStorage( + accessToken=token_response.access_token, + refreshToken=token_response.refresh_token, + userId=token_response.user_id, + accounts=token_response.user_accounts, + expiry=expiry_ms, + ) + + return auth_storage + + finally: + http_client.close() diff --git a/src/amp/auth/service.py b/src/amp/auth/service.py index 1b56bbb..be2e631 100644 --- a/src/amp/auth/service.py +++ b/src/amp/auth/service.py @@ -196,9 +196,7 @@ def refresh_token(self, auth: AuthStorage) -> AuthStorage: # Validate user ID matches (security check) if refresh_response.user.id != auth.userId: - raise ValueError( - f'User ID mismatch after refresh. Expected {auth.userId}, got {refresh_response.user.id}' - ) + raise ValueError(f'User ID mismatch after refresh. Expected {auth.userId}, got {refresh_response.user.id}') # Calculate new expiry now_ms = int(time.time() * 1000) @@ -225,3 +223,32 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): """Context manager exit.""" self._http.close() + + def login(self, verbose: bool = True, auto_open_browser: bool = True) -> None: + """Perform interactive browser-based login. + + Opens browser for OAuth2 device authorization flow with PKCE. + Saves authentication tokens to ~/.amp-cli-config/amp_cli_auth. + + Args: + verbose: Print progress messages + auto_open_browser: Automatically open browser + + Raises: + ValueError: If authentication fails + + Example: + >>> auth = AuthService() + >>> auth.login() # Opens browser for authentication + >>> # Auth tokens saved to ~/.amp-cli-config/amp_cli_auth + """ + from .device_flow import interactive_device_login + + # Perform device authorization flow + auth_storage = interactive_device_login(verbose=verbose, auto_open_browser=auto_open_browser) + + # Save to config file + self.save_auth(auth_storage) + + if verbose: + print(f'✓ Authentication saved to {self.config_path}') From 6fce5c0ff61d7166fd92051d835b9198e977b8f4 Mon Sep 17 00:00:00 2001 From: Ford Date: Fri, 14 Nov 2025 15:46:58 -0800 Subject: [PATCH 4/9] feat(auth): Support multiple ways of passing in auth - `AMP_AUTH_TOKEN` env var, `auth_token` param, or locally stored auth file (from interactive browser login) --- src/amp/admin/client.py | 34 +++++++-- src/amp/client.py | 30 ++++++-- tests/unit/test_client.py | 150 +++++++++++++++++++++++++++++++++++++- 3 files changed, 198 insertions(+), 16 deletions(-) diff --git a/src/amp/admin/client.py b/src/amp/admin/client.py index 75292e9..7cee60b 100644 --- a/src/amp/admin/client.py +++ b/src/amp/admin/client.py @@ -4,6 +4,7 @@ with the Amp Admin API over HTTP. """ +import os from typing import Optional import httpx @@ -19,15 +20,24 @@ class AdminClient: Args: base_url: Base URL for Admin API (e.g., 'http://localhost:8080') - auth_token: Optional Bearer token for authentication + auth_token: Optional Bearer token for authentication (highest priority) auth: If True, load auth token from ~/.amp-cli-config (shared with TS CLI) + Authentication Priority (highest to lowest): + 1. Explicit auth_token parameter + 2. AMP_AUTH_TOKEN environment variable + 3. auth=True - reads from ~/.amp-cli-config/amp_cli_auth + Example: - >>> # Use amp auth system + >>> # Use amp auth from file >>> client = AdminClient('http://localhost:8080', auth=True) >>> - >>> # Or use manual token + >>> # Use manual token >>> client = AdminClient('http://localhost:8080', auth_token='your-token') + >>> + >>> # Use environment variable + >>> # export AMP_AUTH_TOKEN="eyJhbGci..." + >>> client = AdminClient('http://localhost:8080') """ def __init__(self, base_url: str, auth_token: Optional[str] = None, auth: bool = False): @@ -46,17 +56,25 @@ def __init__(self, base_url: str, auth_token: Optional[str] = None, auth: bool = self.base_url = base_url.rstrip('/') - # Load token from amp auth system if requested - if auth: + # Resolve auth token with priority: explicit param > env var > auth file + resolved_token = None + if auth_token: + # Priority 1: Explicit auth_token parameter + resolved_token = auth_token + elif os.getenv('AMP_AUTH_TOKEN'): + # Priority 2: AMP_AUTH_TOKEN environment variable + resolved_token = os.getenv('AMP_AUTH_TOKEN') + elif auth: + # Priority 3: Load from ~/.amp-cli-config/amp_cli_auth from amp.auth import AuthService auth_service = AuthService() - auth_token = auth_service.get_token() + resolved_token = auth_service.get_token() # Build headers headers = {} - if auth_token: - headers['Authorization'] = f'Bearer {auth_token}' + if resolved_token: + headers['Authorization'] = f'Bearer {resolved_token}' # Create HTTP client self._http = httpx.Client( diff --git a/src/amp/client.py b/src/amp/client.py index 0f090b2..904d7e1 100644 --- a/src/amp/client.py +++ b/src/amp/client.py @@ -1,4 +1,5 @@ import logging +import os from typing import Dict, Iterator, List, Optional, Union import pyarrow as pa @@ -270,19 +271,28 @@ class Client: url: Flight SQL URL (for backward compatibility, treated as query_url) query_url: Query endpoint URL via Flight SQL (e.g., 'grpc://localhost:1602') admin_url: Optional Admin API URL (e.g., 'http://localhost:8080') - auth_token: Optional Bearer token for Admin API authentication + auth_token: Optional Bearer token for authentication (highest priority) auth: If True, load auth token from ~/.amp-cli-config (shared with TS CLI) + Authentication Priority (highest to lowest): + 1. Explicit auth_token parameter + 2. AMP_AUTH_TOKEN environment variable + 3. auth=True - reads from ~/.amp-cli-config/amp_cli_auth + Example: >>> # Query-only client (backward compatible) >>> client = Client(url='grpc://localhost:1602') >>> - >>> # Client with admin capabilities and amp auth + >>> # Client with amp auth from file >>> client = Client( ... query_url='grpc://localhost:1602', ... admin_url='http://localhost:8080', ... auth=True ... ) + >>> + >>> # Client with auth from environment variable + >>> # export AMP_AUTH_TOKEN="eyJhbGci..." + >>> client = Client(query_url='grpc://localhost:1602') """ def __init__( @@ -297,15 +307,20 @@ def __init__( if url and not query_url: query_url = url - # Get auth token if using amp auth system + # Resolve auth token with priority: explicit param > env var > auth file flight_auth_token = None - if auth and not auth_token: + if auth_token: + # Priority 1: Explicit auth_token parameter + flight_auth_token = auth_token + elif os.getenv('AMP_AUTH_TOKEN'): + # Priority 2: AMP_AUTH_TOKEN environment variable + flight_auth_token = os.getenv('AMP_AUTH_TOKEN') + elif auth: + # Priority 3: Load from ~/.amp-cli-config/amp_cli_auth from amp.auth import AuthService auth_service = AuthService() flight_auth_token = auth_service.get_token() - elif auth_token: - flight_auth_token = auth_token # Initialize Flight SQL client if query_url: @@ -327,7 +342,8 @@ def __init__( if admin_url: from amp.admin.client import AdminClient - self._admin_client = AdminClient(admin_url, auth_token=auth_token, auth=auth) + # Pass resolved token to AdminClient (maintains same priority logic) + self._admin_client = AdminClient(admin_url, auth_token=flight_auth_token, auth=False) else: self._admin_client = None diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 2a0c04c..9733e3d 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -7,7 +7,7 @@ import json from pathlib import Path -from unittest.mock import Mock +from unittest.mock import Mock, patch import pytest @@ -87,6 +87,154 @@ def test_client_requires_url_or_query_url(self): Client() +@pytest.mark.unit +class TestClientAuthPriority: + """Test Client authentication priority (explicit token > env var > auth file)""" + + @patch('amp.client.os.getenv') + @patch('amp.client.flight.connect') + def test_explicit_token_highest_priority(self, mock_connect, mock_getenv): + """Test that explicit auth_token parameter has highest priority""" + mock_getenv.return_value = 'env-var-token' + + client = Client(query_url='grpc://localhost:1602', auth_token='explicit-token') + + # Verify that explicit token was used (not env var) + mock_connect.assert_called_once() + call_args = mock_connect.call_args + middleware = call_args[1].get('middleware', []) + assert len(middleware) == 1 + assert middleware[0].token == 'explicit-token' + + @patch('amp.client.os.getenv') + @patch('amp.client.flight.connect') + def test_env_var_second_priority(self, mock_connect, mock_getenv): + """Test that AMP_AUTH_TOKEN env var has second priority""" + + # Return 'env-var-token' for AMP_AUTH_TOKEN, None for others + def getenv_side_effect(key, default=None): + if key == 'AMP_AUTH_TOKEN': + return 'env-var-token' + return default + + mock_getenv.side_effect = getenv_side_effect + + client = Client(query_url='grpc://localhost:1602') + + # Verify env var was checked + calls = [str(call) for call in mock_getenv.call_args_list] + assert any('AMP_AUTH_TOKEN' in call for call in calls) + mock_connect.assert_called_once() + call_args = mock_connect.call_args + middleware = call_args[1].get('middleware', []) + assert len(middleware) == 1 + assert middleware[0].token == 'env-var-token' + + @patch('amp.auth.AuthService') + @patch('amp.client.os.getenv') + @patch('amp.client.flight.connect') + def test_auth_file_lowest_priority(self, mock_connect, mock_getenv, mock_auth_service): + """Test that auth=True has lowest priority""" + + # Return None for all getenv calls + def getenv_side_effect(key, default=None): + return default + + mock_getenv.side_effect = getenv_side_effect + + mock_service_instance = Mock() + mock_service_instance.get_token.return_value = 'file-token' + mock_auth_service.return_value = mock_service_instance + + client = Client(query_url='grpc://localhost:1602', auth=True) + + # Verify auth file was used + mock_auth_service.assert_called_once() + mock_service_instance.get_token.assert_called_once() + mock_connect.assert_called_once() + call_args = mock_connect.call_args + middleware = call_args[1].get('middleware', []) + assert len(middleware) == 1 + assert middleware[0].token == 'file-token' + + @patch('amp.client.os.getenv') + @patch('amp.client.flight.connect') + def test_no_auth_when_nothing_provided(self, mock_connect, mock_getenv): + """Test that no auth middleware is added when no auth is provided""" + + # Return None/default for all getenv calls + def getenv_side_effect(key, default=None): + return default + + mock_getenv.side_effect = getenv_side_effect + + client = Client(query_url='grpc://localhost:1602') + + # Verify no middleware was added + mock_connect.assert_called_once() + call_args = mock_connect.call_args + middleware = call_args[1].get('middleware') + assert middleware is None or len(middleware) == 0 + + +@pytest.mark.unit +class TestAdminClientAuthPriority: + """Test AdminClient authentication priority""" + + @patch('amp.admin.client.os.getenv') + def test_admin_explicit_token_highest_priority(self, mock_getenv): + """Test that explicit auth_token parameter has highest priority for AdminClient""" + from amp.admin.client import AdminClient + + mock_getenv.return_value = 'env-var-token' + + client = AdminClient('http://localhost:8080', auth_token='explicit-token') + + # Verify explicit token was used + assert client._http.headers.get('Authorization') == 'Bearer explicit-token' + + @patch('amp.admin.client.os.getenv') + def test_admin_env_var_second_priority(self, mock_getenv): + """Test that AMP_AUTH_TOKEN env var has second priority for AdminClient""" + from amp.admin.client import AdminClient + + mock_getenv.return_value = 'env-var-token' + + client = AdminClient('http://localhost:8080') + + # Verify env var was used + mock_getenv.assert_called_with('AMP_AUTH_TOKEN') + assert client._http.headers.get('Authorization') == 'Bearer env-var-token' + + @patch('amp.auth.AuthService') + @patch('amp.admin.client.os.getenv') + def test_admin_auth_file_lowest_priority(self, mock_getenv, mock_auth_service): + """Test that auth=True has lowest priority for AdminClient""" + from amp.admin.client import AdminClient + + mock_getenv.return_value = None + mock_service_instance = Mock() + mock_service_instance.get_token.return_value = 'file-token' + mock_auth_service.return_value = mock_service_instance + + client = AdminClient('http://localhost:8080', auth=True) + + # Verify auth file was used + assert client._http.headers.get('Authorization') == 'Bearer file-token' + + @patch('amp.admin.client.os.getenv') + def test_admin_no_auth_when_nothing_provided(self, mock_getenv): + """Test that no auth header is added when no auth is provided""" + from amp.admin.client import AdminClient + + mock_getenv.return_value = None + + client = AdminClient('http://localhost:8080') + + # Verify no auth header + assert 'Authorization' not in client._http.headers + + @pytest.mark.unit class TestQueryBuilderManifest: """Test QueryBuilder manifest generation""" From e247d1491b78c7263eee8feb06eeb001f73f552e Mon Sep 17 00:00:00 2001 From: Ford Date: Sun, 16 Nov 2025 18:44:33 -0300 Subject: [PATCH 5/9] fix(auth/service): Update Auth platform url (use thegraph domain) --- src/amp/auth/service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/amp/auth/service.py b/src/amp/auth/service.py index be2e631..1d303f5 100644 --- a/src/amp/auth/service.py +++ b/src/amp/auth/service.py @@ -14,7 +14,7 @@ from .models import AuthStorage, RefreshTokenResponse # Auth platform URL (matches TypeScript implementation) -AUTH_PLATFORM_URL = 'https://auth.amp.edgeandnode.com/' +AUTH_PLATFORM_URL = 'https://auth.amp.thegraph.com/' # Storage location (matches TypeScript implementation) # TypeScript CLI uses: ~/.amp-cli-config/amp_cli_auth (directory with file inside) From 9d6a36dd26d81fde1874b53e2c7916c3f9cfe6bd Mon Sep 17 00:00:00 2001 From: Ford Date: Mon, 17 Nov 2025 00:01:26 -0300 Subject: [PATCH 6/9] tests: Add unit tests for auth service --- tests/unit/test_auth_service.py | 345 ++++++++++++++++++++++++++++++++ 1 file changed, 345 insertions(+) create mode 100644 tests/unit/test_auth_service.py diff --git a/tests/unit/test_auth_service.py b/tests/unit/test_auth_service.py new file mode 100644 index 0000000..450a6e3 --- /dev/null +++ b/tests/unit/test_auth_service.py @@ -0,0 +1,345 @@ +"""Unit tests for Privy authentication service.""" + +import json +import time +from unittest.mock import Mock, patch + +import pytest + +from src.amp.auth.models import AuthStorage +from src.amp.auth.service import AuthService + + +@pytest.mark.unit +class TestAuthService: + """Test AuthService functionality.""" + + def test_is_authenticated_when_no_file(self, tmp_path): + """Test is_authenticated returns False when config file doesn't exist.""" + config_path = tmp_path / 'nonexistent.json' + auth = AuthService(config_path=config_path) + + assert auth.is_authenticated() is False + + def test_is_authenticated_when_no_auth_data(self, tmp_path): + """Test is_authenticated returns False when config exists but has no auth data.""" + config_path = tmp_path / 'config.json' + config_path.write_text('{}') + + auth = AuthService(config_path=config_path) + + assert auth.is_authenticated() is False + + def test_is_authenticated_when_auth_exists(self, tmp_path): + """Test is_authenticated returns True when valid auth exists.""" + config_path = tmp_path / 'amp_cli_auth' + auth_data = { + 'accessToken': 'test-access-token', + 'refreshToken': 'test-refresh-token', + 'userId': 'did:privy:test123', + 'accounts': ['0x123'], + 'expiry': int(time.time() * 1000) + 3600000, # 1 hour from now + } + config_path.write_text(json.dumps(auth_data)) + + auth = AuthService(config_path=config_path) + + assert auth.is_authenticated() is True + + def test_load_auth_success(self, tmp_path): + """Test loading auth from config file.""" + config_path = tmp_path / 'amp_cli_auth' + auth_data = { + 'accessToken': 'test-access-token', + 'refreshToken': 'test-refresh-token', + 'userId': 'did:privy:test123', + 'accounts': ['0x123'], + 'expiry': 1234567890000, + } + config_path.write_text(json.dumps(auth_data)) + + auth = AuthService(config_path=config_path) + result = auth.load_auth() + + assert result is not None + assert result.accessToken == 'test-access-token' + assert result.refreshToken == 'test-refresh-token' + assert result.userId == 'did:privy:test123' + assert result.accounts == ['0x123'] + assert result.expiry == 1234567890000 + + def test_load_auth_returns_none_when_missing(self, tmp_path): + """Test load_auth returns None when config doesn't exist.""" + config_path = tmp_path / 'nonexistent.json' + auth = AuthService(config_path=config_path) + + result = auth.load_auth() + + assert result is None + + def test_save_auth(self, tmp_path): + """Test saving auth to config file.""" + config_path = tmp_path / 'amp_cli_auth' + auth = AuthService(config_path=config_path) + + auth_storage = AuthStorage( + accessToken='new-access-token', + refreshToken='new-refresh-token', + userId='did:privy:test456', + accounts=['0x456'], + expiry=9876543210000, + ) + + auth.save_auth(auth_storage) + + # Verify file was created + assert config_path.exists() + + # Verify content + with open(config_path) as f: + saved_data = json.load(f) + + assert saved_data['accessToken'] == 'new-access-token' + assert saved_data['userId'] == 'did:privy:test456' + + def test_save_auth_overwrites_existing(self, tmp_path): + """Test saving auth overwrites existing auth data.""" + config_path = tmp_path / 'amp_cli_auth' + initial_data = { + 'accessToken': 'old-token', + 'refreshToken': 'old-refresh', + 'userId': 'did:privy:old', + 'accounts': [], + 'expiry': 12345, + } + config_path.write_text(json.dumps(initial_data)) + + auth = AuthService(config_path=config_path) + auth_storage = AuthStorage( + accessToken='new-token', + refreshToken='new-refresh', + userId='did:privy:new', + accounts=['0x123'], + expiry=67890, + ) + + auth.save_auth(auth_storage) + + # Verify new data saved + with open(config_path) as f: + saved_data = json.load(f) + + assert saved_data['accessToken'] == 'new-token' + assert saved_data['userId'] == 'did:privy:new' + assert saved_data['accounts'] == ['0x123'] + + def test_get_token_raises_when_not_authenticated(self, tmp_path): + """Test get_token raises error when not authenticated.""" + config_path = tmp_path / 'nonexistent.json' + auth = AuthService(config_path=config_path) + + with pytest.raises(FileNotFoundError, match='Not authenticated'): + auth.get_token() + + def test_get_token_returns_valid_token(self, tmp_path): + """Test get_token returns token when valid and not expired.""" + config_path = tmp_path / 'amp_cli_auth' + future_expiry = int(time.time() * 1000) + 3600000 # 1 hour from now + auth_data = { + 'accessToken': 'valid-token', + 'refreshToken': 'refresh-token', + 'userId': 'did:privy:test', + 'accounts': ['0x123'], + 'expiry': future_expiry, + } + config_path.write_text(json.dumps(auth_data)) + + auth = AuthService(config_path=config_path) + token = auth.get_token() + + assert token == 'valid-token' + + +@pytest.mark.unit +class TestAuthServiceRefresh: + """Test token refresh functionality.""" + + def test_needs_refresh_when_missing_expiry(self): + """Test needs refresh when expiry field is missing.""" + auth = AuthService() + auth_storage = AuthStorage( + accessToken='token', + refreshToken='refresh', + userId='did:privy:test', + accounts=['0x123'], + expiry=None, # Missing expiry + ) + + assert auth._needs_refresh(auth_storage) is True + + def test_needs_refresh_when_missing_accounts(self): + """Test needs refresh when accounts field is missing.""" + auth = AuthService() + auth_storage = AuthStorage( + accessToken='token', + refreshToken='refresh', + userId='did:privy:test', + accounts=None, # Missing accounts + expiry=int(time.time() * 1000) + 3600000, + ) + + assert auth._needs_refresh(auth_storage) is True + + def test_needs_refresh_when_expired(self): + """Test needs refresh when token is expired.""" + auth = AuthService() + past_expiry = int(time.time() * 1000) - 1000 # 1 second ago + auth_storage = AuthStorage( + accessToken='token', + refreshToken='refresh', + userId='did:privy:test', + accounts=['0x123'], + expiry=past_expiry, + ) + + assert auth._needs_refresh(auth_storage) is True + + def test_needs_refresh_when_expiring_soon(self): + """Test needs refresh when token expires within 5 minutes.""" + auth = AuthService() + soon_expiry = int(time.time() * 1000) + (4 * 60 * 1000) # 4 minutes from now + auth_storage = AuthStorage( + accessToken='token', + refreshToken='refresh', + userId='did:privy:test', + accounts=['0x123'], + expiry=soon_expiry, + ) + + assert auth._needs_refresh(auth_storage) is True + + def test_does_not_need_refresh_when_valid(self): + """Test does not need refresh when token is valid and not expiring soon.""" + auth = AuthService() + future_expiry = int(time.time() * 1000) + (10 * 60 * 1000) # 10 minutes from now + auth_storage = AuthStorage( + accessToken='token', + refreshToken='refresh', + userId='did:privy:test', + accounts=['0x123'], + expiry=future_expiry, + ) + + assert auth._needs_refresh(auth_storage) is False + + @patch('httpx.Client.post') + def test_refresh_token_success(self, mock_post, tmp_path): + """Test successful token refresh.""" + config_path = tmp_path / 'config.json' + + # Mock HTTP response + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + 'token': 'new-access-token', + 'refresh_token': 'new-refresh-token', + 'session_update_action': 'update', + 'expires_in': 3600, + 'user': {'id': 'did:privy:test', 'accounts': ['0x123', '0x456']}, + } + mock_post.return_value = mock_response + + auth = AuthService(config_path=config_path) + old_auth = AuthStorage( + accessToken='old-token', + refreshToken='old-refresh', + userId='did:privy:test', + accounts=['0x123'], + expiry=12345, + ) + + new_auth = auth.refresh_token(old_auth) + + # Verify new tokens + assert new_auth.accessToken == 'new-access-token' + assert new_auth.refreshToken == 'new-refresh-token' + assert new_auth.userId == 'did:privy:test' + assert new_auth.accounts == ['0x123', '0x456'] + assert new_auth.expiry > int(time.time() * 1000) + + # Verify saved to file + assert config_path.exists() + + @patch('httpx.Client.post') + def test_refresh_token_user_id_mismatch(self, mock_post, tmp_path): + """Test refresh fails when user ID doesn't match.""" + config_path = tmp_path / 'config.json' + + # Mock HTTP response with different user ID + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + 'token': 'new-token', + 'refresh_token': 'new-refresh', + 'session_update_action': 'update', + 'expires_in': 3600, + 'user': {'id': 'did:privy:different-user', 'accounts': []}, + } + mock_post.return_value = mock_response + + auth = AuthService(config_path=config_path) + old_auth = AuthStorage( + accessToken='old-token', + refreshToken='old-refresh', + userId='did:privy:original-user', + accounts=[], + expiry=12345, + ) + + with pytest.raises(ValueError, match='User ID mismatch'): + auth.refresh_token(old_auth) + + @patch('httpx.Client.post') + def test_refresh_token_401_error(self, mock_post, tmp_path): + """Test refresh handles 401 authentication error.""" + config_path = tmp_path / 'config.json' + + # Mock 401 response + mock_response = Mock() + mock_response.status_code = 401 + mock_post.return_value = mock_response + + auth = AuthService(config_path=config_path) + old_auth = AuthStorage( + accessToken='old-token', + refreshToken='old-refresh', + userId='did:privy:test', + accounts=[], + expiry=12345, + ) + + with pytest.raises(ValueError, match='Authentication expired'): + auth.refresh_token(old_auth) + + @patch('httpx.Client.post') + def test_refresh_token_429_rate_limit(self, mock_post, tmp_path): + """Test refresh handles 429 rate limit error.""" + config_path = tmp_path / 'config.json' + + # Mock 429 response + mock_response = Mock() + mock_response.status_code = 429 + mock_response.headers = {'retry-after': '120'} + mock_post.return_value = mock_response + + auth = AuthService(config_path=config_path) + old_auth = AuthStorage( + accessToken='old-token', + refreshToken='old-refresh', + userId='did:privy:test', + accounts=[], + expiry=12345, + ) + + with pytest.raises(ValueError, match='rate limited'): + auth.refresh_token(old_auth) From f90ec8e3beb1a2fc013cbb1d274b05bb4330f4e6 Mon Sep 17 00:00:00 2001 From: Ford Date: Mon, 17 Nov 2025 01:28:56 -0300 Subject: [PATCH 7/9] linting and formatting --- tests/unit/test_client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 9733e3d..e3c91f2 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -97,7 +97,7 @@ def test_explicit_token_highest_priority(self, mock_connect, mock_getenv): """Test that explicit auth_token parameter has highest priority""" mock_getenv.return_value = 'env-var-token' - client = Client(query_url='grpc://localhost:1602', auth_token='explicit-token') + Client(query_url='grpc://localhost:1602', auth_token='explicit-token') # Verify that explicit token was used (not env var) mock_connect.assert_called_once() @@ -119,7 +119,7 @@ def getenv_side_effect(key, default=None): mock_getenv.side_effect = getenv_side_effect - client = Client(query_url='grpc://localhost:1602') + Client(query_url='grpc://localhost:1602') # Verify env var was checked calls = [str(call) for call in mock_getenv.call_args_list] @@ -146,7 +146,7 @@ def getenv_side_effect(key, default=None): mock_service_instance.get_token.return_value = 'file-token' mock_auth_service.return_value = mock_service_instance - client = Client(query_url='grpc://localhost:1602', auth=True) + Client(query_url='grpc://localhost:1602', auth=True) # Verify auth file was used mock_auth_service.assert_called_once() @@ -168,7 +168,7 @@ def getenv_side_effect(key, default=None): mock_getenv.side_effect = getenv_side_effect - client = Client(query_url='grpc://localhost:1602') + Client(query_url='grpc://localhost:1602') # Verify no middleware was added mock_connect.assert_called_once() From cc9de479ee42cc8c8dc7f2d272099190c130c0cc Mon Sep 17 00:00:00 2001 From: Ford Date: Tue, 18 Nov 2025 15:28:21 -0300 Subject: [PATCH 8/9] client, admin-client: Keep auth token fresh --- src/amp/admin/client.py | 31 +++++------ src/amp/client.py | 56 ++++++++++++-------- tests/integration/admin/test_admin_client.py | 3 +- tests/integration/test_snowflake_loader.py | 2 +- tests/unit/test_client.py | 16 +++--- 5 files changed, 61 insertions(+), 47 deletions(-) diff --git a/src/amp/admin/client.py b/src/amp/admin/client.py index 7cee60b..7f7eb8f 100644 --- a/src/amp/admin/client.py +++ b/src/amp/admin/client.py @@ -56,30 +56,25 @@ def __init__(self, base_url: str, auth_token: Optional[str] = None, auth: bool = self.base_url = base_url.rstrip('/') - # Resolve auth token with priority: explicit param > env var > auth file - resolved_token = None + # Resolve auth token provider with priority: explicit param > env var > auth file + self._get_token = None if auth_token: - # Priority 1: Explicit auth_token parameter - resolved_token = auth_token + # Priority 1: Explicit auth_token parameter (static token) + self._get_token = lambda: auth_token elif os.getenv('AMP_AUTH_TOKEN'): - # Priority 2: AMP_AUTH_TOKEN environment variable - resolved_token = os.getenv('AMP_AUTH_TOKEN') + # Priority 2: AMP_AUTH_TOKEN environment variable (static token) + env_token = os.getenv('AMP_AUTH_TOKEN') + self._get_token = lambda: env_token elif auth: - # Priority 3: Load from ~/.amp-cli-config/amp_cli_auth + # Priority 3: Load from ~/.amp-cli-config/amp_cli_auth (auto-refreshing) from amp.auth import AuthService auth_service = AuthService() - resolved_token = auth_service.get_token() + self._get_token = auth_service.get_token # Callable that auto-refreshes - # Build headers - headers = {} - if resolved_token: - headers['Authorization'] = f'Bearer {resolved_token}' - - # Create HTTP client + # Create HTTP client (no auth header yet - will be added per-request) self._http = httpx.Client( base_url=self.base_url, - headers=headers, timeout=30.0, follow_redirects=True, ) @@ -102,6 +97,12 @@ def _request( Raises: AdminAPIError: If the API returns an error response """ + # Add auth header dynamically (auto-refreshes if needed) + headers = kwargs.get('headers', {}) + if self._get_token: + headers['Authorization'] = f'Bearer {self._get_token()}' + kwargs['headers'] = headers + response = self._http.request(method, path, json=json, params=params, **kwargs) # Handle error responses diff --git a/src/amp/client.py b/src/amp/client.py index 904d7e1..49f170a 100644 --- a/src/amp/client.py +++ b/src/amp/client.py @@ -24,33 +24,33 @@ class AuthMiddleware(ClientMiddleware): """Flight middleware to add Bearer token authentication header.""" - def __init__(self, token: str): + def __init__(self, get_token): """Initialize auth middleware. Args: - token: Bearer token to add to requests + get_token: Callable that returns the current access token """ - self.token = token + self.get_token = get_token def sending_headers(self): """Add Authorization header to outgoing requests.""" - return {'authorization': f'Bearer {self.token}'} + return {'authorization': f'Bearer {self.get_token()}'} class AuthMiddlewareFactory(ClientMiddlewareFactory): """Factory for creating auth middleware instances.""" - def __init__(self, token: str): + def __init__(self, get_token): """Initialize auth middleware factory. Args: - token: Bearer token to use for authentication + get_token: Callable that returns the current access token """ - self.token = token + self.get_token = get_token def start_call(self, info): """Create auth middleware for each call.""" - return AuthMiddleware(self.token) + return AuthMiddleware(self.get_token) class QueryBuilder: @@ -307,26 +307,30 @@ def __init__( if url and not query_url: query_url = url - # Resolve auth token with priority: explicit param > env var > auth file - flight_auth_token = None + # Resolve auth token provider with priority: explicit param > env var > auth file + get_token = None if auth_token: - # Priority 1: Explicit auth_token parameter - flight_auth_token = auth_token + # Priority 1: Explicit auth_token parameter (static token) + def get_token(): + return auth_token elif os.getenv('AMP_AUTH_TOKEN'): - # Priority 2: AMP_AUTH_TOKEN environment variable - flight_auth_token = os.getenv('AMP_AUTH_TOKEN') + # Priority 2: AMP_AUTH_TOKEN environment variable (static token) + env_token = os.getenv('AMP_AUTH_TOKEN') + + def get_token(): + return env_token elif auth: - # Priority 3: Load from ~/.amp-cli-config/amp_cli_auth + # Priority 3: Load from ~/.amp-cli-config/amp_cli_auth (auto-refreshing) from amp.auth import AuthService auth_service = AuthService() - flight_auth_token = auth_service.get_token() + get_token = auth_service.get_token # Callable that auto-refreshes # Initialize Flight SQL client if query_url: - # Add auth middleware if token is provided - if flight_auth_token: - middleware = [AuthMiddlewareFactory(flight_auth_token)] + # Add auth middleware if token provider exists + if get_token: + middleware = [AuthMiddlewareFactory(get_token)] self.conn = flight.connect(query_url, middleware=middleware) else: self.conn = flight.connect(query_url) @@ -342,8 +346,18 @@ def __init__( if admin_url: from amp.admin.client import AdminClient - # Pass resolved token to AdminClient (maintains same priority logic) - self._admin_client = AdminClient(admin_url, auth_token=flight_auth_token, auth=False) + # Pass auth=True if we have a get_token callable from auth file + # Otherwise pass the static token if available + if auth: + # Use auth file (auto-refreshing) + self._admin_client = AdminClient(admin_url, auth=True) + elif auth_token or os.getenv('AMP_AUTH_TOKEN'): + # Use static token + static_token = auth_token or os.getenv('AMP_AUTH_TOKEN') + self._admin_client = AdminClient(admin_url, auth_token=static_token) + else: + # No auth + self._admin_client = AdminClient(admin_url) else: self._admin_client = None diff --git a/tests/integration/admin/test_admin_client.py b/tests/integration/admin/test_admin_client.py index f8e7c31..82e7648 100644 --- a/tests/integration/admin/test_admin_client.py +++ b/tests/integration/admin/test_admin_client.py @@ -25,8 +25,7 @@ def test_admin_client_with_auth_token(self): """Test AdminClient with authentication token.""" client = AdminClient('http://localhost:8080', auth_token='test-token') - assert 'Authorization' in client._http.headers - assert client._http.headers['Authorization'] == 'Bearer test-token' + assert client._get_token() == 'test-token' @respx.mock def test_request_success(self): diff --git a/tests/integration/test_snowflake_loader.py b/tests/integration/test_snowflake_loader.py index d13f8eb..78c2c17 100644 --- a/tests/integration/test_snowflake_loader.py +++ b/tests/integration/test_snowflake_loader.py @@ -67,7 +67,7 @@ def wait_for_snowpipe_data(loader, table_name, expected_count, max_wait=30, poll # Skip all Snowflake tests -# pytestmark = pytest.mark.skip(reason='Requires active Snowflake account - see module docstring for details') +pytestmark = pytest.mark.skip(reason='Requires active Snowflake account - see module docstring for details') @pytest.fixture diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index e3c91f2..00b0488 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -104,7 +104,7 @@ def test_explicit_token_highest_priority(self, mock_connect, mock_getenv): call_args = mock_connect.call_args middleware = call_args[1].get('middleware', []) assert len(middleware) == 1 - assert middleware[0].token == 'explicit-token' + assert middleware[0].get_token() == 'explicit-token' @patch('amp.client.os.getenv') @patch('amp.client.flight.connect') @@ -128,7 +128,7 @@ def getenv_side_effect(key, default=None): call_args = mock_connect.call_args middleware = call_args[1].get('middleware', []) assert len(middleware) == 1 - assert middleware[0].token == 'env-var-token' + assert middleware[0].get_token() == 'env-var-token' @patch('amp.auth.AuthService') @patch('amp.client.os.getenv') @@ -150,12 +150,12 @@ def getenv_side_effect(key, default=None): # Verify auth file was used mock_auth_service.assert_called_once() - mock_service_instance.get_token.assert_called_once() mock_connect.assert_called_once() call_args = mock_connect.call_args middleware = call_args[1].get('middleware', []) assert len(middleware) == 1 - assert middleware[0].token == 'file-token' + # The middleware should use the auth service's get_token method directly + assert middleware[0].get_token == mock_service_instance.get_token @patch('amp.client.os.getenv') @patch('amp.client.flight.connect') @@ -190,8 +190,8 @@ def test_admin_explicit_token_highest_priority(self, mock_getenv): client = AdminClient('http://localhost:8080', auth_token='explicit-token') - # Verify explicit token was used - assert client._http.headers.get('Authorization') == 'Bearer explicit-token' + # Verify explicit token was used (check get_token callable) + assert client._get_token() == 'explicit-token' @patch('amp.admin.client.os.getenv') def test_admin_env_var_second_priority(self, mock_getenv): @@ -204,7 +204,7 @@ def test_admin_env_var_second_priority(self, mock_getenv): # Verify env var was used mock_getenv.assert_called_with('AMP_AUTH_TOKEN') - assert client._http.headers.get('Authorization') == 'Bearer env-var-token' + assert client._get_token() == 'env-var-token' @patch('amp.auth.AuthService') @patch('amp.admin.client.os.getenv') @@ -220,7 +220,7 @@ def test_admin_auth_file_lowest_priority(self, mock_getenv, mock_auth_service): client = AdminClient('http://localhost:8080', auth=True) # Verify auth file was used - assert client._http.headers.get('Authorization') == 'Bearer file-token' + assert client._get_token == mock_service_instance.get_token @patch('amp.admin.client.os.getenv') def test_admin_no_auth_when_nothing_provided(self, mock_getenv): From 8c0c6a8d5a7267b8c57c5cdbb26a708e98e47fe3 Mon Sep 17 00:00:00 2001 From: Ford Date: Thu, 20 Nov 2025 17:52:19 -0300 Subject: [PATCH 9/9] auth: Update location --- src/amp/admin/client.py | 6 +++--- src/amp/auth/__init__.py | 2 +- src/amp/auth/models.py | 2 +- src/amp/auth/service.py | 22 +++++++++++----------- src/amp/client.py | 6 +++--- 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/amp/admin/client.py b/src/amp/admin/client.py index 7f7eb8f..e6469c0 100644 --- a/src/amp/admin/client.py +++ b/src/amp/admin/client.py @@ -21,12 +21,12 @@ class AdminClient: Args: base_url: Base URL for Admin API (e.g., 'http://localhost:8080') auth_token: Optional Bearer token for authentication (highest priority) - auth: If True, load auth token from ~/.amp-cli-config (shared with TS CLI) + auth: If True, load auth token from ~/.amp/cache (shared with TS CLI) Authentication Priority (highest to lowest): 1. Explicit auth_token parameter 2. AMP_AUTH_TOKEN environment variable - 3. auth=True - reads from ~/.amp-cli-config/amp_cli_auth + 3. auth=True - reads from ~/.amp/cache/amp_cli_auth Example: >>> # Use amp auth from file @@ -46,7 +46,7 @@ def __init__(self, base_url: str, auth_token: Optional[str] = None, auth: bool = Args: base_url: Base URL for Admin API (e.g., 'http://localhost:8080') auth_token: Optional Bearer token for authentication - auth: If True, load auth token from ~/.amp-cli-config + auth: If True, load auth token from ~/.amp/cache Raises: ValueError: If both auth=True and auth_token are provided diff --git a/src/amp/auth/__init__.py b/src/amp/auth/__init__.py index 8e456c7..d648a6f 100644 --- a/src/amp/auth/__init__.py +++ b/src/amp/auth/__init__.py @@ -1,7 +1,7 @@ """Authentication module for amp Python client. Provides Privy authentication support compatible with the TypeScript CLI. -Reads and manages auth tokens from ~/.amp-cli-config. +Reads and manages auth tokens from ~/.amp/cache. """ from .device_flow import interactive_device_login diff --git a/src/amp/auth/models.py b/src/amp/auth/models.py index 925e2e8..d60f7c5 100644 --- a/src/amp/auth/models.py +++ b/src/amp/auth/models.py @@ -9,7 +9,7 @@ class AuthStorage(BaseModel): - """Auth storage schema for ~/.amp-cli-config file. + """Auth storage schema for ~/.amp/cache/amp_cli_auth file. Matches the TypeScript AuthStorageSchema so they can share auth. """ diff --git a/src/amp/auth/service.py b/src/amp/auth/service.py index 1d303f5..b0c83b0 100644 --- a/src/amp/auth/service.py +++ b/src/amp/auth/service.py @@ -1,6 +1,6 @@ """Auth service for managing Privy authentication. -Handles loading, refreshing, and persisting auth tokens from ~/.amp-cli-config. +Handles loading, refreshing, and persisting auth tokens from ~/.amp/cache. Compatible with the TypeScript CLI authentication system. """ @@ -17,15 +17,15 @@ AUTH_PLATFORM_URL = 'https://auth.amp.thegraph.com/' # Storage location (matches TypeScript implementation) -# TypeScript CLI uses: ~/.amp-cli-config/amp_cli_auth (directory with file inside) -AUTH_CONFIG_DIR = Path.home() / '.amp-cli-config' +# TypeScript CLI uses: ~/.amp/cache/amp_cli_auth (directory with file inside) +AUTH_CONFIG_DIR = Path.home() / '.amp' / 'cache' AUTH_CONFIG_FILE = AUTH_CONFIG_DIR / 'amp_cli_auth' class AuthService: """Service for managing Privy authentication tokens. - Loads tokens from ~/.amp-cli-config (shared with TypeScript CLI), + Loads tokens from ~/.amp/cache (shared with TypeScript CLI), automatically refreshes expired tokens, and persists updates. Example: @@ -38,7 +38,7 @@ def __init__(self, config_path: Optional[Path] = None): """Initialize auth service. Args: - config_path: Optional custom path to config file (defaults to ~/.amp-cli-config/amp_cli_auth) + config_path: Optional custom path to config file (defaults to ~/.amp/cache/amp_cli_auth) """ self.config_path = config_path or AUTH_CONFIG_FILE self._http = httpx.Client(timeout=30.0) @@ -47,7 +47,7 @@ def is_authenticated(self) -> bool: """Check if user is authenticated. Returns: - True if valid auth exists in ~/.amp-cli-config + True if valid auth exists in ~/.amp/cache """ try: auth = self.load_auth() @@ -56,7 +56,7 @@ def is_authenticated(self) -> bool: return False def load_auth(self) -> Optional[AuthStorage]: - """Load auth from ~/.amp-cli-config/amp_cli_auth file. + """Load auth from ~/.amp/cache/amp_cli_auth file. Returns: AuthStorage if found, None if not authenticated @@ -75,7 +75,7 @@ def load_auth(self) -> Optional[AuthStorage]: return AuthStorage.model_validate(auth_data) def save_auth(self, auth: AuthStorage) -> None: - """Save auth to ~/.amp-cli-config/amp_cli_auth file. + """Save auth to ~/.amp/cache/amp_cli_auth file. Args: auth: Auth data to persist @@ -97,7 +97,7 @@ def get_token(self) -> str: Valid access token string Raises: - FileNotFoundError: If not authenticated (no ~/.amp-cli-config) + FileNotFoundError: If not authenticated (no ~/.amp/cache) ValueError: If auth data is invalid or refresh fails """ auth = self.load_auth() @@ -228,7 +228,7 @@ def login(self, verbose: bool = True, auto_open_browser: bool = True) -> None: """Perform interactive browser-based login. Opens browser for OAuth2 device authorization flow with PKCE. - Saves authentication tokens to ~/.amp-cli-config/amp_cli_auth. + Saves authentication tokens to ~/.amp/cache/amp_cli_auth. Args: verbose: Print progress messages @@ -240,7 +240,7 @@ def login(self, verbose: bool = True, auto_open_browser: bool = True) -> None: Example: >>> auth = AuthService() >>> auth.login() # Opens browser for authentication - >>> # Auth tokens saved to ~/.amp-cli-config/amp_cli_auth + >>> # Auth tokens saved to ~/.amp/cache/amp_cli_auth """ from .device_flow import interactive_device_login diff --git a/src/amp/client.py b/src/amp/client.py index 49f170a..760bfc7 100644 --- a/src/amp/client.py +++ b/src/amp/client.py @@ -272,12 +272,12 @@ class Client: query_url: Query endpoint URL via Flight SQL (e.g., 'grpc://localhost:1602') admin_url: Optional Admin API URL (e.g., 'http://localhost:8080') auth_token: Optional Bearer token for authentication (highest priority) - auth: If True, load auth token from ~/.amp-cli-config (shared with TS CLI) + auth: If True, load auth token from ~/.amp/cache (shared with TS CLI) Authentication Priority (highest to lowest): 1. Explicit auth_token parameter 2. AMP_AUTH_TOKEN environment variable - 3. auth=True - reads from ~/.amp-cli-config/amp_cli_auth + 3. auth=True - reads from ~/.amp/cache/amp_cli_auth Example: >>> # Query-only client (backward compatible) @@ -320,7 +320,7 @@ def get_token(): def get_token(): return env_token elif auth: - # Priority 3: Load from ~/.amp-cli-config/amp_cli_auth (auto-refreshing) + # Priority 3: Load from ~/.amp/cache/amp_cli_auth (auto-refreshing) from amp.auth import AuthService auth_service = AuthService()