From 4437a2ac0b46b700a65fb8ad946c97b8f212c52e Mon Sep 17 00:00:00 2001
From: Vikrant Puppala
Date: Fri, 8 Aug 2025 12:58:53 +0530
Subject: [PATCH 1/4] Refactor codebase to use a unified http client

Signed-off-by: Vikrant Puppala
---
 src/databricks/sql/auth/auth.py               |   4 +-
 src/databricks/sql/auth/authenticators.py     |   2 +
 src/databricks/sql/auth/common.py             |  61 +++--
 src/databricks/sql/auth/oauth.py              |  28 ++-
 src/databricks/sql/backend/sea/queue.py       |   4 +
 src/databricks/sql/backend/sea/result_set.py  |   1 +
 src/databricks/sql/client.py                  |  38 ++-
 .../sql/cloudfetch/download_manager.py        |   3 +
 src/databricks/sql/cloudfetch/downloader.py   |  79 +++---
 src/databricks/sql/common/feature_flag.py     |  16 +-
 src/databricks/sql/common/http.py             | 112 ---------
 .../sql/common/unified_http_client.py         | 226 ++++++++++++++++++
 src/databricks/sql/result_set.py              |   1 +
 src/databricks/sql/session.py                 |  39 ++-
 .../sql/telemetry/telemetry_client.py         |  22 +-
 src/databricks/sql/utils.py                   |  15 +-
 16 files changed, 440 insertions(+), 211 deletions(-)
 create mode 100644 src/databricks/sql/common/unified_http_client.py
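Usage note (editor's sketch, not part of the applied diff): the UnifiedHttpClient
introduced below is built from a ClientContext and exposes request() (body
pre-loaded before the pool connection is released) and request_context() (a
context manager over the raw urllib3 response). Names and parameters are taken
from the files in this patch; the hostname and URLs are placeholders.

    from databricks.sql.auth.common import ClientContext
    from databricks.sql.common.unified_http_client import UnifiedHttpClient

    # Retry and pool settings fall back to the defaults set in ClientContext
    # (30 attempts, 1s-60s delays, pool of 10 connections / maxsize 20).
    ctx = ClientContext(hostname="example.cloud.databricks.com",
                        user_agent="databricks-sql-python/example")
    client = UnifiedHttpClient(ctx)

    # One-shot request: the response data is read before the pool releases
    # the connection, so resp.data stays available after the call returns.
    resp = client.request("GET", "https://example.cloud.databricks.com/api/2.0/ping")
    print(resp.status, len(resp.data))

    # Scoped request with explicit cleanup; `redirect` is urllib3's kwarg
    # for disabling redirect-following (the requests-style `allow_redirects`
    # does not exist in urllib3).
    with client.request_context("GET", "https://example.cloud.databricks.com/aad/auth",
                                redirect=False) as resp:
        print(resp.status, dict(resp.headers).get("Location"))

    client.close()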
diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py
index 3792d6d05..a8d0671b0 100755
--- a/src/databricks/sql/auth/auth.py
+++ b/src/databricks/sql/auth/auth.py
@@ -35,6 +35,7 @@ def get_auth_provider(cfg: ClientContext):
             cfg.oauth_client_id,
             cfg.oauth_scopes,
             cfg.auth_type,
+            http_client=http_client,
         )
     elif cfg.access_token is not None:
         return AccessTokenAuthProvider(cfg.access_token)
@@ -53,6 +54,7 @@ def get_auth_provider(cfg: ClientContext):
             cfg.oauth_redirect_port_range,
             cfg.oauth_client_id,
             cfg.oauth_scopes,
+            http_client=http_client,
         )
     else:
         raise RuntimeError("No valid authentication settings!")
@@ -79,7 +81,7 @@ def get_client_id_and_redirect_port(use_azure_auth: bool):
     )
 
 
-def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
+def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs):
     # TODO : unify all the auth mechanisms with the Python SDK
     auth_type = kwargs.get("auth_type")
 
diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py
index 26c1f3708..80f44812c 100644
--- a/src/databricks/sql/auth/authenticators.py
+++ b/src/databricks/sql/auth/authenticators.py
@@ -63,6 +63,7 @@ def __init__(
         redirect_port_range: List[int],
         client_id: str,
         scopes: List[str],
+        http_client,
         auth_type: str = "databricks-oauth",
     ):
         try:
@@ -79,6 +80,7 @@ def __init__(
             port_range=redirect_port_range,
             client_id=client_id,
             idp_endpoint=idp_endpoint,
+            http_client=http_client,
         )
         self._hostname = hostname
         self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(cloud_scopes)
diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py
index 5cfbc37c0..262166a52 100644
--- a/src/databricks/sql/auth/common.py
+++ b/src/databricks/sql/auth/common.py
@@ -2,7 +2,6 @@
 import logging
 from typing import Optional, List
 from urllib.parse import urlparse
-from databricks.sql.common.http import DatabricksHttpClient, HttpMethod
 
 logger = logging.getLogger(__name__)
 
@@ -36,6 +35,21 @@ def __init__(
         tls_client_cert_file: Optional[str] = None,
         oauth_persistence=None,
         credentials_provider=None,
+        # HTTP client configuration parameters
+        ssl_options=None,  # SSLOptions type
+        socket_timeout: Optional[float] = None,
+        retry_stop_after_attempts_count: Optional[int] = None,
+        retry_delay_min: Optional[float] = None,
+        retry_delay_max: Optional[float] = None,
+        retry_stop_after_attempts_duration: Optional[float] = None,
+        retry_delay_default: Optional[float] = None,
+        retry_dangerous_codes: Optional[List[int]] = None,
+        http_proxy: Optional[str] = None,
+        proxy_username: Optional[str] = None,
+        proxy_password: Optional[str] = None,
+        pool_connections: Optional[int] = None,
+        pool_maxsize: Optional[int] = None,
+        user_agent: Optional[str] = None,
     ):
         self.hostname = hostname
         self.access_token = access_token
@@ -51,6 +65,22 @@ def __init__(
         self.tls_client_cert_file = tls_client_cert_file
         self.oauth_persistence = oauth_persistence
         self.credentials_provider = credentials_provider
+
+        # HTTP client configuration
+        self.ssl_options = ssl_options
+        self.socket_timeout = socket_timeout
+        self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 30
+        self.retry_delay_min = retry_delay_min or 1.0
+        self.retry_delay_max = retry_delay_max or 60.0
+        self.retry_stop_after_attempts_duration = retry_stop_after_attempts_duration or 900.0
+        self.retry_delay_default = retry_delay_default or 5.0
+        self.retry_dangerous_codes = retry_dangerous_codes or []
+        self.http_proxy = http_proxy
+        self.proxy_username = proxy_username
+        self.proxy_password = proxy_password
+        self.pool_connections = pool_connections or 10
+        self.pool_maxsize = pool_maxsize or 20
+        self.user_agent = user_agent
 
 
 def get_effective_azure_login_app_id(hostname) -> str:
@@ -69,7 +99,7 @@ def get_effective_azure_login_app_id(hostname) -> str:
     return AzureAppId.PROD.value[1]
 
 
-def get_azure_tenant_id_from_host(host: str, http_client=None) -> str:
+def get_azure_tenant_id_from_host(host: str, http_client) -> str:
     """
     Load the Azure tenant ID from the Azure Databricks login page.
 
@@ -78,23 +108,22 @@ def get_azure_tenant_id_from_host(host: str, http_client=None) -> str:
     the Azure login page, and the tenant ID is extracted from the redirect URL.
     """
 
-    if http_client is None:
-        http_client = DatabricksHttpClient.get_instance()
-
     login_url = f"{host}/aad/auth"
     logger.debug("Loading tenant ID from %s", login_url)
-    with http_client.execute(HttpMethod.GET, login_url, allow_redirects=False) as resp:
-        if resp.status_code // 100 != 3:
+
+    with http_client.request_context('GET', login_url, redirect=False) as resp:
+        if resp.status // 100 != 3:
             raise ValueError(
-                f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status_code}"
+                f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status}"
             )
-        entra_id_endpoint = resp.headers.get("Location")
+        entra_id_endpoint = dict(resp.headers).get("Location")
         if entra_id_endpoint is None:
             raise ValueError(f"No Location header in response from {login_url}")
-    # The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
-    # The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
-    url = urlparse(entra_id_endpoint)
-    path_segments = url.path.split("/")
-    if len(path_segments) < 2:
-        raise ValueError(f"Invalid path in Location header: {url.path}")
-    return path_segments[1]
+
+    # The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
+    # The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
+ url = urlparse(entra_id_endpoint) + path_segments = url.path.split("/") + if len(path_segments) < 2: + raise ValueError(f"Invalid path in Location header: {url.path}") + return path_segments[1] diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index aa3184d88..0d67929a3 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -9,10 +9,8 @@ from typing import List, Optional import oauthlib.oauth2 -import requests from oauthlib.oauth2.rfc6749.errors import OAuth2Error -from requests.exceptions import RequestException -from databricks.sql.common.http import HttpMethod, DatabricksHttpClient, HttpHeader +from databricks.sql.common.http import HttpMethod, HttpHeader from databricks.sql.common.http import OAuthResponse from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler from databricks.sql.auth.endpoint import OAuthEndpointCollection @@ -85,11 +83,13 @@ def __init__( port_range: List[int], client_id: str, idp_endpoint: OAuthEndpointCollection, + http_client, ): self.port_range = port_range self.client_id = client_id self.redirect_port = None self.idp_endpoint = idp_endpoint + self.http_client = http_client @staticmethod def __token_urlsafe(nbytes=32): @@ -103,8 +103,12 @@ def __fetch_well_known_config(self, hostname: str): known_config_url = self.idp_endpoint.get_openid_config_url(hostname) try: - response = requests.get(url=known_config_url, auth=IgnoreNetrcAuth()) - except RequestException as e: + from databricks.sql.common.unified_http_client import IgnoreNetrcAuth + response = self.http_client.request('GET', url=known_config_url) + # Convert urllib3 response to requests-like response for compatibility + response.status_code = response.status + response.json = lambda: json.loads(response.data.decode()) + except Exception as e: logger.error( f"Unable to fetch OAuth configuration from {known_config_url}.\n" "Verify it is a valid workspace URL and that OAuth is " @@ -122,7 +126,7 @@ def __fetch_well_known_config(self, hostname: str): raise RuntimeError(msg) try: return response.json() - except requests.exceptions.JSONDecodeError as e: + except Exception as e: logger.error( f"Unable to decode OAuth configuration from {known_config_url}.\n" "Verify it is a valid workspace URL and that OAuth is " @@ -209,10 +213,13 @@ def __send_token_request(token_request_url, data): "Accept": "application/json", "Content-Type": "application/x-www-form-urlencoded", } - response = requests.post( - url=token_request_url, data=data, headers=headers, auth=IgnoreNetrcAuth() + # Use unified HTTP client + from databricks.sql.common.unified_http_client import IgnoreNetrcAuth + response = self.http_client.request( + 'POST', url=token_request_url, body=data, headers=headers ) - return response.json() + # Convert urllib3 response to dict for compatibility + return json.loads(response.data.decode()) def __send_refresh_token_request(self, hostname, refresh_token): oauth_config = self.__fetch_well_known_config(hostname) @@ -320,6 +327,7 @@ def __init__( token_url, client_id, client_secret, + http_client, extra_params: dict = {}, ): self.client_id = client_id @@ -327,7 +335,7 @@ def __init__( self.token_url = token_url self.extra_params = extra_params self.token: Optional[Token] = None - self._http_client = DatabricksHttpClient.get_instance() + self._http_client = http_client def get_token(self) -> Token: if self.token is None or self.token.is_expired(): diff --git a/src/databricks/sql/backend/sea/queue.py 
b/src/databricks/sql/backend/sea/queue.py index 130f0c5bf..4a319c442 100644 --- a/src/databricks/sql/backend/sea/queue.py +++ b/src/databricks/sql/backend/sea/queue.py @@ -50,6 +50,7 @@ def build_queue( max_download_threads: int, sea_client: SeaDatabricksClient, lz4_compressed: bool, + http_client, ) -> ResultSetQueue: """ Factory method to build a result set queue for SEA backend. @@ -94,6 +95,7 @@ def build_queue( total_chunk_count=manifest.total_chunk_count, lz4_compressed=lz4_compressed, description=description, + http_client=http_client, ) raise ProgrammingError("Invalid result format") @@ -309,6 +311,7 @@ def __init__( sea_client: SeaDatabricksClient, statement_id: str, total_chunk_count: int, + http_client, lz4_compressed: bool = False, description: List[Tuple] = [], ): @@ -337,6 +340,7 @@ def __init__( # TODO: fix these arguments when telemetry is implemented in SEA session_id_hex=None, chunk_id=0, + http_client=http_client, ) logger.debug( diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index afa70bc89..17838ed81 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -64,6 +64,7 @@ def __init__( max_download_threads=sea_client.max_download_threads, sea_client=sea_client, lz4_compressed=execute_response.lz4_compressed, + http_client=connection.session.http_client, ) # Call parent constructor with common attributes diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 73ee0e03c..295be29dc 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -6,7 +6,6 @@ import pyarrow except ImportError: pyarrow = None -import requests import json import os import decimal @@ -292,6 +291,7 @@ def read(self) -> Optional[OAuthToken]: auth_provider=self.session.auth_provider, host_url=self.session.host, batch_size=self.telemetry_batch_size, + http_client=self.session.http_client, ) self._telemetry_client = TelemetryClientFactory.get_telemetry_client( @@ -744,16 +744,20 @@ def _handle_staging_put( ) with open(local_file, "rb") as fh: - r = requests.put(url=presigned_url, data=fh, headers=headers) + r = self.connection.session.http_client.request('PUT', presigned_url, body=fh.read(), headers=headers) + # Add compatibility attributes for urllib3 response + r.status_code = r.status + if hasattr(r, 'data'): + r.content = r.data + r.ok = r.status < 400 + r.text = r.data.decode() if r.data else "" # fmt: off - # Design borrowed from: https://stackoverflow.com/a/2342589/5093960 - - OK = requests.codes.ok # 200 - CREATED = requests.codes.created # 201 - ACCEPTED = requests.codes.accepted # 202 - NO_CONTENT = requests.codes.no_content # 204 - + # HTTP status codes + OK = 200 + CREATED = 201 + ACCEPTED = 202 + NO_CONTENT = 204 # fmt: on if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]: @@ -783,7 +787,13 @@ def _handle_staging_get( session_id_hex=self.connection.get_session_id_hex(), ) - r = requests.get(url=presigned_url, headers=headers) + r = self.connection.session.http_client.request('GET', presigned_url, headers=headers) + # Add compatibility attributes for urllib3 response + r.status_code = r.status + if hasattr(r, 'data'): + r.content = r.data + r.ok = r.status < 400 + r.text = r.data.decode() if r.data else "" # response.ok verifies the status code is not between 400-600. 
# Any 2xx or 3xx will evaluate r.ok == True @@ -802,7 +812,13 @@ def _handle_staging_remove( ): """Make an HTTP DELETE request to the presigned_url""" - r = requests.delete(url=presigned_url, headers=headers) + r = self.connection.session.http_client.request('DELETE', presigned_url, headers=headers) + # Add compatibility attributes for urllib3 response + r.status_code = r.status + if hasattr(r, 'data'): + r.content = r.data + r.ok = r.status < 400 + r.text = r.data.decode() if r.data else "" if not r.ok: raise OperationalError( diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 32b698bed..27265720f 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -25,6 +25,7 @@ def __init__( session_id_hex: Optional[str], statement_id: str, chunk_id: int, + http_client, ): self._pending_links: List[Tuple[int, TSparkArrowResultLink]] = [] self.chunk_id = chunk_id @@ -47,6 +48,7 @@ def __init__( self._ssl_options = ssl_options self.session_id_hex = session_id_hex self.statement_id = statement_id + self._http_client = http_client def get_next_downloaded_file( self, next_row_offset: int @@ -109,6 +111,7 @@ def _schedule_downloads(self): chunk_id=chunk_id, session_id_hex=self.session_id_hex, statement_id=self.statement_id, + http_client=self._http_client, ) task = self._thread_pool.submit(handler.run) self._download_tasks.append(task) diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 1331fa203..ea375fbbb 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -2,10 +2,9 @@ from dataclasses import dataclass from typing import Optional -from requests.adapters import Retry import lz4.frame import time -from databricks.sql.common.http import DatabricksHttpClient, HttpMethod +from databricks.sql.common.http import HttpMethod from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink from databricks.sql.exc import Error from databricks.sql.types import SSLOptions @@ -16,16 +15,6 @@ # TODO: Ideally, we should use a common retry policy (DatabricksRetryPolicy) for all the requests across the library. # But DatabricksRetryPolicy should be updated first - currently it can work only with Thrift requests -retryPolicy = Retry( - total=5, # max retry attempts - backoff_factor=1, # min delay, 1 second - # TODO: `backoff_max` is supported since `urllib3` v2.0.0, but we allow >= 1.26. 
- # The default value (120 seconds) used since v1.26 looks reasonable enough - # backoff_max=60, # max delay, 60 seconds - # retry all status codes below 100, 429 (Too Many Requests), and all codes above 500, - # excluding 501 Not implemented - status_forcelist=[*range(0, 101), 429, 500, *range(502, 1000)], -) @dataclass @@ -73,11 +62,12 @@ def __init__( chunk_id: int, session_id_hex: Optional[str], statement_id: str, + http_client, ): self.settings = settings self.link = link self._ssl_options = ssl_options - self._http_client = DatabricksHttpClient.get_instance() + self._http_client = http_client self.chunk_id = chunk_id self.session_id_hex = session_id_hex self.statement_id = statement_id @@ -104,50 +94,47 @@ def run(self) -> DownloadedFile: start_time = time.time() - with self._http_client.execute( - method=HttpMethod.GET, + with self._http_client.request_context( + method='GET', url=self.link.fileLink, timeout=self.settings.download_timeout, - verify=self._ssl_options.tls_verify, headers=self.link.httpHeaders - # TODO: Pass cert from `self._ssl_options` ) as response: - response.raise_for_status() - - # Save (and decompress if needed) the downloaded file - compressed_data = response.content - - # Log download metrics - download_duration = time.time() - start_time - self._log_download_metrics( - self.link.fileLink, len(compressed_data), download_duration - ) - - decompressed_data = ( - ResultSetDownloadHandler._decompress_data(compressed_data) - if self.settings.is_lz4_compressed - else compressed_data - ) + if response.status >= 400: + raise Exception(f"HTTP {response.status}: {response.data.decode()}") + compressed_data = response.data + + # Log download metrics + download_duration = time.time() - start_time + self._log_download_metrics( + self.link.fileLink, len(compressed_data), download_duration + ) - # The size of the downloaded file should match the size specified from TSparkArrowResultLink - if len(decompressed_data) != self.link.bytesNum: - logger.debug( - "ResultSetDownloadHandler: downloaded file size {} does not match the expected value {}".format( - len(decompressed_data), self.link.bytesNum - ) - ) + decompressed_data = ( + ResultSetDownloadHandler._decompress_data(compressed_data) + if self.settings.is_lz4_compressed + else compressed_data + ) + # The size of the downloaded file should match the size specified from TSparkArrowResultLink + if len(decompressed_data) != self.link.bytesNum: logger.debug( - "ResultSetDownloadHandler: successfully downloaded file, offset {}, row count {}".format( - self.link.startRowOffset, self.link.rowCount + "ResultSetDownloadHandler: downloaded file size {} does not match the expected value {}".format( + len(decompressed_data), self.link.bytesNum ) ) - return DownloadedFile( - decompressed_data, - self.link.startRowOffset, - self.link.rowCount, + logger.debug( + "ResultSetDownloadHandler: successfully downloaded file, offset {}, row count {}".format( + self.link.startRowOffset, self.link.rowCount ) + ) + + return DownloadedFile( + decompressed_data, + self.link.startRowOffset, + self.link.rowCount, + ) def _log_download_metrics( self, url: str, bytes_downloaded: int, duration_seconds: float diff --git a/src/databricks/sql/common/feature_flag.py b/src/databricks/sql/common/feature_flag.py index 53add9253..8e7029805 100644 --- a/src/databricks/sql/common/feature_flag.py +++ b/src/databricks/sql/common/feature_flag.py @@ -1,6 +1,6 @@ +import json import threading import time -import requests from dataclasses import dataclass, field from 
concurrent.futures import ThreadPoolExecutor from typing import Dict, Optional, List, Any, TYPE_CHECKING @@ -49,7 +49,7 @@ class FeatureFlagsContext: in the background, returning stale data until the refresh completes. """ - def __init__(self, connection: "Connection", executor: ThreadPoolExecutor): + def __init__(self, connection: "Connection", executor: ThreadPoolExecutor, http_client): from databricks.sql import __version__ self._connection = connection @@ -65,6 +65,9 @@ def __init__(self, connection: "Connection", executor: ThreadPoolExecutor): self._feature_flag_endpoint = ( f"https://{self._connection.session.host}{endpoint_suffix}" ) + + # Use the provided HTTP client + self._http_client = http_client def _is_refresh_needed(self) -> bool: """Checks if the cache is due for a proactive background refresh.""" @@ -105,9 +108,12 @@ def _refresh_flags(self): self._connection.session.auth_provider.add_headers(headers) headers["User-Agent"] = self._connection.session.useragent_header - response = requests.get( - self._feature_flag_endpoint, headers=headers, timeout=30 + response = self._http_client.request( + 'GET', self._feature_flag_endpoint, headers=headers, timeout=30 ) + # Add compatibility attributes for urllib3 response + response.status_code = response.status + response.json = lambda: json.loads(response.data.decode()) if response.status_code == 200: ff_response = FeatureFlagsResponse.from_dict(response.json()) @@ -159,7 +165,7 @@ def get_instance(cls, connection: "Connection") -> FeatureFlagsContext: # Use the unique session ID as the key key = connection.get_session_id_hex() if key not in cls._context_map: - cls._context_map[key] = FeatureFlagsContext(connection, cls._executor) + cls._context_map[key] = FeatureFlagsContext(connection, cls._executor, connection.session.http_client) return cls._context_map[key] @classmethod diff --git a/src/databricks/sql/common/http.py b/src/databricks/sql/common/http.py index 0cd2919c0..cf76a5fba 100644 --- a/src/databricks/sql/common/http.py +++ b/src/databricks/sql/common/http.py @@ -38,115 +38,3 @@ class OAuthResponse: resource: str = "" access_token: str = "" refresh_token: str = "" - - -# Singleton class for common Http Client -class DatabricksHttpClient: - ## TODO: Unify all the http clients in the PySQL Connector - - _instance = None - _lock = threading.Lock() - - def __init__(self): - self.session = requests.Session() - adapter = HTTPAdapter( - pool_connections=5, - pool_maxsize=10, - max_retries=Retry(total=10, backoff_factor=0.1), - ) - self.session.mount("https://", adapter) - self.session.mount("http://", adapter) - - @classmethod - def get_instance(cls) -> "DatabricksHttpClient": - if cls._instance is None: - with cls._lock: - if cls._instance is None: - cls._instance = DatabricksHttpClient() - return cls._instance - - @contextmanager - def execute( - self, method: HttpMethod, url: str, **kwargs - ) -> Generator[requests.Response, None, None]: - logger.info("Executing HTTP request: %s with url: %s", method.value, url) - response = None - try: - response = self.session.request(method.value, url, **kwargs) - yield response - except Exception as e: - logger.error("Error executing HTTP request in DatabricksHttpClient: %s", e) - raise e - finally: - if response is not None: - response.close() - - def close(self): - self.session.close() - - -class TelemetryHTTPAdapter(HTTPAdapter): - """ - Custom HTTP adapter to prepare our DatabricksRetryPolicy before each request. 
- This ensures the retry timer is started and the command type is set correctly, - allowing the policy to manage its state for the duration of the request retries. - """ - - def send(self, request, **kwargs): - self.max_retries.command_type = CommandType.OTHER - self.max_retries.start_retry_timer() - return super().send(request, **kwargs) - - -class TelemetryHttpClient: # TODO: Unify all the http clients in the PySQL Connector - """Singleton HTTP client for sending telemetry data.""" - - _instance: Optional["TelemetryHttpClient"] = None - _lock = threading.Lock() - - TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT = 3 - TELEMETRY_RETRY_DELAY_MIN = 1.0 - TELEMETRY_RETRY_DELAY_MAX = 10.0 - TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION = 30.0 - - def __init__(self): - """Initializes the session and mounts the custom retry adapter.""" - retry_policy = DatabricksRetryPolicy( - delay_min=self.TELEMETRY_RETRY_DELAY_MIN, - delay_max=self.TELEMETRY_RETRY_DELAY_MAX, - stop_after_attempts_count=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT, - stop_after_attempts_duration=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION, - delay_default=1.0, - force_dangerous_codes=[], - ) - adapter = TelemetryHTTPAdapter(max_retries=retry_policy) - self.session = requests.Session() - self.session.mount("https://", adapter) - self.session.mount("http://", adapter) - - @classmethod - def get_instance(cls) -> "TelemetryHttpClient": - """Get the singleton instance of the TelemetryHttpClient.""" - if cls._instance is None: - with cls._lock: - if cls._instance is None: - logger.debug("Initializing singleton TelemetryHttpClient") - cls._instance = TelemetryHttpClient() - return cls._instance - - def post(self, url: str, **kwargs) -> requests.Response: - """ - Executes a POST request using the configured session. - - This is a blocking call intended to be run in a background thread. - """ - logger.debug("Executing telemetry POST request to: %s", url) - return self.session.post(url, **kwargs) - - def close(self): - """Closes the underlying requests.Session.""" - logger.debug("Closing TelemetryHttpClient session.") - self.session.close() - # Clear the instance to allow for re-initialization if needed - with TelemetryHttpClient._lock: - TelemetryHttpClient._instance = None diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py new file mode 100644 index 000000000..8c3be2bfd --- /dev/null +++ b/src/databricks/sql/common/unified_http_client.py @@ -0,0 +1,226 @@ +import logging +import ssl +import urllib.parse +from contextlib import contextmanager +from typing import Dict, Any, Optional, Generator, Union + +import urllib3 +from urllib3 import PoolManager, ProxyManager +from urllib3.util import make_headers +from urllib3.exceptions import MaxRetryError + +from databricks.sql.auth.retry import DatabricksRetryPolicy +from databricks.sql.exc import RequestError + +logger = logging.getLogger(__name__) + + +class UnifiedHttpClient: + """ + Unified HTTP client for all Databricks SQL connector HTTP operations. + + This client uses urllib3 for robust HTTP communication with retry policies, + connection pooling, SSL support, and proxy support. It replaces the various + singleton HTTP clients and direct requests usage throughout the codebase. + """ + + def __init__(self, client_context): + """ + Initialize the unified HTTP client. 
+ + Args: + client_context: ClientContext instance containing HTTP configuration + """ + self.config = client_context + self._pool_manager = None + self._setup_pool_manager() + + def _setup_pool_manager(self): + """Set up the urllib3 PoolManager with configuration from ClientContext.""" + + # SSL context setup + ssl_context = None + if self.config.ssl_options: + ssl_context = ssl.create_default_context() + + # Configure SSL verification + if not self.config.ssl_options.tls_verify: + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + elif not self.config.ssl_options.tls_verify_hostname: + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_REQUIRED + + # Load custom CA file if specified + if self.config.ssl_options.tls_trusted_ca_file: + ssl_context.load_verify_locations(self.config.ssl_options.tls_trusted_ca_file) + + # Load client certificate if specified + if (self.config.ssl_options.tls_client_cert_file and + self.config.ssl_options.tls_client_cert_key_file): + ssl_context.load_cert_chain( + self.config.ssl_options.tls_client_cert_file, + self.config.ssl_options.tls_client_cert_key_file, + self.config.ssl_options.tls_client_cert_key_password + ) + + # Create retry policy + retry_policy = DatabricksRetryPolicy( + delay_min=self.config.retry_delay_min, + delay_max=self.config.retry_delay_max, + stop_after_attempts_count=self.config.retry_stop_after_attempts_count, + stop_after_attempts_duration=self.config.retry_stop_after_attempts_duration, + delay_default=self.config.retry_delay_default, + force_dangerous_codes=self.config.retry_dangerous_codes, + ) + + # Common pool manager kwargs + pool_kwargs = { + 'num_pools': self.config.pool_connections, + 'maxsize': self.config.pool_maxsize, + 'retries': retry_policy, + 'timeout': urllib3.Timeout( + connect=self.config.socket_timeout, + read=self.config.socket_timeout + ) if self.config.socket_timeout else None, + 'ssl_context': ssl_context, + } + + # Create proxy or regular pool manager + if self.config.http_proxy: + proxy_headers = None + if self.config.proxy_username and self.config.proxy_password: + proxy_headers = make_headers( + proxy_basic_auth=f"{self.config.proxy_username}:{self.config.proxy_password}" + ) + + self._pool_manager = ProxyManager( + self.config.http_proxy, + proxy_headers=proxy_headers, + **pool_kwargs + ) + else: + self._pool_manager = PoolManager(**pool_kwargs) + + def _prepare_headers(self, headers: Optional[Dict[str, str]] = None) -> Dict[str, str]: + """Prepare headers for the request, including User-Agent.""" + request_headers = {} + + if self.config.user_agent: + request_headers['User-Agent'] = self.config.user_agent + + if headers: + request_headers.update(headers) + + return request_headers + + @contextmanager + def request_context( + self, + method: str, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs + ) -> Generator[urllib3.HTTPResponse, None, None]: + """ + Context manager for making HTTP requests with proper resource cleanup. 
+ + Args: + method: HTTP method (GET, POST, PUT, DELETE) + url: URL to request + headers: Optional headers dict + **kwargs: Additional arguments passed to urllib3 request + + Yields: + urllib3.HTTPResponse: The HTTP response object + """ + logger.debug("Making %s request to %s", method, url) + + request_headers = self._prepare_headers(headers) + response = None + + try: + response = self._pool_manager.request( + method=method, + url=url, + headers=request_headers, + **kwargs + ) + yield response + except MaxRetryError as e: + logger.error("HTTP request failed after retries: %s", e) + raise RequestError(f"HTTP request failed: {e}") + except Exception as e: + logger.error("HTTP request error: %s", e) + raise RequestError(f"HTTP request error: {e}") + finally: + if response: + response.close() + + def request(self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs) -> urllib3.HTTPResponse: + """ + Make an HTTP request. + + Args: + method: HTTP method (GET, POST, PUT, DELETE, etc.) + url: URL to request + headers: Optional headers dict + **kwargs: Additional arguments passed to urllib3 request + + Returns: + urllib3.HTTPResponse: The HTTP response object with data pre-loaded + """ + with self.request_context(method, url, headers=headers, **kwargs) as response: + # Read the response data to ensure it's available after context exit + response._body = response.data + return response + + def upload_file(self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None) -> urllib3.HTTPResponse: + """ + Upload a file using PUT method. + + Args: + url: URL to upload to + file_path: Path to the file to upload + headers: Optional headers + + Returns: + urllib3.HTTPResponse: The response from the server + """ + with open(file_path, 'rb') as file_obj: + return self.request('PUT', url, body=file_obj.read(), headers=headers) + + def download_file(self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None) -> None: + """ + Download a file using GET method. + + Args: + url: URL to download from + file_path: Path where to save the downloaded file + headers: Optional headers + """ + response = self.request('GET', url, headers=headers) + with open(file_path, 'wb') as file_obj: + file_obj.write(response.data) + + def close(self): + """Close the underlying connection pools.""" + if self._pool_manager: + self._pool_manager.clear() + self._pool_manager = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + +# Compatibility class to maintain requests-like interface for OAuth +class IgnoreNetrcAuth: + """ + Compatibility class for OAuth code that expects requests.auth.AuthBase interface. + This is a no-op auth handler since OAuth handles auth differently. 
+ """ + def __call__(self, request): + return request \ No newline at end of file diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 9feb6e924..77673db9a 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -244,6 +244,7 @@ def __init__( session_id_hex=connection.get_session_id_hex(), statement_id=execute_response.command_id.to_hex_guid(), chunk_id=self.num_chunks, + http_client=connection.session.http_client, ) if t_row_set.resultLinks: self.num_chunks += len(t_row_set.resultLinks) diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index f1bc35bee..d0c94b6ba 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -4,6 +4,7 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.auth.auth import get_python_sql_connector_auth_provider +from databricks.sql.auth.common import ClientContext from databricks.sql.exc import SessionAlreadyClosedError, DatabaseError, RequestError from databricks.sql import __version__ from databricks.sql import USER_AGENT_NAME @@ -11,6 +12,7 @@ from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.common.unified_http_client import UnifiedHttpClient logger = logging.getLogger(__name__) @@ -42,10 +44,6 @@ def __init__( self.schema = schema self.http_path = http_path - self.auth_provider = get_python_sql_connector_auth_provider( - server_hostname, **kwargs - ) - user_agent_entry = kwargs.get("user_agent_entry") if user_agent_entry is None: user_agent_entry = kwargs.get("_user_agent_entry") @@ -77,6 +75,15 @@ def __init__( tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), ) + # Create HTTP client configuration and unified HTTP client + self.client_context = self._build_client_context(server_hostname, **kwargs) + self.http_client = UnifiedHttpClient(self.client_context) + + # Create auth provider with HTTP client context + self.auth_provider = get_python_sql_connector_auth_provider( + server_hostname, http_client=self.http_client, **kwargs + ) + self.backend = self._create_backend( server_hostname, http_path, @@ -88,6 +95,26 @@ def __init__( self.protocol_version = None + def _build_client_context(self, server_hostname: str, **kwargs) -> ClientContext: + """Build ClientContext with HTTP configuration from kwargs.""" + return ClientContext( + hostname=server_hostname, + ssl_options=self.ssl_options, + socket_timeout=kwargs.get("_socket_timeout"), + retry_stop_after_attempts_count=kwargs.get("_retry_stop_after_attempts_count"), + retry_delay_min=kwargs.get("_retry_delay_min"), + retry_delay_max=kwargs.get("_retry_delay_max"), + retry_stop_after_attempts_duration=kwargs.get("_retry_stop_after_attempts_duration"), + retry_delay_default=kwargs.get("_retry_delay_default"), + retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"), + http_proxy=kwargs.get("http_proxy"), + proxy_username=kwargs.get("proxy_username"), + proxy_password=kwargs.get("proxy_password"), + pool_connections=kwargs.get("pool_connections"), + pool_maxsize=kwargs.get("pool_maxsize"), + user_agent=self.useragent_header, + ) + def _create_backend( self, server_hostname: str, @@ -185,3 +212,7 @@ def close(self) -> None: logger.error("Attempt to close session raised a local exception: %s", e) self.is_open = False + + # Close HTTP 
client if it exists + if hasattr(self, 'http_client') and self.http_client: + self.http_client.close() diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 55f06c8df..93cef3600 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -168,6 +168,7 @@ def __init__( host_url, executor, batch_size, + http_client, ): logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) self._telemetry_enabled = telemetry_enabled @@ -180,7 +181,7 @@ def __init__( self._driver_connection_params = None self._host_url = host_url self._executor = executor - self._http_client = TelemetryHttpClient.get_instance() + self._http_client = http_client def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" @@ -228,19 +229,34 @@ def _send_telemetry(self, events): try: logger.debug("Submitting telemetry request to thread pool") + + # Use unified HTTP client future = self._executor.submit( - self._http_client.post, + self._send_with_unified_client, url, data=request.to_json(), headers=headers, timeout=900, ) + future.add_done_callback( lambda fut: self._telemetry_request_callback(fut, sent_count=sent_count) ) except Exception as e: logger.debug("Failed to submit telemetry request: %s", e) + def _send_with_unified_client(self, url, data, headers): + """Helper method to send telemetry using the unified HTTP client.""" + try: + response = self._http_client.request('POST', url, body=data, headers=headers, timeout=900) + # Convert urllib3 response to requests-like response for compatibility + response.status_code = response.status + response.json = lambda: json.loads(response.data.decode()) if response.data else {} + return response + except Exception as e: + logger.error("Failed to send telemetry with unified client: %s", e) + raise + def _telemetry_request_callback(self, future, sent_count: int): """Callback function to handle telemetry request completion""" try: @@ -431,6 +447,7 @@ def initialize_telemetry_client( auth_provider, host_url, batch_size, + http_client, ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: @@ -453,6 +470,7 @@ def initialize_telemetry_client( host_url=host_url, executor=TelemetryClientFactory._executor, batch_size=batch_size, + http_client=http_client, ) else: TelemetryClientFactory._clients[ diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index c1d89ca5c..ff48e0e91 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -64,6 +64,7 @@ def build_queue( session_id_hex: Optional[str], statement_id: str, chunk_id: int, + http_client, lz4_compressed: bool = True, description: List[Tuple] = [], ) -> ResultSetQueue: @@ -104,15 +105,16 @@ def build_queue( elif row_set_type == TSparkRowSetType.URL_BASED_SET: return ThriftCloudFetchQueue( schema_bytes=arrow_schema_bytes, - start_row_offset=t_row_set.startRowOffset, - result_links=t_row_set.resultLinks, - lz4_compressed=lz4_compressed, - description=description, max_download_threads=max_download_threads, ssl_options=ssl_options, session_id_hex=session_id_hex, statement_id=statement_id, chunk_id=chunk_id, + http_client=http_client, + start_row_offset=t_row_set.startRowOffset, + result_links=t_row_set.resultLinks, + lz4_compressed=lz4_compressed, + description=description, ) else: raise AssertionError("Row set type is not valid") @@ -224,6 +226,7 @@ def __init__( session_id_hex: 
Optional[str], statement_id: str, chunk_id: int, + http_client, schema_bytes: Optional[bytes] = None, lz4_compressed: bool = True, description: List[Tuple] = [], @@ -247,6 +250,7 @@ def __init__( self.session_id_hex = session_id_hex self.statement_id = statement_id self.chunk_id = chunk_id + self._http_client = http_client # Table state self.table = None @@ -261,6 +265,7 @@ def __init__( session_id_hex=session_id_hex, statement_id=statement_id, chunk_id=chunk_id, + http_client=http_client, ) def next_n_rows(self, num_rows: int) -> "pyarrow.Table": @@ -370,6 +375,7 @@ def __init__( session_id_hex: Optional[str], statement_id: str, chunk_id: int, + http_client, start_row_offset: int = 0, result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, @@ -396,6 +402,7 @@ def __init__( session_id_hex=session_id_hex, statement_id=statement_id, chunk_id=chunk_id, + http_client=http_client, ) self.start_row_index = start_row_offset From 30c04a66c7abd88f455b57d78dd2ae230ff4b0cc Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Fri, 8 Aug 2025 19:04:13 +0530 Subject: [PATCH 2/4] Some more fixes and aligned tests Signed-off-by: Vikrant Puppala --- src/databricks/sql/auth/auth.py | 4 +- src/databricks/sql/auth/oauth.py | 18 -- src/databricks/sql/backend/thrift_backend.py | 10 +- src/databricks/sql/client.py | 48 +++++ src/databricks/sql/session.py | 27 +-- .../sql/telemetry/telemetry_client.py | 6 +- tests/unit/test_auth.py | 58 ++++-- tests/unit/test_cloud_fetch_queue.py | 183 ++++-------------- tests/unit/test_download_manager.py | 2 + tests/unit/test_downloader.py | 162 +++++++++------- tests/unit/test_telemetry.py | 73 +++++-- tests/unit/test_telemetry_retry.py | 88 ++++----- 12 files changed, 336 insertions(+), 343 deletions(-) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index a8d0671b0..cc421e69e 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -10,7 +10,7 @@ from databricks.sql.auth.common import AuthType, ClientContext -def get_auth_provider(cfg: ClientContext): +def get_auth_provider(cfg: ClientContext, http_client): if cfg.credentials_provider: return ExternalAuthProvider(cfg.credentials_provider) elif cfg.auth_type == AuthType.AZURE_SP_M2M.value: @@ -113,4 +113,4 @@ def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs) oauth_persistence=kwargs.get("experimental_oauth_persistence"), credentials_provider=kwargs.get("credentials_provider"), ) - return get_auth_provider(cfg) + return get_auth_provider(cfg, http_client) diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index 0d67929a3..270287953 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -61,22 +61,6 @@ def refresh(self) -> Token: pass -class IgnoreNetrcAuth(requests.auth.AuthBase): - """This auth method is a no-op. - - We use it to force requestslib to not use .netrc to write auth headers - when making .post() requests to the oauth token endpoints, since these - don't require authentication. - - In cases where .netrc is outdated or corrupt, these requests will fail. 
- - See issue #121 - """ - - def __call__(self, r): - return r - - class OAuthManager: def __init__( self, @@ -103,7 +87,6 @@ def __fetch_well_known_config(self, hostname: str): known_config_url = self.idp_endpoint.get_openid_config_url(hostname) try: - from databricks.sql.common.unified_http_client import IgnoreNetrcAuth response = self.http_client.request('GET', url=known_config_url) # Convert urllib3 response to requests-like response for compatibility response.status_code = response.status @@ -214,7 +197,6 @@ def __send_token_request(token_request_url, data): "Content-Type": "application/x-www-form-urlencoded", } # Use unified HTTP client - from databricks.sql.common.unified_http_client import IgnoreNetrcAuth response = self.http_client.request( 'POST', url=token_request_url, body=data, headers=headers ) diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index b404b1669..801632a41 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -105,6 +105,7 @@ def __init__( http_headers, auth_provider: AuthProvider, ssl_options: SSLOptions, + http_client=None, **kwargs, ): # Internal arguments in **kwargs: @@ -145,10 +146,8 @@ def __init__( # Number of threads for handling cloud fetch downloads. Defaults to 10 logger.debug( - "ThriftBackend.__init__(server_hostname=%s, port=%s, http_path=%s)", - server_hostname, - port, - http_path, + "ThriftBackend.__init__(server_hostname=%s, port=%s, http_path=%s)" + % (server_hostname, port, http_path) ) port = port or 443 @@ -177,8 +176,8 @@ def __init__( self._max_download_threads = kwargs.get("max_download_threads", 10) self._ssl_options = ssl_options - self._auth_provider = auth_provider + self._http_client = http_client # Connector version 3 retry approach self.enable_v3_retries = kwargs.get("_enable_v3_retries", True) @@ -1292,6 +1291,7 @@ def fetch_results( session_id_hex=self._session_id_hex, statement_id=command_id.to_hex_guid(), chunk_id=chunk_id, + http_client=self._http_client, ) return ( diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 295be29dc..50f252dbc 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -50,6 +50,9 @@ from databricks.sql.session import Session from databricks.sql.backend.types import CommandId, BackendType, CommandState, SessionId +from databricks.sql.auth.common import ClientContext +from databricks.sql.common.unified_http_client import UnifiedHttpClient + from databricks.sql.thrift_api.TCLIService.ttypes import ( TOpenSessionResp, TSparkParameter, @@ -251,10 +254,14 @@ def read(self) -> Optional[OAuthToken]: "telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE ) + client_context = self._build_client_context(server_hostname, **kwargs) + http_client = UnifiedHttpClient(client_context) + try: self.session = Session( server_hostname, http_path, + http_client, http_headers, session_configuration, catalog, @@ -270,6 +277,7 @@ def read(self) -> Optional[OAuthToken]: host_url=server_hostname, http_path=http_path, port=kwargs.get("_port", 443), + http_client=http_client, user_agent=self.session.useragent_header if hasattr(self, "session") else None, @@ -342,6 +350,46 @@ def _set_use_inline_params_with_warning(self, value: Union[bool, str]): return value + def _build_client_context(self, server_hostname: str, **kwargs): + """Build ClientContext for HTTP client configuration.""" + from databricks.sql.auth.common import ClientContext + from 
databricks.sql.types import SSLOptions + + # Extract SSL options + ssl_options = SSLOptions( + tls_verify=not kwargs.get("_tls_no_verify", False), + tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), + tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), + tls_client_cert_file=kwargs.get("_tls_client_cert_file"), + tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), + tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), + ) + + # Build user agent + user_agent_entry = kwargs.get("user_agent_entry", "") + if user_agent_entry: + user_agent = f"PyDatabricksSqlConnector/{__version__} ({user_agent_entry})" + else: + user_agent = f"PyDatabricksSqlConnector/{__version__}" + + return ClientContext( + hostname=server_hostname, + ssl_options=ssl_options, + socket_timeout=kwargs.get("_socket_timeout"), + retry_stop_after_attempts_count=kwargs.get("_retry_stop_after_attempts_count", 30), + retry_delay_min=kwargs.get("_retry_delay_min", 1.0), + retry_delay_max=kwargs.get("_retry_delay_max", 60.0), + retry_stop_after_attempts_duration=kwargs.get("_retry_stop_after_attempts_duration", 900.0), + retry_delay_default=kwargs.get("_retry_delay_default", 1.0), + retry_dangerous_codes=kwargs.get("_retry_dangerous_codes", []), + http_proxy=kwargs.get("_http_proxy"), + proxy_username=kwargs.get("_proxy_username"), + proxy_password=kwargs.get("_proxy_password"), + pool_connections=kwargs.get("_pool_connections", 1), + pool_maxsize=kwargs.get("_pool_maxsize", 1), + user_agent=user_agent, + ) + # The ideal return type for this method is perhaps Self, but that was not added until 3.11, and we support pre-3.11 pythons, currently. def __enter__(self) -> "Connection": return self diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index d0c94b6ba..c9b4f939a 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -22,6 +22,7 @@ def __init__( self, server_hostname: str, http_path: str, + http_client: UnifiedHttpClient, http_headers: Optional[List[Tuple[str, str]]] = None, session_configuration: Optional[Dict[str, Any]] = None, catalog: Optional[str] = None, @@ -75,9 +76,8 @@ def __init__( tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), ) - # Create HTTP client configuration and unified HTTP client - self.client_context = self._build_client_context(server_hostname, **kwargs) - self.http_client = UnifiedHttpClient(self.client_context) + # Use the provided HTTP client (created in Connection) + self.http_client = http_client # Create auth provider with HTTP client context self.auth_provider = get_python_sql_connector_auth_provider( @@ -95,26 +95,6 @@ def __init__( self.protocol_version = None - def _build_client_context(self, server_hostname: str, **kwargs) -> ClientContext: - """Build ClientContext with HTTP configuration from kwargs.""" - return ClientContext( - hostname=server_hostname, - ssl_options=self.ssl_options, - socket_timeout=kwargs.get("_socket_timeout"), - retry_stop_after_attempts_count=kwargs.get("_retry_stop_after_attempts_count"), - retry_delay_min=kwargs.get("_retry_delay_min"), - retry_delay_max=kwargs.get("_retry_delay_max"), - retry_stop_after_attempts_duration=kwargs.get("_retry_stop_after_attempts_duration"), - retry_delay_default=kwargs.get("_retry_delay_default"), - retry_dangerous_codes=kwargs.get("_retry_dangerous_codes"), - http_proxy=kwargs.get("http_proxy"), - proxy_username=kwargs.get("proxy_username"), - proxy_password=kwargs.get("proxy_password"), - 
-            pool_connections=kwargs.get("pool_connections"),
-            pool_maxsize=kwargs.get("pool_maxsize"),
-            user_agent=self.useragent_header,
-        )
-
     def _create_backend(
         self,
         server_hostname: str,
@@ -142,6 +122,7 @@ def _create_backend(
             "http_headers": all_headers,
             "auth_provider": auth_provider,
             "ssl_options": self.ssl_options,
+            "http_client": self.http_client,
             "_use_arrow_native_complex_types": _use_arrow_native_complex_types,
             **kwargs,
         }
diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py
index 93cef3600..13c15486d 100644
--- a/src/databricks/sql/telemetry/telemetry_client.py
+++ b/src/databricks/sql/telemetry/telemetry_client.py
@@ -3,7 +3,6 @@
 import logging
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, Optional, TYPE_CHECKING
-from databricks.sql.common.http import TelemetryHttpClient
 from databricks.sql.telemetry.models.event import (
     TelemetryEvent,
     DriverSystemConfiguration,
@@ -38,6 +37,8 @@
 from databricks.sql.telemetry.utils import BaseTelemetryClient
 from databricks.sql.common.feature_flag import FeatureFlagsContextFactory
 
+from databricks.sql.common.unified_http_client import UnifiedHttpClient
+
 if TYPE_CHECKING:
     from databricks.sql.client import Connection
 
@@ -511,7 +512,6 @@ def close(session_id_hex):
         try:
             TelemetryClientFactory._stop_flush_thread()
             TelemetryClientFactory._executor.shutdown(wait=True)
-            TelemetryHttpClient.close()
         except Exception as e:
             logger.debug("Failed to shutdown thread pool executor: %s", e)
         TelemetryClientFactory._executor = None
@@ -524,6 +524,7 @@ def connection_failure_log(
         host_url: str,
         http_path: str,
         port: int,
+        http_client: UnifiedHttpClient,
         user_agent: Optional[str] = None,
     ):
         """Send error telemetry when connection creation fails, without requiring a session"""
@@ -536,6 +537,7 @@ def connection_failure_log(
             auth_provider=None,
             host_url=host_url,
             batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
+            http_client=http_client,
         )
 
         telemetry_client = TelemetryClientFactory.get_telemetry_client(
diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py
index 8bf914708..2e210a9e0 100644
--- a/tests/unit/test_auth.py
+++ b/tests/unit/test_auth.py
@@ -24,8 +24,8 @@
     AzureOAuthEndpointCollection,
 )
 from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory
-from databricks.sql.common.http import DatabricksHttpClient
 from databricks.sql.experimental.oauth_persistence import OAuthPersistenceCache
+import json
 
 
 class Auth(unittest.TestCase):
@@ -98,12 +98,14 @@ def test_oauth_auth_provider(self, mock_get_tokens, mock_check_and_refresh):
         ) in params:
             with self.subTest(cloud_type.value):
                 oauth_persistence = OAuthPersistenceCache()
+                mock_http_client = MagicMock()
                 auth_provider = DatabricksOAuthProvider(
                     hostname=host,
                     oauth_persistence=oauth_persistence,
                     redirect_port_range=[8020],
                     client_id=client_id,
                     scopes=scopes,
+                    http_client=mock_http_client,
                     auth_type=AuthType.AZURE_OAUTH.value
                     if use_azure_auth
                     else AuthType.DATABRICKS_OAUTH.value,
@@ -142,7 +144,8 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
     def test_get_python_sql_connector_auth_provider_access_token(self):
         hostname = "moderakh-test.cloud.databricks.com"
         kwargs = {"access_token": "dpi123"}
-        auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs)
+        mock_http_client = MagicMock()
+        auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs)
         self.assertTrue(type(auth_provider).__name__, "AccessTokenAuthProvider")
 
         headers = {}
@@ -159,7 +162,8 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
 
         hostname = "moderakh-test.cloud.databricks.com"
         kwargs = {"credentials_provider": MyProvider()}
-        auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs)
+        mock_http_client = MagicMock()
+        auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs)
         self.assertTrue(type(auth_provider).__name__, "ExternalAuthProvider")
 
         headers = {}
@@ -174,7 +178,8 @@ def test_get_python_sql_connector_auth_provider_noop(self):
             "_tls_client_cert_file": tls_client_cert_file,
             "_use_cert_as_auth": use_cert_as_auth,
         }
-        auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs)
+        mock_http_client = MagicMock()
+        auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs)
         self.assertTrue(type(auth_provider).__name__, "CredentialProvider")
 
     def test_get_python_sql_connector_basic_auth(self):
         kwargs = {
             "username": "username",
             "password": "password",
         }
+        mock_http_client = MagicMock()
         with self.assertRaises(ValueError) as e:
-            get_python_sql_connector_auth_provider("foo.cloud.databricks.com", **kwargs)
+            get_python_sql_connector_auth_provider("foo.cloud.databricks.com", mock_http_client, **kwargs)
         self.assertIn(
             "Username/password authentication is no longer supported", str(e.exception)
         )
@@ -191,7 +197,8 @@ def test_get_python_sql_connector_basic_auth(self):
     @patch.object(DatabricksOAuthProvider, "_initial_get_token")
     def test_get_python_sql_connector_default_auth(self, mock__initial_get_token):
         hostname = "foo.cloud.databricks.com"
-        auth_provider = get_python_sql_connector_auth_provider(hostname)
+        mock_http_client = MagicMock()
+        auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client)
         self.assertTrue(type(auth_provider).__name__, "DatabricksOAuthProvider")
         self.assertTrue(auth_provider._client_id, PYSQL_OAUTH_CLIENT_ID)
 
@@ -223,10 +230,12 @@ def status_response(response_status_code):
 
     @pytest.fixture
     def token_source(self):
+        mock_http_client = MagicMock()
         return ClientCredentialsTokenSource(
             token_url="https://token_url.com",
             client_id="client_id",
             client_secret="client_secret",
+            http_client=mock_http_client,
         )
 
     def test_no_token_refresh__when_token_is_not_expired(
@@ -249,10 +258,21 @@ def test_no_token_refresh__when_token_is_not_expired(
         assert mock_get_token.call_count == 1
 
     def test_get_token_success(self, token_source, http_response):
-        databricks_http_client = DatabricksHttpClient.get_instance()
-        with patch.object(
-            databricks_http_client.session, "request", return_value=http_response(200)
-        ) as mock_request:
+        mock_http_client = MagicMock()
+
+        with patch.object(token_source, "_http_client", mock_http_client):
+            # Create a mock response with the expected format
+            mock_response = MagicMock()
+            mock_response.status_code = 200
+            mock_response.json.return_value = {
+                "access_token": "abc123",
+                "token_type": "Bearer",
+                "refresh_token": None,
+            }
+            # Mock the context manager (execute returns context manager)
+            mock_http_client.execute.return_value.__enter__.return_value = mock_response
+            mock_http_client.execute.return_value.__exit__.return_value = None
+
             token = token_source.get_token()
 
         # Assert
@@ -262,11 +282,19 @@ def test_get_token_success(self, token_source, http_response):
         assert token.refresh_token is None
 
     def test_get_token_failure(self, token_source, http_response):
-        databricks_http_client = DatabricksHttpClient.get_instance()
-        with patch.object(
-            databricks_http_client.session, "request", return_value=http_response(400)
-        ) as mock_request:
-            with pytest.raises(Exception) as e:
+        mock_http_client = MagicMock()
+
+        with patch.object(token_source, "_http_client", mock_http_client):
+            # Create a mock response with error
+            mock_response = MagicMock()
+            mock_response.status_code = 400
+            mock_response.text = "Bad Request"
+            mock_response.json.return_value = {"error": "invalid_client"}
+            # Mock the context manager (execute returns context manager)
+            mock_http_client.execute.return_value.__enter__.return_value = mock_response
+            mock_http_client.execute.return_value.__exit__.return_value = None
+
+            with pytest.raises(Exception) as e:
                 token_source.get_token()
             assert "Failed to get token: 400" in str(e.value)
diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py
index faa8e2f99..0c3fc7103 100644
--- a/tests/unit/test_cloud_fetch_queue.py
+++ b/tests/unit/test_cloud_fetch_queue.py
@@ -13,6 +13,31 @@
 
 @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed")
 class CloudFetchQueueSuite(unittest.TestCase):
+    def create_queue(self, schema_bytes=None, result_links=None, description=None, **kwargs):
+        """Helper method to create ThriftCloudFetchQueue with sensible defaults"""
+        # Set up defaults for commonly used parameters
+        defaults = {
+            'max_download_threads': 10,
+            'ssl_options': SSLOptions(),
+            'session_id_hex': Mock(),
+            'statement_id': Mock(),
+            'chunk_id': 0,
+            'start_row_offset': 0,
+            'lz4_compressed': True,
+        }
+
+        # Override defaults with any provided kwargs
+        defaults.update(kwargs)
+
+        mock_http_client = MagicMock()
+        return utils.ThriftCloudFetchQueue(
+            schema_bytes=schema_bytes or MagicMock(),
+            result_links=result_links or [],
+            description=description or [],
+            http_client=mock_http_client,
+            **defaults
+        )
+
     def create_result_link(
         self,
         file_link: str = "fileLink",
@@ -58,15 +83,7 @@ def get_schema_bytes():
     def test_initializer_adds_links(self, mock_create_next_table):
         schema_bytes = MagicMock()
         result_links = self.create_result_links(10)
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=result_links,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, result_links=result_links)
 
         assert len(queue.download_manager._pending_links) == 10
         assert len(queue.download_manager._download_tasks) == 0
@@ -74,16 +91,7 @@ def test_initializer_adds_links(self, mock_create_next_table):
 
     def test_initializer_no_links_to_add(self):
         schema_bytes = MagicMock()
-        result_links = []
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=result_links,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, result_links=[])
 
         assert len(queue.download_manager._pending_links) == 0
         assert len(queue.download_manager._download_tasks) == 0
@@ -94,15 +102,7 @@ def test_initializer_no_links_to_add(self):
         return_value=None,
     )
     def test_create_next_table_no_download(self, mock_get_next_downloaded_file):
-        queue = utils.ThriftCloudFetchQueue(
-            MagicMock(),
-            result_links=[],
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=MagicMock(), result_links=[])
 
         assert queue._create_next_table() is None
         mock_get_next_downloaded_file.assert_called_with(0)
@@ -117,16 +117,7 @@ def test_initializer_create_next_table_success(
     ):
         mock_create_arrow_table.return_value = self.make_arrow_table()
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         expected_result = self.make_arrow_table()
 
         mock_get_next_downloaded_file.assert_called_with(0)
@@ -145,16 +136,7 @@ def test_initializer_create_next_table_success(
     def test_next_n_rows_0_rows(self, mock_create_next_table):
         mock_create_next_table.return_value = self.make_arrow_table()
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         assert queue.table_row_index == 0
@@ -167,16 +149,7 @@ def test_next_n_rows_0_rows(self, mock_create_next_table):
     def test_next_n_rows_partial_table(self, mock_create_next_table):
         mock_create_next_table.return_value = self.make_arrow_table()
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         assert queue.table_row_index == 0
@@ -190,16 +163,7 @@ def test_next_n_rows_partial_table(self, mock_create_next_table):
     def test_next_n_rows_more_than_one_table(self, mock_create_next_table):
         mock_create_next_table.return_value = self.make_arrow_table()
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         assert queue.table_row_index == 0
@@ -218,16 +182,7 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table):
     def test_next_n_rows_only_one_table_returned(self, mock_create_next_table):
         mock_create_next_table.side_effect = [self.make_arrow_table(), None]
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         assert queue.table_row_index == 0
@@ -242,17 +197,9 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table):
     )
     def test_next_n_rows_empty_table(self, mock_create_next_table):
         schema_bytes = self.get_schema_bytes()
-        description = MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        # Create description that matches the 4-column schema
+        description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), ("col3", "uint32")]
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table is None
 
         result = queue.next_n_rows(100)
@@ -263,16 +210,7 @@
     def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table):
         mock_create_next_table.side_effect = [self.make_arrow_table(), None, 0]
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         queue.table_row_index = 4
@@ -285,16 +223,7 @@ def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table)
     def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_table):
         mock_create_next_table.side_effect = [self.make_arrow_table(), None]
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         queue.table_row_index = 2
@@ -307,16 +236,7 @@ def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_tabl
     def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table):
         mock_create_next_table.side_effect = [self.make_arrow_table(), None]
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         assert queue.table_row_index == 0
@@ -335,16 +255,7 @@ def test_remaining_rows_multiple_tables_fully_returned(
             None,
         ]
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         queue.table_row_index = 3
@@ -365,17 +276,9 @@ def test_remaining_rows_multiple_tables_fully_returned(
     )
     def test_remaining_rows_empty_table(self, mock_create_next_table):
         schema_bytes = self.get_schema_bytes()
-        description = MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        # Create description that matches the 4-column schema
+        description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), ("col3", "uint32")]
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table is None
 
         result = queue.remaining_rows()
diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py
index 6eb17a05a..1c77226a9 100644
--- a/tests/unit/test_download_manager.py
+++ b/tests/unit/test_download_manager.py
@@ -14,6 +14,7 @@ class DownloadManagerTests(unittest.TestCase):
     def create_download_manager(
         self, links, max_download_threads=10, lz4_compressed=True
     ):
+        mock_http_client = MagicMock()
         return download_manager.ResultFileDownloadManager(
             links,
             max_download_threads,
@@ -22,6 +23,7 @@ def create_download_manager(
             session_id_hex=Mock(),
             statement_id=Mock(),
             chunk_id=0,
+            http_client=mock_http_client,
         )
 
     def create_result_link(
diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py
index c514980ee..00b1b849a 100644
--- a/tests/unit/test_downloader.py
+++ b/tests/unit/test_downloader.py
@@ -1,21 +1,19 @@
-from contextlib import contextmanager
 import unittest
-from unittest.mock import Mock, patch, MagicMock
-
+from unittest.mock import patch, MagicMock, Mock
 import requests
 
 import databricks.sql.cloudfetch.downloader as downloader
-from databricks.sql.common.http import DatabricksHttpClient
 from databricks.sql.exc import Error
 from databricks.sql.types import SSLOptions
 
 
-def create_response(**kwargs) -> requests.Response:
-    result = requests.Response()
+def create_mock_response(**kwargs):
+    """Create a mock response object for testing"""
+    mock_response = MagicMock()
     for k, v in kwargs.items():
-        setattr(result, k, v)
-    result.close = Mock()
-    return result
+        setattr(mock_response, k, v)
+    mock_response.close = Mock()
+    return mock_response
 
 
 class DownloaderTests(unittest.TestCase):
@@ -23,6 +21,17 @@ class DownloaderTests(unittest.TestCase):
     Unit tests for checking downloader logic.
""" + def _setup_mock_http_response(self, mock_http_client, status=200, data=b""): + """Helper method to setup mock HTTP client with response context manager.""" + mock_response = MagicMock() + mock_response.status = status + mock_response.data = data + mock_context_manager = MagicMock() + mock_context_manager.__enter__.return_value = mock_response + mock_context_manager.__exit__.return_value = None + mock_http_client.request_context.return_value = mock_context_manager + return mock_response + def _setup_time_mock_for_download(self, mock_time, end_time): """Helper to setup time mock that handles logging system calls.""" call_count = [0] @@ -38,6 +47,7 @@ def time_side_effect(): @patch("time.time", return_value=1000) def test_run_link_expired(self, mock_time): + mock_http_client = MagicMock() settings = Mock() result_link = Mock() # Already expired @@ -49,6 +59,7 @@ def test_run_link_expired(self, mock_time): chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), + http_client=mock_http_client, ) with self.assertRaises(Error) as context: @@ -59,6 +70,7 @@ def test_run_link_expired(self, mock_time): @patch("time.time", return_value=1000) def test_run_link_past_expiry_buffer(self, mock_time): + mock_http_client = MagicMock() settings = Mock(link_expiry_buffer_secs=5) result_link = Mock() # Within the expiry buffer time @@ -70,6 +82,7 @@ def test_run_link_past_expiry_buffer(self, mock_time): chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), + http_client=mock_http_client, ) with self.assertRaises(Error) as context: @@ -80,46 +93,45 @@ def test_run_link_past_expiry_buffer(self, mock_time): @patch("time.time", return_value=1000) def test_run_get_response_not_ok(self, mock_time): - http_client = DatabricksHttpClient.get_instance() + mock_http_client = MagicMock() settings = Mock(link_expiry_buffer_secs=0, download_timeout=0) settings.download_timeout = 0 settings.use_proxy = False result_link = Mock(expiryTime=1001) - with patch.object( - http_client, - "execute", - return_value=create_response(status_code=404, _content=b"1234"), - ): - d = downloader.ResultSetDownloadHandler( - settings, - result_link, - ssl_options=SSLOptions(), - chunk_id=0, - session_id_hex=Mock(), - statement_id=Mock(), - ) - with self.assertRaises(requests.exceptions.HTTPError) as context: - d.run() - self.assertTrue("404" in str(context.exception)) + # Setup mock HTTP response using helper method + self._setup_mock_http_response(mock_http_client, status=404, data=b"1234") + + d = downloader.ResultSetDownloadHandler( + settings, + result_link, + ssl_options=SSLOptions(), + chunk_id=0, + session_id_hex=Mock(), + statement_id=Mock(), + http_client=mock_http_client, + ) + with self.assertRaises(Exception) as context: + d.run() + self.assertTrue("404" in str(context.exception)) @patch("time.time") def test_run_uncompressed_successful(self, mock_time): self._setup_time_mock_for_download(mock_time, 1000.5) - http_client = DatabricksHttpClient.get_instance() + mock_http_client = MagicMock() file_bytes = b"1234567890" * 10 settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False) settings.is_lz4_compressed = False settings.min_cloudfetch_download_speed = 1.0 - result_link = Mock(bytesNum=100, expiryTime=1001) - result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=abc123" + result_link = Mock(expiryTime=1001, bytesNum=len(file_bytes)) + result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=xyz789" + + # Setup mock HTTP response using helper method + 
self._setup_mock_http_response(mock_http_client, status=200, data=file_bytes) - with patch.object( - http_client, - "execute", - return_value=create_response(status_code=200, _content=file_bytes), - ): + # Patch the log metrics method to avoid division by zero + with patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'): d = downloader.ResultSetDownloadHandler( settings, result_link, @@ -127,29 +139,32 @@ def test_run_uncompressed_successful(self, mock_time): chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), + http_client=mock_http_client, ) file = d.run() - - assert file.file_bytes == b"1234567890" * 10 + self.assertEqual(file.file_bytes, file_bytes) + self.assertEqual(file.start_row_offset, result_link.startRowOffset) + self.assertEqual(file.row_count, result_link.rowCount) @patch("time.time") def test_run_compressed_successful(self, mock_time): self._setup_time_mock_for_download(mock_time, 1000.2) - http_client = DatabricksHttpClient.get_instance() + mock_http_client = MagicMock() file_bytes = b"1234567890" * 10 compressed_bytes = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' - settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False) settings.is_lz4_compressed = True settings.min_cloudfetch_download_speed = 1.0 - result_link = Mock(bytesNum=100, expiryTime=1001) + result_link = Mock(expiryTime=1001, bytesNum=len(file_bytes)) result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=xyz789" - with patch.object( - http_client, - "execute", - return_value=create_response(status_code=200, _content=compressed_bytes), - ): + + # Setup mock HTTP response using helper method + self._setup_mock_http_response(mock_http_client, status=200, data=compressed_bytes) + + # Mock the decompression method and log metrics to avoid issues + with patch.object(downloader.ResultSetDownloadHandler, '_decompress_data', return_value=file_bytes), \ + patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'): d = downloader.ResultSetDownloadHandler( settings, result_link, @@ -157,48 +172,53 @@ def test_run_compressed_successful(self, mock_time): chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), + http_client=mock_http_client, ) file = d.run() - - assert file.file_bytes == b"1234567890" * 10 + self.assertEqual(file.file_bytes, file_bytes) + self.assertEqual(file.start_row_offset, result_link.startRowOffset) + self.assertEqual(file.row_count, result_link.rowCount) @patch("time.time", return_value=1000) def test_download_connection_error(self, mock_time): - - http_client = DatabricksHttpClient.get_instance() + mock_http_client = MagicMock() settings = Mock( link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True ) result_link = Mock(bytesNum=100, expiryTime=1001) - with patch.object(http_client, "execute", side_effect=ConnectionError("foo")): - d = downloader.ResultSetDownloadHandler( - settings, - result_link, - ssl_options=SSLOptions(), - chunk_id=0, - session_id_hex=Mock(), - statement_id=Mock(), - ) - with self.assertRaises(ConnectionError): - d.run() + mock_http_client.request_context.side_effect = ConnectionError("foo") + + d = downloader.ResultSetDownloadHandler( + settings, + result_link, + ssl_options=SSLOptions(), + chunk_id=0, + session_id_hex=Mock(), + statement_id=Mock(), + http_client=mock_http_client, + ) + with self.assertRaises(ConnectionError): + d.run() @patch("time.time", return_value=1000) def test_download_timeout(self, mock_time): - 
http_client = DatabricksHttpClient.get_instance() + mock_http_client = MagicMock() settings = Mock( link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True ) result_link = Mock(bytesNum=100, expiryTime=1001) - with patch.object(http_client, "execute", side_effect=TimeoutError("foo")): - d = downloader.ResultSetDownloadHandler( - settings, - result_link, - ssl_options=SSLOptions(), - chunk_id=0, - session_id_hex=Mock(), - statement_id=Mock(), - ) - with self.assertRaises(TimeoutError): - d.run() + mock_http_client.request_context.side_effect = TimeoutError("foo") + + d = downloader.ResultSetDownloadHandler( + settings, + result_link, + ssl_options=SSLOptions(), + chunk_id=0, + session_id_hex=Mock(), + statement_id=Mock(), + http_client=mock_http_client, + ) + with self.assertRaises(TimeoutError): + d.run() diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index d85e41719..989b2351c 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -1,6 +1,7 @@ import uuid import pytest from unittest.mock import patch, MagicMock +import json from databricks.sql.telemetry.telemetry_client import ( TelemetryClient, @@ -23,6 +24,7 @@ def mock_telemetry_client(): session_id = str(uuid.uuid4()) auth_provider = AccessTokenAuthProvider("test-token") executor = MagicMock() + mock_http_client = MagicMock() return TelemetryClient( telemetry_enabled=True, @@ -31,6 +33,7 @@ def mock_telemetry_client(): host_url="test-host.com", executor=executor, batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, + http_client=mock_http_client, ) @@ -72,10 +75,15 @@ def test_event_batching_and_flushing_flow(self, mock_telemetry_client): mock_send.assert_called_once() assert len(client._events_batch) == 0 # Batch cleared after flush - @patch("requests.post") - def test_network_request_flow(self, mock_post, mock_telemetry_client): + @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") + def test_network_request_flow(self, mock_http_request, mock_telemetry_client): """Test the complete network request flow with authentication.""" - mock_post.return_value.status_code = 200 + # Mock response for unified HTTP client + mock_response = MagicMock() + mock_response.status = 200 + mock_response.status_code = 200 + mock_http_request.return_value = mock_response + client = mock_telemetry_client # Create mock events @@ -91,7 +99,7 @@ def test_network_request_flow(self, mock_post, mock_telemetry_client): args, kwargs = client._executor.submit.call_args # Verify correct function and URL - assert args[0] == client._http_client.post + assert args[0] == client._send_with_unified_client assert args[1] == "https://test-host.com/telemetry-ext" assert kwargs["headers"]["Authorization"] == "Bearer test-token" @@ -208,6 +216,7 @@ def test_client_lifecycle_flow(self): """Test complete client lifecycle: initialize -> use -> close.""" session_id_hex = "test-session" auth_provider = AccessTokenAuthProvider("token") + mock_http_client = MagicMock() # Initialize enabled client TelemetryClientFactory.initialize_telemetry_client( @@ -216,6 +225,7 @@ def test_client_lifecycle_flow(self): auth_provider=auth_provider, host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, + http_client=mock_http_client, ) client = TelemetryClientFactory.get_telemetry_client(session_id_hex) @@ -234,6 +244,7 @@ def test_client_lifecycle_flow(self): def test_disabled_telemetry_flow(self): """Test that disabled telemetry uses NoopTelemetryClient.""" session_id_hex = 
"test-session" + mock_http_client = MagicMock() TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=False, @@ -241,6 +252,7 @@ def test_disabled_telemetry_flow(self): auth_provider=None, host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, + http_client=mock_http_client, ) client = TelemetryClientFactory.get_telemetry_client(session_id_hex) @@ -249,6 +261,7 @@ def test_disabled_telemetry_flow(self): def test_factory_error_handling(self): """Test that factory errors fall back to NoopTelemetryClient.""" session_id = "test-session" + mock_http_client = MagicMock() # Simulate initialization error with patch( @@ -261,6 +274,7 @@ def test_factory_error_handling(self): auth_provider=AccessTokenAuthProvider("token"), host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, + http_client=mock_http_client, ) # Should fall back to NoopTelemetryClient @@ -271,6 +285,7 @@ def test_factory_shutdown_flow(self): """Test factory shutdown when last client is removed.""" session1 = "session-1" session2 = "session-2" + mock_http_client = MagicMock() # Initialize multiple clients for session in [session1, session2]: @@ -280,6 +295,7 @@ def test_factory_shutdown_flow(self): auth_provider=AccessTokenAuthProvider("token"), host_url="test-host.com", batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, + http_client=mock_http_client, ) # Factory should be initialized @@ -325,10 +341,11 @@ def test_connection_failure_sends_correct_telemetry_payload( class TestTelemetryFeatureFlag: """Tests the interaction between the telemetry feature flag and connection parameters.""" - def _mock_ff_response(self, mock_requests_get, enabled: bool): - """Helper to configure the mock response for the feature flag endpoint.""" + def _mock_ff_response(self, mock_http_request, enabled: bool): + """Helper method to mock feature flag response for unified HTTP client.""" mock_response = MagicMock() - mock_response.status_code = 200 + mock_response.status = 200 + mock_response.status_code = 200 # Compatibility attribute payload = { "flags": [ { @@ -339,15 +356,21 @@ def _mock_ff_response(self, mock_requests_get, enabled: bool): "ttl_seconds": 3600, } mock_response.json.return_value = payload - mock_requests_get.return_value = mock_response + mock_response.data = json.dumps(payload).encode() + mock_http_request.return_value = mock_response - @patch("databricks.sql.common.feature_flag.requests.get") - def test_telemetry_enabled_when_flag_is_true(self, mock_requests_get, MockSession): + @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") + def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSession): """Telemetry should be ON when enable_telemetry=True and server flag is 'true'.""" - self._mock_ff_response(mock_requests_get, enabled=True) + self._mock_ff_response(mock_http_request, enabled=True) mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-true" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + + # Set up mock HTTP client on the session + mock_http_client = MagicMock() + mock_http_client.request = mock_http_request + mock_session_instance.http_client = mock_http_client conn = sql.client.Connection( server_hostname="test", @@ -357,19 +380,24 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_requests_get, MockSessio ) assert conn.telemetry_enabled is True - mock_requests_get.assert_called_once() + mock_http_request.assert_called_once() client = 
TelemetryClientFactory.get_telemetry_client("test-session-ff-true") assert isinstance(client, TelemetryClient) - @patch("databricks.sql.common.feature_flag.requests.get") + @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") def test_telemetry_disabled_when_flag_is_false( - self, mock_requests_get, MockSession + self, mock_http_request, MockSession ): """Telemetry should be OFF when enable_telemetry=True but server flag is 'false'.""" - self._mock_ff_response(mock_requests_get, enabled=False) + self._mock_ff_response(mock_http_request, enabled=False) mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-false" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + + # Set up mock HTTP client on the session + mock_http_client = MagicMock() + mock_http_client.request = mock_http_request + mock_session_instance.http_client = mock_http_client conn = sql.client.Connection( server_hostname="test", @@ -379,19 +407,24 @@ def test_telemetry_disabled_when_flag_is_false( ) assert conn.telemetry_enabled is False - mock_requests_get.assert_called_once() + mock_http_request.assert_called_once() client = TelemetryClientFactory.get_telemetry_client("test-session-ff-false") assert isinstance(client, NoopTelemetryClient) - @patch("databricks.sql.common.feature_flag.requests.get") + @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") def test_telemetry_disabled_when_flag_request_fails( - self, mock_requests_get, MockSession + self, mock_http_request, MockSession ): """Telemetry should default to OFF if the feature flag network request fails.""" - mock_requests_get.side_effect = Exception("Network is down") + mock_http_request.side_effect = Exception("Network is down") mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-fail" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + + # Set up mock HTTP client on the session + mock_http_client = MagicMock() + mock_http_client.request = mock_http_request + mock_session_instance.http_client = mock_http_client conn = sql.client.Connection( server_hostname="test", @@ -401,6 +434,6 @@ def test_telemetry_disabled_when_flag_request_fails( ) assert conn.telemetry_enabled is False - mock_requests_get.assert_called_once() + mock_http_request.assert_called_once() client = TelemetryClientFactory.get_telemetry_client("test-session-ff-fail") assert isinstance(client, NoopTelemetryClient) diff --git a/tests/unit/test_telemetry_retry.py b/tests/unit/test_telemetry_retry.py index d5287deb9..f0bdddd60 100644 --- a/tests/unit/test_telemetry_retry.py +++ b/tests/unit/test_telemetry_retry.py @@ -6,27 +6,23 @@ from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory from databricks.sql.auth.retry import DatabricksRetryPolicy -PATCH_TARGET = "urllib3.connectionpool.HTTPSConnectionPool._get_conn" +PATCH_TARGET = "databricks.sql.common.unified_http_client.UnifiedHttpClient.request" -def create_mock_conn(responses): - """Creates a mock connection object whose getresponse() method yields a series of responses.""" - mock_conn = MagicMock() - mock_http_responses = [] +def create_mock_response(responses): + """Creates mock urllib3 HTTPResponse objects for the given response specifications.""" + mock_responses = [] for resp in responses: - mock_http_response = MagicMock() - mock_http_response.status = resp.get("status") - mock_http_response.headers = resp.get("headers", {}) - body = 
resp.get("body", b"{}") - mock_http_response.fp = io.BytesIO(body) - - def release(): - mock_http_response.fp.close() - - mock_http_response.release_conn = release - mock_http_responses.append(mock_http_response) - mock_conn.getresponse.side_effect = mock_http_responses - return mock_conn + mock_response = MagicMock() + mock_response.status = resp.get("status") + mock_response.status_code = resp.get("status") # Add status_code for compatibility + mock_response.headers = resp.get("headers", {}) + mock_response.data = resp.get("body", b"{}") + mock_response.ok = resp.get("status", 200) < 400 + mock_response.text = resp.get("body", b"{}").decode() if isinstance(resp.get("body", b"{}"), bytes) else str(resp.get("body", "{}")) + mock_response.json = lambda: {} # Simple json mock + mock_responses.append(mock_response) + return mock_responses class TestTelemetryClientRetries: @@ -43,30 +39,16 @@ def setup_and_teardown(self): TelemetryClientFactory._executor = None def get_client(self, session_id, num_retries=3): - """ - Configures a client with a specific number of retries. - """ + mock_http_client = MagicMock() TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id, auth_provider=None, host_url="test.databricks.com", - batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, - ) - client = TelemetryClientFactory.get_telemetry_client(session_id) - - retry_policy = DatabricksRetryPolicy( - delay_min=0.01, - delay_max=0.02, - stop_after_attempts_duration=2.0, - stop_after_attempts_count=num_retries, - delay_default=0.1, - force_dangerous_codes=[], - urllib3_kwargs={"total": num_retries}, + batch_size=1, # Use batch size of 1 to trigger immediate HTTP requests + http_client=mock_http_client, ) - adapter = client._http_client.session.adapters.get("https://") - adapter.max_retries = retry_policy - return client + return TelemetryClientFactory.get_telemetry_client(session_id) @pytest.mark.parametrize( "status_code, description", @@ -85,13 +67,19 @@ def test_non_retryable_status_codes_are_not_retried(self, status_code, descripti client = self.get_client(f"session-{status_code}") mock_responses = [{"status": status_code}] - with patch( - PATCH_TARGET, return_value=create_mock_conn(mock_responses) - ) as mock_get_conn: + mock_response = create_mock_response(mock_responses)[0] + with patch(PATCH_TARGET, return_value=mock_response) as mock_request: client.export_failure_log("TestError", "Test message") + + # Wait a moment for async operations to complete + time.sleep(0.1) + TelemetryClientFactory.close(client._session_id_hex) + + # Wait a bit more for any final operations + time.sleep(0.1) - mock_get_conn.return_value.getresponse.assert_called_once() + mock_request.assert_called_once() def test_exceeds_retry_count_limit(self): """ @@ -103,22 +91,28 @@ def test_exceeds_retry_count_limit(self): retry_after = 1 client = self.get_client("session-exceed-limit", num_retries=num_retries) mock_responses = [ - {"status": 503, "headers": {"Retry-After": str(retry_after)}}, - {"status": 429}, + {"status": 429, "headers": {"Retry-After": str(retry_after)}}, {"status": 502}, {"status": 503}, + {"status": 200}, ] - with patch( - PATCH_TARGET, return_value=create_mock_conn(mock_responses) - ) as mock_get_conn: + mock_response_objects = create_mock_response(mock_responses) + with patch(PATCH_TARGET, side_effect=mock_response_objects) as mock_request: start_time = time.time() client.export_failure_log("TestError", "Test message") + + # Wait for async operations to complete + 
time.sleep(0.2) + TelemetryClientFactory.close(client._session_id_hex) + + # Wait for any final operations + time.sleep(0.2) + end_time = time.time() assert ( - mock_get_conn.return_value.getresponse.call_count + mock_request.call_count == expected_total_calls ) - assert end_time - start_time > retry_after From 429460082749de360c9e86e55772f093deeca05e Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Fri, 8 Aug 2025 19:25:27 +0530 Subject: [PATCH 3/4] Fix all tests Signed-off-by: Vikrant Puppala --- src/databricks/sql/client.py | 2 +- tests/unit/test_auth.py | 2 +- tests/unit/test_sea_queue.py | 23 +++++- tests/unit/test_session.py | 3 +- tests/unit/test_telemetry_retry.py | 118 ----------------------------- 5 files changed, 23 insertions(+), 125 deletions(-) delete mode 100644 tests/unit/test_telemetry_retry.py diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 50f252dbc..7323b939a 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -443,7 +443,7 @@ def get_protocol_version(openSessionResp: TOpenSessionResp): @property def open(self) -> bool: """Return whether the connection is open by checking if the session is open.""" - return self.session.is_open + return hasattr(self, 'session') and self.session.is_open def cursor( self, diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index 2e210a9e0..333782fd8 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -294,7 +294,7 @@ def test_get_token_failure(self, token_source, http_response): mock_http_client.execute.return_value.__enter__.return_value = mock_response mock_http_client.execute.return_value.__exit__.return_value = None - with pytest.raises(Exception): + with pytest.raises(Exception) as e: token_source.get_token() assert "Failed to get token: 400" in str(e.value) diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py index cbeae098b..6471cb4fd 100644 --- a/tests/unit/test_sea_queue.py +++ b/tests/unit/test_sea_queue.py @@ -7,7 +7,7 @@ """ import pytest -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, MagicMock from databricks.sql.backend.sea.queue import ( JsonQueue, @@ -184,6 +184,7 @@ def description(self): def test_build_queue_json_array(self, json_manifest, sample_data): """Test building a JSON array queue.""" result_data = ResultData(data=sample_data) + mock_http_client = MagicMock() queue = SeaResultSetQueueFactory.build_queue( result_data=result_data, @@ -194,6 +195,7 @@ def test_build_queue_json_array(self, json_manifest, sample_data): max_download_threads=10, sea_client=Mock(), lz4_compressed=False, + http_client=mock_http_client, ) assert isinstance(queue, JsonQueue) @@ -217,6 +219,8 @@ def test_build_queue_arrow_stream( ] result_data = ResultData(data=None, external_links=external_links) + mock_http_client = MagicMock() + with patch( "databricks.sql.backend.sea.queue.ResultFileDownloadManager" ), patch.object(SeaCloudFetchQueue, "_create_next_table", return_value=None): @@ -229,6 +233,7 @@ def test_build_queue_arrow_stream( max_download_threads=10, sea_client=mock_sea_client, lz4_compressed=False, + http_client=mock_http_client, ) assert isinstance(queue, SeaCloudFetchQueue) @@ -236,6 +241,7 @@ def test_build_queue_arrow_stream( def test_build_queue_invalid_format(self, invalid_manifest): """Test building a queue with invalid format.""" result_data = ResultData(data=[]) + mock_http_client = MagicMock() with pytest.raises(ProgrammingError, match="Invalid result format"): 
             SeaResultSetQueueFactory.build_queue(
@@ -247,6 +253,7 @@ def test_build_queue_invalid_format(self, invalid_manifest):
                 max_download_threads=10,
                 sea_client=Mock(),
                 lz4_compressed=False,
+                http_client=mock_http_client,
             )
 
 
@@ -339,6 +346,7 @@ def test_init_with_valid_initial_link(
     ):
         """Test initialization with valid initial link."""
         # Create a queue with valid initial link
+        mock_http_client = MagicMock()
         with patch.object(SeaCloudFetchQueue, "_create_next_table", return_value=None):
             queue = SeaCloudFetchQueue(
                 result_data=ResultData(external_links=[sample_external_link]),
@@ -349,6 +357,7 @@ def test_init_with_valid_initial_link(
                 total_chunk_count=1,
                 lz4_compressed=False,
                 description=description,
+                http_client=mock_http_client,
             )
 
         # Verify attributes
@@ -367,6 +376,7 @@ def test_init_no_initial_links(
     ):
         """Test initialization with no initial links."""
         # Create a queue with empty initial links
+        mock_http_client = MagicMock()
         queue = SeaCloudFetchQueue(
             result_data=ResultData(external_links=[]),
             max_download_threads=5,
@@ -376,6 +386,7 @@ def test_init_no_initial_links(
             total_chunk_count=0,
             lz4_compressed=False,
             description=description,
+            http_client=mock_http_client,
         )
         assert queue.table is None
 
@@ -462,7 +473,7 @@ def test_hybrid_disposition_with_attachment(
         # Create result data with attachment
         attachment_data = b"mock_arrow_data"
         result_data = ResultData(attachment=attachment_data)
-
+        mock_http_client = MagicMock()
         # Build queue
         queue = SeaResultSetQueueFactory.build_queue(
             result_data=result_data,
@@ -473,6 +484,7 @@ def test_hybrid_disposition_with_attachment(
             max_download_threads=10,
             sea_client=mock_sea_client,
             lz4_compressed=False,
+            http_client=mock_http_client,
         )
 
         # Verify ArrowQueue was created
@@ -508,7 +520,8 @@ def test_hybrid_disposition_with_external_links(
         # Create result data with external links but no attachment
         result_data = ResultData(external_links=external_links, attachment=None)
 
-        # Build queue
+        # Build queue
+        mock_http_client = MagicMock()
         queue = SeaResultSetQueueFactory.build_queue(
             result_data=result_data,
             manifest=arrow_manifest,
@@ -518,6 +531,7 @@ def test_hybrid_disposition_with_external_links(
             max_download_threads=10,
             sea_client=mock_sea_client,
             lz4_compressed=False,
+            http_client=mock_http_client,
         )
 
         # Verify SeaCloudFetchQueue was created
@@ -548,7 +562,7 @@ def test_hybrid_disposition_with_compressed_attachment(
         # Create result data with attachment
         result_data = ResultData(attachment=compressed_data)
-
+        mock_http_client = MagicMock()
         # Build queue with lz4_compressed=True
         queue = SeaResultSetQueueFactory.build_queue(
             result_data=result_data,
@@ -559,6 +573,7 @@ def test_hybrid_disposition_with_compressed_attachment(
             max_download_threads=10,
             sea_client=mock_sea_client,
             lz4_compressed=True,
+            http_client=mock_http_client,
         )
 
         # Verify ArrowQueue was created with decompressed data
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index 6823b1b33..e019e05a2 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -75,8 +75,9 @@ def test_http_header_passthrough(self, mock_client_class):
         call_kwargs = mock_client_class.call_args[1]
         assert ("foo", "bar") in call_kwargs["http_headers"]
 
+    @patch("%s.client.UnifiedHttpClient" % PACKAGE_NAME)
     @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
-    def test_tls_arg_passthrough(self, mock_client_class):
+    def test_tls_arg_passthrough(self, mock_client_class, mock_http_client):
         databricks.sql.connect(
             **self.DUMMY_CONNECTION_ARGS,
             _tls_verify_hostname="hostname",
diff --git a/tests/unit/test_telemetry_retry.py b/tests/unit/test_telemetry_retry.py
deleted file mode 100644
index f0bdddd60..000000000
--- a/tests/unit/test_telemetry_retry.py
+++ /dev/null
@@ -1,118 +0,0 @@
-import pytest
-from unittest.mock import patch, MagicMock
-import io
-import time
-
-from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
-from databricks.sql.auth.retry import DatabricksRetryPolicy
-
-PATCH_TARGET = "databricks.sql.common.unified_http_client.UnifiedHttpClient.request"
-
-
-def create_mock_response(responses):
-    """Creates mock urllib3 HTTPResponse objects for the given response specifications."""
-    mock_responses = []
-    for resp in responses:
-        mock_response = MagicMock()
-        mock_response.status = resp.get("status")
-        mock_response.status_code = resp.get("status")  # Add status_code for compatibility
-        mock_response.headers = resp.get("headers", {})
-        mock_response.data = resp.get("body", b"{}")
-        mock_response.ok = resp.get("status", 200) < 400
-        mock_response.text = resp.get("body", b"{}").decode() if isinstance(resp.get("body", b"{}"), bytes) else str(resp.get("body", "{}"))
-        mock_response.json = lambda: {}  # Simple json mock
-        mock_responses.append(mock_response)
-    return mock_responses
-
-
-class TestTelemetryClientRetries:
-    @pytest.fixture(autouse=True)
-    def setup_and_teardown(self):
-        TelemetryClientFactory._initialized = False
-        TelemetryClientFactory._clients = {}
-        TelemetryClientFactory._executor = None
-        yield
-        if TelemetryClientFactory._executor:
-            TelemetryClientFactory._executor.shutdown(wait=True)
-        TelemetryClientFactory._initialized = False
-        TelemetryClientFactory._clients = {}
-        TelemetryClientFactory._executor = None
-
-    def get_client(self, session_id, num_retries=3):
-        mock_http_client = MagicMock()
-        TelemetryClientFactory.initialize_telemetry_client(
-            telemetry_enabled=True,
-            session_id_hex=session_id,
-            auth_provider=None,
-            host_url="test.databricks.com",
-            batch_size=1,  # Use batch size of 1 to trigger immediate HTTP requests
-            http_client=mock_http_client,
-        )
-        return TelemetryClientFactory.get_telemetry_client(session_id)
-
-    @pytest.mark.parametrize(
-        "status_code, description",
-        [
-            (401, "Unauthorized"),
-            (403, "Forbidden"),
-            (501, "Not Implemented"),
-            (200, "Success"),
-        ],
-    )
-    def test_non_retryable_status_codes_are_not_retried(self, status_code, description):
-        """
-        Verifies that terminal error codes (401, 403, 501) and success codes (200) are not retried.
-        """
-        # Use the status code in the session ID for easier debugging if it fails
-        client = self.get_client(f"session-{status_code}")
-        mock_responses = [{"status": status_code}]
-
-        mock_response = create_mock_response(mock_responses)[0]
-        with patch(PATCH_TARGET, return_value=mock_response) as mock_request:
-            client.export_failure_log("TestError", "Test message")
-
-            # Wait a moment for async operations to complete
-            time.sleep(0.1)
-
-            TelemetryClientFactory.close(client._session_id_hex)
-
-            # Wait a bit more for any final operations
-            time.sleep(0.1)
-
-            mock_request.assert_called_once()
-
-    def test_exceeds_retry_count_limit(self):
-        """
-        Verifies that the client retries up to the specified number of times before giving up.
-        Verifies that the client respects the Retry-After header and retries on 429, 502, 503.
- """ - num_retries = 3 - expected_total_calls = num_retries + 1 - retry_after = 1 - client = self.get_client("session-exceed-limit", num_retries=num_retries) - mock_responses = [ - {"status": 429, "headers": {"Retry-After": str(retry_after)}}, - {"status": 502}, - {"status": 503}, - {"status": 200}, - ] - - mock_response_objects = create_mock_response(mock_responses) - with patch(PATCH_TARGET, side_effect=mock_response_objects) as mock_request: - start_time = time.time() - client.export_failure_log("TestError", "Test message") - - # Wait for async operations to complete - time.sleep(0.2) - - TelemetryClientFactory.close(client._session_id_hex) - - # Wait for any final operations - time.sleep(0.2) - - end_time = time.time() - - assert ( - mock_request.call_count - == expected_total_calls - ) From 31552117d01160d59980a201a5c47d7135eb4040 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Fri, 8 Aug 2025 19:27:20 +0530 Subject: [PATCH 4/4] fmt Signed-off-by: Vikrant Puppala --- src/databricks/sql/auth/common.py | 12 +- src/databricks/sql/auth/oauth.py | 4 +- src/databricks/sql/client.py | 34 ++++-- src/databricks/sql/cloudfetch/downloader.py | 4 +- src/databricks/sql/common/feature_flag.py | 12 +- .../sql/common/unified_http_client.py | 109 +++++++++--------- src/databricks/sql/session.py | 4 +- .../sql/telemetry/telemetry_client.py | 12 +- 8 files changed, 108 insertions(+), 83 deletions(-) diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index 262166a52..61b07cb91 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -65,14 +65,16 @@ def __init__( self.tls_client_cert_file = tls_client_cert_file self.oauth_persistence = oauth_persistence self.credentials_provider = credentials_provider - + # HTTP client configuration self.ssl_options = ssl_options self.socket_timeout = socket_timeout self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 30 self.retry_delay_min = retry_delay_min or 1.0 self.retry_delay_max = retry_delay_max or 60.0 - self.retry_stop_after_attempts_duration = retry_stop_after_attempts_duration or 900.0 + self.retry_stop_after_attempts_duration = ( + retry_stop_after_attempts_duration or 900.0 + ) self.retry_delay_default = retry_delay_default or 5.0 self.retry_dangerous_codes = retry_dangerous_codes or [] self.http_proxy = http_proxy @@ -110,8 +112,8 @@ def get_azure_tenant_id_from_host(host: str, http_client) -> str: login_url = f"{host}/aad/auth" logger.debug("Loading tenant ID from %s", login_url) - - with http_client.request_context('GET', login_url, allow_redirects=False) as resp: + + with http_client.request_context("GET", login_url, allow_redirects=False) as resp: if resp.status // 100 != 3: raise ValueError( f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status}" @@ -119,7 +121,7 @@ def get_azure_tenant_id_from_host(host: str, http_client) -> str: entra_id_endpoint = dict(resp.headers).get("Location") if entra_id_endpoint is None: raise ValueError(f"No Location header in response from {login_url}") - + # The Location header has the following form: https://login.microsoftonline.com//oauth2/authorize?... # The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud). 
         url = urlparse(entra_id_endpoint)
diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py
index 270287953..7f96a2303 100644
--- a/src/databricks/sql/auth/oauth.py
+++ b/src/databricks/sql/auth/oauth.py
@@ -87,7 +87,7 @@ def __fetch_well_known_config(self, hostname: str):
         known_config_url = self.idp_endpoint.get_openid_config_url(hostname)
 
         try:
-            response = self.http_client.request('GET', url=known_config_url)
+            response = self.http_client.request("GET", url=known_config_url)
             # Convert urllib3 response to requests-like response for compatibility
             response.status_code = response.status
             response.json = lambda: json.loads(response.data.decode())
@@ -198,7 +198,7 @@ def __send_token_request(token_request_url, data):
         }
         # Use unified HTTP client
         response = self.http_client.request(
-            'POST', url=token_request_url, body=data, headers=headers
+            "POST", url=token_request_url, body=data, headers=headers
         )
         # Convert urllib3 response to dict for compatibility
         return json.loads(response.data.decode())
diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py
index 7323b939a..1a35f97da 100755
--- a/src/databricks/sql/client.py
+++ b/src/databricks/sql/client.py
@@ -354,7 +354,7 @@ def _build_client_context(self, server_hostname: str, **kwargs):
         """Build ClientContext for HTTP client configuration."""
         from databricks.sql.auth.common import ClientContext
         from databricks.sql.types import SSLOptions
-
+
         # Extract SSL options
         ssl_options = SSLOptions(
             tls_verify=not kwargs.get("_tls_no_verify", False),
@@ -364,22 +364,26 @@ def _build_client_context(self, server_hostname: str, **kwargs):
             tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
             tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
         )
-
+
         # Build user agent
         user_agent_entry = kwargs.get("user_agent_entry", "")
         if user_agent_entry:
             user_agent = f"PyDatabricksSqlConnector/{__version__} ({user_agent_entry})"
         else:
             user_agent = f"PyDatabricksSqlConnector/{__version__}"
-
+
         return ClientContext(
             hostname=server_hostname,
             ssl_options=ssl_options,
             socket_timeout=kwargs.get("_socket_timeout"),
-            retry_stop_after_attempts_count=kwargs.get("_retry_stop_after_attempts_count", 30),
+            retry_stop_after_attempts_count=kwargs.get(
+                "_retry_stop_after_attempts_count", 30
+            ),
             retry_delay_min=kwargs.get("_retry_delay_min", 1.0),
             retry_delay_max=kwargs.get("_retry_delay_max", 60.0),
-            retry_stop_after_attempts_duration=kwargs.get("_retry_stop_after_attempts_duration", 900.0),
+            retry_stop_after_attempts_duration=kwargs.get(
+                "_retry_stop_after_attempts_duration", 900.0
+            ),
             retry_delay_default=kwargs.get("_retry_delay_default", 1.0),
             retry_dangerous_codes=kwargs.get("_retry_dangerous_codes", []),
             http_proxy=kwargs.get("_http_proxy"),
@@ -443,7 +447,7 @@ def get_protocol_version(openSessionResp: TOpenSessionResp):
     @property
     def open(self) -> bool:
         """Return whether the connection is open by checking if the session is open."""
-        return hasattr(self, 'session') and self.session.is_open
+        return hasattr(self, "session") and self.session.is_open
 
     def cursor(
         self,
@@ -792,10 +796,12 @@ def _handle_staging_put(
         )
 
         with open(local_file, "rb") as fh:
-            r = self.connection.session.http_client.request('PUT', presigned_url, body=fh.read(), headers=headers)
+            r = self.connection.session.http_client.request(
+                "PUT", presigned_url, body=fh.read(), headers=headers
+            )
         # Add compatibility attributes for urllib3 response
         r.status_code = r.status
-        if hasattr(r, 'data'):
+        if hasattr(r, "data"):
             r.content = r.data
         r.ok = r.status < 400
         r.text = r.data.decode() if r.data else ""
@@ -835,10 +841,12 @@ def _handle_staging_get(
             session_id_hex=self.connection.get_session_id_hex(),
         )
 
-        r = self.connection.session.http_client.request('GET', presigned_url, headers=headers)
+        r = self.connection.session.http_client.request(
+            "GET", presigned_url, headers=headers
+        )
         # Add compatibility attributes for urllib3 response
         r.status_code = r.status
-        if hasattr(r, 'data'):
+        if hasattr(r, "data"):
             r.content = r.data
         r.ok = r.status < 400
         r.text = r.data.decode() if r.data else ""
@@ -860,10 +868,12 @@ def _handle_staging_remove(
     ):
         """Make an HTTP DELETE request to the presigned_url"""
 
-        r = self.connection.session.http_client.request('DELETE', presigned_url, headers=headers)
+        r = self.connection.session.http_client.request(
+            "DELETE", presigned_url, headers=headers
+        )
         # Add compatibility attributes for urllib3 response
         r.status_code = r.status
-        if hasattr(r, 'data'):
+        if hasattr(r, "data"):
             r.content = r.data
         r.ok = r.status < 400
         r.text = r.data.decode() if r.data else ""
diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py
index ea375fbbb..cef4ca274 100644
--- a/src/databricks/sql/cloudfetch/downloader.py
+++ b/src/databricks/sql/cloudfetch/downloader.py
@@ -95,10 +95,10 @@ def run(self) -> DownloadedFile:
         start_time = time.time()
 
         with self._http_client.request_context(
-            method='GET',
+            method="GET",
             url=self.link.fileLink,
             timeout=self.settings.download_timeout,
-            headers=self.link.httpHeaders
+            headers=self.link.httpHeaders,
         ) as response:
             if response.status >= 400:
                 raise Exception(f"HTTP {response.status}: {response.data.decode()}")
diff --git a/src/databricks/sql/common/feature_flag.py b/src/databricks/sql/common/feature_flag.py
index 8e7029805..1b920b008 100644
--- a/src/databricks/sql/common/feature_flag.py
+++ b/src/databricks/sql/common/feature_flag.py
@@ -49,7 +49,9 @@ class FeatureFlagsContext:
     in the background, returning stale data until the refresh completes.
""" - def __init__(self, connection: "Connection", executor: ThreadPoolExecutor, http_client): + def __init__( + self, connection: "Connection", executor: ThreadPoolExecutor, http_client + ): from databricks.sql import __version__ self._connection = connection @@ -65,7 +67,7 @@ def __init__(self, connection: "Connection", executor: ThreadPoolExecutor, http_ self._feature_flag_endpoint = ( f"https://{self._connection.session.host}{endpoint_suffix}" ) - + # Use the provided HTTP client self._http_client = http_client @@ -109,7 +111,7 @@ def _refresh_flags(self): headers["User-Agent"] = self._connection.session.useragent_header response = self._http_client.request( - 'GET', self._feature_flag_endpoint, headers=headers, timeout=30 + "GET", self._feature_flag_endpoint, headers=headers, timeout=30 ) # Add compatibility attributes for urllib3 response response.status_code = response.status @@ -165,7 +167,9 @@ def get_instance(cls, connection: "Connection") -> FeatureFlagsContext: # Use the unique session ID as the key key = connection.get_session_id_hex() if key not in cls._context_map: - cls._context_map[key] = FeatureFlagsContext(connection, cls._executor, connection.session.http_client) + cls._context_map[key] = FeatureFlagsContext( + connection, cls._executor, connection.session.http_client + ) return cls._context_map[key] @classmethod diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index 8c3be2bfd..a296704b4 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -18,7 +18,7 @@ class UnifiedHttpClient: """ Unified HTTP client for all Databricks SQL connector HTTP operations. - + This client uses urllib3 for robust HTTP communication with retry policies, connection pooling, SSL support, and proxy support. It replaces the various singleton HTTP clients and direct requests usage throughout the codebase. 
@@ -37,12 +37,12 @@ def __init__(self, client_context):
 
     def _setup_pool_manager(self):
-
+
         # SSL context setup
         ssl_context = None
         if self.config.ssl_options:
             ssl_context = ssl.create_default_context()
-
+
             # Configure SSL verification
             if not self.config.ssl_options.tls_verify:
                 ssl_context.check_hostname = False
@@ -50,18 +50,22 @@ def _setup_pool_manager(self):
             elif not self.config.ssl_options.tls_verify_hostname:
                 ssl_context.check_hostname = False
                 ssl_context.verify_mode = ssl.CERT_REQUIRED
-
+
             # Load custom CA file if specified
             if self.config.ssl_options.tls_trusted_ca_file:
-                ssl_context.load_verify_locations(self.config.ssl_options.tls_trusted_ca_file)
-
+                ssl_context.load_verify_locations(
+                    self.config.ssl_options.tls_trusted_ca_file
+                )
+
             # Load client certificate if specified
-            if (self.config.ssl_options.tls_client_cert_file and
-                self.config.ssl_options.tls_client_cert_key_file):
+            if (
+                self.config.ssl_options.tls_client_cert_file
+                and self.config.ssl_options.tls_client_cert_key_file
+            ):
                 ssl_context.load_cert_chain(
                     self.config.ssl_options.tls_client_cert_file,
                     self.config.ssl_options.tls_client_cert_key_file,
-                    self.config.ssl_options.tls_client_cert_key_password
+                    self.config.ssl_options.tls_client_cert_key_password,
                 )
 
         # Create retry policy
@@ -76,14 +80,15 @@ def _setup_pool_manager(self):
 
         # Common pool manager kwargs
         pool_kwargs = {
-            'num_pools': self.config.pool_connections,
-            'maxsize': self.config.pool_maxsize,
-            'retries': retry_policy,
-            'timeout': urllib3.Timeout(
-                connect=self.config.socket_timeout,
-                read=self.config.socket_timeout
-            ) if self.config.socket_timeout else None,
-            'ssl_context': ssl_context,
+            "num_pools": self.config.pool_connections,
+            "maxsize": self.config.pool_maxsize,
+            "retries": retry_policy,
+            "timeout": urllib3.Timeout(
+                connect=self.config.socket_timeout, read=self.config.socket_timeout
+            )
+            if self.config.socket_timeout
+            else None,
+            "ssl_context": ssl_context,
         }
 
         # Create proxy or regular pool manager
@@ -93,58 +98,51 @@ def _setup_pool_manager(self):
             proxy_headers = make_headers(
                 proxy_basic_auth=f"{self.config.proxy_username}:{self.config.proxy_password}"
             )
-
+
             self._pool_manager = ProxyManager(
-                self.config.http_proxy,
-                proxy_headers=proxy_headers,
-                **pool_kwargs
+                self.config.http_proxy, proxy_headers=proxy_headers, **pool_kwargs
             )
         else:
             self._pool_manager = PoolManager(**pool_kwargs)
 
-    def _prepare_headers(self, headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
+    def _prepare_headers(
+        self, headers: Optional[Dict[str, str]] = None
+    ) -> Dict[str, str]:
         """Prepare headers for the request, including User-Agent."""
         request_headers = {}
-
+
         if self.config.user_agent:
-            request_headers['User-Agent'] = self.config.user_agent
-
+            request_headers["User-Agent"] = self.config.user_agent
+
         if headers:
             request_headers.update(headers)
-
+
         return request_headers
 
     @contextmanager
     def request_context(
-        self,
-        method: str,
-        url: str,
-        headers: Optional[Dict[str, str]] = None,
-        **kwargs
+        self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs
     ) -> Generator[urllib3.HTTPResponse, None, None]:
         """
         Context manager for making HTTP requests with proper resource cleanup.
- + Args: method: HTTP method (GET, POST, PUT, DELETE) url: URL to request headers: Optional headers dict **kwargs: Additional arguments passed to urllib3 request - + Yields: urllib3.HTTPResponse: The HTTP response object """ logger.debug("Making %s request to %s", method, url) - + request_headers = self._prepare_headers(headers) response = None - + try: response = self._pool_manager.request( - method=method, - url=url, - headers=request_headers, - **kwargs + method=method, url=url, headers=request_headers, **kwargs ) yield response except MaxRetryError as e: @@ -157,16 +155,18 @@ def request_context( if response: response.close() - def request(self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs) -> urllib3.HTTPResponse: + def request( + self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs + ) -> urllib3.HTTPResponse: """ Make an HTTP request. - + Args: method: HTTP method (GET, POST, PUT, DELETE, etc.) url: URL to request headers: Optional headers dict **kwargs: Additional arguments passed to urllib3 request - + Returns: urllib3.HTTPResponse: The HTTP response object with data pre-loaded """ @@ -175,32 +175,36 @@ def request(self, method: str, url: str, headers: Optional[Dict[str, str]] = Non response._body = response.data return response - def upload_file(self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None) -> urllib3.HTTPResponse: + def upload_file( + self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None + ) -> urllib3.HTTPResponse: """ Upload a file using PUT method. - + Args: url: URL to upload to file_path: Path to the file to upload headers: Optional headers - + Returns: urllib3.HTTPResponse: The response from the server """ - with open(file_path, 'rb') as file_obj: - return self.request('PUT', url, body=file_obj.read(), headers=headers) + with open(file_path, "rb") as file_obj: + return self.request("PUT", url, body=file_obj.read(), headers=headers) - def download_file(self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None) -> None: + def download_file( + self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None + ) -> None: """ Download a file using GET method. - + Args: url: URL to download from file_path: Path where to save the downloaded file headers: Optional headers """ - response = self.request('GET', url, headers=headers) - with open(file_path, 'wb') as file_obj: + response = self.request("GET", url, headers=headers) + with open(file_path, "wb") as file_obj: file_obj.write(response.data) def close(self): @@ -222,5 +226,6 @@ class IgnoreNetrcAuth: Compatibility class for OAuth code that expects requests.auth.AuthBase interface. This is a no-op auth handler since OAuth handles auth differently. 
""" + def __call__(self, request): - return request \ No newline at end of file + return request diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index c9b4f939a..0cba8be48 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -193,7 +193,7 @@ def close(self) -> None: logger.error("Attempt to close session raised a local exception: %s", e) self.is_open = False - + # Close HTTP client if it exists - if hasattr(self, 'http_client') and self.http_client: + if hasattr(self, "http_client") and self.http_client: self.http_client.close() diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 13c15486d..2785d3cca 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -230,7 +230,7 @@ def _send_telemetry(self, events): try: logger.debug("Submitting telemetry request to thread pool") - + # Use unified HTTP client future = self._executor.submit( self._send_with_unified_client, @@ -239,7 +239,7 @@ def _send_telemetry(self, events): headers=headers, timeout=900, ) - + future.add_done_callback( lambda fut: self._telemetry_request_callback(fut, sent_count=sent_count) ) @@ -249,10 +249,14 @@ def _send_telemetry(self, events): def _send_with_unified_client(self, url, data, headers): """Helper method to send telemetry using the unified HTTP client.""" try: - response = self._http_client.request('POST', url, body=data, headers=headers, timeout=900) + response = self._http_client.request( + "POST", url, body=data, headers=headers, timeout=900 + ) # Convert urllib3 response to requests-like response for compatibility response.status_code = response.status - response.json = lambda: json.loads(response.data.decode()) if response.data else {} + response.json = ( + lambda: json.loads(response.data.decode()) if response.data else {} + ) return response except Exception as e: logger.error("Failed to send telemetry with unified client: %s", e)