Skip to content

Commit 877b3b5

Browse files
fix typing, errors
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent e43c07b commit 877b3b5

File tree

3 files changed

+30
-115
lines changed

3 files changed

+30
-115
lines changed

src/databricks/sql/backend/sea/backend.py

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
if TYPE_CHECKING:
2020
from databricks.sql.client import Cursor
2121

22-
from databricks.sql.backend.sea.result_set import SeaResultSet
22+
from databricks.sql.result_set import SeaResultSet
2323

2424
from databricks.sql.backend.databricks_client import DatabricksClient
2525
from databricks.sql.backend.types import (
@@ -29,7 +29,7 @@
2929
BackendType,
3030
ExecuteResponse,
3131
)
32-
from databricks.sql.exc import DatabaseError, ProgrammingError, ServerOperationError
32+
from databricks.sql.exc import DatabaseError, ServerOperationError
3333
from databricks.sql.backend.sea.utils.http_client import SeaHttpClient
3434
from databricks.sql.types import SSLOptions
3535

@@ -135,7 +135,7 @@ def __init__(
135135
self.warehouse_id = self._extract_warehouse_id(http_path)
136136

137137
# Initialize HTTP client
138-
self.http_client = SeaHttpClient(
138+
self._http_client = SeaHttpClient(
139139
server_hostname=server_hostname,
140140
port=port,
141141
http_path=http_path,
@@ -180,7 +180,7 @@ def _extract_warehouse_id(self, http_path: str) -> str:
180180
f"Note: SEA only works for warehouses."
181181
)
182182
logger.error(error_message)
183-
raise ProgrammingError(error_message)
183+
raise ValueError(error_message)
184184

185185
@property
186186
def max_download_threads(self) -> int:
@@ -227,7 +227,7 @@ def open_session(
227227
schema=schema,
228228
)
229229

230-
response = self.http_client._make_request(
230+
response = self._http_client._make_request(
231231
method="POST", path=self.SESSION_PATH, data=request_data.to_dict()
232232
)
233233

@@ -252,22 +252,22 @@ def close_session(self, session_id: SessionId) -> None:
252252
session_id: The session identifier returned by open_session()
253253
254254
Raises:
255-
ProgrammingError: If the session ID is invalid
255+
ValueError: If the session ID is invalid
256256
OperationalError: If there's an error closing the session
257257
"""
258258

259259
logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id)
260260

261261
if session_id.backend_type != BackendType.SEA:
262-
raise ProgrammingError("Not a valid SEA session ID")
262+
raise ValueError("Not a valid SEA session ID")
263263
sea_session_id = session_id.to_sea_session_id()
264264

265265
request_data = DeleteSessionRequest(
266266
warehouse_id=self.warehouse_id,
267267
session_id=sea_session_id,
268268
)
269269

270-
self.http_client._make_request(
270+
self._http_client._make_request(
271271
method="DELETE",
272272
path=self.SESSION_PATH_WITH_ID.format(sea_session_id),
273273
data=request_data.to_dict(),
@@ -462,7 +462,7 @@ def execute_command(
462462
"""
463463

464464
if session_id.backend_type != BackendType.SEA:
465-
raise ProgrammingError("Not a valid SEA session ID")
465+
raise ValueError("Not a valid SEA session ID")
466466

467467
sea_session_id = session_id.to_sea_session_id()
468468

@@ -509,7 +509,7 @@ def execute_command(
509509
result_compression=result_compression,
510510
)
511511

512-
response_data = self.http_client._make_request(
512+
response_data = self._http_client._make_request(
513513
method="POST", path=self.STATEMENT_PATH, data=request.to_dict()
514514
)
515515
response = ExecuteStatementResponse.from_dict(response_data)
@@ -546,16 +546,16 @@ def cancel_command(self, command_id: CommandId) -> None:
546546
command_id: Command identifier to cancel
547547
548548
Raises:
549-
ProgrammingError: If the command ID is invalid
549+
ValueError: If the command ID is invalid
550550
"""
551551

552552
if command_id.backend_type != BackendType.SEA:
553-
raise ProgrammingError("Not a valid SEA command ID")
553+
raise ValueError("Not a valid SEA command ID")
554554

555555
sea_statement_id = command_id.to_sea_statement_id()
556556

557557
request = CancelStatementRequest(statement_id=sea_statement_id)
558-
self.http_client._make_request(
558+
self._http_client._make_request(
559559
method="POST",
560560
path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id),
561561
data=request.to_dict(),
@@ -569,16 +569,16 @@ def close_command(self, command_id: CommandId) -> None:
569569
command_id: Command identifier to close
570570
571571
Raises:
572-
ProgrammingError: If the command ID is invalid
572+
ValueError: If the command ID is invalid
573573
"""
574574

575575
if command_id.backend_type != BackendType.SEA:
576-
raise ProgrammingError("Not a valid SEA command ID")
576+
raise ValueError("Not a valid SEA command ID")
577577

578578
sea_statement_id = command_id.to_sea_statement_id()
579579

580580
request = CloseStatementRequest(statement_id=sea_statement_id)
581-
self.http_client._make_request(
581+
self._http_client._make_request(
582582
method="DELETE",
583583
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
584584
data=request.to_dict(),
@@ -595,7 +595,7 @@ def _poll_query(self, command_id: CommandId) -> GetStatementResponse:
595595
sea_statement_id = command_id.to_sea_statement_id()
596596

597597
request = GetStatementRequest(statement_id=sea_statement_id)
598-
response_data = self.http_client._make_request(
598+
response_data = self._http_client._make_request(
599599
method="GET",
600600
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
601601
data=request.to_dict(),
@@ -615,7 +615,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
615615
CommandState: The current state of the command
616616
617617
Raises:
618-
ProgrammingError: If the command ID is invalid
618+
ValueError: If the command ID is invalid
619619
"""
620620

621621
response = self._poll_query(command_id)
@@ -643,27 +643,6 @@ def get_execution_result(
643643
response = self._poll_query(command_id)
644644
return self._response_to_result_set(response, cursor)
645645

646-
def get_chunk_links(
647-
self, statement_id: str, chunk_index: int
648-
) -> List[ExternalLink]:
649-
"""
650-
Get links for chunks starting from the specified index.
651-
Args:
652-
statement_id: The statement ID
653-
chunk_index: The starting chunk index
654-
Returns:
655-
ExternalLink: External link for the chunk
656-
"""
657-
658-
response_data = self._http_client._make_request(
659-
method="GET",
660-
path=self.CHUNK_PATH_WITH_ID_AND_INDEX.format(statement_id, chunk_index),
661-
)
662-
response = GetChunksResponse.from_dict(response_data)
663-
664-
links = response.external_links or []
665-
return links
666-
667646
# == Metadata Operations ==
668647

669648
def get_catalogs(

src/databricks/sql/result_set.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
from abc import ABC, abstractmethod
2-
from typing import List, Optional, Any, Union, Tuple, TYPE_CHECKING
2+
from typing import List, Optional, Any, TYPE_CHECKING
33

44
import logging
5-
import time
65
import pandas
76

8-
from databricks.sql.backend.sea.backend import SeaDatabricksClient
97

108
try:
119
import pyarrow
@@ -14,11 +12,12 @@
1412

1513
if TYPE_CHECKING:
1614
from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
15+
from databricks.sql.backend.sea.backend import SeaDatabricksClient
1716
from databricks.sql.client import Connection
17+
1818
from databricks.sql.backend.databricks_client import DatabricksClient
19-
from databricks.sql.thrift_api.TCLIService import ttypes
2019
from databricks.sql.types import Row
21-
from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError
20+
from databricks.sql.exc import RequestError, CursorAlreadyClosedError
2221
from databricks.sql.utils import ColumnTable, ColumnQueue
2322
from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse
2423

@@ -136,7 +135,8 @@ def close(self) -> None:
136135
been closed on the server for some other reason, issue a request to the server to close it.
137136
"""
138137
try:
139-
self.results.close()
138+
if self.results:
139+
self.results.close()
140140
if (
141141
self.status != CommandState.CLOSED
142142
and not self.has_been_closed_server_side

tests/unit/test_sea_backend.py

Lines changed: 6 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def test_initialization(self, mock_http_client):
132132
assert client3.max_download_threads == 5
133133

134134
# Test with invalid HTTP path
135-
with pytest.raises(ProgrammingError) as excinfo:
135+
with pytest.raises(ValueError) as excinfo:
136136
SeaDatabricksClient(
137137
server_hostname="test-server.databricks.com",
138138
port=443,
@@ -198,7 +198,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i
198198
)
199199

200200
# Test close_session with invalid ID type
201-
with pytest.raises(ProgrammingError) as excinfo:
201+
with pytest.raises(ValueError) as excinfo:
202202
sea_client.close_session(thrift_session_id)
203203
assert "Not a valid SEA session ID" in str(excinfo.value)
204204

@@ -244,7 +244,7 @@ def test_command_execution_sync(
244244
assert result == "mock_result_set"
245245

246246
# Test with invalid session ID
247-
with pytest.raises(ProgrammingError) as excinfo:
247+
with pytest.raises(ValueError) as excinfo:
248248
mock_thrift_handle = MagicMock()
249249
mock_thrift_handle.sessionId.guid = b"guid"
250250
mock_thrift_handle.sessionId.secret = b"secret"
@@ -449,7 +449,7 @@ def test_command_management(
449449
)
450450

451451
# Test cancel_command with invalid ID
452-
with pytest.raises(ProgrammingError) as excinfo:
452+
with pytest.raises(ValueError) as excinfo:
453453
sea_client.cancel_command(thrift_command_id)
454454
assert "Not a valid SEA command ID" in str(excinfo.value)
455455

@@ -463,7 +463,7 @@ def test_command_management(
463463
)
464464

465465
# Test close_command with invalid ID
466-
with pytest.raises(ProgrammingError) as excinfo:
466+
with pytest.raises(ValueError) as excinfo:
467467
sea_client.close_command(thrift_command_id)
468468
assert "Not a valid SEA command ID" in str(excinfo.value)
469469

@@ -522,7 +522,7 @@ def test_command_management(
522522
assert result.status == CommandState.SUCCEEDED
523523

524524
# Test get_execution_result with invalid ID
525-
with pytest.raises(ProgrammingError) as excinfo:
525+
with pytest.raises(ValueError) as excinfo:
526526
sea_client.get_execution_result(thrift_command_id, mock_cursor)
527527
assert "Not a valid SEA command ID" in str(excinfo.value)
528528

@@ -955,67 +955,3 @@ def test_get_columns(self, sea_client, sea_session_id, mock_cursor):
955955
cursor=mock_cursor,
956956
)
957957
assert "Catalog name is required for get_columns" in str(excinfo.value)
958-
959-
def test_get_chunk_links(self, sea_client, mock_http_client, sea_command_id):
960-
"""Test get_chunk_links method when links are available."""
961-
# Setup mock response
962-
mock_response = {
963-
"external_links": [
964-
{
965-
"external_link": "https://example.com/data/chunk0",
966-
"expiration": "2025-07-03T05:51:18.118009",
967-
"row_count": 100,
968-
"byte_count": 1024,
969-
"row_offset": 0,
970-
"chunk_index": 0,
971-
"next_chunk_index": 1,
972-
"http_headers": {"Authorization": "Bearer token123"},
973-
}
974-
]
975-
}
976-
mock_http_client._make_request.return_value = mock_response
977-
978-
# Call the method
979-
results = sea_client.get_chunk_links("test-statement-123", 0)
980-
981-
# Verify the HTTP client was called correctly
982-
mock_http_client._make_request.assert_called_once_with(
983-
method="GET",
984-
path=sea_client.CHUNK_PATH_WITH_ID_AND_INDEX.format(
985-
"test-statement-123", 0
986-
),
987-
)
988-
989-
# Verify the results
990-
assert isinstance(results, list)
991-
assert len(results) == 1
992-
result = results[0]
993-
assert result.external_link == "https://example.com/data/chunk0"
994-
assert result.expiration == "2025-07-03T05:51:18.118009"
995-
assert result.row_count == 100
996-
assert result.byte_count == 1024
997-
assert result.row_offset == 0
998-
assert result.chunk_index == 0
999-
assert result.next_chunk_index == 1
1000-
assert result.http_headers == {"Authorization": "Bearer token123"}
1001-
1002-
def test_get_chunk_links_empty(self, sea_client, mock_http_client):
1003-
"""Test get_chunk_links when no links are returned (empty list)."""
1004-
# Setup mock response with no matching chunk
1005-
mock_response = {"external_links": []}
1006-
mock_http_client._make_request.return_value = mock_response
1007-
1008-
# Call the method
1009-
results = sea_client.get_chunk_links("test-statement-123", 0)
1010-
1011-
# Verify the HTTP client was called correctly
1012-
mock_http_client._make_request.assert_called_once_with(
1013-
method="GET",
1014-
path=sea_client.CHUNK_PATH_WITH_ID_AND_INDEX.format(
1015-
"test-statement-123", 0
1016-
),
1017-
)
1018-
1019-
# Verify the results are empty
1020-
assert isinstance(results, list)
1021-
assert results == []

0 commit comments

Comments
 (0)