Skip to content

Commit 2f96e37

Browse files
remove deepcopy
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 81bf7db commit 2f96e37

File tree

4 files changed

+47
-34
lines changed

4 files changed

+47
-34
lines changed

src/databricks/sql/backend/sea/utils/filters.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import io
1010
import logging
11-
from copy import deepcopy
1211
from typing import (
1312
List,
1413
Optional,
@@ -62,7 +61,7 @@ def _create_execute_response(result_set: SeaResultSet) -> ExecuteResponse:
6261
)
6362

6463
@staticmethod
65-
def _create_filtered_manifest(result_set: SeaResultSet, new_row_count: int):
64+
def _update_manifest(result_set: SeaResultSet, new_row_count: int):
6665
"""
6766
Create a copy of the manifest with updated row count.
6867
@@ -73,7 +72,7 @@ def _create_filtered_manifest(result_set: SeaResultSet, new_row_count: int):
7372
Returns:
7473
Updated manifest copy
7574
"""
76-
filtered_manifest = deepcopy(result_set.manifest)
75+
filtered_manifest = result_set.manifest
7776
filtered_manifest.total_row_count = new_row_count
7877
return filtered_manifest
7978

@@ -97,9 +96,7 @@ def _create_filtered_result_set(
9796
from databricks.sql.backend.sea.result_set import SeaResultSet
9897

9998
execute_response = ResultSetFilter._create_execute_response(result_set)
100-
filtered_manifest = ResultSetFilter._create_filtered_manifest(
101-
result_set, row_count
102-
)
99+
filtered_manifest = ResultSetFilter._update_manifest(result_set, row_count)
103100

104101
return SeaResultSet(
105102
connection=result_set.connection,

tests/e2e/test_concurrent_telemetry.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,13 @@
66
import pytest
77

88
from databricks.sql.telemetry.models.enums import StatementType
9-
from databricks.sql.telemetry.telemetry_client import TelemetryClient, TelemetryClientFactory
9+
from databricks.sql.telemetry.telemetry_client import (
10+
TelemetryClient,
11+
TelemetryClientFactory,
12+
)
1013
from tests.e2e.test_driver import PySQLPytestTestCase
1114

15+
1216
def run_in_threads(target, num_threads, pass_index=False):
1317
"""Helper to run target function in multiple threads."""
1418
threads = [
@@ -22,7 +26,6 @@ def run_in_threads(target, num_threads, pass_index=False):
2226

2327

2428
class TestE2ETelemetry(PySQLPytestTestCase):
25-
2629
@pytest.fixture(autouse=True)
2730
def telemetry_setup_teardown(self):
2831
"""
@@ -31,7 +34,7 @@ def telemetry_setup_teardown(self):
3134
this robust and automatic.
3235
"""
3336
try:
34-
yield
37+
yield
3538
finally:
3639
if TelemetryClientFactory._executor:
3740
TelemetryClientFactory._executor.shutdown(wait=True)
@@ -68,20 +71,25 @@ def callback_wrapper(self_client, future, sent_count):
6871
captured_futures.append(future)
6972
original_callback(self_client, future, sent_count)
7073

71-
with patch.object(TelemetryClient, "_send_telemetry", send_telemetry_wrapper), \
72-
patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper):
74+
with patch.object(
75+
TelemetryClient, "_send_telemetry", send_telemetry_wrapper
76+
), patch.object(
77+
TelemetryClient, "_telemetry_request_callback", callback_wrapper
78+
):
7379

7480
def execute_query_worker(thread_id):
7581
"""Each thread creates a connection and executes a query."""
7682

7783
time.sleep(random.uniform(0, 0.05))
78-
79-
with self.connection(extra_params={"force_enable_telemetry": True}) as conn:
84+
85+
with self.connection(
86+
extra_params={"force_enable_telemetry": True}
87+
) as conn:
8088
# Capture the session ID from the connection before executing the query
8189
session_id_hex = conn.get_session_id_hex()
8290
with capture_lock:
8391
captured_session_ids.append(session_id_hex)
84-
92+
8593
with conn.cursor() as cursor:
8694
cursor.execute(f"SELECT {thread_id}")
8795
# Capture the statement ID after executing the query
@@ -97,7 +105,10 @@ def execute_query_worker(thread_id):
97105
start_time = time.time()
98106
expected_event_count = num_threads
99107

100-
while len(captured_futures) < expected_event_count and time.time() - start_time < timeout_seconds:
108+
while (
109+
len(captured_futures) < expected_event_count
110+
and time.time() - start_time < timeout_seconds
111+
):
101112
time.sleep(0.1)
102113

103114
done, not_done = wait(captured_futures, timeout=timeout_seconds)
@@ -115,30 +126,37 @@ def execute_query_worker(thread_id):
115126

116127
assert not captured_exceptions
117128
assert len(captured_responses) > 0
118-
129+
119130
total_successful_events = 0
120131
for response in captured_responses:
121132
assert "errors" not in response or not response["errors"]
122133
if "numProtoSuccess" in response:
123134
total_successful_events += response["numProtoSuccess"]
124135
assert total_successful_events == num_threads * 2
125136

126-
assert len(captured_telemetry) == num_threads * 2 # 2 events per thread (initial_telemetry_log, latency_log (execute))
137+
assert (
138+
len(captured_telemetry) == num_threads * 2
139+
) # 2 events per thread (initial_telemetry_log, latency_log (execute))
127140
assert len(captured_session_ids) == num_threads # One session ID per thread
128-
assert len(captured_statement_ids) == num_threads # One statement ID per thread (per query)
141+
assert (
142+
len(captured_statement_ids) == num_threads
143+
) # One statement ID per thread (per query)
129144

130145
# Separate initial logs from latency logs
131146
initial_logs = [
132-
e for e in captured_telemetry
147+
e
148+
for e in captured_telemetry
133149
if e.entry.sql_driver_log.operation_latency_ms is None
134150
and e.entry.sql_driver_log.driver_connection_params is not None
135151
and e.entry.sql_driver_log.system_configuration is not None
136152
]
137153
latency_logs = [
138-
e for e in captured_telemetry
139-
if e.entry.sql_driver_log.operation_latency_ms is not None
140-
and e.entry.sql_driver_log.sql_statement_id is not None
141-
and e.entry.sql_driver_log.sql_operation.statement_type == StatementType.QUERY
154+
e
155+
for e in captured_telemetry
156+
if e.entry.sql_driver_log.operation_latency_ms is not None
157+
and e.entry.sql_driver_log.sql_statement_id is not None
158+
and e.entry.sql_driver_log.sql_operation.statement_type
159+
== StatementType.QUERY
142160
]
143161

144162
# Verify counts
@@ -171,4 +189,4 @@ def execute_query_worker(thread_id):
171189
for event in latency_logs:
172190
log = event.entry.sql_driver_log
173191
assert log.sql_statement_id in captured_statement_ids
174-
assert log.session_id in captured_session_ids
192+
assert log.session_id in captured_session_ids

tests/unit/test_telemetry.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def mock_telemetry_client():
3030
auth_provider=auth_provider,
3131
host_url="test-host.com",
3232
executor=executor,
33-
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
33+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
3434
)
3535

3636

@@ -215,7 +215,7 @@ def test_client_lifecycle_flow(self):
215215
session_id_hex=session_id_hex,
216216
auth_provider=auth_provider,
217217
host_url="test-host.com",
218-
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
218+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
219219
)
220220

