Skip to content

Commit f277b07

Browse files
stop passing session_id_hex
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 2201765 commit f277b07

File tree

3 files changed

+37
-28
lines changed

3 files changed

+37
-28
lines changed

tests/unit/test_client.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
127127
connection=connection,
128128
execute_response=mock_execute_response,
129129
thrift_client=mock_backend,
130-
session_id_hex=Mock(),
131130
)
132131

133132
# Mock execute_command to return our real result set
@@ -187,7 +186,6 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(self):
187186
connection=mock_connection,
188187
execute_response=Mock(),
189188
thrift_client=mock_backend,
190-
session_id_hex=Mock(),
191189
)
192190
result_set.results = mock_results
193191

@@ -217,7 +215,6 @@ def test_closing_result_set_hard_closes_commands(self):
217215
mock_connection,
218216
mock_results_response,
219217
mock_thrift_backend,
220-
session_id_hex=Mock(),
221218
)
222219
result_set.results = mock_results
223220

@@ -265,9 +262,7 @@ def test_negative_fetch_throws_exception(self):
265262
mock_backend = Mock()
266263
mock_backend.fetch_results.return_value = (Mock(), False, 0)
267264

268-
result_set = ThriftResultSet(
269-
Mock(), Mock(), mock_backend
270-
)
265+
result_set = ThriftResultSet(Mock(), Mock(), mock_backend)
271266

272267
with self.assertRaises(ValueError) as e:
273268
result_set.fetchmany(-1)

tests/unit/test_fetches.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def make_dummy_result_set_from_initial_results(initial_results):
6363
),
6464
thrift_client=mock_thrift_backend,
6565
t_row_set=None,
66-
session_id_hex=Mock(),
6766
)
6867
return rs
6968

@@ -108,7 +107,6 @@ def fetch_results(
108107
is_staging_operation=False,
109108
),
110109
thrift_client=mock_thrift_backend,
111-
session_id_hex=Mock(),
112110
)
113111
return rs
114112

tests/unit/test_telemetry_retry.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
77
from databricks.sql.auth.retry import DatabricksRetryPolicy
88

9-
PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn'
9+
PATCH_TARGET = "urllib3.connectionpool.HTTPSConnectionPool._get_conn"
10+
1011

1112
def create_mock_conn(responses):
1213
"""Creates a mock connection object whose getresponse() method yields a series of responses."""
@@ -16,15 +17,18 @@ def create_mock_conn(responses):
1617
mock_http_response = MagicMock()
1718
mock_http_response.status = resp.get("status")
1819
mock_http_response.headers = resp.get("headers", {})
19-
body = resp.get("body", b'{}')
20+
body = resp.get("body", b"{}")
2021
mock_http_response.fp = io.BytesIO(body)
22+
2123
def release():
2224
mock_http_response.fp.close()
25+
2326
mock_http_response.release_conn = release
2427
mock_http_responses.append(mock_http_response)
2528
mock_conn.getresponse.side_effect = mock_http_responses
2629
return mock_conn
2730

31+
2832
class TestTelemetryClientRetries:
2933
@pytest.fixture(autouse=True)
3034
def setup_and_teardown(self):
@@ -49,28 +53,28 @@ def get_client(self, session_id, num_retries=3):
4953
host_url="test.databricks.com",
5054
)
5155
client = TelemetryClientFactory.get_telemetry_client(session_id)
52-
56+
5357
retry_policy = DatabricksRetryPolicy(
5458
delay_min=0.01,
5559
delay_max=0.02,
5660
stop_after_attempts_duration=2.0,
57-
stop_after_attempts_count=num_retries,
61+
stop_after_attempts_count=num_retries,
5862
delay_default=0.1,
5963
force_dangerous_codes=[],
60-
urllib3_kwargs={'total': num_retries}
64+
urllib3_kwargs={"total": num_retries},
6165
)
6266
adapter = client._http_client.session.adapters.get("https://")
6367
adapter.max_retries = retry_policy
6468
return client
6569

6670
@pytest.mark.parametrize(
67-
"status_code, description",
68-
[
69-
(401, "Unauthorized"),
70-
(403, "Forbidden"),
71-
(501, "Not Implemented"),
72-
(200, "Success"),
73-
],
71+
"status_code, description",
72+
[
73+
(401, "Unauthorized"),
74+
(403, "Forbidden"),
75+
(501, "Not Implemented"),
76+
(200, "Success"),
77+
],
7478
)
7579
def test_non_retryable_status_codes_are_not_retried(self, status_code, description):
7680
"""
@@ -80,7 +84,9 @@ def test_non_retryable_status_codes_are_not_retried(self, status_code, descripti
8084
client = self.get_client(f"session-{status_code}")
8185
mock_responses = [{"status": status_code}]
8286

83-
with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn:
87+
with patch(
88+
PATCH_TARGET, return_value=create_mock_conn(mock_responses)
89+
) as mock_get_conn:
8490
client.export_failure_log("TestError", "Test message")
8591
TelemetryClientFactory.close(client._session_id_hex)
8692

@@ -92,16 +98,26 @@ def test_exceeds_retry_count_limit(self):
9298
Verifies that the client respects the Retry-After header and retries on 429, 502, 503.
9399
"""
94100
num_retries = 3
95-
expected_total_calls = num_retries + 1
101+
expected_total_calls = num_retries + 1
96102
retry_after = 1
97103
client = self.get_client("session-exceed-limit", num_retries=num_retries)
98-
mock_responses = [{"status": 503, "headers": {"Retry-After": str(retry_after)}}, {"status": 429}, {"status": 502}, {"status": 503}]
99-
100-
with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn:
104+
mock_responses = [
105+
{"status": 503, "headers": {"Retry-After": str(retry_after)}},
106+
{"status": 429},
107+
{"status": 502},
108+
{"status": 503},
109+
]
110+
111+
with patch(
112+
PATCH_TARGET, return_value=create_mock_conn(mock_responses)
113+
) as mock_get_conn:
101114
start_time = time.time()
102115
client.export_failure_log("TestError", "Test message")
103116
TelemetryClientFactory.close(client._session_id_hex)
104117
end_time = time.time()
105-
106-
assert mock_get_conn.return_value.getresponse.call_count == expected_total_calls
107-
assert end_time - start_time > retry_after
118+
119+
assert (
120+
mock_get_conn.return_value.getresponse.call_count
121+
== expected_total_calls
122+
)
123+
assert end_time - start_time > retry_after

0 commit comments

Comments
 (0)