diff --git a/AGENTS.MD b/AGENTS.MD index 56c05f8..a97d92a 100644 --- a/AGENTS.MD +++ b/AGENTS.MD @@ -337,11 +337,48 @@ When building your own BYOVA integration: **Important**: This gateway is not production-ready. For production use: - **Implement proper security**: Add authentication, authorization, and encryption +- **Enable JWT validation**: Configure JWT token validation for secure gRPC communication (see below) - **Add production monitoring**: Implement comprehensive logging, metrics, and alerting - **Handle scaling**: Design for horizontal scaling and load balancing - **Add error handling**: Implement robust error handling and recovery mechanisms - **Security review**: Conduct thorough security reviews before deployment +#### JWT Authentication for Production + +The gateway includes JWT (JSON Web Token) validation for securing gRPC requests from Webex Contact Center. This should be enabled for production deployments. + +**Key Features:** +- Validates JWT signatures using RSA public keys from Webex identity broker +- Verifies all required claims (issuer, audience, subject, JWT ID, expiration) +- Validates datasource-specific claims (URL and schema UUID) +- Caches public keys for 60 minutes to optimize performance +- Supports optional enforcement for gradual rollout + +**Configuration in `config/config.yaml`:** + +```yaml +jwt_validation: + enabled: true # Enable JWT validation + enforce_validation: true # Reject invalid tokens (set to false for logging only) + datasource_url: "https://your-gateway.example.com:443" # Must match BYODS registration + datasource_schema_uuid: "5397013b-7920-4ffc-807c-e8a3e0a18f43" # BYOVA schema UUID + cache_duration_minutes: 60 # Public key cache duration +``` + +**Implementation Details:** +- **Module**: `src/auth/jwt_validator.py` - Core validation logic +- **Interceptor**: `src/auth/jwt_interceptor.py` - gRPC request interceptor +- **Integration**: Automatically loaded in `main.py` when enabled +- **Reference**: Based on [Webex sample Java implementation](https://github.com/CiscoDevNet/webex-contact-center-provider-sample-code/blob/main/media-service-api/dialog-connector-simulator/src/main/java/com/cisco/wccai/grpc/server/interceptors/JWTAuthorizationHandler.java) + +**Deployment Recommendations:** +1. Start with `enforce_validation: false` to monitor validation without blocking requests +2. Verify logs show successful validation for all requests +3. Enable `enforce_validation: true` for full security +4. Monitor for authentication errors and adjust configuration as needed + +See [README.md](README.md) for complete JWT authentication documentation and troubleshooting. + ## Working with the Codebase To begin development (for learning or building upon this example): diff --git a/README.md b/README.md index 8172e8d..390eaf3 100644 --- a/README.md +++ b/README.md @@ -56,10 +56,12 @@ This comprehensive guide walks you through: which python # Should show path to venv/bin/python ``` -3. **Install Dependencies** +3. **Install Dependencies** (Required) ```bash pip install -r requirements.txt ``` + + **Important**: All dependencies including JWT authentication libraries are required. The gateway will not start if dependencies are missing. 4. **Generate gRPC Stubs** ```bash @@ -119,60 +121,144 @@ The gateway will start with the local audio connector by default, which uses the ### Configuration -The gateway is configured via `config/config.yaml`. Key configuration options: +The gateway is configured via `config/config.yaml`. Key configuration sections: ```yaml # Gateway settings gateway: - host: "0.0.0.0" - port: 50051 + host: "0.0.0.0" + port: 50051 + +# Connectors configuration +connectors: + # Local Audio Connector - plays audio files from the audio/ directory + local_audio_connector: + type: "local_audio_connector" + class: "LocalAudioConnector" + module: "connectors.local_audio_connector" + config: + audio_files: + welcome: "welcome.wav" + transfer: "transferring.wav" + goodbye: "goodbye.wav" + error: "error.wav" + default: "default_response.wav" + agents: + - "Local Playback" + + # AWS Lex Connector - integrates with Amazon Lex bots + aws_lex_connector: + type: "aws_lex_connector" + class: "AWSLexConnector" + module: "connectors.aws_lex_connector" + config: + region_name: "us-east-1" + # bot_alias_id: "YOUR_BOT_ALIAS_ID" # Required for specific bot + # aws_access_key_id: "YOUR_ACCESS_KEY" # Optional, uses env vars if not set + # aws_secret_access_key: "YOUR_SECRET_KEY" # Optional, uses env vars if not set + initial_trigger_text: "hello" + barge_in_enabled: false + audio_logging: + enabled: true + output_dir: "logs/audio_recordings" + filename_format: "{conversation_id}_{timestamp}_{source}.wav" + log_all_audio: true + max_file_size: 10485760 + sample_rate: 8000 + bit_depth: 8 + channels: 1 + encoding: "ulaw" + agents: [] # Monitoring interface monitoring: - enabled: true - host: "0.0.0.0" - port: 8080 - -# Connectors -connectors: - local_audio_connector: - type: "local_audio_connector" - class: "LocalAudioConnector" - module: "connectors.local_audio_connector" - config: - audio_files: - welcome: "welcome.wav" - transfer: "transferring.wav" - goodbye: "goodbye.wav" - error: "error.wav" - default: "default_response.wav" - agents: - - "Local Playback" - - # Example: AWS Lex Connector Configuration - aws_lex_connector: - type: "aws_lex_connector" - class: "AWSLexConnector" - module: "connectors.aws_lex_connector" - config: - region_name: "us-east-1" # Set your AWS region (required) - initial_trigger_text: "hello" # Text sent when starting conversation (default: "hello") - barge_in_enabled: false # Allow users to interrupt bot responses (default: false) - audio_logging: - enabled: true - output_dir: "logs/audio_recordings" - filename_format: "{conversation_id}_{timestamp}_{source}.wav" - log_all_audio: true - max_file_size: 10485760 - sample_rate: 8000 - bit_depth: 8 - channels: 1 - encoding: "ulaw" + enabled: true + host: "0.0.0.0" + port: 8080 + metrics_enabled: true + health_check_interval: 30 + +# Web dashboard authentication +authentication: + enabled: true + environment: "dev" # Options: "dev" or "production" + session: + timeout_hours: 24 + secret_key_env: "FLASK_SECRET_KEY" + webex_oauth: + scopes: "openid email profile" + state: "byova_gateway_auth" + +# JWT validation for gRPC requests (REQUIRED when enabled) +jwt_validation: + # Enable/disable JWT validation (default: true for security) + enabled: true + + # Enforce validation - if false, invalid tokens are logged but allowed + enforce_validation: true + + # REQUIRED: Datasource URL - must match URL registered with Webex Contact Center + # Example: "https://your-gateway-domain.com:443" + datasource_url: "" # Must be configured if enabled=true + + # Datasource schema UUID (default is standard BYOVA schema) + # This is the schema ID from https://github.com/webex/dataSourceSchemas + # Path: Services/VoiceVirtualAgent/5397013b-7920-4ffc-807c-e8a3e0a18f43/schema.json + # This value should not change unless there is a major modification to the BYOVA schema + datasource_schema_uuid: "5397013b-7920-4ffc-807c-e8a3e0a18f43" + + # Public key cache duration in minutes + cache_duration_minutes: 60 + +# Logging configuration +logging: + gateway: + level: "INFO" # DEBUG, INFO, WARNING, ERROR + format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + file: "logs/gateway.log" + max_size: "10MB" + backup_count: 5 + web: + level: "WARNING" + format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + file: "logs/web.log" + max_size: "5MB" + backup_count: 3 + +# Session management +sessions: + timeout: 600 # Session timeout in seconds + max_sessions: 1000 + cleanup_interval: 60 + enable_auto_cleanup: true + max_session_duration: 3600 + +# Audio processing +audio: + supported_formats: + - "wav" + - "mp3" + - "flac" + - "ogg" ``` -### AWS Credentials Configuration +#### Important Configuration Notes + +**JWT Validation** (Required): +- JWT validation is **enabled by default** for security +- You **must** configure `datasource_url` before starting the gateway +- The `datasource_url` must exactly match the URL you register with Webex Contact Center via the BYoDS API +- If JWT validation is enabled without `datasource_url`, the gateway will **fail to start** +- For development without JWT validation, explicitly set `jwt_validation.enabled: false` + +**Connector Configuration**: +- Multiple connectors can be configured simultaneously +- Each connector must have a unique identifier (e.g., `local_audio_connector`, `aws_lex_connector`) +- Connectors are loaded dynamically based on the `module` and `class` specified -**IMPORTANT:** AWS credentials are NOT configured in config files for security reasons. +**AWS Credentials**: +- Prefer environment variables (`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`) over hardcoded credentials +- Explicit credentials in config files should only be used for development/testing The connector uses the standard AWS credential chain (in order of precedence): 1. **Environment variables**: `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` @@ -196,6 +282,150 @@ aws configure # No credentials needed in config files! ``` +You can add these lines to your shell profile (e.g., `.bashrc`, `.zshrc`) or set them in your deployment environment. The gateway will automatically use these credentials if they are set. + +### JWT Authentication for gRPC Requests + +The gateway supports JWT (JSON Web Token) validation for all gRPC requests to ensure secure communication with Webex Contact Center. This feature validates tokens from Webex identity broker endpoints and verifies all required claims. + +#### Overview + +JWT authentication provides: +- **Signature Verification**: Validates tokens using RSA public keys from Webex JWKS endpoints +- **Claims Validation**: Verifies issuer, audience, subject, JWT ID, and expiration +- **Datasource Validation**: Ensures tokens are issued for the correct datasource URL and schema +- **Caching**: Public keys are cached for 60 minutes to reduce endpoint load +- **Optional Enforcement**: Can be configured to log violations without rejecting requests + +#### Configuration + +JWT validation is configured in the `jwt_validation` section of `config/config.yaml`. See the [Configuration](#configuration) section above for the complete configuration structure. + +**Key Points**: +- JWT validation is **enabled by default** (`enabled: true`) +- You **must** configure `datasource_url` or the gateway will fail to start +- Set `enabled: false` to disable JWT validation for development/testing + +#### How to Obtain Your Datasource URL + +The `datasource_url` must **EXACTLY match** (character-for-character) the URL you provide when registering your datasource via the [BYoDS (Bring Your Own Data Source)](https://developer.webex.com/webex-contact-center/docs/api/v1/data-sources) API. + +**Critical**: The JWT token from Webex Contact Center contains a `com.cisco.datasource.url` claim that must match this value exactly. Use the **exact same format** that you used in the BYoDS API registration. + +**Examples** (use whatever format YOU registered): + +```yaml +# If you registered with explicit :443 port +datasource_url: "https://your-gateway.example.com:443" + +# If you registered without port (common for standard HTTPS) +datasource_url: "https://your-gateway.example.com" + +# Ngrok URLs (check your BYoDS registration for exact format) +datasource_url: "https://abc123def456.ngrok-free.app" +``` + +**How to verify**: +1. Check your BYoDS datasource registration (via API or Control Hub) +2. Copy the EXACT URL you registered (character-for-character) +3. Paste it into `jwt_validation.datasource_url` in your config + +**Common mistakes**: +- ❌ Registered: `https://example.com` → Config: `https://example.com:443` (MISMATCH!) +- ❌ Registered: `https://example.com:443` → Config: `https://example.com` (MISMATCH!) +- ✅ Registered: `https://example.com` → Config: `https://example.com` (MATCH!) +- ✅ Registered: `https://example.com:443` → Config: `https://example.com:443` (MATCH!) + +#### Understanding the Datasource Schema UUID + +The `datasource_schema_uuid` identifies the specific schema definition used for communication between Webex Contact Center and your gateway. This UUID comes from the [Webex dataSourceSchemas repository](https://github.com/webex/dataSourceSchemas). + +**For BYOVA (Voice Virtual Agent)**: +- **Schema UUID**: `5397013b-7920-4ffc-807c-e8a3e0a18f43` +- **Schema Location**: `Services/VoiceVirtualAgent/5397013b-7920-4ffc-807c-e8a3e0a18f43/schema.json` +- **Proto Definitions**: Defined in the same directory structure +- **Stability**: This UUID should **not change** unless there is a major modification to the BYOVA schema definition by Webex + +**What is it?** +The schema UUID defines the structure of request and response payloads, protocol (gRPC), and supported app types. It ensures that both Webex Contact Center and your gateway are using the same communication protocol and message formats. + +**Do I need to change it?** +In most cases, **no**. The default value is the standard BYOVA schema UUID and will work for all standard BYOVA implementations. You would only change this if: +- Webex releases a new major version of the BYOVA schema +- You're using a different Webex Contact Center service schema (not BYOVA) + +**Reference**: [Webex dataSourceSchemas Documentation](https://github.com/webex/dataSourceSchemas) + +#### Supported Webex Regions + +The gateway validates tokens from these Webex identity broker issuers: +- `https://idbrokerbts.webex.com/idb` (BTS US) +- `https://idbrokerbts-eu.webex.com/idb` (BTS EU) +- `https://idbroker.webex.com/idb` (Production US) +- `https://idbroker-eu.webex.com/idb` (Production EU) +- `https://idbroker-b-us.webex.com/idb` (B-US) +- `https://idbroker-ca.webex.com/idb` (Canada) + +#### Token Format + +Tokens are expected in the gRPC metadata `authorization` header: +``` +authorization: Bearer +``` + +#### Deployment Recommendations + +**Development**: +```yaml +jwt_validation: + enabled: false # Or enabled: true with enforce_validation: false for testing +``` + +**Production**: +```yaml +jwt_validation: + enabled: true + enforce_validation: true + datasource_url: "https://your-production-url.com:443" +``` + +#### Troubleshooting + +**Error: "Missing JWT token in authorization metadata"** +- Ensure Webex Contact Center is configured to send JWT tokens with gRPC requests +- Verify your datasource is properly registered with Webex Contact Center + +**Error: "JWT token signature not valid"** +- Check that public keys can be fetched from Webex identity broker +- Verify your network allows outbound HTTPS connections to Webex endpoints + +**Error: "Invalid issuer"** +- The JWT token's issuer claim must be from a valid Webex identity broker +- **Security**: Issuer is validated BEFORE fetching keys to prevent SSRF attacks +- Supported issuers: + - `https://idbrokerbts.webex.com/idb` (BTS US) + - `https://idbrokerbts-eu.webex.com/idb` (BTS EU) + - `https://idbroker.webex.com/idb` (Production US) + - `https://idbroker-eu.webex.com/idb` (Production EU) + - `https://idbroker-b-us.webex.com/idb` (B-US) + - `https://idbroker-ca.webex.com/idb` (Canada) +- Verify your datasource is properly configured in Webex Contact Center +- If you see this error with a malformed issuer URL, it may indicate a security attack attempt + +**Error: "Datasource URL mismatch"** or "Datasource claims validation failed" +- Your `datasource_url` in config must EXACTLY match (character-for-character) the URL you registered via BYoDS API +- The JWT token contains a `com.cisco.datasource.url` claim that must match your config value exactly +- Check if you registered with or without the port (`:443`) and match it exactly +- Common issue: Config has `:443` but BYoDS registration doesn't (or vice versa) +- **To debug**: Set log level to DEBUG and check the log message showing expected vs actual URL +- **Solution**: Copy the exact URL from your BYoDS datasource registration and update your config + +**Error: "JWT token is expired"** +- This indicates Webex Contact Center sent an expired token +- Check system clock synchronization between your gateway and Webex services + +For gradual rollout, start with `enforce_validation: false` to log validation results without rejecting requests, then enable enforcement after verification. + ### Running the Server #### Method 1: Manual Start (Recommended for Development) diff --git a/config/config.yaml b/config/config.yaml index 25ce1d0..df04d05 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -75,6 +75,30 @@ authentication: scopes: "openid email profile" state: "byova_gateway_auth" +# JWT validation configuration for gRPC requests +jwt_validation: + # Enable/disable JWT validation (default: true for production security) + enabled: true + + # Enforce validation - if false, invalid tokens are logged but allowed (default: true) + enforce_validation: true + + # Datasource URL - must match the URL configured in your Webex Contact Center datasource + # This should be the publicly accessible URL of your gateway (including port if non-standard) + # Example: "https://your-gateway-domain.com:443" + datasource_url: "https://b9eb5df4443d.ngrok-free.app" + + # Datasource schema UUID - typically this is the BYOVA schema UUID + # This is the schema ID from https://github.com/webex/dataSourceSchemas + # Path: Services/VoiceVirtualAgent/5397013b-7920-4ffc-807c-e8a3e0a18f43/schema.json + # Default value is the standard BYOVA schema UUID and should not change unless + # there is a major modification to the BYOVA schema definition + datasource_schema_uuid: "5397013b-7920-4ffc-807c-e8a3e0a18f43" + + # Cache duration for public keys in minutes (default: 60) + # Public keys are fetched from Webex identity broker and cached to improve performance + cache_duration_minutes: 60 + # Logging configuration logging: gateway: diff --git a/main.py b/main.py index df46cfe..d41cef5 100644 --- a/main.py +++ b/main.py @@ -11,6 +11,7 @@ import threading from concurrent import futures from pathlib import Path +from typing import Optional import grpc import yaml @@ -19,12 +20,19 @@ sys.path.insert(0, str(Path(__file__).parent / "src")) sys.path.insert(0, str(Path(__file__).parent / "src" / "core")) +from grpc_health.v1 import health_pb2_grpc + +from auth.jwt_interceptor import JWTAuthInterceptor + +# Import JWT authentication components (required) +from auth.jwt_validator import JWTValidator +from core.health_service import HealthCheckService from core.virtual_agent_router import VirtualAgentRouter from core.wxcc_gateway_server import WxCCGatewayServer -from core.health_service import HealthCheckService from monitoring.app import run_web_app -from src.generated.voicevirtualagent_pb2_grpc import add_VoiceVirtualAgentServicer_to_server -from grpc_health.v1 import health_pb2_grpc +from src.generated.voicevirtualagent_pb2_grpc import ( + add_VoiceVirtualAgentServicer_to_server, +) def setup_logging(config: dict) -> None: @@ -155,23 +163,108 @@ def create_router_config(config: dict) -> dict: """ # The connectors config is already in the correct dictionary format connectors_config = config.get("connectors", {}) - + # Ensure each connector has the required fields for connector_id, connector_config in connectors_config.items(): if not isinstance(connector_config, dict): - raise ValueError(f"Connector {connector_id} configuration must be a dictionary") - + raise ValueError( + f"Connector {connector_id} configuration must be a dictionary" + ) + # Ensure required fields exist if "class" not in connector_config: raise ValueError(f"Connector {connector_id} missing required 'class' field") if "module" not in connector_config: - raise ValueError(f"Connector {connector_id} missing required 'module' field") + raise ValueError( + f"Connector {connector_id} missing required 'module' field" + ) if "config" not in connector_config: connectors_config[connector_id]["config"] = {} return {"connectors": connectors_config} +def create_jwt_interceptor( + config: dict, logger: logging.Logger +) -> Optional[JWTAuthInterceptor]: + """ + Create JWT authentication interceptor if configured. + + Args: + config: Configuration dictionary containing JWT settings + logger: Logger instance + + Returns: + JWTAuthInterceptor instance or None if not configured + + Raises: + ValueError: If JWT validation is enabled but datasource_url is not configured + """ + jwt_config = config.get("jwt_validation", {}) + + if not jwt_config.get("enabled", False): + logger.info("JWT validation is disabled in configuration") + return None + + # Validate required configuration + datasource_url = jwt_config.get("datasource_url", "") + if not datasource_url: + error_msg = ( + "JWT validation is enabled but datasource_url is not configured. " + "Please set jwt_validation.datasource_url in config.yaml or disable JWT validation." + ) + logger.error(error_msg) + raise ValueError(error_msg) + + # Get configuration values + datasource_schema_uuid = jwt_config.get( + "datasource_schema_uuid", "5397013b-7920-4ffc-807c-e8a3e0a18f43" + ) + cache_duration_minutes = jwt_config.get("cache_duration_minutes", 60) + enforce_validation = jwt_config.get("enforce_validation", True) + + try: + # Create JWT validator + validator = JWTValidator( + datasource_url=datasource_url, + datasource_schema_uuid=datasource_schema_uuid, + cache_duration_minutes=cache_duration_minutes, + ) + + # Create interceptor + interceptor = JWTAuthInterceptor( + jwt_validator=validator, + enabled=True, + enforce=enforce_validation, + ) + + logger.info("JWT authentication interceptor created successfully") + logger.info(f"Datasource URL: {datasource_url}") + logger.info( + f"Enforcement: {'ENABLED' if enforce_validation else 'DISABLED (logging only)'}" + ) + + return interceptor + + except Exception as e: + logger.error(f"Failed to create JWT interceptor: {e}") + + # If JWT validation is enabled, this is a fatal error + jwt_config = config.get("jwt_validation", {}) + if jwt_config.get("enabled", True): + error_msg = ( + "Failed to create JWT interceptor but JWT validation is enabled. " + "This is a fatal configuration error. Please check your configuration and dependencies." + ) + logger.error(error_msg) + raise RuntimeError(error_msg) from e + else: + logger.warning( + "JWT validation is disabled, continuing without JWT interceptor" + ) + return None + + def main(): """ Main entry point for the BYOVA Gateway. @@ -208,7 +301,7 @@ def main(): # Create WxCCGatewayServer server = WxCCGatewayServer(router) logger.info("WxCCGatewayServer created") - + # Create health service with router for real health monitoring health_service = HealthCheckService(router) logger.info("HealthCheckService created with real health monitoring") @@ -218,21 +311,38 @@ def main(): host = gateway_config.get("host", "0.0.0.0") port = gateway_config.get("port", 50051) - # Create gRPC server - grpc_server = grpc.server( - futures.ThreadPoolExecutor(max_workers=10), - options=[ - ("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB - ("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB - ("grpc.max_concurrent_streams", 100), - ], - ) + # Create JWT interceptor if configured + jwt_interceptor = create_jwt_interceptor(config, logger) + interceptors = [] + if jwt_interceptor: + interceptors.append(jwt_interceptor) + + # Create gRPC server with interceptors + if interceptors: + grpc_server = grpc.server( + futures.ThreadPoolExecutor(max_workers=10), + interceptors=interceptors, + options=[ + ("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB + ("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB + ("grpc.max_concurrent_streams", 100), + ], + ) + logger.info(f"gRPC server created with {len(interceptors)} interceptor(s)") + else: + grpc_server = grpc.server( + futures.ThreadPoolExecutor(max_workers=10), + options=[ + ("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB + ("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB + ("grpc.max_concurrent_streams", 100), + ], + ) + logger.info("gRPC server created without interceptors") # Add servicer to the server - add_VoiceVirtualAgentServicer_to_server( - server, grpc_server - ) - + add_VoiceVirtualAgentServicer_to_server(server, grpc_server) + # Add health service to the server health_pb2_grpc.add_HealthServicer_to_server(health_service, grpc_server) logger.info("Health service registered with gRPC server") @@ -298,6 +408,20 @@ def main(): else: print(" • Disabled") + print() + print("🔐 JWT Authentication:") + jwt_config = config.get("jwt_validation", {}) + if jwt_config.get("enabled", False) and jwt_interceptor: + print(" • Status: ENABLED") + print( + f" • Enforcement: {'ENABLED' if jwt_config.get('enforce_validation', True) else 'DISABLED (logging only)'}" + ) + print( + f" • Datasource URL: {jwt_config.get('datasource_url', 'Not configured')}" + ) + else: + print(" • Status: DISABLED") + print() print("✅ Gateway is running! Press Ctrl+C to stop.") print("=" * 60) diff --git a/requirements.txt b/requirements.txt index cda4d00..45851f6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,8 @@ click>=8.2.1 blinker>=1.9.0 # Authentication and HTTP -PyJWT>=2.8.0 +PyJWT[crypto]>=2.8.0 # JWT with cryptography support for RSA signature verification +cryptography>=43.0.0 # Cryptographic operations for JWT validation requests>=2.31.0 # Audio processing diff --git a/src/auth/__init__.py b/src/auth/__init__.py new file mode 100644 index 0000000..a73cf68 --- /dev/null +++ b/src/auth/__init__.py @@ -0,0 +1,6 @@ +""" +Authentication module for Webex Contact Center BYOVA Gateway. + +This module provides JWT validation and authentication functionality +for securing gRPC requests. +""" diff --git a/src/auth/jwt_interceptor.py b/src/auth/jwt_interceptor.py new file mode 100644 index 0000000..885995b --- /dev/null +++ b/src/auth/jwt_interceptor.py @@ -0,0 +1,175 @@ +""" +gRPC Interceptor for JWT validation. + +This module implements a gRPC server interceptor that validates JWT tokens +from request metadata before allowing requests to proceed. +""" + +import logging +from typing import Callable + +import grpc + +from .jwt_validator import AccessTokenException, JWTValidator + + +class JWTAuthInterceptor(grpc.ServerInterceptor): + """ + gRPC server interceptor that validates JWT tokens from metadata. + + This interceptor: + - Extracts JWT token from 'authorization' metadata header + - Validates the token using JWTValidator + - Returns UNAUTHENTICATED status on validation failure (if enforcement enabled) + - Supports optional enforcement for gradual rollout + """ + + def __init__( + self, + jwt_validator: JWTValidator, + enabled: bool = True, + enforce: bool = True, + ): + """ + Initialize JWT authentication interceptor. + + Args: + jwt_validator: JWTValidator instance for token validation + enabled: Whether JWT validation is enabled + enforce: Whether to reject requests with invalid tokens + """ + self.jwt_validator = jwt_validator + self.enabled = enabled + self.enforce = enforce + self.logger = logging.getLogger(__name__) + + if not self.enabled: + self.logger.warning( + "JWT validation is DISABLED - all requests will be allowed" + ) + elif not self.enforce: + self.logger.warning( + "JWT validation enforcement is DISABLED - invalid tokens will be logged but allowed" + ) + else: + self.logger.info("JWT validation is ENABLED and ENFORCED") + + def intercept_service( + self, continuation: Callable, handler_call_details: grpc.HandlerCallDetails + ) -> grpc.RpcMethodHandler: + """ + Intercept gRPC service calls to validate JWT tokens. + + Args: + continuation: Function to invoke the next interceptor or handler + handler_call_details: Details about the RPC call + + Returns: + RPC method handler + """ + # If validation is disabled, proceed without checking + if not self.enabled: + return continuation(handler_call_details) + + # Extract token from metadata + metadata = dict(handler_call_details.invocation_metadata) + token = None + + # Look for authorization header (case-insensitive) + for key in metadata: + if key.lower() == "authorization": + auth_value = metadata[key] + # Handle "Bearer " format + if auth_value.startswith("Bearer "): + token = auth_value[7:] # Remove "Bearer " prefix + else: + token = auth_value + break + + # Validate token + if not token: + error_msg = "Missing JWT token in authorization metadata" + self.logger.warning(f"{error_msg} for method {handler_call_details.method}") + + if self.enforce: + # Return error handler that aborts with UNAUTHENTICATED + return self._abort_unauthenticated(error_msg) + else: + self.logger.info( + "Allowing request without token (enforcement disabled)" + ) + return continuation(handler_call_details) + + # Validate the token + try: + self.jwt_validator.validate_token(token) + self.logger.debug( + f"JWT validated successfully for method {handler_call_details.method}" + ) + + except AccessTokenException as e: + error_msg = f"JWT validation failed: {str(e)}" + self.logger.error(f"{error_msg} for method {handler_call_details.method}") + + if self.enforce: + # Return error handler that aborts with UNAUTHENTICATED + return self._abort_unauthenticated(error_msg) + else: + self.logger.warning( + "Allowing request with invalid token (enforcement disabled)" + ) + + except Exception as e: + error_msg = f"Unexpected error during JWT validation: {str(e)}" + self.logger.error(f"{error_msg} for method {handler_call_details.method}") + + if self.enforce: + # Return error handler that aborts with INTERNAL + return self._abort_internal(error_msg) + else: + self.logger.warning( + "Allowing request due to validation error (enforcement disabled)" + ) + + # Proceed with the request + return continuation(handler_call_details) + + def _abort_unauthenticated(self, error_message: str) -> grpc.RpcMethodHandler: + """ + Create a handler that aborts with UNAUTHENTICATED status. + + Args: + error_message: Error message to return + + Returns: + RPC method handler that aborts the call + """ + + def abort(request, context): + context.abort(grpc.StatusCode.UNAUTHENTICATED, error_message) + + return grpc.unary_unary_rpc_method_handler( + abort, + request_deserializer=lambda x: x, + response_serializer=lambda x: x, + ) + + def _abort_internal(self, error_message: str) -> grpc.RpcMethodHandler: + """ + Create a handler that aborts with INTERNAL status. + + Args: + error_message: Error message to return + + Returns: + RPC method handler that aborts the call + """ + + def abort(request, context): + context.abort(grpc.StatusCode.INTERNAL, error_message) + + return grpc.unary_unary_rpc_method_handler( + abort, + request_deserializer=lambda x: x, + response_serializer=lambda x: x, + ) diff --git a/src/auth/jwt_validator.py b/src/auth/jwt_validator.py new file mode 100644 index 0000000..983f0e9 --- /dev/null +++ b/src/auth/jwt_validator.py @@ -0,0 +1,361 @@ +""" +JWT Validator for Webex Contact Center BYOVA Gateway. + +This module implements JWS/JWT token validation for gRPC requests, +verifying tokens against Webex identity broker public keys. +""" + +import logging +import threading +import time +from typing import Any, Dict + +import jwt +import requests +from jwt import PyJWK +from jwt.exceptions import ( + DecodeError, + ExpiredSignatureError, + InvalidSignatureError, +) + + +class AccessTokenException(Exception): + """Exception raised when token validation fails.""" + + pass + + +class JWTValidator: + """ + Validates JWT tokens from Webex Contact Center. + + This validator: + - Fetches public keys from Webex JWKS endpoints + - Validates JWT signatures using RSA public keys + - Verifies token expiration and claims + - Caches public keys for improved performance + - Handles rate limiting by falling back to cache + """ + + # Valid Webex identity broker issuers for different regions + VALID_ISSUERS = [ + "https://idbrokerbts.webex.com/idb", # BTS US + "https://idbrokerbts-eu.webex.com/idb", # BTS EU + "https://idbroker.webex.com/idb", # Production US + "https://idbroker-eu.webex.com/idb", # Production EU + "https://idbroker-b-us.webex.com/idb", # B-US + "https://idbroker-ca.webex.com/idb", # Canada + ] + + # Datasource claim keys + DATASOURCE_URL_KEY = "com.cisco.datasource.url" + DATASOURCE_SCHEMA_KEY = "com.cisco.datasource.schema.uuid" + + # Default BYOVA schema UUID from https://github.com/webex/dataSourceSchemas + # Path: Services/VoiceVirtualAgent/5397013b-7920-4ffc-807c-e8a3e0a18f43/schema.json + # This should not change unless there is a major modification to the BYOVA schema + DEFAULT_SCHEMA_UUID = "5397013b-7920-4ffc-807c-e8a3e0a18f43" + + def __init__( + self, + datasource_url: str, + datasource_schema_uuid: str = None, + cache_duration_minutes: int = 60, + ): + """ + Initialize JWT validator. + + Args: + datasource_url: Expected datasource URL for claim validation + datasource_schema_uuid: Expected schema UUID (default is standard BYOVA schema + from https://github.com/webex/dataSourceSchemas). If None, uses DEFAULT_SCHEMA_UUID. + cache_duration_minutes: How long to cache public keys (default 60 minutes) + """ + self.datasource_url = datasource_url + self.datasource_schema_uuid = datasource_schema_uuid or self.DEFAULT_SCHEMA_UUID + self.cache_duration_seconds = cache_duration_minutes * 60 + + # Cache for public keys by issuer + self._public_keys_cache: Dict[str, Dict[str, Any]] = {} + self._cache_lock = threading.RLock() + + self.logger = logging.getLogger(__name__) + self.logger.info( + f"JWTValidator initialized with datasource URL: {datasource_url}" + ) + + def validate_token(self, token: str) -> bool: + """ + Validate a JWT token. + + Args: + token: The JWT token string to validate + + Returns: + True if token is valid + + Raises: + AccessTokenException: If token validation fails + """ + try: + # Decode token header to get issuer without verification + unverified_token = jwt.decode( + token, + options={ + "verify_signature": False, + "verify_exp": False, + "verify_aud": False, + }, + ) + + issuer = unverified_token.get("iss") + if not issuer: + raise AccessTokenException("Token missing 'iss' claim") + + # SECURITY: Validate issuer BEFORE fetching keys to prevent SSRF attacks + if issuer not in self.VALID_ISSUERS: + self.logger.error( + f"Invalid issuer: {issuer}. Must be one of: {self.VALID_ISSUERS}" + ) + raise AccessTokenException( + f"Invalid issuer: {issuer}. Must be one of the allowed Webex identity brokers." + ) + + # Debug logging + self.logger.info(f"Validating token from issuer: {issuer}") + self.logger.debug(f"Token claims (unverified): {unverified_token.keys()}") + + # Fetch public keys for this issuer (issuer is now validated) + public_keys = self._fetch_public_keys(issuer) + num_keys = len(public_keys.get("keys", [])) + self.logger.info(f"Fetched {num_keys} public key(s) from JWKS endpoint") + + # Try to validate signature with each public key + is_valid_signature = False + decoded_token = None + last_error = None + + for idx, key_data in enumerate(public_keys.get("keys", [])): + try: + # Check if key_data is already a key object (from tests) or a JWK dict + if isinstance(key_data, dict): + # It's a JWK dictionary from JWKS endpoint + kid = key_data.get("kid", "unknown") + kty = key_data.get("kty", "unknown") + alg = key_data.get("alg", "RS256") + + self.logger.debug( + f"Trying key {idx + 1}/{num_keys}: kid={kid}, kty={kty}, alg={alg}" + ) + + # Convert JWK to RSA public key + jwk = PyJWK(key_data) + public_key = jwk.key + else: + # It's already a key object (likely from tests) + self.logger.debug( + f"Trying key {idx + 1}/{num_keys} (direct key object)" + ) + public_key = key_data + kid = "direct-key" + + # Validate JWT signature and decode + decoded_token = jwt.decode( + token, + key=public_key, + algorithms=["RS256"], + options={ + "verify_signature": True, + "verify_exp": True, + "verify_aud": False, # We'll verify manually + }, + ) + is_valid_signature = True + self.logger.info( + f"JWT signature validated successfully with key kid={kid}" + ) + break + except (InvalidSignatureError, DecodeError) as e: + # Try next key - this is expected if key doesn't match + key_id = kid if "kid" in locals() else f"key-{idx}" + self.logger.debug(f"Key {key_id} signature validation failed: {e}") + last_error = e + continue + except ExpiredSignatureError as e: + self.logger.error("JWT token is expired") + raise AccessTokenException("JWT token is expired") from e + except Exception as e: + # Unexpected error with this key + key_id = kid if "kid" in locals() else f"key-{idx}" + self.logger.warning( + f"Unexpected error validating with key {key_id}: {e}" + ) + last_error = e + continue + + if not is_valid_signature or not decoded_token: + error_msg = f"JWT token signature not valid. Tried {num_keys} key(s)." + if last_error: + error_msg += f" Last error: {last_error}" + self.logger.error(error_msg) + raise AccessTokenException("JWT token signature not valid") + + # Verify all claims + if not self._verify_claims(decoded_token): + self.logger.error("Claims validation failed") + raise AccessTokenException("Claims validation failed") + + if not self._verify_datasource_claims(decoded_token): + self.logger.error("Datasource claims validation failed") + raise AccessTokenException("Datasource claims validation failed") + + self.logger.info("JWT token validated successfully") + return True + + except AccessTokenException: + raise + except Exception as e: + self.logger.error(f"Token validation failed: {e}") + raise AccessTokenException(f"Token validation failed: {str(e)}") from e + + def _fetch_public_keys(self, issuer: str) -> Dict[str, Any]: + """ + Fetch public keys from JWKS endpoint with caching. + + Args: + issuer: The issuer URL from the token + + Returns: + Dictionary containing public keys + + Raises: + AccessTokenException: If keys cannot be fetched + """ + with self._cache_lock: + current_time = time.time() + + # Check cache first + if issuer in self._public_keys_cache: + cached_data = self._public_keys_cache[issuer] + if current_time < cached_data.get("expiration_at", 0): + self.logger.debug("Returning cached public keys") + return cached_data.get("keys_data", {}) + + # Fetch fresh keys + try: + # Construct JWKS URL - issuer is guaranteed to be valid at this point + jwks_url = f"{issuer}/oauth2/v2/keys/verificationjwk" + self.logger.debug(f"Fetching public keys from: {jwks_url}") + + response = requests.get(jwks_url, timeout=10) + + if response.status_code == 200: + self.logger.info("Public keys fetched successfully") + keys_data = response.json() + + # Cache the keys + self._public_keys_cache[issuer] = { + "keys_data": keys_data, + "expiration_at": current_time + self.cache_duration_seconds, + } + + return keys_data + + elif response.status_code == 429: + # Rate limited - try to use cached keys even if expired + self.logger.warning( + "Rate limit exceeded, attempting to use cached keys" + ) + if issuer in self._public_keys_cache: + self.logger.info("Using cached public keys despite rate limit") + return self._public_keys_cache[issuer].get("keys_data", {}) + else: + raise AccessTokenException( + "Rate limit exceeded and no cached public keys available" + ) + + else: + error_message = ( + f"Failed to fetch public keys: HTTP {response.status_code}" + ) + if response.text: + error_message += f" - {response.text}" + raise AccessTokenException(error_message) + + except requests.RequestException as e: + self.logger.error(f"Error fetching public keys: {e}") + # Try to use cached keys on network error + if issuer in self._public_keys_cache: + self.logger.warning("Using cached keys due to network error") + return self._public_keys_cache[issuer].get("keys_data", {}) + raise AccessTokenException( + f"Error while fetching public keys: {str(e)}" + ) from e + + def _verify_claims(self, decoded_token: Dict[str, Any]) -> bool: + """ + Verify standard JWT claims. + + Args: + decoded_token: The decoded JWT token + + Returns: + True if all required claims are valid + """ + try: + # Verify issuer + issuer = decoded_token.get("iss") + if not issuer or issuer not in self.VALID_ISSUERS: + self.logger.error(f"Invalid or missing issuer: {issuer}") + return False + + # Verify required claims are present + required_claims = ["aud", "sub", "jti"] + for claim in required_claims: + if claim not in decoded_token or not decoded_token[claim]: + self.logger.error(f"Missing or empty required claim: {claim}") + return False + + self.logger.debug("Standard claims validated successfully") + return True + + except Exception as e: + self.logger.error(f"Error verifying claims: {e}") + return False + + def _verify_datasource_claims(self, decoded_token: Dict[str, Any]) -> bool: + """ + Verify datasource-specific claims. + + Args: + decoded_token: The decoded JWT token + + Returns: + True if datasource claims are valid + """ + try: + # Verify datasource URL + token_datasource_url = decoded_token.get(self.DATASOURCE_URL_KEY) + if token_datasource_url != self.datasource_url: + self.logger.error( + f"Datasource URL mismatch. Expected: {self.datasource_url}, " + f"Got: {token_datasource_url}" + ) + return False + + # Verify datasource schema UUID + token_schema_uuid = decoded_token.get(self.DATASOURCE_SCHEMA_KEY) + if token_schema_uuid != self.datasource_schema_uuid: + self.logger.error( + f"Datasource schema UUID mismatch. Expected: {self.datasource_schema_uuid}, " + f"Got: {token_schema_uuid}" + ) + return False + + self.logger.debug("Datasource claims validated successfully") + return True + + except Exception as e: + self.logger.error(f"Error verifying datasource claims: {e}") + return False diff --git a/tests/test_jwt_validation.py b/tests/test_jwt_validation.py new file mode 100644 index 0000000..3920e58 --- /dev/null +++ b/tests/test_jwt_validation.py @@ -0,0 +1,514 @@ +""" +Unit tests for JWT validation functionality. + +This module tests the JWT validator and interceptor to ensure proper +token validation, signature verification, and claims checking. +""" + +import time +from datetime import datetime, timedelta, timezone +from unittest.mock import Mock, patch + +import grpc +import jwt +import pytest +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa + +from src.auth.jwt_interceptor import JWTAuthInterceptor +from src.auth.jwt_validator import AccessTokenException, JWTValidator + + +class TestJWTValidator: + """Test cases for JWTValidator class.""" + + @pytest.fixture + def rsa_keys(self): + """Generate RSA key pair for testing.""" + private_key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + public_key = private_key.public_key() + + # Get PEM format + private_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + public_pem = public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + return { + "private_key": private_key, + "public_key": public_key, + "private_pem": private_pem, + "public_pem": public_pem, + } + + @pytest.fixture + def validator(self): + """Create a JWTValidator instance for testing.""" + return JWTValidator( + datasource_url="https://test-gateway.example.com:443", + datasource_schema_uuid="5397013b-7920-4ffc-807c-e8a3e0a18f43", + cache_duration_minutes=60, + ) + + def test_default_schema_uuid(self): + """Test that default schema UUID is used when not provided.""" + validator = JWTValidator(datasource_url="https://test-gateway.example.com:443") + assert ( + validator.datasource_schema_uuid == "5397013b-7920-4ffc-807c-e8a3e0a18f43" + ) + assert validator.datasource_schema_uuid == JWTValidator.DEFAULT_SCHEMA_UUID + + @pytest.fixture + def valid_claims(self): + """Generate valid JWT claims.""" + return { + "iss": "https://idbroker.webex.com/idb", + "aud": "test-audience", + "sub": "test-subject", + "jti": "test-jwt-id", + "exp": datetime.now(timezone.utc) + timedelta(hours=1), + "com.cisco.datasource.url": "https://test-gateway.example.com:443", + "com.cisco.datasource.schema.uuid": "5397013b-7920-4ffc-807c-e8a3e0a18f43", + } + + def create_test_token(self, claims, private_key): + """Create a test JWT token.""" + return jwt.encode(claims, private_key, algorithm="RS256") + + def test_valid_token_validation(self, validator, rsa_keys, valid_claims): + """Test successful validation of a valid token.""" + token = self.create_test_token(valid_claims, rsa_keys["private_key"]) + + # Mock the public key fetching + mock_jwks_response = { + "keys": [ + { + "kty": "RSA", + "use": "sig", + "kid": "test-key-id", + "n": jwt.utils.base64url_encode( + rsa_keys["public_key"].public_numbers().n.to_bytes(256, "big") + ).decode("utf-8"), + "e": jwt.utils.base64url_encode( + rsa_keys["public_key"].public_numbers().e.to_bytes(3, "big") + ).decode("utf-8"), + } + ] + } + + with patch.object( + validator, "_fetch_public_keys", return_value=mock_jwks_response + ): + # Mock jwt.decode to use our public key + with patch("jwt.decode") as mock_decode: + mock_decode.return_value = valid_claims + assert validator.validate_token(token) is True + + def test_expired_token_rejection(self, validator, rsa_keys, valid_claims): + """Test rejection of an expired token.""" + expired_claims = valid_claims.copy() + expired_claims["exp"] = datetime.now(timezone.utc) - timedelta(hours=1) + + token = self.create_test_token(expired_claims, rsa_keys["private_key"]) + + # Mock the public key fetching + mock_jwks_response = {"keys": [rsa_keys["public_key"]]} + + with patch.object( + validator, "_fetch_public_keys", return_value=mock_jwks_response + ): + # The first jwt.decode call (unverified) should succeed, second should fail + with patch("src.auth.jwt_validator.jwt.decode") as mock_decode: + # First call returns the claims, second call raises ExpiredSignatureError + mock_decode.side_effect = [ + expired_claims, # Unverified decode succeeds + jwt.ExpiredSignatureError("Token expired"), # Verified decode fails + ] + with pytest.raises(AccessTokenException, match="JWT token is expired"): + validator.validate_token(token) + + def test_invalid_signature_rejection(self, validator, rsa_keys, valid_claims): + """Test rejection of a token with invalid signature.""" + # Create token with one key + token = self.create_test_token(valid_claims, rsa_keys["private_key"]) + + # Try to validate with a different key + other_private_key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + + mock_jwks_response = {"keys": [other_private_key.public_key()]} + + with patch.object( + validator, "_fetch_public_keys", return_value=mock_jwks_response + ): + # The first jwt.decode call (unverified) should succeed, second should fail + with patch("src.auth.jwt_validator.jwt.decode") as mock_decode: + # First call returns the claims, second call raises InvalidSignatureError + mock_decode.side_effect = [ + valid_claims, # Unverified decode succeeds + jwt.InvalidSignatureError( + "Invalid signature" + ), # Verified decode fails + ] + with pytest.raises( + AccessTokenException, match="JWT token signature not valid" + ): + validator.validate_token(token) + + def test_missing_issuer_rejection(self, validator, rsa_keys, valid_claims): + """Test rejection of token missing issuer claim.""" + invalid_claims = valid_claims.copy() + del invalid_claims["iss"] + + token = self.create_test_token(invalid_claims, rsa_keys["private_key"]) + + with pytest.raises(AccessTokenException, match="Token missing 'iss' claim"): + validator.validate_token(token) + + def test_invalid_issuer_rejection(self, validator, rsa_keys, valid_claims): + """Test rejection of token with invalid issuer (SSRF prevention).""" + invalid_claims = valid_claims.copy() + invalid_claims["iss"] = "https://evil-attacker.com/malicious" + + token = self.create_test_token(invalid_claims, rsa_keys["private_key"]) + + # Mock jwt.decode for the unverified decode call + with patch("jwt.decode", return_value=invalid_claims): + # The validator should reject BEFORE attempting to fetch keys + with pytest.raises( + AccessTokenException, match="Invalid issuer.*Must be one of the allowed" + ): + validator.validate_token(token) + + def test_invalid_issuer_prevents_ssrf(self, validator, rsa_keys, valid_claims): + """Test that invalid issuer is rejected BEFORE fetching keys (SSRF prevention).""" + invalid_claims = valid_claims.copy() + invalid_claims["iss"] = "https://attacker-controlled-server.com/idb" + + token = self.create_test_token(invalid_claims, rsa_keys["private_key"]) + + # Mock _fetch_public_keys to verify it is NEVER called with invalid issuer + with patch.object(validator, "_fetch_public_keys") as mock_fetch: + with patch("jwt.decode", return_value=invalid_claims): + with pytest.raises(AccessTokenException, match="Invalid issuer"): + validator.validate_token(token) + + # Verify that _fetch_public_keys was NEVER called (SSRF prevented) + mock_fetch.assert_not_called() + + def test_missing_required_claims_rejection(self, validator, rsa_keys, valid_claims): + """Test rejection of token missing required claims.""" + for claim in ["aud", "sub", "jti"]: + invalid_claims = valid_claims.copy() + del invalid_claims[claim] + + token = self.create_test_token(invalid_claims, rsa_keys["private_key"]) + + mock_jwks_response = {"keys": [rsa_keys["public_key"]]} + + with patch.object( + validator, "_fetch_public_keys", return_value=mock_jwks_response + ): + with patch("jwt.decode", return_value=invalid_claims): + with pytest.raises( + AccessTokenException, match="Claims validation failed" + ): + validator.validate_token(token) + + def test_datasource_url_mismatch_rejection(self, validator, rsa_keys, valid_claims): + """Test rejection of token with mismatched datasource URL.""" + invalid_claims = valid_claims.copy() + invalid_claims["com.cisco.datasource.url"] = "https://wrong-url.com:443" + + token = self.create_test_token(invalid_claims, rsa_keys["private_key"]) + + mock_jwks_response = {"keys": [rsa_keys["public_key"]]} + + with patch.object( + validator, "_fetch_public_keys", return_value=mock_jwks_response + ): + with patch("jwt.decode", return_value=invalid_claims): + with pytest.raises( + AccessTokenException, match="Datasource claims validation failed" + ): + validator.validate_token(token) + + def test_datasource_schema_uuid_mismatch_rejection( + self, validator, rsa_keys, valid_claims + ): + """Test rejection of token with mismatched schema UUID.""" + invalid_claims = valid_claims.copy() + invalid_claims["com.cisco.datasource.schema.uuid"] = "wrong-uuid" + + token = self.create_test_token(invalid_claims, rsa_keys["private_key"]) + + mock_jwks_response = {"keys": [rsa_keys["public_key"]]} + + with patch.object( + validator, "_fetch_public_keys", return_value=mock_jwks_response + ): + with patch("jwt.decode", return_value=invalid_claims): + with pytest.raises( + AccessTokenException, match="Datasource claims validation failed" + ): + validator.validate_token(token) + + def test_public_key_caching(self, validator): + """Test that public keys are cached properly.""" + issuer = "https://idbroker.webex.com/idb" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"keys": [{"test": "key"}]} + + with patch("requests.get", return_value=mock_response) as mock_get: + # First call should fetch keys + result1 = validator._fetch_public_keys(issuer) + assert mock_get.call_count == 1 + + # Second call should use cache + result2 = validator._fetch_public_keys(issuer) + assert mock_get.call_count == 1 # Still 1, no new call + + assert result1 == result2 + + def test_rate_limit_handling(self, validator): + """Test handling of rate limit (429) response.""" + issuer = "https://idbroker.webex.com/idb" + + # Pre-populate cache + validator._public_keys_cache[issuer] = { + "keys_data": {"keys": [{"test": "key"}]}, + "expiration_at": time.time() - 1, # Expired + } + + # Mock rate limit response + mock_response = Mock() + mock_response.status_code = 429 + + with patch("requests.get", return_value=mock_response): + # Should return cached keys despite being expired + result = validator._fetch_public_keys(issuer) + assert result == {"keys": [{"test": "key"}]} + + def test_network_error_fallback_to_cache(self, validator): + """Test fallback to cache on network error.""" + issuer = "https://idbroker.webex.com/idb" + + # Pre-populate cache + validator._public_keys_cache[issuer] = { + "keys_data": {"keys": [{"test": "key"}]}, + "expiration_at": time.time() + 3600, + } + + # Mock network error + with patch("requests.get", side_effect=Exception("Network error")): + result = validator._fetch_public_keys(issuer) + assert result == {"keys": [{"test": "key"}]} + + +class TestJWTAuthInterceptor: + """Test cases for JWTAuthInterceptor class.""" + + @pytest.fixture + def mock_validator(self): + """Create a mock JWTValidator.""" + validator = Mock(spec=JWTValidator) + validator.validate_token = Mock(return_value=True) + return validator + + @pytest.fixture + def interceptor(self, mock_validator): + """Create a JWTAuthInterceptor for testing.""" + return JWTAuthInterceptor( + jwt_validator=mock_validator, + enabled=True, + enforce=True, + ) + + @pytest.fixture + def mock_handler_call_details(self): + """Create mock handler call details.""" + details = Mock(spec=grpc.HandlerCallDetails) + details.method = "/test.Service/TestMethod" + return details + + def test_valid_token_allowed(self, interceptor, mock_handler_call_details): + """Test that requests with valid tokens are allowed.""" + # Set up metadata with valid token + mock_handler_call_details.invocation_metadata = [ + ("authorization", "Bearer valid-token-here") + ] + + # Mock continuation + mock_continuation = Mock(return_value="handler") + + # Intercept + result = interceptor.intercept_service( + mock_continuation, mock_handler_call_details + ) + + # Should call continuation + mock_continuation.assert_called_once() + assert result == "handler" + + def test_missing_token_rejected_when_enforced( + self, interceptor, mock_handler_call_details + ): + """Test that requests without tokens are rejected when enforcement is enabled.""" + # Set up metadata without token + mock_handler_call_details.invocation_metadata = [] + + # Mock continuation + mock_continuation = Mock(return_value="handler") + + # Intercept + result = interceptor.intercept_service( + mock_continuation, mock_handler_call_details + ) + + # Should NOT call continuation + mock_continuation.assert_not_called() + + # Result should be an abort handler (not the continuation result) + assert result != "handler" + + def test_missing_token_allowed_when_not_enforced(self, mock_validator): + """Test that requests without tokens are allowed when enforcement is disabled.""" + interceptor = JWTAuthInterceptor( + jwt_validator=mock_validator, + enabled=True, + enforce=False, + ) + + mock_handler_call_details = Mock(spec=grpc.HandlerCallDetails) + mock_handler_call_details.method = "/test.Service/TestMethod" + mock_handler_call_details.invocation_metadata = [] + + mock_continuation = Mock(return_value="handler") + + result = interceptor.intercept_service( + mock_continuation, mock_handler_call_details + ) + + # Should call continuation even without token + mock_continuation.assert_called_once() + assert result == "handler" + + def test_invalid_token_rejected_when_enforced( + self, mock_validator, mock_handler_call_details + ): + """Test that requests with invalid tokens are rejected when enforcement is enabled.""" + # Set up validator to reject token + mock_validator.validate_token.side_effect = AccessTokenException( + "Invalid token" + ) + + interceptor = JWTAuthInterceptor( + jwt_validator=mock_validator, + enabled=True, + enforce=True, + ) + + mock_handler_call_details.invocation_metadata = [ + ("authorization", "Bearer invalid-token") + ] + + mock_continuation = Mock(return_value="handler") + + result = interceptor.intercept_service( + mock_continuation, mock_handler_call_details + ) + + # Should NOT call continuation + mock_continuation.assert_not_called() + assert result != "handler" + + def test_invalid_token_allowed_when_not_enforced(self, mock_validator): + """Test that requests with invalid tokens are allowed when enforcement is disabled.""" + mock_validator.validate_token.side_effect = AccessTokenException( + "Invalid token" + ) + + interceptor = JWTAuthInterceptor( + jwt_validator=mock_validator, + enabled=True, + enforce=False, + ) + + mock_handler_call_details = Mock(spec=grpc.HandlerCallDetails) + mock_handler_call_details.method = "/test.Service/TestMethod" + mock_handler_call_details.invocation_metadata = [ + ("authorization", "Bearer invalid-token") + ] + + mock_continuation = Mock(return_value="handler") + + result = interceptor.intercept_service( + mock_continuation, mock_handler_call_details + ) + + # Should call continuation even with invalid token + mock_continuation.assert_called_once() + assert result == "handler" + + def test_disabled_interceptor_allows_all(self, mock_validator): + """Test that disabled interceptor allows all requests.""" + interceptor = JWTAuthInterceptor( + jwt_validator=mock_validator, + enabled=False, + enforce=True, + ) + + mock_handler_call_details = Mock(spec=grpc.HandlerCallDetails) + mock_handler_call_details.method = "/test.Service/TestMethod" + mock_handler_call_details.invocation_metadata = [] + + mock_continuation = Mock(return_value="handler") + + result = interceptor.intercept_service( + mock_continuation, mock_handler_call_details + ) + + # Should call continuation without checking token + mock_continuation.assert_called_once() + assert result == "handler" + mock_validator.validate_token.assert_not_called() + + def test_bearer_token_format_handling(self, interceptor, mock_handler_call_details): + """Test proper handling of 'Bearer ' format.""" + mock_handler_call_details.invocation_metadata = [ + ("authorization", "Bearer test-token-123") + ] + + mock_continuation = Mock(return_value="handler") + + interceptor.intercept_service(mock_continuation, mock_handler_call_details) + + # Should extract token without "Bearer " prefix + interceptor.jwt_validator.validate_token.assert_called_once_with( + "test-token-123" + ) + + def test_direct_token_format_handling(self, interceptor, mock_handler_call_details): + """Test handling of token without 'Bearer ' prefix.""" + mock_handler_call_details.invocation_metadata = [ + ("authorization", "test-token-123") + ] + + mock_continuation = Mock(return_value="handler") + + interceptor.intercept_service(mock_continuation, mock_handler_call_details) + + # Should use token as-is + interceptor.jwt_validator.validate_token.assert_called_once_with( + "test-token-123" + ) diff --git a/tests/test_main_jwt_interceptor.py b/tests/test_main_jwt_interceptor.py new file mode 100644 index 0000000..c20fdc0 --- /dev/null +++ b/tests/test_main_jwt_interceptor.py @@ -0,0 +1,200 @@ +""" +Unit tests for JWT interceptor creation in main.py + +This module tests the strict error handling for JWT authentication +configuration and initialization. +""" + +import logging + +# Import the function we're testing +import sys +from pathlib import Path +from unittest.mock import Mock, patch + +import pytest + +# Add src to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +# Import after path is set +from main import create_jwt_interceptor + + +class TestCreateJWTInterceptor: + """Test cases for create_jwt_interceptor function in main.py""" + + @pytest.fixture + def logger(self): + """Create a test logger.""" + return logging.getLogger("test") + + @pytest.fixture + def valid_config(self): + """Create a valid JWT configuration.""" + return { + "jwt_validation": { + "enabled": True, + "enforce_validation": True, + "datasource_url": "https://test-gateway.example.com:443", + "datasource_schema_uuid": "5397013b-7920-4ffc-807c-e8a3e0a18f43", + "cache_duration_minutes": 60, + } + } + + def test_jwt_enabled_missing_datasource_url_raises_valueerror(self, logger): + """Test that missing datasource_url raises ValueError when JWT is enabled.""" + config = { + "jwt_validation": { + "enabled": True, + "datasource_url": "", # Empty/missing + } + } + + with pytest.raises(ValueError, match="datasource_url is not configured"): + create_jwt_interceptor(config, logger) + + def test_jwt_disabled_missing_datasource_url_returns_none(self, logger): + """Test that missing datasource_url returns None when JWT is disabled.""" + config = { + "jwt_validation": { + "enabled": False, + "datasource_url": "", # Empty + } + } + + result = create_jwt_interceptor(config, logger) + assert result is None + + def test_jwt_enabled_interceptor_creation_fails_raises_runtime_error( + self, logger, valid_config + ): + """Test that interceptor creation failure raises RuntimeError when JWT is enabled.""" + # Mock JWTValidator to raise an exception + with patch("main.JWTValidator") as mock_validator: + mock_validator.side_effect = Exception("Test error during initialization") + + with pytest.raises( + RuntimeError, + match="Failed to create JWT interceptor but JWT validation is enabled", + ): + create_jwt_interceptor(valid_config, logger) + + def test_jwt_disabled_interceptor_creation_fails_returns_none(self, logger): + """Test that interceptor creation failure returns None when JWT is disabled.""" + config = { + "jwt_validation": { + "enabled": False, + "datasource_url": "https://test.example.com:443", + } + } + + # Mock JWTValidator to raise an exception + with patch("main.JWTValidator") as mock_validator: + mock_validator.side_effect = Exception("Test error during initialization") + + result = create_jwt_interceptor(config, logger) + assert result is None + + def test_jwt_enabled_valid_config_creates_interceptor(self, logger, valid_config): + """Test that valid configuration creates interceptor successfully.""" + with patch("main.JWTValidator") as mock_validator, patch( + "main.JWTAuthInterceptor" + ) as mock_interceptor: + # Set up mocks + mock_validator_instance = Mock() + mock_validator.return_value = mock_validator_instance + + mock_interceptor_instance = Mock() + mock_interceptor.return_value = mock_interceptor_instance + + result = create_jwt_interceptor(valid_config, logger) + + # Verify interceptor was created + assert result is mock_interceptor_instance + + # Verify validator was created with correct parameters + mock_validator.assert_called_once_with( + datasource_url="https://test-gateway.example.com:443", + datasource_schema_uuid="5397013b-7920-4ffc-807c-e8a3e0a18f43", + cache_duration_minutes=60, + ) + + # Verify interceptor was created with validator + mock_interceptor.assert_called_once_with( + jwt_validator=mock_validator_instance, enabled=True, enforce=True + ) + + def test_jwt_enabled_default_schema_uuid(self, logger): + """Test that default schema UUID is used when not specified.""" + config = { + "jwt_validation": { + "enabled": True, + "datasource_url": "https://test-gateway.example.com:443", + # datasource_schema_uuid not specified + } + } + + with patch("main.JWTValidator") as mock_validator, patch( + "main.JWTAuthInterceptor" + ): + create_jwt_interceptor(config, logger) + + # Verify default schema UUID was used + call_kwargs = mock_validator.call_args[1] + assert ( + call_kwargs["datasource_schema_uuid"] + == "5397013b-7920-4ffc-807c-e8a3e0a18f43" + ) + + def test_jwt_enabled_custom_cache_duration(self, logger): + """Test that custom cache duration is used.""" + config = { + "jwt_validation": { + "enabled": True, + "datasource_url": "https://test-gateway.example.com:443", + "cache_duration_minutes": 120, # Custom value + } + } + + with patch("main.JWTValidator") as mock_validator, patch( + "main.JWTAuthInterceptor" + ): + create_jwt_interceptor(config, logger) + + # Verify custom cache duration was used + call_kwargs = mock_validator.call_args[1] + assert call_kwargs["cache_duration_minutes"] == 120 + + def test_jwt_enabled_enforce_false(self, logger, valid_config): + """Test that enforce_validation: false is respected.""" + valid_config["jwt_validation"]["enforce_validation"] = False + + with patch("main.JWTValidator"), patch( + "main.JWTAuthInterceptor" + ) as mock_interceptor: + create_jwt_interceptor(valid_config, logger) + + # Verify interceptor was created with enforce=False + call_kwargs = mock_interceptor.call_args[1] + assert call_kwargs["enforce"] is False + + def test_jwt_no_config_section_returns_none(self, logger): + """Test that missing jwt_validation config section returns None.""" + config = {} # No jwt_validation section + + result = create_jwt_interceptor(config, logger) + assert result is None + + def test_jwt_enabled_not_specified_returns_none(self, logger): + """Test that JWT validation returns None when enabled is not specified (backward compatibility).""" + config = { + "jwt_validation": { + # enabled not specified - should return None for backward compatibility + "datasource_url": "https://test-gateway.example.com:443", + } + } + + result = create_jwt_interceptor(config, logger) + # When enabled is not specified, it defaults to False (get returns False if not present) + assert result is None