Skip to content

Commit 9a8cbbf

Browse files
committed
Merge branch 'main' into telemetry-server-flag
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
2 parents 0d300ab + a0d7cd1 commit 9a8cbbf

File tree

4 files changed

+52
-4
lines changed

4 files changed

+52
-4
lines changed

src/databricks/sql/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,9 @@ def read(self) -> Optional[OAuthToken]:
248248
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
249249
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
250250
self._cursors = [] # type: List[Cursor]
251+
self.telemetry_batch_size = kwargs.get(
252+
"telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE
253+
)
251254

252255
try:
253256
self.session = Session(
@@ -288,6 +291,7 @@ def read(self) -> Optional[OAuthToken]:
288291
session_id_hex=self.get_session_id_hex(),
289292
auth_provider=self.session.auth_provider,
290293
host_url=self.session.host,
294+
batch_size=self.telemetry_batch_size,
291295
)
292296

293297
self._telemetry_client = TelemetryClientFactory.get_telemetry_client(

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,19 +158,18 @@ class TelemetryClient(BaseTelemetryClient):
158158
TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext"
159159
TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth"
160160

161-
DEFAULT_BATCH_SIZE = 100
162-
163161
def __init__(
164162
self,
165163
telemetry_enabled,
166164
session_id_hex,
167165
auth_provider,
168166
host_url,
169167
executor,
168+
batch_size,
170169
):
171170
logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex)
172171
self._telemetry_enabled = telemetry_enabled
173-
self._batch_size = self.DEFAULT_BATCH_SIZE
172+
self._batch_size = batch_size
174173
self._session_id_hex = session_id_hex
175174
self._auth_provider = auth_provider
176175
self._user_agent = None
@@ -338,7 +337,7 @@ def close(self):
338337
class TelemetryClientFactory:
339338
"""
340339
Static factory class for creating and managing telemetry clients.
341-
It uses a thread pool to handle asynchronous operations.
340+
It uses a thread pool to handle asynchronous operations and a single flush thread for all clients.
342341
"""
343342

344343
_clients: Dict[
@@ -351,6 +350,13 @@ class TelemetryClientFactory:
351350
_original_excepthook = None
352351
_excepthook_installed = False
353352

353+
# Shared flush thread for all clients
354+
_flush_thread = None
355+
_flush_event = threading.Event()
356+
_flush_interval_seconds = 90
357+
358+
DEFAULT_BATCH_SIZE = 100
359+
354360
@classmethod
355361
def _initialize(cls):
356362
"""Initialize the factory if not already initialized"""
@@ -361,11 +367,39 @@ def _initialize(cls):
361367
max_workers=10
362368
) # Thread pool for async operations
363369
cls._install_exception_hook()
370+
cls._start_flush_thread()
364371
cls._initialized = True
365372
logger.debug(
366373
"TelemetryClientFactory initialized with thread pool (max_workers=10)"
367374
)
368375

376+
@classmethod
377+
def _start_flush_thread(cls):
378+
"""Start the shared background thread for periodic flushing of all clients"""
379+
cls._flush_event.clear()
380+
cls._flush_thread = threading.Thread(target=cls._flush_worker, daemon=True)
381+
cls._flush_thread.start()
382+
383+
@classmethod
384+
def _flush_worker(cls):
385+
"""Background worker thread for periodic flushing of all clients"""
386+
while not cls._flush_event.wait(cls._flush_interval_seconds):
387+
logger.debug("Performing periodic flush for all telemetry clients")
388+
389+
with cls._lock:
390+
clients_to_flush = list(cls._clients.values())
391+
392+
for client in clients_to_flush:
393+
client._flush()
394+
395+
@classmethod
396+
def _stop_flush_thread(cls):
397+
"""Stop the shared background flush thread"""
398+
if cls._flush_thread is not None:
399+
cls._flush_event.set()
400+
cls._flush_thread.join(timeout=1.0)
401+
cls._flush_thread = None
402+
369403
@classmethod
370404
def _install_exception_hook(cls):
371405
"""Install global exception handler for unhandled exceptions"""
@@ -394,6 +428,7 @@ def initialize_telemetry_client(
394428
session_id_hex,
395429
auth_provider,
396430
host_url,
431+
batch_size,
397432
):
398433
"""Initialize a telemetry client for a specific connection if telemetry is enabled"""
399434
try:
@@ -415,6 +450,7 @@ def initialize_telemetry_client(
415450
auth_provider=auth_provider,
416451
host_url=host_url,
417452
executor=TelemetryClientFactory._executor,
453+
batch_size=batch_size,
418454
)
419455
else:
420456
TelemetryClientFactory._clients[
@@ -453,6 +489,7 @@ def close(session_id_hex):
453489
"No more telemetry clients, shutting down thread pool executor"
454490
)
455491
try:
492+
TelemetryClientFactory._stop_flush_thread()
456493
TelemetryClientFactory._executor.shutdown(wait=True)
457494
TelemetryHttpClient.close()
458495
except Exception as e:
@@ -478,6 +515,7 @@ def connection_failure_log(
478515
session_id_hex=UNAUTH_DUMMY_SESSION_ID,
479516
auth_provider=None,
480517
host_url=host_url,
518+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
481519
)
482520

483521
telemetry_client = TelemetryClientFactory.get_telemetry_client(

tests/unit/test_telemetry.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def mock_telemetry_client():
3030
auth_provider=auth_provider,
3131
host_url="test-host.com",
3232
executor=executor,
33+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
3334
)
3435

3536

@@ -214,6 +215,7 @@ def test_client_lifecycle_flow(self):
214215
session_id_hex=session_id_hex,
215216
auth_provider=auth_provider,
216217
host_url="test-host.com",
218+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
217219
)
218220

219221
client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
@@ -238,6 +240,7 @@ def test_disabled_telemetry_flow(self):
238240
session_id_hex=session_id_hex,
239241
auth_provider=None,
240242
host_url="test-host.com",
243+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
241244
)
242245

243246
client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
@@ -257,6 +260,7 @@ def test_factory_error_handling(self):
257260
session_id_hex=session_id,
258261
auth_provider=AccessTokenAuthProvider("token"),
259262
host_url="test-host.com",
263+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
260264
)
261265

262266
# Should fall back to NoopTelemetryClient
@@ -275,6 +279,7 @@ def test_factory_shutdown_flow(self):
275279
session_id_hex=session,
276280
auth_provider=AccessTokenAuthProvider("token"),
277281
host_url="test-host.com",
282+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
278283
)
279284

280285
# Factory should be initialized

tests/unit/test_telemetry_retry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def get_client(self, session_id, num_retries=3):
4747
session_id_hex=session_id,
4848
auth_provider=None,
4949
host_url="test.databricks.com",
50+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
5051
)
5152
client = TelemetryClientFactory.get_telemetry_client(session_id)
5253

0 commit comments

Comments
 (0)