
Commit c23d540

Merge branch 'main' into normalise-code
2 parents: a636bdc + 59d28b0

File tree: 10 files changed, +254 −23 lines


src/databricks/sql/backend/thrift_backend.py

Lines changed: 0 additions & 5 deletions

@@ -1040,7 +1040,6 @@ def execute_command(
     max_download_threads=self.max_download_threads,
     ssl_options=self._ssl_options,
     has_more_rows=has_more_rows,
-    session_id_hex=self._session_id_hex,
 )

 def get_catalogs(
@@ -1079,7 +1078,6 @@ def get_catalogs(
     max_download_threads=self.max_download_threads,
     ssl_options=self._ssl_options,
     has_more_rows=has_more_rows,
-    session_id_hex=self._session_id_hex,
 )

 def get_schemas(
@@ -1124,7 +1122,6 @@ def get_schemas(
     max_download_threads=self.max_download_threads,
     ssl_options=self._ssl_options,
     has_more_rows=has_more_rows,
-    session_id_hex=self._session_id_hex,
 )

 def get_tables(
@@ -1173,7 +1170,6 @@ def get_tables(
     max_download_threads=self.max_download_threads,
     ssl_options=self._ssl_options,
     has_more_rows=has_more_rows,
-    session_id_hex=self._session_id_hex,
 )

 def get_columns(
@@ -1222,7 +1218,6 @@ def get_columns(
     max_download_threads=self.max_download_threads,
     ssl_options=self._ssl_options,
     has_more_rows=has_more_rows,
-    session_id_hex=self._session_id_hex,
 )

 def _handle_execute_response(self, resp, cursor):

src/databricks/sql/cloudfetch/downloader.py

Lines changed: 37 additions & 0 deletions

@@ -54,12 +54,14 @@ class DownloadableResultSettings:
     link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs.
     download_timeout (int): Timeout for download requests. Default 60 secs.
     max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down.
+    min_cloudfetch_download_speed (float): Threshold in MB/s below which to log warning. Default 0.1 MB/s.
     """

     is_lz4_compressed: bool
     link_expiry_buffer_secs: int = 0
     download_timeout: int = 60
     max_consecutive_file_download_retries: int = 0
+    min_cloudfetch_download_speed: float = 0.1


 class ResultSetDownloadHandler:
@@ -100,6 +102,8 @@ def run(self) -> DownloadedFile:
         self.link, self.settings.link_expiry_buffer_secs
     )

+    start_time = time.time()
+
     with self._http_client.execute(
         method=HttpMethod.GET,
         url=self.link.fileLink,
@@ -112,6 +116,13 @@ def run(self) -> DownloadedFile:

     # Save (and decompress if needed) the downloaded file
     compressed_data = response.content
+
+    # Log download metrics
+    download_duration = time.time() - start_time
+    self._log_download_metrics(
+        self.link.fileLink, len(compressed_data), download_duration
+    )
+
     decompressed_data = (
         ResultSetDownloadHandler._decompress_data(compressed_data)
         if self.settings.is_lz4_compressed
@@ -138,6 +149,32 @@ def run(self) -> DownloadedFile:
             self.link.rowCount,
         )

+    def _log_download_metrics(
+        self, url: str, bytes_downloaded: int, duration_seconds: float
+    ):
+        """Log download speed metrics at INFO/WARN levels."""
+        # Calculate speed in MB/s (ensure float division for precision)
+        speed_mbps = (float(bytes_downloaded) / (1024 * 1024)) / duration_seconds
+
+        urlEndpoint = url.split("?")[0]
+        # INFO level logging
+        logger.info(
+            "CloudFetch download completed: %.4f MB/s, %d bytes in %.3fs from %s",
+            speed_mbps,
+            bytes_downloaded,
+            duration_seconds,
+            urlEndpoint,
+        )
+
+        # WARN level logging if below threshold
+        if speed_mbps < self.settings.min_cloudfetch_download_speed:
+            logger.warning(
+                "CloudFetch download slower than threshold: %.4f MB/s (threshold: %.1f MB/s) from %s",
+                speed_mbps,
+                self.settings.min_cloudfetch_download_speed,
+                url,
+            )
+
     @staticmethod
     def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int):
         """

src/databricks/sql/common/http.py

Lines changed: 70 additions & 1 deletion

@@ -5,8 +5,10 @@
 import threading
 from dataclasses import dataclass
 from contextlib import contextmanager
-from typing import Generator
+from typing import Generator, Optional
 import logging
+from requests.adapters import HTTPAdapter
+from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType

 logger = logging.getLogger(__name__)

@@ -81,3 +83,70 @@ def execute(

     def close(self):
         self.session.close()
+
+
+class TelemetryHTTPAdapter(HTTPAdapter):
+    """
+    Custom HTTP adapter to prepare our DatabricksRetryPolicy before each request.
+    This ensures the retry timer is started and the command type is set correctly,
+    allowing the policy to manage its state for the duration of the request retries.
+    """
+
+    def send(self, request, **kwargs):
+        self.max_retries.command_type = CommandType.OTHER
+        self.max_retries.start_retry_timer()
+        return super().send(request, **kwargs)
+
+
+class TelemetryHttpClient:  # TODO: Unify all the http clients in the PySQL Connector
+    """Singleton HTTP client for sending telemetry data."""
+
+    _instance: Optional["TelemetryHttpClient"] = None
+    _lock = threading.Lock()
+
+    TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT = 3
+    TELEMETRY_RETRY_DELAY_MIN = 1.0
+    TELEMETRY_RETRY_DELAY_MAX = 10.0
+    TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION = 30.0
+
+    def __init__(self):
+        """Initializes the session and mounts the custom retry adapter."""
+        retry_policy = DatabricksRetryPolicy(
+            delay_min=self.TELEMETRY_RETRY_DELAY_MIN,
+            delay_max=self.TELEMETRY_RETRY_DELAY_MAX,
+            stop_after_attempts_count=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT,
+            stop_after_attempts_duration=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION,
+            delay_default=1.0,
+            force_dangerous_codes=[],
+        )
+        adapter = TelemetryHTTPAdapter(max_retries=retry_policy)
+        self.session = requests.Session()
+        self.session.mount("https://", adapter)
+        self.session.mount("http://", adapter)
+
+    @classmethod
+    def get_instance(cls) -> "TelemetryHttpClient":
+        """Get the singleton instance of the TelemetryHttpClient."""
+        if cls._instance is None:
+            with cls._lock:
+                if cls._instance is None:
+                    logger.debug("Initializing singleton TelemetryHttpClient")
+                    cls._instance = TelemetryHttpClient()
+        return cls._instance
+
+    def post(self, url: str, **kwargs) -> requests.Response:
+        """
+        Executes a POST request using the configured session.
+
+        This is a blocking call intended to be run in a background thread.
+        """
+        logger.debug("Executing telemetry POST request to: %s", url)
+        return self.session.post(url, **kwargs)
+
+    def close(self):
+        """Closes the underlying requests.Session."""
+        logger.debug("Closing TelemetryHttpClient session.")
+        self.session.close()
+        # Clear the instance to allow for re-initialization if needed
+        with TelemetryHttpClient._lock:
+            TelemetryHttpClient._instance = None
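
Note: a possible usage sketch of the new singleton (URL and payload are placeholders, not from the commit). The double-checked locking in get_instance keeps repeated lookups cheap while remaining thread-safe on first construction.

from databricks.sql.common.http import TelemetryHttpClient

client = TelemetryHttpClient.get_instance()   # first call constructs the shared instance
same = TelemetryHttpClient.get_instance()     # later calls return the same object
assert client is same
response = client.post("https://example.com/telemetry-ext", data="{}")  # placeholder URL/payload
client.close()  # closes the session and clears the singleton so it can be rebuilt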

src/databricks/sql/exc.py

Lines changed: 2 additions & 2 deletions

@@ -2,8 +2,6 @@
 import logging

 logger = logging.getLogger(__name__)
-from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
-

 ### PEP-249 Mandated ###
 # https://peps.python.org/pep-0249/#exceptions
@@ -22,6 +20,8 @@ def __init__(

         error_name = self.__class__.__name__
         if session_id_hex:
+            from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
+
             telemetry_client = TelemetryClientFactory.get_telemetry_client(
                 session_id_hex
            )
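
Note: the import moves from module scope into the handler body, which is the standard way in Python to break an import cycle (likely the motivation here, since telemetry_client itself imports from exc; the commit does not state this). A generic sketch of the pattern, with an illustrative helper name:

def _get_factory():
    # Resolving the import at call time means exc.py no longer needs
    # telemetry_client at import time, sidestepping the cycle.
    from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
    return TelemetryClientFactory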

src/databricks/sql/result_set.py

Lines changed: 11 additions & 8 deletions

@@ -170,7 +170,11 @@ def close(self) -> None:
         been closed on the server for some other reason, issue a request to the server to close it.
         """
         try:
-            self.results.close()
+            if self.results is not None:
+                self.results.close()
+            else:
+                logger.warning("result set close: queue not initialized")
+
             if (
                 self.status != CommandState.CLOSED
                 and not self.has_been_closed_server_side
@@ -193,7 +197,6 @@ def __init__(
         connection: Connection,
         execute_response: ExecuteResponse,
         thrift_client: ThriftDatabricksClient,
-        session_id_hex: Optional[str],
         buffer_size_bytes: int = 104857600,
         arraysize: int = 10000,
         use_cloud_fetch: bool = True,
@@ -217,7 +220,7 @@ def __init__(
         :param ssl_options: SSL options for cloud fetch
         :param has_more_rows: Whether there are more rows to fetch
         """
-        self.num_downloaded_chunks = 0
+        self.num_chunks = 0

         # Initialize ThriftResultSet-specific attributes
         self._use_cloud_fetch = use_cloud_fetch
@@ -237,12 +240,12 @@ def __init__(
                 lz4_compressed=execute_response.lz4_compressed,
                 description=execute_response.description,
                 ssl_options=ssl_options,
-                session_id_hex=session_id_hex,
+                session_id_hex=connection.get_session_id_hex(),
                 statement_id=execute_response.command_id.to_hex_guid(),
-                chunk_id=self.num_downloaded_chunks,
+                chunk_id=self.num_chunks,
             )
             if t_row_set.resultLinks:
-                self.num_downloaded_chunks += len(t_row_set.resultLinks)
+                self.num_chunks += len(t_row_set.resultLinks)

         # Call parent constructor with common attributes
         super().__init__(
@@ -275,11 +278,11 @@ def _fill_results_buffer(self):
             arrow_schema_bytes=self._arrow_schema_bytes,
             description=self.description,
             use_cloud_fetch=self._use_cloud_fetch,
-            chunk_id=self.num_downloaded_chunks,
+            chunk_id=self.num_chunks,
         )
         self.results = results
         self.has_more_rows = has_more_rows
-        self.num_downloaded_chunks += result_links_count
+        self.num_chunks += result_links_count

     def _convert_columnar_table(self, table):
         column_names = [c[0] for c in self.description]

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 4 additions & 2 deletions

@@ -1,9 +1,9 @@
 import threading
 import time
-import requests
 import logging
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, Optional
+from databricks.sql.common.http import TelemetryHttpClient
 from databricks.sql.telemetry.models.event import (
     TelemetryEvent,
     DriverSystemConfiguration,
@@ -159,6 +159,7 @@ def __init__(
         self._driver_connection_params = None
         self._host_url = host_url
         self._executor = executor
+        self._http_client = TelemetryHttpClient.get_instance()

     def _export_event(self, event):
         """Add an event to the batch queue and flush if batch is full"""
@@ -207,7 +208,7 @@ def _send_telemetry(self, events):
         try:
             logger.debug("Submitting telemetry request to thread pool")
             future = self._executor.submit(
-                requests.post,
+                self._http_client.post,
                 url,
                 data=request.to_json(),
                 headers=headers,
@@ -433,6 +434,7 @@ def close(session_id_hex):
             )
             try:
                 TelemetryClientFactory._executor.shutdown(wait=True)
+                TelemetryHttpClient.close()
             except Exception as e:
                 logger.debug("Failed to shutdown thread pool executor: %s", e)
             TelemetryClientFactory._executor = None
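
Note: the net effect is that telemetry uploads now route through the shared retry-aware session instead of bare requests.post. A hedged sketch of the submit pattern (URL, payload, and headers are placeholders):

from concurrent.futures import ThreadPoolExecutor
from databricks.sql.common.http import TelemetryHttpClient

executor = ThreadPoolExecutor(max_workers=1)
http_client = TelemetryHttpClient.get_instance()
# The callable handed to the executor is the client's bound post method, so the
# mounted TelemetryHTTPAdapter applies DatabricksRetryPolicy to each request.
future = executor.submit(
    http_client.post,
    "https://example.com/telemetry-ext",  # placeholder URL
    data="{}",                            # placeholder JSON payload
    headers={"Accept": "application/json"},
)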

tests/unit/test_client.py

Lines changed: 1 addition & 1 deletion

@@ -266,7 +266,7 @@ def test_negative_fetch_throws_exception(self):
         mock_backend.fetch_results.return_value = (Mock(), False, 0)

         result_set = ThriftResultSet(
-            Mock(), Mock(), mock_backend, session_id_hex=Mock()
+            Mock(), Mock(), mock_backend
         )

         with self.assertRaises(ValueError) as e:

tests/unit/test_downloader.py

Lines changed: 21 additions & 2 deletions

@@ -23,6 +23,17 @@ class DownloaderTests(unittest.TestCase):
     Unit tests for checking downloader logic.
     """

+    def _setup_time_mock_for_download(self, mock_time, end_time):
+        """Helper to setup time mock that handles logging system calls."""
+        call_count = [0]
+        def time_side_effect():
+            call_count[0] += 1
+            if call_count[0] <= 2:  # First two calls (validation, start_time)
+                return 1000
+            else:  # All subsequent calls (logging, duration calculation)
+                return end_time
+        mock_time.side_effect = time_side_effect
+
     @patch("time.time", return_value=1000)
     def test_run_link_expired(self, mock_time):
         settings = Mock()
@@ -90,13 +101,17 @@ def test_run_get_response_not_ok(self, mock_time):
             d.run()
         self.assertTrue("404" in str(context.exception))

-    @patch("time.time", return_value=1000)
+    @patch("time.time")
     def test_run_uncompressed_successful(self, mock_time):
+        self._setup_time_mock_for_download(mock_time, 1000.5)
+
         http_client = DatabricksHttpClient.get_instance()
         file_bytes = b"1234567890" * 10
         settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
         settings.is_lz4_compressed = False
+        settings.min_cloudfetch_download_speed = 1.0
         result_link = Mock(bytesNum=100, expiryTime=1001)
+        result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=abc123"

         with patch.object(
             http_client,
@@ -115,15 +130,19 @@ def test_run_uncompressed_successful(self, mock_time):

         assert file.file_bytes == b"1234567890" * 10

-    @patch("time.time", return_value=1000)
+    @patch("time.time")
     def test_run_compressed_successful(self, mock_time):
+        self._setup_time_mock_for_download(mock_time, 1000.2)
+
         http_client = DatabricksHttpClient.get_instance()
         file_bytes = b"1234567890" * 10
         compressed_bytes = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'

         settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
         settings.is_lz4_compressed = True
+        settings.min_cloudfetch_download_speed = 1.0
         result_link = Mock(bytesNum=100, expiryTime=1001)
+        result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=xyz789"
         with patch.object(
             http_client,
             "execute",

tests/unit/test_telemetry.py

Lines changed: 1 addition & 2 deletions

@@ -1,6 +1,5 @@
 import uuid
 import pytest
-import requests
 from unittest.mock import patch, MagicMock

 from databricks.sql.telemetry.telemetry_client import (
@@ -91,7 +90,7 @@ def test_network_request_flow(self, mock_post, mock_telemetry_client):
         args, kwargs = client._executor.submit.call_args

         # Verify correct function and URL
-        assert args[0] == requests.post
+        assert args[0] == client._http_client.post
         assert args[1] == "https://test-host.com/telemetry-ext"
         assert kwargs["headers"]["Authorization"] == "Bearer test-token"