221221
client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
@@ -240,7 +240,7 @@ def test_disabled_telemetry_flow(self):
240240
session_id_hex=session_id_hex,
241241
auth_provider=None,
242242
host_url="test-host.com",
243-
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
243+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
244244
)
245245

246246
client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
@@ -260,7 +260,7 @@ def test_factory_error_handling(self):
260260
session_id_hex=session_id,
261261
auth_provider=AccessTokenAuthProvider("token"),
262262
host_url="test-host.com",
263-
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
263+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
264264
)
265265

266266
# Should fall back to NoopTelemetryClient
@@ -279,7 +279,7 @@ def test_factory_shutdown_flow(self):
279279
session_id_hex=session,
280280
auth_provider=AccessTokenAuthProvider("token"),
281281
host_url="test-host.com",
282-
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
282+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
283283
)
284284

285285
# Factory should be initialized
@@ -342,9 +342,7 @@ def _mock_ff_response(self, mock_requests_get, enabled: bool):
342342
mock_requests_get.return_value = mock_response
343343

344344
@patch("databricks.sql.common.feature_flag.requests.get")
345-
def test_telemetry_enabled_when_flag_is_true(
346-
self, mock_requests_get, MockSession
347-
):
345+
def test_telemetry_enabled_when_flag_is_true(self, mock_requests_get, MockSession):
348346
"""Telemetry should be ON when enable_telemetry=True and server flag is 'true'."""
349347
self._mock_ff_response(mock_requests_get, enabled=True)
350348
mock_session_instance = MockSession.return_value
@@ -405,4 +403,4 @@ def test_telemetry_disabled_when_flag_request_fails(
405403
assert conn.telemetry_enabled is False
406404
mock_requests_get.assert_called_once()
407405
client = TelemetryClientFactory.get_telemetry_client("test-session-ff-fail")
408-
assert isinstance(client, NoopTelemetryClient)
406+
assert isinstance(client, NoopTelemetryClient)

tests/unit/test_telemetry_retry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def get_client(self, session_id, num_retries=3):
5151
session_id_hex=session_id,
5252
auth_provider=None,
5353
host_url="test.databricks.com",
54-
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
54+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
5555
)
5656
client = TelemetryClientFactory.get_telemetry_client(session_id)
5757

0 commit comments

Comments
 (0)