From 02d16d286adbfc5998c8d621339e5ac737ea6763 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Fri, 5 Sep 2025 05:07:12 +0000 Subject: [PATCH 1/7] token federation for python driver --- src/databricks/sql/auth/auth.py | 30 +- src/databricks/sql/auth/common.py | 6 + src/databricks/sql/auth/token_federation.py | 226 ++++++++++++++ tests/unit/test_token_federation.py | 325 ++++++++++++++++++++ 4 files changed, 581 insertions(+), 6 deletions(-) create mode 100644 src/databricks/sql/auth/token_federation.py create mode 100644 tests/unit/test_token_federation.py diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index a8accac06..122928b74 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -8,13 +8,17 @@ AzureServicePrincipalCredentialProvider, ) from databricks.sql.auth.common import AuthType, ClientContext +from databricks.sql.auth.token_federation import TokenFederationProvider, ExternalTokenProvider def get_auth_provider(cfg: ClientContext, http_client): + # Determine the base auth provider + base_provider = None + if cfg.credentials_provider: - return ExternalAuthProvider(cfg.credentials_provider) + base_provider = ExternalAuthProvider(cfg.credentials_provider) elif cfg.auth_type == AuthType.AZURE_SP_M2M.value: - return ExternalAuthProvider( + base_provider = ExternalAuthProvider( AzureServicePrincipalCredentialProvider( cfg.hostname, cfg.azure_client_id, @@ -29,7 +33,7 @@ def get_auth_provider(cfg: ClientContext, http_client): assert cfg.oauth_client_id is not None assert cfg.oauth_scopes is not None - return DatabricksOAuthProvider( + base_provider = DatabricksOAuthProvider( cfg.hostname, cfg.oauth_persistence, cfg.oauth_redirect_port_range, @@ -39,17 +43,17 @@ def get_auth_provider(cfg: ClientContext, http_client): cfg.auth_type, ) elif cfg.access_token is not None: - return AccessTokenAuthProvider(cfg.access_token) + base_provider = AccessTokenAuthProvider(cfg.access_token) elif cfg.use_cert_as_auth and cfg.tls_client_cert_file: # no op authenticator. authentication is performed using ssl certificate outside of headers - return AuthProvider() + base_provider = AuthProvider() else: if ( cfg.oauth_redirect_port_range is not None and cfg.oauth_client_id is not None and cfg.oauth_scopes is not None ): - return DatabricksOAuthProvider( + base_provider = DatabricksOAuthProvider( cfg.hostname, cfg.oauth_persistence, cfg.oauth_redirect_port_range, @@ -60,6 +64,17 @@ def get_auth_provider(cfg: ClientContext, http_client): ) else: raise RuntimeError("No valid authentication settings!") + + # Wrap with token federation if enabled + if cfg.enable_token_federation and base_provider: + return TokenFederationProvider( + hostname=cfg.hostname, + external_provider=base_provider, + http_client=http_client, + identity_federation_client_id=cfg.identity_federation_client_id, + ) + + return base_provider PYSQL_OAUTH_SCOPES = ["sql", "offline_access"] @@ -114,5 +129,8 @@ def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs) else redirect_port_range, oauth_persistence=kwargs.get("experimental_oauth_persistence"), credentials_provider=kwargs.get("credentials_provider"), + # Token federation parameters + enable_token_federation=kwargs.get("enable_token_federation", False), + identity_federation_client_id=kwargs.get("identity_federation_client_id"), ) return get_auth_provider(cfg, http_client) diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index 679e353f1..c3a3c9c18 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -37,6 +37,9 @@ def __init__( tls_client_cert_file: Optional[str] = None, oauth_persistence=None, credentials_provider=None, + # Token federation parameters + enable_token_federation: bool = False, + identity_federation_client_id: Optional[str] = None, # HTTP client configuration parameters ssl_options=None, # SSLOptions type socket_timeout: Optional[float] = None, @@ -65,6 +68,9 @@ def __init__( self.tls_client_cert_file = tls_client_cert_file self.oauth_persistence = oauth_persistence self.credentials_provider = credentials_provider + # Token federation + self.enable_token_federation = enable_token_federation + self.identity_federation_client_id = identity_federation_client_id # HTTP client configuration self.ssl_options = ssl_options diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py new file mode 100644 index 000000000..2154ca2a4 --- /dev/null +++ b/src/databricks/sql/auth/token_federation.py @@ -0,0 +1,226 @@ +import logging +import json +from datetime import datetime, timedelta +from typing import Optional, Dict, Tuple +from urllib.parse import urlparse, urlencode +import jwt +import requests + +from databricks.sql.auth.authenticators import AuthProvider +from databricks.sql.auth.common import AuthType +from databricks.sql.common.http import HttpMethod + +logger = logging.getLogger(__name__) + + +class TokenFederationProvider(AuthProvider): + """ + Implementation of Token Federation for Databricks SQL Python driver. + + This provider exchanges third-party access tokens for Databricks in-house tokens + when the token issuer is different from the Databricks host. + """ + + TOKEN_EXCHANGE_ENDPOINT = "/oidc/v1/token" + TOKEN_EXCHANGE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange" + TOKEN_EXCHANGE_SUBJECT_TYPE = "urn:ietf:params:oauth:token-type:jwt" + + def __init__( + self, + hostname: str, + external_provider: AuthProvider, + http_client=None, + identity_federation_client_id: Optional[str] = None, + ): + """ + Initialize the Token Federation Provider. + + Args: + hostname: The Databricks workspace hostname + external_provider: The external authentication provider + http_client: HTTP client for making requests + identity_federation_client_id: Optional client ID for token federation + """ + self.hostname = self._normalize_hostname(hostname) + self.external_provider = external_provider + self.http_client = http_client or requests.Session() + self.identity_federation_client_id = identity_federation_client_id + + self._cached_token = None + self._cached_token_expiry = None + self._external_headers = {} + + def add_headers(self, request_headers: Dict[str, str]): + """Add authentication headers to the request.""" + token_info = self._get_token() + request_headers["Authorization"] = f"{token_info['token_type']} {token_info['access_token']}" + + def _get_token(self) -> Dict[str, str]: + """Get or refresh the authentication token.""" + # Check if cached token is still valid + if self._is_token_valid(): + return self._cached_token + + # Get the external token + self._external_headers = {} + self.external_provider.add_headers(self._external_headers) + + # Extract token from Authorization header + auth_header = self._external_headers.get("Authorization", "") + token_type, access_token = self._extract_token_from_header(auth_header) + + # Check if token exchange is needed + if self._should_exchange_token(access_token): + try: + exchanged_token = self._exchange_token(access_token) + self._cache_token(exchanged_token) + return exchanged_token + except Exception as e: + logger.warning(f"Token exchange failed, using external token: {e}") + # Fall back to using the external token + + # Use external token directly + token_info = { + "access_token": access_token, + "token_type": token_type, + } + self._cache_token(token_info) + return token_info + + def _should_exchange_token(self, access_token: str) -> bool: + """Check if the token should be exchanged based on issuer.""" + try: + # Decode JWT without verification to check issuer + decoded = jwt.decode(access_token, options={"verify_signature": False}) + issuer = decoded.get("iss", "") + + # Check if issuer host is different from Databricks host + return not self._is_same_host(issuer, self.hostname) + except Exception as e: + logger.debug(f"Failed to decode JWT token: {e}") + return False + + def _exchange_token(self, access_token: str) -> Dict[str, str]: + """Exchange the external token for a Databricks token.""" + token_url = f"{self.hostname.rstrip('/')}{self.TOKEN_EXCHANGE_ENDPOINT}" + + # Prepare the token exchange request + data = { + "grant_type": self.TOKEN_EXCHANGE_GRANT_TYPE, + "subject_token": access_token, + "subject_token_type": self.TOKEN_EXCHANGE_SUBJECT_TYPE, + "scope": "sql", + "return_original_token_if_authenticated": "true", + } + + # Add client_id if provided + if self.identity_federation_client_id: + data["client_id"] = self.identity_federation_client_id + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "*/*", + } + + # Encode data as URL-encoded form + body = urlencode(data) + + # Make the token exchange request using UnifiedHttpClient API + response = self.http_client.request( + HttpMethod.POST, url=token_url, body=body, headers=headers + ) + + # Parse the response + token_response = json.loads(response.data.decode()) + + return { + "access_token": token_response["access_token"], + "token_type": token_response.get("token_type", "Bearer"), + "expires_in": token_response.get("expires_in"), + } + + def _extract_token_from_header(self, auth_header: str) -> Tuple[str, str]: + """Extract token type and access token from Authorization header.""" + if not auth_header: + raise ValueError("Authorization header is missing") + + parts = auth_header.split(" ", 1) + if len(parts) != 2: + raise ValueError("Invalid Authorization header format") + + return parts[0], parts[1] + + def _is_same_host(self, url1: str, url2: str) -> bool: + """Check if two URLs have the same host.""" + try: + host1 = urlparse(url1).netloc + host2 = urlparse(url2).netloc + return host1 == host2 + except Exception as e: + logger.debug(f"Failed to parse URLs: {e}") + return False + + def _normalize_hostname(self, hostname: str) -> str: + """Normalize the hostname to include scheme and trailing slash.""" + if not hostname.startswith("http://") and not hostname.startswith("https://"): + hostname = f"https://{hostname}" + if not hostname.endswith("/"): + hostname = f"{hostname}/" + return hostname + + def _cache_token(self, token_info: Dict[str, str]): + """Cache the token with its expiry time.""" + self._cached_token = token_info + + # Calculate expiry time + if "expires_in" in token_info: + expires_in = int(token_info["expires_in"]) + # Set expiry with a 1-minute buffer + self._cached_token_expiry = datetime.now() + timedelta(seconds=expires_in - 60) + else: + # Try to get expiry from JWT + try: + decoded = jwt.decode( + token_info["access_token"], + options={"verify_signature": False} + ) + exp = decoded.get("exp") + if exp: + self._cached_token_expiry = datetime.fromtimestamp(exp) - timedelta(minutes=1) + else: + # Default to 1 hour if no expiry info + self._cached_token_expiry = datetime.now() + timedelta(hours=1) + except: + # Default to 1 hour if we can't decode + self._cached_token_expiry = datetime.now() + timedelta(hours=1) + + def _is_token_valid(self) -> bool: + """Check if the cached token is still valid.""" + if not self._cached_token or not self._cached_token_expiry: + return False + return datetime.now() < self._cached_token_expiry + + +class ExternalTokenProvider(AuthProvider): + """ + A simple provider that wraps an external credentials provider for token federation. + """ + + def __init__(self, credentials_provider): + """ + Initialize with an external credentials provider. + + Args: + credentials_provider: A callable that returns authentication headers + """ + self.credentials_provider = credentials_provider + self._header_factory = None + + def add_headers(self, request_headers: Dict[str, str]): + """Add headers from the external provider.""" + if self._header_factory is None: + self._header_factory = self.credentials_provider() + + headers = self._header_factory() + for key, value in headers.items(): + request_headers[key] = value \ No newline at end of file diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py new file mode 100644 index 000000000..d4288b075 --- /dev/null +++ b/tests/unit/test_token_federation.py @@ -0,0 +1,325 @@ +import unittest +from unittest.mock import Mock, MagicMock, patch, call +import json +import jwt +from datetime import datetime, timedelta +import requests + +from databricks.sql.auth.token_federation import TokenFederationProvider, ExternalTokenProvider +from databricks.sql.auth.authenticators import AccessTokenAuthProvider + + +class TestTokenFederationProvider(unittest.TestCase): + + def setUp(self): + self.hostname = "https://test.databricks.com/" + self.external_provider = Mock() + self.http_client = Mock() + self.identity_federation_client_id = "test-client-id" + + self.provider = TokenFederationProvider( + hostname=self.hostname, + external_provider=self.external_provider, + http_client=self.http_client, + identity_federation_client_id=self.identity_federation_client_id, + ) + + def test_normalize_hostname(self): + """Test hostname normalization.""" + test_cases = [ + ("test.databricks.com", "https://test.databricks.com/"), + ("https://test.databricks.com", "https://test.databricks.com/"), + ("https://test.databricks.com/", "https://test.databricks.com/"), + ("test.databricks.com/", "https://test.databricks.com/"), + ] + + for input_hostname, expected in test_cases: + provider = TokenFederationProvider( + hostname=input_hostname, + external_provider=self.external_provider, + ) + self.assertEqual(provider.hostname, expected) + + def test_extract_token_from_header(self): + """Test extraction of token from Authorization header.""" + # Valid header + token_type, access_token = self.provider._extract_token_from_header( + "Bearer test-token-123" + ) + self.assertEqual(token_type, "Bearer") + self.assertEqual(access_token, "test-token-123") + + # Invalid header - missing token + with self.assertRaises(ValueError): + self.provider._extract_token_from_header("Bearer") + + # Invalid header - empty + with self.assertRaises(ValueError): + self.provider._extract_token_from_header("") + + def test_is_same_host(self): + """Test host comparison.""" + test_cases = [ + ("https://test.databricks.com", "https://test.databricks.com", True), + ("https://test.databricks.com", "https://test.databricks.com:443", False), + ("https://test1.databricks.com", "https://test2.databricks.com", False), + ("https://login.microsoftonline.com", "https://test.databricks.com", False), + ] + + for url1, url2, expected in test_cases: + result = self.provider._is_same_host(url1, url2) + self.assertEqual(result, expected) + + def test_should_exchange_token_different_issuer(self): + """Test that token exchange is triggered for different issuer.""" + # Create a mock JWT with different issuer + token_payload = { + "iss": "https://login.microsoftonline.com/tenant-id/", + "aud": "databricks", + "exp": int((datetime.now() + timedelta(hours=1)).timestamp()), + } + + with patch("jwt.decode", return_value=token_payload): + result = self.provider._should_exchange_token("mock-jwt-token") + self.assertTrue(result) + + def test_should_exchange_token_same_issuer(self): + """Test that token exchange is not triggered for same issuer.""" + # Create a mock JWT with same issuer + token_payload = { + "iss": "https://test.databricks.com", + "aud": "databricks", + "exp": int((datetime.now() + timedelta(hours=1)).timestamp()), + } + + with patch("jwt.decode", return_value=token_payload): + result = self.provider._should_exchange_token("mock-jwt-token") + self.assertFalse(result) + + def test_exchange_token_success(self): + """Test successful token exchange.""" + access_token = "external-token-123" + + # Mock successful response (UnifiedHttpClient style) + mock_response = Mock() + mock_response.data = json.dumps({ + "access_token": "databricks-token-456", + "token_type": "Bearer", + "expires_in": 3600, + }).encode('utf-8') + mock_response.status = 200 + + self.http_client.request.return_value = mock_response + + result = self.provider._exchange_token(access_token) + + # Verify the request + self.http_client.request.assert_called_once() + call_args = self.http_client.request.call_args + + # Check method and URL (HttpMethod.POST is first arg, URL is second) + from databricks.sql.common.http import HttpMethod + self.assertEqual(call_args[0][0], HttpMethod.POST) + self.assertEqual(call_args[1]["url"], f"{self.hostname}oidc/v1/token") + + # Check body contains expected parameters + from urllib.parse import parse_qs + body = call_args[1]["body"] + parsed_body = parse_qs(body) + + # Verify all expected params are in the body + self.assertEqual(parsed_body["grant_type"][0], "urn:ietf:params:oauth:grant-type:token-exchange") + self.assertEqual(parsed_body["subject_token"][0], access_token) + self.assertEqual(parsed_body["subject_token_type"][0], "urn:ietf:params:oauth:token-type:jwt") + self.assertEqual(parsed_body["scope"][0], "sql") + self.assertEqual(parsed_body["return_original_token_if_authenticated"][0], "true") + self.assertEqual(parsed_body["client_id"][0], self.identity_federation_client_id) + + # Check result + self.assertEqual(result["access_token"], "databricks-token-456") + self.assertEqual(result["token_type"], "Bearer") + self.assertEqual(result["expires_in"], 3600) + + def test_exchange_token_failure(self): + """Test token exchange failure handling.""" + access_token = "external-token-123" + + # Mock failed response + mock_response = Mock() + mock_response.data = b'{"error": "invalid_request"}' + mock_response.status = 400 + self.http_client.request.return_value = mock_response + + # Should not raise, but should return None or handle gracefully + # The actual implementation should handle this + with self.assertRaises(KeyError): # Will raise KeyError due to missing access_token + self.provider._exchange_token(access_token) + + def test_add_headers_with_token_exchange(self): + """Test adding headers with token exchange.""" + # Setup external provider to return a token + self.external_provider.add_headers = Mock( + side_effect=lambda headers: headers.update({ + "Authorization": "Bearer external-token-123" + }) + ) + + # Mock JWT decode to indicate different issuer + token_payload = { + "iss": "https://login.microsoftonline.com/tenant-id/", + "aud": "databricks", + "exp": int((datetime.now() + timedelta(hours=1)).timestamp()), + } + + # Mock successful token exchange + mock_response = Mock() + mock_response.data = json.dumps({ + "access_token": "databricks-token-456", + "token_type": "Bearer", + "expires_in": 3600, + }).encode('utf-8') + mock_response.status = 200 + self.http_client.request.return_value = mock_response + + with patch("jwt.decode", return_value=token_payload): + headers = {} + self.provider.add_headers(headers) + + # Should have the exchanged token + self.assertEqual(headers["Authorization"], "Bearer databricks-token-456") + + def test_add_headers_without_token_exchange(self): + """Test adding headers without token exchange (same issuer).""" + # Setup external provider to return a token + self.external_provider.add_headers = Mock( + side_effect=lambda headers: headers.update({ + "Authorization": "Bearer external-token-123" + }) + ) + + # Mock JWT decode to indicate same issuer + token_payload = { + "iss": "https://test.databricks.com", + "aud": "databricks", + "exp": int((datetime.now() + timedelta(hours=1)).timestamp()), + } + + with patch("jwt.decode", return_value=token_payload): + headers = {} + self.provider.add_headers(headers) + + # Should have the original external token + self.assertEqual(headers["Authorization"], "Bearer external-token-123") + + def test_token_caching(self): + """Test that tokens are cached and reused.""" + # Setup external provider + self.external_provider.add_headers = Mock( + side_effect=lambda headers: headers.update({ + "Authorization": "Bearer external-token-123" + }) + ) + + # Mock JWT decode + token_payload = { + "iss": "https://test.databricks.com", + "exp": int((datetime.now() + timedelta(hours=1)).timestamp()), + } + + with patch("jwt.decode", return_value=token_payload): + # First call + headers1 = {} + self.provider.add_headers(headers1) + + # Second call - should use cached token + headers2 = {} + self.provider.add_headers(headers2) + + # External provider should only be called once + self.assertEqual(self.external_provider.add_headers.call_count, 1) + + # Both headers should be the same + self.assertEqual(headers1["Authorization"], headers2["Authorization"]) + + def test_token_cache_expiry(self): + """Test that expired cached tokens are refreshed.""" + # Setup external provider + call_count = [0] + def add_headers_side_effect(headers): + call_count[0] += 1 + headers.update({ + "Authorization": f"Bearer external-token-{call_count[0]}" + }) + + self.external_provider.add_headers = Mock(side_effect=add_headers_side_effect) + + # Mock JWT decode with short expiry + token_payload = { + "iss": "https://test.databricks.com", + "exp": int((datetime.now() + timedelta(seconds=5)).timestamp()), + } + + with patch("jwt.decode", return_value=token_payload): + # First call + headers1 = {} + self.provider.add_headers(headers1) + self.assertEqual(headers1["Authorization"], "Bearer external-token-1") + + # Expire the cache + self.provider._cached_token_expiry = datetime.now() - timedelta(seconds=1) + + # Second call - should get new token + headers2 = {} + self.provider.add_headers(headers2) + self.assertEqual(headers2["Authorization"], "Bearer external-token-2") + + # External provider should be called twice + self.assertEqual(self.external_provider.add_headers.call_count, 2) + + +class TestExternalTokenProvider(unittest.TestCase): + + def test_add_headers(self): + """Test adding headers from external credentials provider.""" + # Create mock credentials provider + mock_headers = { + "Authorization": "Bearer test-token", + "X-Custom-Header": "custom-value", + } + credentials_provider = Mock(return_value=Mock(return_value=mock_headers)) + + provider = ExternalTokenProvider(credentials_provider) + + # Test adding headers + request_headers = {} + provider.add_headers(request_headers) + + # Verify headers were added + self.assertEqual(request_headers["Authorization"], "Bearer test-token") + self.assertEqual(request_headers["X-Custom-Header"], "custom-value") + + # Verify credentials provider was called once + credentials_provider.assert_called_once() + + def test_header_factory_cached(self): + """Test that header factory is cached.""" + mock_headers = {"Authorization": "Bearer test-token"} + header_factory = Mock(return_value=mock_headers) + credentials_provider = Mock(return_value=header_factory) + + provider = ExternalTokenProvider(credentials_provider) + + # Call add_headers multiple times + for _ in range(3): + request_headers = {} + provider.add_headers(request_headers) + + # Credentials provider should only be called once + credentials_provider.assert_called_once() + + # Header factory should be called 3 times + self.assertEqual(header_factory.call_count, 3) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From ae5ee50c83c0e60876bb5475b15c583e232c2aae Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Thu, 25 Sep 2025 19:40:15 +0000 Subject: [PATCH 2/7] address comment --- src/databricks/sql/auth/auth.py | 14 +- src/databricks/sql/auth/auth_utils.py | 117 ++++ src/databricks/sql/auth/common.py | 4 - src/databricks/sql/auth/token_federation.py | 203 ++----- tests/unit/test_token_federation.py | 614 ++++++++++---------- 5 files changed, 498 insertions(+), 454 deletions(-) create mode 100644 src/databricks/sql/auth/auth_utils.py diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 122928b74..b50f93c9c 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -8,13 +8,13 @@ AzureServicePrincipalCredentialProvider, ) from databricks.sql.auth.common import AuthType, ClientContext -from databricks.sql.auth.token_federation import TokenFederationProvider, ExternalTokenProvider +from databricks.sql.auth.token_federation import TokenFederationProvider def get_auth_provider(cfg: ClientContext, http_client): # Determine the base auth provider base_provider = None - + if cfg.credentials_provider: base_provider = ExternalAuthProvider(cfg.credentials_provider) elif cfg.auth_type == AuthType.AZURE_SP_M2M.value: @@ -64,16 +64,16 @@ def get_auth_provider(cfg: ClientContext, http_client): ) else: raise RuntimeError("No valid authentication settings!") - - # Wrap with token federation if enabled - if cfg.enable_token_federation and base_provider: + + # Always wrap with token federation (falls back gracefully if not needed) + if base_provider: return TokenFederationProvider( hostname=cfg.hostname, external_provider=base_provider, http_client=http_client, identity_federation_client_id=cfg.identity_federation_client_id, ) - + return base_provider @@ -129,8 +129,6 @@ def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs) else redirect_port_range, oauth_persistence=kwargs.get("experimental_oauth_persistence"), credentials_provider=kwargs.get("credentials_provider"), - # Token federation parameters - enable_token_federation=kwargs.get("enable_token_federation", False), identity_federation_client_id=kwargs.get("identity_federation_client_id"), ) return get_auth_provider(cfg, http_client) diff --git a/src/databricks/sql/auth/auth_utils.py b/src/databricks/sql/auth/auth_utils.py new file mode 100644 index 000000000..76233d96c --- /dev/null +++ b/src/databricks/sql/auth/auth_utils.py @@ -0,0 +1,117 @@ +import logging +import jwt +from datetime import datetime, timedelta +from typing import Optional, Dict, Tuple +from urllib.parse import urlparse + +logger = logging.getLogger(__name__) + + +def parse_hostname(hostname: str) -> str: + """ + Normalize the hostname to include scheme and trailing slash. + + Args: + hostname: The hostname to normalize + + Returns: + Normalized hostname with scheme and trailing slash + """ + if not hostname.startswith("http://") and not hostname.startswith("https://"): + hostname = f"https://{hostname}" + if not hostname.endswith("/"): + hostname = f"{hostname}/" + return hostname + + +def decode_token(access_token: str) -> Optional[Dict]: + """ + Decode a JWT token without verification to extract claims. + + Args: + access_token: The JWT access token to decode + + Returns: + Decoded token claims or None if decoding fails + """ + try: + return jwt.decode(access_token, options={"verify_signature": False}) + except Exception as e: + logger.debug("Failed to decode JWT token: %s", e) + return None + + +def is_same_host(url1: str, url2: str) -> bool: + """ + Check if two URLs have the same host. + + Args: + url1: First URL + url2: Second URL + + Returns: + True if hosts are the same, False otherwise + """ + try: + host1 = urlparse(url1).netloc + host2 = urlparse(url2).netloc + # Handle port differences (e.g., example.com vs example.com:443) + host1_without_port = host1.split(":")[0] + host2_without_port = host2.split(":")[0] + return host1_without_port == host2_without_port + except Exception as e: + logger.debug("Failed to parse URLs: %s", e) + return False + + +class Token: + """ + Represents an OAuth token with expiration management. + """ + + def __init__(self, access_token: str, token_type: str = "Bearer"): + """ + Initialize a token. + + Args: + access_token: The access token string + token_type: The token type (default: Bearer) + """ + self.access_token = access_token + self.token_type = token_type + self.expiry_time = self._calculate_expiry() + + def _calculate_expiry(self) -> datetime: + """ + Calculate the token expiry time from JWT claims. + + Returns: + The token expiry datetime + """ + decoded = decode_token(self.access_token) + if decoded and "exp" in decoded: + # Use JWT exp claim with 1 minute buffer + return datetime.fromtimestamp(decoded["exp"]) - timedelta(minutes=1) + # Default to 1 hour if no expiry info + return datetime.now() + timedelta(hours=1) + + def is_expired(self) -> bool: + """ + Check if the token is expired. + + Returns: + True if token is expired, False otherwise + """ + return datetime.now() >= self.expiry_time + + def to_dict(self) -> Dict[str, str]: + """ + Convert token to dictionary format. + + Returns: + Dictionary with access_token and token_type + """ + return { + "access_token": self.access_token, + "token_type": self.token_type, + } diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index c3a3c9c18..3e0be0d2b 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -37,8 +37,6 @@ def __init__( tls_client_cert_file: Optional[str] = None, oauth_persistence=None, credentials_provider=None, - # Token federation parameters - enable_token_federation: bool = False, identity_federation_client_id: Optional[str] = None, # HTTP client configuration parameters ssl_options=None, # SSLOptions type @@ -68,8 +66,6 @@ def __init__( self.tls_client_cert_file = tls_client_cert_file self.oauth_persistence = oauth_persistence self.credentials_provider = credentials_provider - # Token federation - self.enable_token_federation = enable_token_federation self.identity_federation_client_id = identity_federation_client_id # HTTP client configuration diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index 2154ca2a4..746ce0f86 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -1,13 +1,15 @@ import logging import json -from datetime import datetime, timedelta from typing import Optional, Dict, Tuple -from urllib.parse import urlparse, urlencode -import jwt -import requests +from urllib.parse import urlencode from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.auth.common import AuthType +from databricks.sql.auth.auth_utils import ( + Token, + parse_hostname, + decode_token, + is_same_host, +) from databricks.sql.common.http import HttpMethod logger = logging.getLogger(__name__) @@ -16,95 +18,89 @@ class TokenFederationProvider(AuthProvider): """ Implementation of Token Federation for Databricks SQL Python driver. - + This provider exchanges third-party access tokens for Databricks in-house tokens when the token issuer is different from the Databricks host. """ - + TOKEN_EXCHANGE_ENDPOINT = "/oidc/v1/token" TOKEN_EXCHANGE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange" TOKEN_EXCHANGE_SUBJECT_TYPE = "urn:ietf:params:oauth:token-type:jwt" - + def __init__( self, hostname: str, external_provider: AuthProvider, - http_client=None, + http_client, identity_federation_client_id: Optional[str] = None, ): """ Initialize the Token Federation Provider. - + Args: hostname: The Databricks workspace hostname external_provider: The external authentication provider - http_client: HTTP client for making requests + http_client: HTTP client for making requests (required) identity_federation_client_id: Optional client ID for token federation """ - self.hostname = self._normalize_hostname(hostname) + if not http_client: + raise ValueError("http_client is required for TokenFederationProvider") + + self.hostname = parse_hostname(hostname) self.external_provider = external_provider - self.http_client = http_client or requests.Session() + self.http_client = http_client self.identity_federation_client_id = identity_federation_client_id - - self._cached_token = None - self._cached_token_expiry = None + + self._cached_token: Optional[Token] = None self._external_headers = {} - + def add_headers(self, request_headers: Dict[str, str]): """Add authentication headers to the request.""" - token_info = self._get_token() - request_headers["Authorization"] = f"{token_info['token_type']} {token_info['access_token']}" - - def _get_token(self) -> Dict[str, str]: + token = self._get_token() + request_headers["Authorization"] = f"{token.token_type} {token.access_token}" + + def _get_token(self) -> Token: """Get or refresh the authentication token.""" # Check if cached token is still valid - if self._is_token_valid(): + if self._cached_token and not self._cached_token.is_expired(): return self._cached_token - + # Get the external token self._external_headers = {} self.external_provider.add_headers(self._external_headers) - + # Extract token from Authorization header auth_header = self._external_headers.get("Authorization", "") token_type, access_token = self._extract_token_from_header(auth_header) - + # Check if token exchange is needed if self._should_exchange_token(access_token): try: - exchanged_token = self._exchange_token(access_token) - self._cache_token(exchanged_token) - return exchanged_token + token = self._exchange_token(access_token) + self._cached_token = token + return token except Exception as e: - logger.warning(f"Token exchange failed, using external token: {e}") - # Fall back to using the external token - + logger.warning("Token exchange failed, using external token: %s", e) + # Use external token directly - token_info = { - "access_token": access_token, - "token_type": token_type, - } - self._cache_token(token_info) - return token_info - + token = Token(access_token, token_type) + self._cached_token = token + return token + def _should_exchange_token(self, access_token: str) -> bool: """Check if the token should be exchanged based on issuer.""" - try: - # Decode JWT without verification to check issuer - decoded = jwt.decode(access_token, options={"verify_signature": False}) - issuer = decoded.get("iss", "") - - # Check if issuer host is different from Databricks host - return not self._is_same_host(issuer, self.hostname) - except Exception as e: - logger.debug(f"Failed to decode JWT token: {e}") + decoded = decode_token(access_token) + if not decoded: return False - - def _exchange_token(self, access_token: str) -> Dict[str, str]: + + issuer = decoded.get("iss", "") + # Check if issuer host is different from Databricks host + return not is_same_host(issuer, self.hostname) + + def _exchange_token(self, access_token: str) -> Token: """Exchange the external token for a Databricks token.""" token_url = f"{self.hostname.rstrip('/')}{self.TOKEN_EXCHANGE_ENDPOINT}" - - # Prepare the token exchange request + data = { "grant_type": self.TOKEN_EXCHANGE_GRANT_TYPE, "subject_token": access_token, @@ -112,115 +108,34 @@ def _exchange_token(self, access_token: str) -> Dict[str, str]: "scope": "sql", "return_original_token_if_authenticated": "true", } - - # Add client_id if provided + if self.identity_federation_client_id: data["client_id"] = self.identity_federation_client_id - + headers = { "Content-Type": "application/x-www-form-urlencoded", "Accept": "*/*", } - - # Encode data as URL-encoded form + body = urlencode(data) - - # Make the token exchange request using UnifiedHttpClient API + response = self.http_client.request( HttpMethod.POST, url=token_url, body=body, headers=headers ) - - # Parse the response + token_response = json.loads(response.data.decode()) - - return { - "access_token": token_response["access_token"], - "token_type": token_response.get("token_type", "Bearer"), - "expires_in": token_response.get("expires_in"), - } - + + return Token( + token_response["access_token"], token_response.get("token_type", "Bearer") + ) + def _extract_token_from_header(self, auth_header: str) -> Tuple[str, str]: """Extract token type and access token from Authorization header.""" if not auth_header: raise ValueError("Authorization header is missing") - + parts = auth_header.split(" ", 1) if len(parts) != 2: raise ValueError("Invalid Authorization header format") - - return parts[0], parts[1] - - def _is_same_host(self, url1: str, url2: str) -> bool: - """Check if two URLs have the same host.""" - try: - host1 = urlparse(url1).netloc - host2 = urlparse(url2).netloc - return host1 == host2 - except Exception as e: - logger.debug(f"Failed to parse URLs: {e}") - return False - - def _normalize_hostname(self, hostname: str) -> str: - """Normalize the hostname to include scheme and trailing slash.""" - if not hostname.startswith("http://") and not hostname.startswith("https://"): - hostname = f"https://{hostname}" - if not hostname.endswith("/"): - hostname = f"{hostname}/" - return hostname - - def _cache_token(self, token_info: Dict[str, str]): - """Cache the token with its expiry time.""" - self._cached_token = token_info - - # Calculate expiry time - if "expires_in" in token_info: - expires_in = int(token_info["expires_in"]) - # Set expiry with a 1-minute buffer - self._cached_token_expiry = datetime.now() + timedelta(seconds=expires_in - 60) - else: - # Try to get expiry from JWT - try: - decoded = jwt.decode( - token_info["access_token"], - options={"verify_signature": False} - ) - exp = decoded.get("exp") - if exp: - self._cached_token_expiry = datetime.fromtimestamp(exp) - timedelta(minutes=1) - else: - # Default to 1 hour if no expiry info - self._cached_token_expiry = datetime.now() + timedelta(hours=1) - except: - # Default to 1 hour if we can't decode - self._cached_token_expiry = datetime.now() + timedelta(hours=1) - - def _is_token_valid(self) -> bool: - """Check if the cached token is still valid.""" - if not self._cached_token or not self._cached_token_expiry: - return False - return datetime.now() < self._cached_token_expiry - -class ExternalTokenProvider(AuthProvider): - """ - A simple provider that wraps an external credentials provider for token federation. - """ - - def __init__(self, credentials_provider): - """ - Initialize with an external credentials provider. - - Args: - credentials_provider: A callable that returns authentication headers - """ - self.credentials_provider = credentials_provider - self._header_factory = None - - def add_headers(self, request_headers: Dict[str, str]): - """Add headers from the external provider.""" - if self._header_factory is None: - self._header_factory = self.credentials_provider() - - headers = self._header_factory() - for key, value in headers.items(): - request_headers[key] = value \ No newline at end of file + return parts[0], parts[1] diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py index d4288b075..e2b8b57eb 100644 --- a/tests/unit/test_token_federation.py +++ b/tests/unit/test_token_federation.py @@ -1,325 +1,343 @@ -import unittest -from unittest.mock import Mock, MagicMock, patch, call +import pytest +from unittest.mock import Mock, patch import json import jwt from datetime import datetime, timedelta -import requests - -from databricks.sql.auth.token_federation import TokenFederationProvider, ExternalTokenProvider -from databricks.sql.auth.authenticators import AccessTokenAuthProvider - - -class TestTokenFederationProvider(unittest.TestCase): - - def setUp(self): - self.hostname = "https://test.databricks.com/" - self.external_provider = Mock() - self.http_client = Mock() - self.identity_federation_client_id = "test-client-id" - - self.provider = TokenFederationProvider( - hostname=self.hostname, - external_provider=self.external_provider, - http_client=self.http_client, - identity_federation_client_id=self.identity_federation_client_id, - ) - - def test_normalize_hostname(self): - """Test hostname normalization.""" - test_cases = [ + +from databricks.sql.auth.token_federation import TokenFederationProvider +from databricks.sql.auth.auth_utils import ( + Token, + parse_hostname, + decode_token, + is_same_host, +) +from databricks.sql.common.http import HttpMethod + + +@pytest.fixture +def mock_http_client(): + """Fixture for mock HTTP client.""" + return Mock() + + +@pytest.fixture +def mock_external_provider(): + """Fixture for mock external provider.""" + return Mock() + + +@pytest.fixture +def token_federation_provider(mock_http_client, mock_external_provider): + """Fixture for TokenFederationProvider.""" + return TokenFederationProvider( + hostname="https://test.databricks.com/", + external_provider=mock_external_provider, + http_client=mock_http_client, + identity_federation_client_id="test-client-id", + ) + + +def create_mock_token_response( + access_token="databricks-token-456", token_type="Bearer", expires_in=3600 +): + """Helper function to create mock token exchange response.""" + mock_response = Mock() + mock_response.data = json.dumps( + { + "access_token": access_token, + "token_type": token_type, + "expires_in": expires_in, + } + ).encode("utf-8") + mock_response.status = 200 + return mock_response + + +def create_jwt_token(issuer="https://test.databricks.com", exp_hours=1, **kwargs): + """Helper function to create JWT tokens for testing.""" + payload = { + "iss": issuer, + "aud": "databricks", + "exp": int((datetime.now() + timedelta(hours=exp_hours)).timestamp()), + **kwargs, + } + return jwt.encode(payload, "secret", algorithm="HS256") + + +class TestTokenFederationProvider: + """Test TokenFederationProvider functionality.""" + + def test_init_requires_http_client(self, mock_external_provider): + """Test that http_client is required.""" + with pytest.raises(ValueError, match="http_client is required"): + TokenFederationProvider( + hostname="test.databricks.com", + external_provider=mock_external_provider, + http_client=None, + ) + + @pytest.mark.parametrize( + "input_hostname,expected", + [ ("test.databricks.com", "https://test.databricks.com/"), ("https://test.databricks.com", "https://test.databricks.com/"), ("https://test.databricks.com/", "https://test.databricks.com/"), ("test.databricks.com/", "https://test.databricks.com/"), - ] - - for input_hostname, expected in test_cases: - provider = TokenFederationProvider( - hostname=input_hostname, - external_provider=self.external_provider, - ) - self.assertEqual(provider.hostname, expected) - - def test_extract_token_from_header(self): - """Test extraction of token from Authorization header.""" - # Valid header - token_type, access_token = self.provider._extract_token_from_header( - "Bearer test-token-123" + ], + ) + def test_hostname_normalization( + self, input_hostname, expected, mock_http_client, mock_external_provider + ): + """Test hostname normalization during initialization.""" + provider = TokenFederationProvider( + hostname=input_hostname, + external_provider=mock_external_provider, + http_client=mock_http_client, ) - self.assertEqual(token_type, "Bearer") - self.assertEqual(access_token, "test-token-123") - - # Invalid header - missing token - with self.assertRaises(ValueError): - self.provider._extract_token_from_header("Bearer") - - # Invalid header - empty - with self.assertRaises(ValueError): - self.provider._extract_token_from_header("") - - def test_is_same_host(self): - """Test host comparison.""" - test_cases = [ - ("https://test.databricks.com", "https://test.databricks.com", True), - ("https://test.databricks.com", "https://test.databricks.com:443", False), - ("https://test1.databricks.com", "https://test2.databricks.com", False), - ("https://login.microsoftonline.com", "https://test.databricks.com", False), - ] - - for url1, url2, expected in test_cases: - result = self.provider._is_same_host(url1, url2) - self.assertEqual(result, expected) - - def test_should_exchange_token_different_issuer(self): - """Test that token exchange is triggered for different issuer.""" - # Create a mock JWT with different issuer - token_payload = { - "iss": "https://login.microsoftonline.com/tenant-id/", - "aud": "databricks", - "exp": int((datetime.now() + timedelta(hours=1)).timestamp()), - } - - with patch("jwt.decode", return_value=token_payload): - result = self.provider._should_exchange_token("mock-jwt-token") - self.assertTrue(result) - - def test_should_exchange_token_same_issuer(self): - """Test that token exchange is not triggered for same issuer.""" - # Create a mock JWT with same issuer - token_payload = { - "iss": "https://test.databricks.com", - "aud": "databricks", - "exp": int((datetime.now() + timedelta(hours=1)).timestamp()), - } - - with patch("jwt.decode", return_value=token_payload): - result = self.provider._should_exchange_token("mock-jwt-token") - self.assertFalse(result) - - def test_exchange_token_success(self): + assert provider.hostname == expected + + @pytest.mark.parametrize( + "auth_header,expected_type,expected_token", + [ + ("Bearer test-token-123", "Bearer", "test-token-123"), + ("Basic dGVzdDp0ZXN0", "Basic", "dGVzdDp0ZXN0"), + ], + ) + def test_extract_token_from_valid_header( + self, token_federation_provider, auth_header, expected_type, expected_token + ): + """Test extraction of token from valid Authorization header.""" + token_type, access_token = token_federation_provider._extract_token_from_header( + auth_header + ) + assert token_type == expected_type + assert access_token == expected_token + + @pytest.mark.parametrize( + "invalid_header", + [ + "Bearer", # Missing token + "", # Empty header + "InvalidFormat", # No space separator + ], + ) + def test_extract_token_from_invalid_header( + self, token_federation_provider, invalid_header + ): + """Test extraction fails for invalid Authorization headers.""" + with pytest.raises(ValueError): + token_federation_provider._extract_token_from_header(invalid_header) + + @pytest.mark.parametrize( + "issuer,hostname,should_exchange", + [ + ( + "https://login.microsoftonline.com/tenant-id/", + "https://test.databricks.com/", + True, + ), + ("https://test.databricks.com", "https://test.databricks.com/", False), + ("https://test.databricks.com:443", "https://test.databricks.com/", False), + ("https://accounts.google.com", "https://test.databricks.com/", True), + ], + ) + def test_should_exchange_token( + self, token_federation_provider, issuer, hostname, should_exchange + ): + """Test token exchange decision based on issuer.""" + token_federation_provider.hostname = hostname + jwt_token = create_jwt_token(issuer=issuer) + + result = token_federation_provider._should_exchange_token(jwt_token) + assert result == should_exchange + + def test_should_exchange_token_invalid_jwt(self, token_federation_provider): + """Test that invalid JWT returns False for exchange.""" + result = token_federation_provider._should_exchange_token("invalid-jwt-token") + assert result is False + + def test_exchange_token_success(self, token_federation_provider, mock_http_client): """Test successful token exchange.""" access_token = "external-token-123" - - # Mock successful response (UnifiedHttpClient style) - mock_response = Mock() - mock_response.data = json.dumps({ - "access_token": "databricks-token-456", - "token_type": "Bearer", - "expires_in": 3600, - }).encode('utf-8') - mock_response.status = 200 - - self.http_client.request.return_value = mock_response - - result = self.provider._exchange_token(access_token) - + mock_http_client.request.return_value = create_mock_token_response() + + result = token_federation_provider._exchange_token(access_token) + + # Verify result is a Token object + assert isinstance(result, Token) + assert result.access_token == "databricks-token-456" + assert result.token_type == "Bearer" + # Verify the request - self.http_client.request.assert_called_once() - call_args = self.http_client.request.call_args - - # Check method and URL (HttpMethod.POST is first arg, URL is second) - from databricks.sql.common.http import HttpMethod - self.assertEqual(call_args[0][0], HttpMethod.POST) - self.assertEqual(call_args[1]["url"], f"{self.hostname}oidc/v1/token") - + mock_http_client.request.assert_called_once() + call_args = mock_http_client.request.call_args + + # Check method and URL + assert call_args[0][0] == HttpMethod.POST + assert call_args[1]["url"] == "https://test.databricks.com/oidc/v1/token" + # Check body contains expected parameters from urllib.parse import parse_qs + body = call_args[1]["body"] parsed_body = parse_qs(body) - - # Verify all expected params are in the body - self.assertEqual(parsed_body["grant_type"][0], "urn:ietf:params:oauth:grant-type:token-exchange") - self.assertEqual(parsed_body["subject_token"][0], access_token) - self.assertEqual(parsed_body["subject_token_type"][0], "urn:ietf:params:oauth:token-type:jwt") - self.assertEqual(parsed_body["scope"][0], "sql") - self.assertEqual(parsed_body["return_original_token_if_authenticated"][0], "true") - self.assertEqual(parsed_body["client_id"][0], self.identity_federation_client_id) - - # Check result - self.assertEqual(result["access_token"], "databricks-token-456") - self.assertEqual(result["token_type"], "Bearer") - self.assertEqual(result["expires_in"], 3600) - - def test_exchange_token_failure(self): + + assert ( + parsed_body["grant_type"][0] + == "urn:ietf:params:oauth:grant-type:token-exchange" + ) + assert parsed_body["subject_token"][0] == access_token + assert ( + parsed_body["subject_token_type"][0] + == "urn:ietf:params:oauth:token-type:jwt" + ) + assert parsed_body["scope"][0] == "sql" + assert parsed_body["client_id"][0] == "test-client-id" + + def test_exchange_token_failure(self, token_federation_provider, mock_http_client): """Test token exchange failure handling.""" - access_token = "external-token-123" - - # Mock failed response mock_response = Mock() mock_response.data = b'{"error": "invalid_request"}' mock_response.status = 400 - self.http_client.request.return_value = mock_response - - # Should not raise, but should return None or handle gracefully - # The actual implementation should handle this - with self.assertRaises(KeyError): # Will raise KeyError due to missing access_token - self.provider._exchange_token(access_token) - - def test_add_headers_with_token_exchange(self): - """Test adding headers with token exchange.""" - # Setup external provider to return a token - self.external_provider.add_headers = Mock( - side_effect=lambda headers: headers.update({ - "Authorization": "Bearer external-token-123" - }) - ) - - # Mock JWT decode to indicate different issuer - token_payload = { - "iss": "https://login.microsoftonline.com/tenant-id/", - "aud": "databricks", - "exp": int((datetime.now() + timedelta(hours=1)).timestamp()), - } - - # Mock successful token exchange - mock_response = Mock() - mock_response.data = json.dumps({ - "access_token": "databricks-token-456", - "token_type": "Bearer", - "expires_in": 3600, - }).encode('utf-8') - mock_response.status = 200 - self.http_client.request.return_value = mock_response - - with patch("jwt.decode", return_value=token_payload): - headers = {} - self.provider.add_headers(headers) - - # Should have the exchanged token - self.assertEqual(headers["Authorization"], "Bearer databricks-token-456") - - def test_add_headers_without_token_exchange(self): - """Test adding headers without token exchange (same issuer).""" + mock_http_client.request.return_value = mock_response + + with pytest.raises(KeyError): # Will raise KeyError due to missing access_token + token_federation_provider._exchange_token("external-token-123") + + @pytest.mark.parametrize( + "external_issuer,should_exchange", + [ + ("https://login.microsoftonline.com/tenant-id/", True), + ("https://test.databricks.com", False), + ], + ) + def test_add_headers_token_exchange( + self, + token_federation_provider, + mock_external_provider, + mock_http_client, + external_issuer, + should_exchange, + ): + """Test adding headers with and without token exchange.""" # Setup external provider to return a token - self.external_provider.add_headers = Mock( - side_effect=lambda headers: headers.update({ - "Authorization": "Bearer external-token-123" - }) + external_token = create_jwt_token(issuer=external_issuer) + mock_external_provider.add_headers = Mock( + side_effect=lambda headers: headers.update( + {"Authorization": f"Bearer {external_token}"} + ) ) - - # Mock JWT decode to indicate same issuer - token_payload = { - "iss": "https://test.databricks.com", - "aud": "databricks", - "exp": int((datetime.now() + timedelta(hours=1)).timestamp()), - } - - with patch("jwt.decode", return_value=token_payload): - headers = {} - self.provider.add_headers(headers) - - # Should have the original external token - self.assertEqual(headers["Authorization"], "Bearer external-token-123") - - def test_token_caching(self): + + if should_exchange: + # Mock successful token exchange + mock_http_client.request.return_value = create_mock_token_response() + expected_token = "databricks-token-456" + else: + expected_token = external_token + + headers = {} + token_federation_provider.add_headers(headers) + + assert headers["Authorization"] == f"Bearer {expected_token}" + + def test_token_caching(self, token_federation_provider, mock_external_provider): """Test that tokens are cached and reused.""" - # Setup external provider - self.external_provider.add_headers = Mock( - side_effect=lambda headers: headers.update({ - "Authorization": "Bearer external-token-123" - }) + external_token = create_jwt_token() + mock_external_provider.add_headers = Mock( + side_effect=lambda headers: headers.update( + {"Authorization": f"Bearer {external_token}"} + ) ) - - # Mock JWT decode - token_payload = { - "iss": "https://test.databricks.com", - "exp": int((datetime.now() + timedelta(hours=1)).timestamp()), - } - - with patch("jwt.decode", return_value=token_payload): - # First call - headers1 = {} - self.provider.add_headers(headers1) - - # Second call - should use cached token - headers2 = {} - self.provider.add_headers(headers2) - - # External provider should only be called once - self.assertEqual(self.external_provider.add_headers.call_count, 1) - - # Both headers should be the same - self.assertEqual(headers1["Authorization"], headers2["Authorization"]) - - def test_token_cache_expiry(self): + + # First call + headers1 = {} + token_federation_provider.add_headers(headers1) + + # Second call - should use cached token + headers2 = {} + token_federation_provider.add_headers(headers2) + + # External provider should only be called once + assert mock_external_provider.add_headers.call_count == 1 + + # Both headers should be the same + assert headers1["Authorization"] == headers2["Authorization"] + + def test_token_cache_expiry( + self, token_federation_provider, mock_external_provider + ): """Test that expired cached tokens are refreshed.""" - # Setup external provider call_count = [0] + def add_headers_side_effect(headers): call_count[0] += 1 - headers.update({ - "Authorization": f"Bearer external-token-{call_count[0]}" - }) - - self.external_provider.add_headers = Mock(side_effect=add_headers_side_effect) - - # Mock JWT decode with short expiry - token_payload = { - "iss": "https://test.databricks.com", - "exp": int((datetime.now() + timedelta(seconds=5)).timestamp()), - } - - with patch("jwt.decode", return_value=token_payload): - # First call - headers1 = {} - self.provider.add_headers(headers1) - self.assertEqual(headers1["Authorization"], "Bearer external-token-1") - - # Expire the cache - self.provider._cached_token_expiry = datetime.now() - timedelta(seconds=1) - - # Second call - should get new token - headers2 = {} - self.provider.add_headers(headers2) - self.assertEqual(headers2["Authorization"], "Bearer external-token-2") - - # External provider should be called twice - self.assertEqual(self.external_provider.add_headers.call_count, 2) - - -class TestExternalTokenProvider(unittest.TestCase): - - def test_add_headers(self): - """Test adding headers from external credentials provider.""" - # Create mock credentials provider - mock_headers = { - "Authorization": "Bearer test-token", - "X-Custom-Header": "custom-value", - } - credentials_provider = Mock(return_value=Mock(return_value=mock_headers)) - - provider = ExternalTokenProvider(credentials_provider) - - # Test adding headers - request_headers = {} - provider.add_headers(request_headers) - - # Verify headers were added - self.assertEqual(request_headers["Authorization"], "Bearer test-token") - self.assertEqual(request_headers["X-Custom-Header"], "custom-value") - - # Verify credentials provider was called once - credentials_provider.assert_called_once() - - def test_header_factory_cached(self): - """Test that header factory is cached.""" - mock_headers = {"Authorization": "Bearer test-token"} - header_factory = Mock(return_value=mock_headers) - credentials_provider = Mock(return_value=header_factory) - - provider = ExternalTokenProvider(credentials_provider) - - # Call add_headers multiple times - for _ in range(3): - request_headers = {} - provider.add_headers(request_headers) - - # Credentials provider should only be called once - credentials_provider.assert_called_once() - - # Header factory should be called 3 times - self.assertEqual(header_factory.call_count, 3) - - -if __name__ == "__main__": - unittest.main() \ No newline at end of file + token = create_jwt_token( + exp_hours=0.001 if call_count[0] == 1 else 1 + ) # First token expires quickly + headers.update({"Authorization": f"Bearer {token}"}) + + mock_external_provider.add_headers = Mock(side_effect=add_headers_side_effect) + + # First call + headers1 = {} + token_federation_provider.add_headers(headers1) + first_token = headers1["Authorization"].split(" ")[1] + + # Force cache expiry + token_federation_provider._cached_token = Token(first_token) + token_federation_provider._cached_token.expiry_time = ( + datetime.now() - timedelta(seconds=1) + ) + + # Second call - should get new token + headers2 = {} + token_federation_provider.add_headers(headers2) + second_token = headers2["Authorization"].split(" ")[1] + + # External provider should be called twice + assert mock_external_provider.add_headers.call_count == 2 + # Tokens should be different + assert first_token != second_token + + +class TestUtilityFunctions: + """Test utility functions used by TokenFederationProvider.""" + + @pytest.mark.parametrize( + "input_hostname,expected", + [ + ("test.databricks.com", "https://test.databricks.com/"), + ("https://test.databricks.com", "https://test.databricks.com/"), + ("https://test.databricks.com/", "https://test.databricks.com/"), + ("test.databricks.com/", "https://test.databricks.com/"), + ], + ) + def test_parse_hostname(self, input_hostname, expected): + """Test hostname parsing.""" + assert parse_hostname(input_hostname) == expected + + @pytest.mark.parametrize( + "url1,url2,expected", + [ + ("https://test.databricks.com", "https://test.databricks.com", True), + ("https://test.databricks.com", "https://test.databricks.com:443", True), + ("https://test1.databricks.com", "https://test2.databricks.com", False), + ("https://login.microsoftonline.com", "https://test.databricks.com", False), + ], + ) + def test_is_same_host(self, url1, url2, expected): + """Test host comparison.""" + assert is_same_host(url1, url2) == expected + + def test_decode_token_valid(self): + """Test decoding a valid JWT token.""" + token = create_jwt_token() + result = decode_token(token) + assert result is not None + assert "iss" in result + assert "exp" in result + + def test_decode_token_invalid(self): + """Test decoding an invalid token.""" + result = decode_token("invalid-token") + assert result is None From 63a7f82d1bc93fb46716a3e99cdc406414baf470 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Thu, 25 Sep 2025 20:41:07 +0000 Subject: [PATCH 3/7] address comments --- src/databricks/sql/auth/token_federation.py | 20 +++++++++++++++----- tests/unit/test_auth.py | 11 ++++++++--- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index 746ce0f86..e8e2d4d43 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -56,6 +56,20 @@ def __init__( def add_headers(self, request_headers: Dict[str, str]): """Add authentication headers to the request.""" + + if self._cached_token and not self._cached_token.is_expired(): + request_headers["Authorization"] = f"{self._cached_token.token_type} {self._cached_token.access_token}" + return + + # Get the external headers first to check if we need token federation + self._external_headers = {} + self.external_provider.add_headers(self._external_headers) + + # If no Authorization header from external provider, pass through all headers + if "Authorization" not in self._external_headers: + request_headers.update(self._external_headers) + return + token = self._get_token() request_headers["Authorization"] = f"{token.token_type} {token.access_token}" @@ -65,11 +79,7 @@ def _get_token(self) -> Token: if self._cached_token and not self._cached_token.is_expired(): return self._cached_token - # Get the external token - self._external_headers = {} - self.external_provider.add_headers(self._external_headers) - - # Extract token from Authorization header + # Extract token from already-fetched headers auth_header = self._external_headers.get("Authorization", "") token_type, access_token = self._extract_token_from_header(auth_header) diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index a5ad7562e..d1b941208 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -164,7 +164,9 @@ def __call__(self, *args, **kwargs) -> HeaderFactory: kwargs = {"credentials_provider": MyProvider()} mock_http_client = MagicMock() auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) - self.assertTrue(type(auth_provider).__name__, "ExternalAuthProvider") + + self.assertEqual(type(auth_provider).__name__, "TokenFederationProvider") + self.assertEqual(type(auth_provider.external_provider).__name__, "ExternalAuthProvider") headers = {} auth_provider.add_headers(headers) @@ -199,8 +201,11 @@ def test_get_python_sql_connector_default_auth(self, mock__initial_get_token): hostname = "foo.cloud.databricks.com" mock_http_client = MagicMock() auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client) - self.assertTrue(type(auth_provider).__name__, "DatabricksOAuthProvider") - self.assertTrue(auth_provider._client_id, PYSQL_OAUTH_CLIENT_ID) + + self.assertEqual(type(auth_provider).__name__, "TokenFederationProvider") + self.assertEqual(type(auth_provider.external_provider).__name__, "DatabricksOAuthProvider") + + self.assertEqual(auth_provider.external_provider._client_id, PYSQL_OAUTH_CLIENT_ID) class TestClientCredentialsTokenSource: From 1e604a158d81edc9138b8b375f12355ebb511c80 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Fri, 26 Sep 2025 06:28:50 +0000 Subject: [PATCH 4/7] lint --- src/databricks/sql/auth/token_federation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index e8e2d4d43..40a3ef52e 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -58,7 +58,9 @@ def add_headers(self, request_headers: Dict[str, str]): """Add authentication headers to the request.""" if self._cached_token and not self._cached_token.is_expired(): - request_headers["Authorization"] = f"{self._cached_token.token_type} {self._cached_token.access_token}" + request_headers[ + "Authorization" + ] = f"{self._cached_token.token_type} {self._cached_token.access_token}" return # Get the external headers first to check if we need token federation From b79af036aa36347f0065358395c988f4c4cdcf8e Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Fri, 26 Sep 2025 06:35:36 +0000 Subject: [PATCH 5/7] lint fix --- src/databricks/sql/auth/auth.py | 2 +- src/databricks/sql/auth/token_federation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index b50f93c9c..a5de0d622 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -13,7 +13,7 @@ def get_auth_provider(cfg: ClientContext, http_client): # Determine the base auth provider - base_provider = None + base_provider: Optional[AuthProvider] = None if cfg.credentials_provider: base_provider = ExternalAuthProvider(cfg.credentials_provider) diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index 40a3ef52e..660935425 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -52,7 +52,7 @@ def __init__( self.identity_federation_client_id = identity_federation_client_id self._cached_token: Optional[Token] = None - self._external_headers = {} + self._external_headers: Dict[str, str] = {} def add_headers(self, request_headers: Dict[str, str]): """Add authentication headers to the request.""" From 010b58e424572250fc053d27435abad2402a1ea2 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Fri, 26 Sep 2025 09:14:35 +0000 Subject: [PATCH 6/7] nit --- src/databricks/sql/auth/auth_utils.py | 53 -------------------- src/databricks/sql/auth/token_federation.py | 55 ++++++++++++++++++++- 2 files changed, 54 insertions(+), 54 deletions(-) diff --git a/src/databricks/sql/auth/auth_utils.py b/src/databricks/sql/auth/auth_utils.py index 76233d96c..439aabc51 100644 --- a/src/databricks/sql/auth/auth_utils.py +++ b/src/databricks/sql/auth/auth_utils.py @@ -62,56 +62,3 @@ def is_same_host(url1: str, url2: str) -> bool: except Exception as e: logger.debug("Failed to parse URLs: %s", e) return False - - -class Token: - """ - Represents an OAuth token with expiration management. - """ - - def __init__(self, access_token: str, token_type: str = "Bearer"): - """ - Initialize a token. - - Args: - access_token: The access token string - token_type: The token type (default: Bearer) - """ - self.access_token = access_token - self.token_type = token_type - self.expiry_time = self._calculate_expiry() - - def _calculate_expiry(self) -> datetime: - """ - Calculate the token expiry time from JWT claims. - - Returns: - The token expiry datetime - """ - decoded = decode_token(self.access_token) - if decoded and "exp" in decoded: - # Use JWT exp claim with 1 minute buffer - return datetime.fromtimestamp(decoded["exp"]) - timedelta(minutes=1) - # Default to 1 hour if no expiry info - return datetime.now() + timedelta(hours=1) - - def is_expired(self) -> bool: - """ - Check if the token is expired. - - Returns: - True if token is expired, False otherwise - """ - return datetime.now() >= self.expiry_time - - def to_dict(self) -> Dict[str, str]: - """ - Convert token to dictionary format. - - Returns: - Dictionary with access_token and token_type - """ - return { - "access_token": self.access_token, - "token_type": self.token_type, - } diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index 660935425..7b62f6762 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -1,11 +1,11 @@ import logging import json +from datetime import datetime, timedelta from typing import Optional, Dict, Tuple from urllib.parse import urlencode from databricks.sql.auth.authenticators import AuthProvider from databricks.sql.auth.auth_utils import ( - Token, parse_hostname, decode_token, is_same_host, @@ -15,6 +15,59 @@ logger = logging.getLogger(__name__) +class Token: + """ + Represents an OAuth token with expiration management. + """ + + def __init__(self, access_token: str, token_type: str = "Bearer"): + """ + Initialize a token. + + Args: + access_token: The access token string + token_type: The token type (default: Bearer) + """ + self.access_token = access_token + self.token_type = token_type + self.expiry_time = self._calculate_expiry() + + def _calculate_expiry(self) -> datetime: + """ + Calculate the token expiry time from JWT claims. + + Returns: + The token expiry datetime + """ + decoded = decode_token(self.access_token) + if decoded and "exp" in decoded: + # Use JWT exp claim with 1 minute buffer + return datetime.fromtimestamp(decoded["exp"]) - timedelta(minutes=1) + # Default to 1 hour if no expiry info + return datetime.now() + timedelta(hours=1) + + def is_expired(self) -> bool: + """ + Check if the token is expired. + + Returns: + True if token is expired, False otherwise + """ + return datetime.now() >= self.expiry_time + + def to_dict(self) -> Dict[str, str]: + """ + Convert token to dictionary format. + + Returns: + Dictionary with access_token and token_type + """ + return { + "access_token": self.access_token, + "token_type": self.token_type, + } + + class TokenFederationProvider(AuthProvider): """ Implementation of Token Federation for Databricks SQL Python driver. From 3398ce6bbbedbfa9f1242e1e9e3957947311e2c4 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Fri, 26 Sep 2025 09:15:44 +0000 Subject: [PATCH 7/7] change import --- tests/unit/test_token_federation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py index e2b8b57eb..2e671c33e 100644 --- a/tests/unit/test_token_federation.py +++ b/tests/unit/test_token_federation.py @@ -4,9 +4,8 @@ import jwt from datetime import datetime, timedelta -from databricks.sql.auth.token_federation import TokenFederationProvider +from databricks.sql.auth.token_federation import TokenFederationProvider, Token from databricks.sql.auth.auth_utils import ( - Token, parse_hostname, decode_token, is_same_host,