Commit 9de6c8b

init col norm
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent fe8cd57 commit 9de6c8b

7 files changed (+287, -31 lines)

src/databricks/sql/backend/sea/backend.py

Lines changed: 17 additions & 4 deletions
@@ -19,6 +19,7 @@
     WaitTimeout,
     MetadataCommands,
 )
+from databricks.sql.backend.sea.utils.metadata_mappings import MetadataColumnMappings
 from databricks.sql.backend.sea.utils.normalize import normalize_sea_type_to_thrift
 from databricks.sql.thrift_api.TCLIService import ttypes
 
@@ -699,7 +700,10 @@ def get_catalogs(
             async_op=False,
             enforce_embedded_schema_correctness=False,
         )
-        assert result is not None, "execute_command returned None in synchronous mode"
+        assert isinstance(
+            result, SeaResultSet
+        ), "Expected SeaResultSet from SEA backend"
+        result.prepare_metadata_columns(MetadataColumnMappings.CATALOG_COLUMNS)
         return result
 
     def get_schemas(
@@ -732,7 +736,10 @@ def get_schemas(
             async_op=False,
             enforce_embedded_schema_correctness=False,
         )
-        assert result is not None, "execute_command returned None in synchronous mode"
+        assert isinstance(
+            result, SeaResultSet
+        ), "Expected SeaResultSet from SEA backend"
+        result.prepare_metadata_columns(MetadataColumnMappings.SCHEMA_COLUMNS)
         return result
 
     def get_tables(
@@ -773,7 +780,10 @@ def get_tables(
             async_op=False,
             enforce_embedded_schema_correctness=False,
         )
-        assert result is not None, "execute_command returned None in synchronous mode"
+        assert isinstance(
+            result, SeaResultSet
+        ), "Expected SeaResultSet from SEA backend"
+        result.prepare_metadata_columns(MetadataColumnMappings.TABLE_COLUMNS)
 
         # Apply client-side filtering by table_types
         from databricks.sql.backend.sea.utils.filters import ResultSetFilter
@@ -820,5 +830,8 @@ def get_columns(
             async_op=False,
             enforce_embedded_schema_correctness=False,
         )
-        assert result is not None, "execute_command returned None in synchronous mode"
+        assert isinstance(
+            result, SeaResultSet
+        ), "Expected SeaResultSet from SEA backend"
+        result.prepare_metadata_columns(MetadataColumnMappings.COLUMN_COLUMNS)
         return result
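
For orientation, a hedged sketch of what this normalization means for a caller of the metadata methods. The connection values are placeholders, and the claim in the comment assumes the TABLE_COLUMNS mapping introduced in this commit:

    from databricks import sql

    with sql.connect(
        server_hostname="example.cloud.databricks.com",  # placeholder
        http_path="/sql/1.0/warehouses/abc123",          # placeholder
        access_token="dapi-example",                     # placeholder
    ) as connection:
        with connection.cursor() as cursor:
            cursor.tables(catalog_name="main")
            # With this commit, a SEA metadata result set describes itself with
            # JDBC names (TABLE_CAT, TABLE_SCHEM, TABLE_NAME, TABLE_TYPE, ...)
            # regardless of the raw SEA column names.
            print([col[0] for col in cursor.description])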

src/databricks/sql/backend/sea/result_set.py

Lines changed: 125 additions & 3 deletions
@@ -1,11 +1,12 @@
 from __future__ import annotations
 
-from typing import Any, List, Optional, TYPE_CHECKING
+from typing import Any, List, Optional, TYPE_CHECKING, Dict
 
 import logging
 
 from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
 from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter
+from databricks.sql.backend.sea.utils.result_column import ResultColumn
 
 try:
     import pyarrow
@@ -82,6 +83,10 @@ def __init__(
             arrow_schema_bytes=execute_response.arrow_schema_bytes,
         )
 
+        # Initialize metadata columns for post-fetch transformation
+        self._metadata_columns = None
+        self._column_index_mapping = None
+
     def _convert_json_types(self, row: List[str]) -> List[Any]:
         """
         Convert string values in the row to appropriate Python types based on column metadata.
@@ -160,6 +165,7 @@ def fetchmany_json(self, size: int) -> List[List[str]]:
             raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")
 
         results = self.results.next_n_rows(size)
+        results = self._transform_json_rows(results)
         self._next_row_index += len(results)
 
         return results
@@ -173,6 +179,7 @@ def fetchall_json(self) -> List[List[str]]:
         """
 
         results = self.results.remaining_rows()
+        results = self._transform_json_rows(results)
         self._next_row_index += len(results)
 
         return results
@@ -197,7 +204,12 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
 
         results = self.results.next_n_rows(size)
         if isinstance(self.results, JsonQueue):
-            results = self._convert_json_to_arrow_table(results)
+            # Transform JSON first, then convert to Arrow
+            transformed_json = self._transform_json_rows(results)
+            results = self._convert_json_to_arrow_table(transformed_json)
+        else:
+            # Transform Arrow table directly
+            results = self._transform_arrow_table(results)
 
         self._next_row_index += results.num_rows
 
@@ -210,7 +222,12 @@ def fetchall_arrow(self) -> "pyarrow.Table":
 
         results = self.results.remaining_rows()
         if isinstance(self.results, JsonQueue):
-            results = self._convert_json_to_arrow_table(results)
+            # Transform JSON first, then convert to Arrow
+            transformed_json = self._transform_json_rows(results)
+            results = self._convert_json_to_arrow_table(transformed_json)
+        else:
+            # Transform Arrow table directly
+            results = self._transform_arrow_table(results)
 
         self._next_row_index += results.num_rows
 
@@ -263,3 +280,108 @@ def fetchall(self) -> List[Row]:
             return self._create_json_table(self.fetchall_json())
         else:
             return self._convert_arrow_table(self.fetchall_arrow())
+
+    def prepare_metadata_columns(self, metadata_columns: List[ResultColumn]) -> None:
+        """
+        Prepare result set for metadata column normalization.
+
+        Args:
+            metadata_columns: List of ResultColumn objects defining the expected columns
+                and their mappings from SEA column names
+        """
+        self._metadata_columns = metadata_columns
+        self._prepare_column_mapping()
+
+    def _prepare_column_mapping(self) -> None:
+        """
+        Prepare column index mapping for metadata queries.
+        Updates description to use JDBC column names.
+        """
+        # Ensure description is available
+        if not self.description:
+            raise ValueError("Cannot prepare column mapping without result description")
+
+        # Build mapping from SEA column names to their indices
+        sea_column_indices = {}
+        for idx, col in enumerate(self.description):
+            sea_column_indices[col[0]] = idx
+
+        # Create new description and index mapping
+        new_description = []
+        self._column_index_mapping = {}  # Maps new index -> old index
+
+        for new_idx, result_column in enumerate(self._metadata_columns):
+            # Find the corresponding SEA column
+            if (
+                result_column.result_set_column_name
+                and result_column.result_set_column_name in sea_column_indices
+            ):
+                old_idx = sea_column_indices[result_column.result_set_column_name]
+                self._column_index_mapping[new_idx] = old_idx
+                # Use the original column metadata but with JDBC name
+                old_col = self.description[old_idx]
+                new_description.append(
+                    (
+                        result_column.column_name,  # JDBC name
+                        result_column.column_type,  # Expected type
+                        old_col[2],  # display_size
+                        old_col[3],  # internal_size
+                        old_col[4],  # precision
+                        old_col[5],  # scale
+                        old_col[6],  # null_ok
+                    )
+                )
+            else:
+                # Column doesn't exist in SEA - add with None values
+                new_description.append(
+                    (
+                        result_column.column_name,
+                        result_column.column_type,
+                        None,
+                        None,
+                        None,
+                        None,
+                        True,
+                    )
+                )
+                self._column_index_mapping[new_idx] = None
+
+        self.description = new_description
+
+    def _transform_arrow_table(self, table: "pyarrow.Table") -> "pyarrow.Table":
+        """Transform arrow table columns for metadata normalization."""
+        if not self._metadata_columns:
+            return table
+
+        # Reorder columns and add missing ones
+        new_columns = []
+        column_names = []
+
+        for new_idx, result_column in enumerate(self._metadata_columns):
+            old_idx = self._column_index_mapping.get(new_idx)
+            if old_idx is not None:
+                new_columns.append(table.column(old_idx))
+            else:
+                # Create null column for missing data
+                null_array = pyarrow.nulls(table.num_rows)
+                new_columns.append(null_array)
+            column_names.append(result_column.column_name)
+
+        return pyarrow.Table.from_arrays(new_columns, names=column_names)
+
+    def _transform_json_rows(self, rows: List[List[str]]) -> List[List[Any]]:
+        """Transform JSON rows for metadata normalization."""
+        if not self._metadata_columns:
+            return rows
+
+        transformed_rows = []
+        for row in rows:
+            new_row = []
+            for new_idx in range(len(self._metadata_columns)):
+                old_idx = self._column_index_mapping.get(new_idx)
+                if old_idx is not None:
+                    new_row.append(row[old_idx])
+                else:
+                    new_row.append(None)
+            transformed_rows.append(new_row)
+        return transformed_rows
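
To see the remapping in isolation from the result-set machinery, here is a minimal self-contained sketch of the same technique on toy data. MetaCol, build_index_mapping, and transform_rows are simplifications invented for this illustration, not names from the library:

    from typing import Any, Dict, List, Optional, Tuple

    # Simplified stand-in for ResultColumn: (jdbc_name, sea_name_or_None).
    MetaCol = Tuple[str, Optional[str]]

    def build_index_mapping(
        description: List[Tuple[str, str]], metadata_columns: List[MetaCol]
    ) -> Dict[int, Optional[int]]:
        """Map each expected (JDBC) column index to a SEA column index, or None."""
        sea_indices = {name: i for i, (name, _typ) in enumerate(description)}
        return {
            new_idx: sea_indices.get(sea_name) if sea_name else None
            for new_idx, (_jdbc, sea_name) in enumerate(metadata_columns)
        }

    def transform_rows(
        rows: List[List[Any]], mapping: Dict[int, Optional[int]]
    ) -> List[List[Any]]:
        """Reorder each row into the JDBC layout, filling missing columns with None."""
        return [
            [row[old] if old is not None else None for _new, old in sorted(mapping.items())]
            for row in rows
        ]

    # A toy two-column slice of a SEA getTables() result; JDBC callers expect
    # TABLE_CAT, TABLE_NAME, and a REF_GENERATION column SEA never sends.
    description = [("catalogName", "string"), ("tableName", "string")]
    table_columns: List[MetaCol] = [
        ("TABLE_CAT", "catalogName"),
        ("TABLE_NAME", "tableName"),
        ("REF_GENERATION", None),
    ]
    mapping = build_index_mapping(description, table_columns)
    print(transform_rows([["main", "t1"]], mapping))  # [['main', 't1', None]]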
src/databricks/sql/backend/sea/utils/metadata_mappings.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+from databricks.sql.backend.sea.utils.result_column import ResultColumn
+
+
+class MetadataColumnMappings:
+    """Column mappings for metadata queries following JDBC specification."""
+
+    # Common columns used across multiple metadata queries
+    CATALOG_COLUMN = ResultColumn("TABLE_CAT", "catalog", "string")
+    CATALOG_COLUMN_FOR_TABLES = ResultColumn("TABLE_CAT", "catalogName", "string")
+    SCHEMA_COLUMN = ResultColumn("TABLE_SCHEM", "namespace", "string")
+    SCHEMA_COLUMN_FOR_GET_SCHEMA = ResultColumn("TABLE_SCHEM", "databaseName", "string")
+    TABLE_NAME_COLUMN = ResultColumn("TABLE_NAME", "tableName", "string")
+    TABLE_TYPE_COLUMN = ResultColumn("TABLE_TYPE", "tableType", "string")
+    REMARKS_COLUMN = ResultColumn("REMARKS", "remarks", "string")
+
+    # Columns specific to getColumns()
+    COLUMN_NAME_COLUMN = ResultColumn("COLUMN_NAME", "col_name", "string")
+    DATA_TYPE_COLUMN = ResultColumn(
+        "DATA_TYPE", None, "int"
+    )  # SEA doesn't provide this
+    TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", "columnType", "string")
+    COLUMN_SIZE_COLUMN = ResultColumn("COLUMN_SIZE", None, "int")
+    DECIMAL_DIGITS_COLUMN = ResultColumn("DECIMAL_DIGITS", None, "int")
+    NUM_PREC_RADIX_COLUMN = ResultColumn("NUM_PREC_RADIX", None, "int")
+    NULLABLE_COLUMN = ResultColumn("NULLABLE", None, "int")
+    COLUMN_DEF_COLUMN = ResultColumn(
+        "COLUMN_DEF", "columnType", "string"
+    )  # Note: duplicate mapping
+    SQL_DATA_TYPE_COLUMN = ResultColumn("SQL_DATA_TYPE", None, "int")
+    SQL_DATETIME_SUB_COLUMN = ResultColumn("SQL_DATETIME_SUB", None, "int")
+    CHAR_OCTET_LENGTH_COLUMN = ResultColumn("CHAR_OCTET_LENGTH", None, "int")
+    ORDINAL_POSITION_COLUMN = ResultColumn("ORDINAL_POSITION", None, "int")
+    IS_NULLABLE_COLUMN = ResultColumn("IS_NULLABLE", "isNullable", "string")
+
+    # Columns for getTables() that don't exist in SEA
+    TYPE_CAT_COLUMN = ResultColumn("TYPE_CAT", None, "string")
+    TYPE_SCHEM_COLUMN = ResultColumn("TYPE_SCHEM", None, "string")
+    TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", None, "string")
+    SELF_REFERENCING_COL_NAME_COLUMN = ResultColumn(
+        "SELF_REFERENCING_COL_NAME", None, "string"
+    )
+    REF_GENERATION_COLUMN = ResultColumn("REF_GENERATION", None, "string")
+
+    # Column lists for each metadata operation
+    CATALOG_COLUMNS = [CATALOG_COLUMN]
+
+    SCHEMA_COLUMNS = [
+        SCHEMA_COLUMN_FOR_GET_SCHEMA,
+        ResultColumn("TABLE_CATALOG", None, "string"),  # SEA doesn't return this
+    ]
+
+    TABLE_COLUMNS = [
+        CATALOG_COLUMN_FOR_TABLES,
+        SCHEMA_COLUMN,
+        TABLE_NAME_COLUMN,
+        TABLE_TYPE_COLUMN,
+        REMARKS_COLUMN,
+        TYPE_CAT_COLUMN,
+        TYPE_SCHEM_COLUMN,
+        TYPE_NAME_COLUMN,
+        SELF_REFERENCING_COL_NAME_COLUMN,
+        REF_GENERATION_COLUMN,
+    ]
+
+    COLUMN_COLUMNS = [
+        CATALOG_COLUMN_FOR_TABLES,
+        SCHEMA_COLUMN,
+        TABLE_NAME_COLUMN,
+        COLUMN_NAME_COLUMN,
+        DATA_TYPE_COLUMN,
+        TYPE_NAME_COLUMN,
+        COLUMN_SIZE_COLUMN,
+        ResultColumn("BUFFER_LENGTH", None, "int"),
+        DECIMAL_DIGITS_COLUMN,
+        NUM_PREC_RADIX_COLUMN,
+        NULLABLE_COLUMN,
+        REMARKS_COLUMN,
+        COLUMN_DEF_COLUMN,
+        SQL_DATA_TYPE_COLUMN,
+        SQL_DATETIME_SUB_COLUMN,
+        CHAR_OCTET_LENGTH_COLUMN,
+        ORDINAL_POSITION_COLUMN,
+        IS_NULLABLE_COLUMN,
+        ResultColumn("SCOPE_CATALOG", None, "string"),
+        ResultColumn("SCOPE_SCHEMA", None, "string"),
+        ResultColumn("SCOPE_TABLE", None, "string"),
+        ResultColumn("SOURCE_DATA_TYPE", None, "smallint"),
+        ResultColumn("IS_AUTO_INCREMENT", None, "string"),
+        ResultColumn("IS_GENERATEDCOLUMN", None, "string"),
+    ]
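
Two reading notes on this file. First, a class body executes top to bottom, so the second TYPE_NAME_COLUMN assignment (the None-mapped getTables() variant) rebinds the name before COLUMN_COLUMNS is built, which means getColumns() will surface TYPE_NAME as all nulls rather than columnType; whether that is intended is not clear from the diff alone. Second, the None-mapped entries are exactly what _transform_arrow_table materializes as null columns via pyarrow.nulls; a tiny standalone check of that mechanism, assuming only that pyarrow is installed:

    import pyarrow

    # A toy getTables() result that, like SEA, lacks REF_GENERATION entirely.
    catalog = pyarrow.array(["main"])
    table_name = pyarrow.array(["t1"])
    normalized = pyarrow.Table.from_arrays(
        [catalog, table_name, pyarrow.nulls(len(catalog))],
        names=["TABLE_CAT", "TABLE_NAME", "REF_GENERATION"],
    )
    print(normalized.column("REF_GENERATION"))  # one null value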
src/databricks/sql/backend/sea/utils/result_column.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+from dataclasses import dataclass
+from typing import Optional
+
+
+@dataclass(frozen=True)
+class ResultColumn:
+    """
+    Represents a mapping between JDBC specification column names and actual result set column names.
+
+    Attributes:
+        column_name: JDBC specification column name (e.g., "TABLE_CAT")
+        result_set_column_name: Server result column name from SEA (e.g., "catalog")
+        column_type: SQL type code from databricks.sql.types
+    """
+
+    column_name: str
+    result_set_column_name: Optional[str]  # None if SEA doesn't return this column
+    column_type: str
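
A quick illustration of the dataclass in use; because it is frozen, instances are immutable and hashable, which makes them safe to share as the class-level constants in MetadataColumnMappings (the import path is taken from the diff and assumes this branch is installed):

    from databricks.sql.backend.sea.utils.result_column import ResultColumn

    col = ResultColumn("TABLE_CAT", "catalog", "string")
    missing = ResultColumn("REF_GENERATION", None, "string")  # SEA never returns this
    print(col.column_name, "<-", col.result_set_column_name)  # TABLE_CAT <- catalog
    # Frozen dataclass: attempting `col.column_name = "X"` raises FrozenInstanceError.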
