Commit 3ca9678
Merge branch 'main' into robust-metadata-sea
2 parents 2b42ea0 + e732e96

File tree: 11 files changed (+277, −129)

src/databricks/sql/backend/sea/backend.py

Lines changed: 6 additions & 0 deletions

@@ -19,6 +19,7 @@
     WaitTimeout,
     MetadataCommands,
 )
+from databricks.sql.backend.sea.utils.normalize import normalize_sea_type_to_thrift
 from databricks.sql.thrift_api.TCLIService import ttypes

 if TYPE_CHECKING:

@@ -323,6 +324,11 @@ def _extract_description_from_manifest(
             # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
             name = col_data.get("name", "")
             type_name = col_data.get("type_name", "")
+
+            # Normalize SEA type to Thrift conventions before any processing
+            type_name = normalize_sea_type_to_thrift(type_name, col_data)
+
+            # Now strip _TYPE suffix and convert to lowercase
             type_name = (
                 type_name[:-5] if type_name.endswith("_TYPE") else type_name
             ).lower()
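
For illustration, a minimal sketch of what the new normalization step does to a manifest column before the existing suffix-stripping line runs (the sample col_data dict is hypothetical):

    from databricks.sql.backend.sea.utils.normalize import normalize_sea_type_to_thrift

    # Hypothetical column entry from a SEA result manifest
    col_data = {"name": "age", "type_name": "LONG"}

    type_name = col_data.get("type_name", "")                      # "LONG"
    type_name = normalize_sea_type_to_thrift(type_name, col_data)  # "BIGINT"
    type_name = (
        type_name[:-5] if type_name.endswith("_TYPE") else type_name
    ).lower()                                                      # "bigint"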

src/databricks/sql/backend/sea/result_set.py

Lines changed: 9 additions & 10 deletions

@@ -92,20 +92,19 @@ def _convert_json_types(self, row: List[str]) -> List[Any]:
         converted_row = []

         for i, value in enumerate(row):
+            column_name = self.description[i][0]
             column_type = self.description[i][1]
             precision = self.description[i][4]
             scale = self.description[i][5]

-            try:
-                converted_value = SqlTypeConverter.convert_value(
-                    value, column_type, precision=precision, scale=scale
-                )
-                converted_row.append(converted_value)
-            except Exception as e:
-                logger.warning(
-                    f"Error converting value '{value}' to {column_type}: {e}"
-                )
-                converted_row.append(value)
+            converted_value = SqlTypeConverter.convert_value(
+                value,
+                column_type,
+                column_name=column_name,
+                precision=precision,
+                scale=scale,
+            )
+            converted_row.append(converted_value)

         return converted_row
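
With the try/except moved into the converter, the row loop stays linear and per-column context travels with the call. A rough sketch against a hypothetical description tuple (format: name, type_code, display_size, internal_size, precision, scale, null_ok):

    description = [("price", "decimal", None, None, 10, 2, True)]
    row = ["123.45"]

    converted_row = [
        SqlTypeConverter.convert_value(
            value,
            description[i][1],                # column_type
            column_name=description[i][0],
            precision=description[i][4],
            scale=description[i][5],
        )
        for i, value in enumerate(row)
    ]
    # converted_row == [Decimal("123.45")]; on a conversion error the raw string comes back instead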

src/databricks/sql/backend/sea/utils/conversion.py

Lines changed: 44 additions & 31 deletions

@@ -50,60 +50,65 @@ def _convert_decimal(

 class SqlType:
     """
-    SQL type constants
-
-    The list of types can be found in the SEA REST API Reference:
-    https://docs.databricks.com/api/workspace/statementexecution/executestatement
+    SQL type constants based on Thrift TTypeId values.
+
+    These correspond to the normalized type names that come from the SEA backend
+    after normalize_sea_type_to_thrift processing (lowercase, without _TYPE suffix).
     """

     # Numeric types
-    BYTE = "byte"
-    SHORT = "short"
-    INT = "int"
-    LONG = "long"
-    FLOAT = "float"
-    DOUBLE = "double"
-    DECIMAL = "decimal"
+    TINYINT = "tinyint"  # Maps to TTypeId.TINYINT_TYPE
+    SMALLINT = "smallint"  # Maps to TTypeId.SMALLINT_TYPE
+    INT = "int"  # Maps to TTypeId.INT_TYPE
+    BIGINT = "bigint"  # Maps to TTypeId.BIGINT_TYPE
+    FLOAT = "float"  # Maps to TTypeId.FLOAT_TYPE
+    DOUBLE = "double"  # Maps to TTypeId.DOUBLE_TYPE
+    DECIMAL = "decimal"  # Maps to TTypeId.DECIMAL_TYPE

     # Boolean type
-    BOOLEAN = "boolean"
+    BOOLEAN = "boolean"  # Maps to TTypeId.BOOLEAN_TYPE

     # Date/Time types
-    DATE = "date"
-    TIMESTAMP = "timestamp"
-    INTERVAL = "interval"
+    DATE = "date"  # Maps to TTypeId.DATE_TYPE
+    TIMESTAMP = "timestamp"  # Maps to TTypeId.TIMESTAMP_TYPE
+    INTERVAL_YEAR_MONTH = (
+        "interval_year_month"  # Maps to TTypeId.INTERVAL_YEAR_MONTH_TYPE
+    )
+    INTERVAL_DAY_TIME = "interval_day_time"  # Maps to TTypeId.INTERVAL_DAY_TIME_TYPE

     # String types
-    CHAR = "char"
-    STRING = "string"
+    CHAR = "char"  # Maps to TTypeId.CHAR_TYPE
+    VARCHAR = "varchar"  # Maps to TTypeId.VARCHAR_TYPE
+    STRING = "string"  # Maps to TTypeId.STRING_TYPE

     # Binary type
-    BINARY = "binary"
+    BINARY = "binary"  # Maps to TTypeId.BINARY_TYPE

     # Complex types
-    ARRAY = "array"
-    MAP = "map"
-    STRUCT = "struct"
+    ARRAY = "array"  # Maps to TTypeId.ARRAY_TYPE
+    MAP = "map"  # Maps to TTypeId.MAP_TYPE
+    STRUCT = "struct"  # Maps to TTypeId.STRUCT_TYPE

     # Other types
-    NULL = "null"
-    USER_DEFINED_TYPE = "user_defined_type"
+    NULL = "null"  # Maps to TTypeId.NULL_TYPE
+    UNION = "union"  # Maps to TTypeId.UNION_TYPE
+    USER_DEFINED = "user_defined"  # Maps to TTypeId.USER_DEFINED_TYPE


 class SqlTypeConverter:
     """
     Utility class for converting SQL types to Python types.
-    Based on the types supported by the Databricks SDK.
+    Based on the Thrift TTypeId types after normalization.
     """

     # SQL type to conversion function mapping
     # TODO: complex types
     TYPE_MAPPING: Dict[str, Callable] = {
         # Numeric types
-        SqlType.BYTE: lambda v: int(v),
-        SqlType.SHORT: lambda v: int(v),
+        SqlType.TINYINT: lambda v: int(v),
+        SqlType.SMALLINT: lambda v: int(v),
         SqlType.INT: lambda v: int(v),
-        SqlType.LONG: lambda v: int(v),
+        SqlType.BIGINT: lambda v: int(v),
         SqlType.FLOAT: lambda v: float(v),
         SqlType.DOUBLE: lambda v: float(v),
         SqlType.DECIMAL: _convert_decimal,

@@ -112,30 +117,34 @@ class SqlTypeConverter:
         # Date/Time types
         SqlType.DATE: lambda v: datetime.date.fromisoformat(v),
         SqlType.TIMESTAMP: lambda v: parser.parse(v),
-        SqlType.INTERVAL: lambda v: v,  # Keep as string for now
+        SqlType.INTERVAL_YEAR_MONTH: lambda v: v,  # Keep as string for now
+        SqlType.INTERVAL_DAY_TIME: lambda v: v,  # Keep as string for now
         # String types - no conversion needed
         SqlType.CHAR: lambda v: v,
+        SqlType.VARCHAR: lambda v: v,
         SqlType.STRING: lambda v: v,
         # Binary type
         SqlType.BINARY: lambda v: bytes.fromhex(v),
         # Other types
         SqlType.NULL: lambda v: None,
         # Complex types and user-defined types return as-is
-        SqlType.USER_DEFINED_TYPE: lambda v: v,
+        SqlType.USER_DEFINED: lambda v: v,
     }

     @staticmethod
     def convert_value(
         value: str,
         sql_type: str,
+        column_name: Optional[str],
         **kwargs,
     ) -> object:
         """
         Convert a string value to the appropriate Python type based on SQL type.

         Args:
             value: The string value to convert
-            sql_type: The SQL type (e.g., 'int', 'decimal')
+            sql_type: The SQL type (e.g., 'tinyint', 'decimal')
+            column_name: The name of the column being converted
             **kwargs: Additional keyword arguments for the conversion function

         Returns:

@@ -155,6 +164,10 @@ def convert_value(
                 return converter_func(value, precision, scale)
             else:
                 return converter_func(value)
-        except (ValueError, TypeError, decimal.InvalidOperation) as e:
-            logger.warning(f"Error converting value '{value}' to {sql_type}: {e}")
+        except Exception as e:
+            warning_message = f"Error converting value '{value}' to {sql_type}"
+            if column_name:
+                warning_message += f" in column {column_name}"
+            warning_message += f": {e}"
+            logger.warning(warning_message)
             return value
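
To illustrate the widened error path: a malformed value now logs a column-qualified warning and falls back to the raw string (the values below are made up):

    ok = SqlTypeConverter.convert_value("42", "tinyint", column_name="flags")
    # ok == 42

    bad = SqlTypeConverter.convert_value("not-a-date", "date", column_name="created_at")
    # logs: Error converting value 'not-a-date' to date in column created_at: ...
    # bad == "not-a-date" (returned unconverted)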
src/databricks/sql/backend/sea/utils/normalize.py (new file)

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+"""
+Type normalization utilities for SEA backend.
+
+This module provides functionality to normalize SEA type names to match
+Thrift type naming conventions.
+"""
+
+from typing import Dict, Any
+
+# SEA types that need to be translated to Thrift types
+# The list of all SEA types is available in the REST reference at:
+# https://docs.databricks.com/api/workspace/statementexecution/executestatement
+# The list of all Thrift types can be found in the ttypes.TTypeId definition
+# The SEA types that do not align with Thrift are explicitly mapped below
+SEA_TO_THRIFT_TYPE_MAP = {
+    "BYTE": "TINYINT",
+    "SHORT": "SMALLINT",
+    "LONG": "BIGINT",
+    "INTERVAL": "INTERVAL",  # Default mapping, will be overridden if type_interval_type is present
+}
+
+
+def normalize_sea_type_to_thrift(type_name: str, col_data: Dict[str, Any]) -> str:
+    """
+    Normalize SEA type names to match Thrift type naming conventions.
+
+    Args:
+        type_name: The type name from SEA (e.g., "BYTE", "LONG", "INTERVAL")
+        col_data: The full column data dictionary from manifest (for accessing type_interval_type)
+
+    Returns:
+        Normalized type name matching Thrift conventions
+    """
+    # Early return if type doesn't need mapping
+    if type_name not in SEA_TO_THRIFT_TYPE_MAP:
+        return type_name
+
+    normalized_type = SEA_TO_THRIFT_TYPE_MAP[type_name]
+
+    # Special handling for interval types
+    if type_name == "INTERVAL":
+        type_interval_type = col_data.get("type_interval_type")
+        if type_interval_type:
+            return (
+                "INTERVAL_YEAR_MONTH"
+                if any(t in type_interval_type.upper() for t in ["YEAR", "MONTH"])
+                else "INTERVAL_DAY_TIME"
+            )
+
+    return normalized_type
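
A few assumed inputs showing how the helper behaves (the interval qualifier strings are illustrative; any qualifier containing YEAR or MONTH selects the year-month variant):

    normalize_sea_type_to_thrift("BYTE", {})     # "TINYINT"
    normalize_sea_type_to_thrift("STRING", {})   # "STRING" — unmapped types pass through
    normalize_sea_type_to_thrift("INTERVAL", {"type_interval_type": "YEAR TO MONTH"})
    # "INTERVAL_YEAR_MONTH"
    normalize_sea_type_to_thrift("INTERVAL", {"type_interval_type": "DAY TO SECOND"})
    # "INTERVAL_DAY_TIME"
    normalize_sea_type_to_thrift("INTERVAL", {})  # "INTERVAL" — no qualifier present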

src/databricks/sql/client.py

Lines changed: 4 additions & 0 deletions

@@ -254,6 +254,9 @@ def read(self) -> Optional[OAuthToken]:
         self.telemetry_enabled = (
            self.client_telemetry_enabled and self.server_telemetry_enabled
         )
+        self.telemetry_batch_size = kwargs.get(
+            "telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE
+        )

         try:
             self.session = Session(

@@ -290,6 +293,7 @@ def read(self) -> Optional[OAuthToken]:
             session_id_hex=self.get_session_id_hex(),
             auth_provider=self.session.auth_provider,
             host_url=self.session.host,
+            batch_size=self.telemetry_batch_size,
         )

         self._telemetry_client = TelemetryClientFactory.get_telemetry_client(
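
With this change a caller can size telemetry batches per connection; a sketch, where everything except the telemetry_batch_size keyword is a placeholder:

    from databricks import sql

    conn = sql.connect(
        server_hostname="example.cloud.databricks.com",
        http_path="/sql/1.0/warehouses/abc123",
        access_token="dapi...",
        telemetry_batch_size=50,  # omitted -> TelemetryClientFactory.DEFAULT_BATCH_SIZE (100)
    )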

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 42 additions & 4 deletions

@@ -138,19 +138,18 @@ class TelemetryClient(BaseTelemetryClient):
     TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext"
     TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth"

-    DEFAULT_BATCH_SIZE = 100
-
     def __init__(
         self,
         telemetry_enabled,
         session_id_hex,
         auth_provider,
         host_url,
         executor,
+        batch_size,
     ):
         logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex)
         self._telemetry_enabled = telemetry_enabled
-        self._batch_size = self.DEFAULT_BATCH_SIZE
+        self._batch_size = batch_size
         self._session_id_hex = session_id_hex
         self._auth_provider = auth_provider
         self._user_agent = None

@@ -318,7 +317,7 @@ def close(self):
 class TelemetryClientFactory:
     """
     Static factory class for creating and managing telemetry clients.
-    It uses a thread pool to handle asynchronous operations.
+    It uses a thread pool to handle asynchronous operations and a single flush thread for all clients.
     """

     _clients: Dict[

@@ -331,6 +330,13 @@ class TelemetryClientFactory:
     _original_excepthook = None
     _excepthook_installed = False

+    # Shared flush thread for all clients
+    _flush_thread = None
+    _flush_event = threading.Event()
+    _flush_interval_seconds = 90
+
+    DEFAULT_BATCH_SIZE = 100
+
     @classmethod
     def _initialize(cls):
         """Initialize the factory if not already initialized"""

@@ -341,11 +347,39 @@ def _initialize(cls):
             max_workers=10
         )  # Thread pool for async operations
         cls._install_exception_hook()
+        cls._start_flush_thread()
         cls._initialized = True
         logger.debug(
             "TelemetryClientFactory initialized with thread pool (max_workers=10)"
         )

+    @classmethod
+    def _start_flush_thread(cls):
+        """Start the shared background thread for periodic flushing of all clients"""
+        cls._flush_event.clear()
+        cls._flush_thread = threading.Thread(target=cls._flush_worker, daemon=True)
+        cls._flush_thread.start()
+
+    @classmethod
+    def _flush_worker(cls):
+        """Background worker thread for periodic flushing of all clients"""
+        while not cls._flush_event.wait(cls._flush_interval_seconds):
+            logger.debug("Performing periodic flush for all telemetry clients")
+
+            with cls._lock:
+                clients_to_flush = list(cls._clients.values())
+
+            for client in clients_to_flush:
+                client._flush()
+
+    @classmethod
+    def _stop_flush_thread(cls):
+        """Stop the shared background flush thread"""
+        if cls._flush_thread is not None:
+            cls._flush_event.set()
+            cls._flush_thread.join(timeout=1.0)
+            cls._flush_thread = None
+
     @classmethod
     def _install_exception_hook(cls):
         """Install global exception handler for unhandled exceptions"""
@@ -374,6 +408,7 @@ def initialize_telemetry_client(
         session_id_hex,
         auth_provider,
         host_url,
+        batch_size,
     ):
         """Initialize a telemetry client for a specific connection if telemetry is enabled"""
         try:

@@ -395,6 +430,7 @@ def initialize_telemetry_client(
                     auth_provider=auth_provider,
                     host_url=host_url,
                     executor=TelemetryClientFactory._executor,
+                    batch_size=batch_size,
                 )
             else:
                 TelemetryClientFactory._clients[

@@ -433,6 +469,7 @@ def close(session_id_hex):
                     "No more telemetry clients, shutting down thread pool executor"
                 )
                 try:
+                    TelemetryClientFactory._stop_flush_thread()
                     TelemetryClientFactory._executor.shutdown(wait=True)
                     TelemetryHttpClient.close()
                 except Exception as e:

@@ -458,6 +495,7 @@ def connection_failure_log(
             session_id_hex=UNAUTH_DUMMY_SESSION_ID,
             auth_provider=None,
             host_url=host_url,
+            batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
         )

         telemetry_client = TelemetryClientFactory.get_telemetry_client(
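
End to end, the batch size now flows from the connection kwargs into each client rather than living as a TelemetryClient constant. Roughly, with the call abridged from the hunks above (telemetry_enabled and the surrounding argument names are assumptions):

    batch = kwargs.get("telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE)

    TelemetryClientFactory.initialize_telemetry_client(
        telemetry_enabled=True,
        session_id_hex=session_id_hex,
        auth_provider=auth_provider,
        host_url=host_url,
        batch_size=batch,  # ends up as TelemetryClient._batch_size
    )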
