
Commit c90875f

some comparator stuff
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 08717be

3 files changed: +268 −37 lines changed


src/databricks/sql/backend/sea/backend.py

Lines changed: 30 additions & 9 deletions
@@ -736,12 +736,19 @@ def get_schemas(
         )
         assert result is not None, "execute_command returned None in synchronous mode"
 
-        # Normalize column names to match JDBC/thrift backend
-        from .metadata_constants import SCHEMA_COLUMNS, SCHEMA_TYPE_CODES, normalize_metadata_description
-
+        # Normalize column names and transform data to match JDBC/thrift backend
+        from .metadata_constants import SCHEMA_COLUMNS, SCHEMA_TYPE_CODES, normalize_metadata_description, transform_schemas_data_rows
+
+        # Store original description before normalization for data transformation
+        original_description = result.description[:]
+
+        # Normalize the description (column names and types)
         result.description = normalize_metadata_description(
             result.description, SCHEMA_COLUMNS, SCHEMA_TYPE_CODES
         )
+
+        # Transform the actual data rows to match the new column order and format
+        transform_schemas_data_rows(result, catalog_name, original_description)
 
         return result

@@ -785,12 +792,19 @@ def get_tables(
         )
         assert result is not None, "execute_command returned None in synchronous mode"
 
-        # Normalize column names to match JDBC/thrift backend
-        from .metadata_constants import TABLE_COLUMNS, TABLE_TYPE_CODES, normalize_metadata_description
-
+        # Normalize column names and transform data to match JDBC/thrift backend
+        from .metadata_constants import TABLE_COLUMNS, TABLE_TYPE_CODES, normalize_metadata_description, transform_tables_data_rows
+
+        # Store original description before normalization for data transformation
+        original_description = result.description[:]
+
+        # Normalize the description (column names and types)
         result.description = normalize_metadata_description(
             result.description, TABLE_COLUMNS, TABLE_TYPE_CODES
         )
+
+        # Transform the actual data rows to match the new column order and format
+        transform_tables_data_rows(result, catalog_name, original_description)
 
         # Apply client-side filtering by table_types
         from databricks.sql.backend.sea.utils.filters import ResultSetFilter
@@ -839,9 +853,16 @@ def get_columns(
         )
         assert result is not None, "execute_command returned None in synchronous mode"
 
-        # Normalize column names to match JDBC/thrift backend
-        from .metadata_constants import normalize_columns_metadata_description
-
+        # Normalize column names and transform data to match JDBC/thrift backend
+        from .metadata_constants import normalize_columns_metadata_description, transform_columns_data_rows
+
+        # Store original description before normalization for data transformation
+        original_description = result.description[:]
+
+        # Normalize the description (column names and types)
         result.description = normalize_columns_metadata_description(result.description)
+
+        # Transform the actual data rows to match the new column order and format
+        transform_columns_data_rows(result, original_description)
 
         return result
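
Taken together, the three hunks above follow one pattern: snapshot the pre-normalization description, normalize it, then rewrite the rows against the snapshot. Below is a minimal sketch of that flow for get_schemas, using a hypothetical SimpleNamespace stub in place of the real SEA result set; only the imported names come from this commit.

    from types import SimpleNamespace

    from databricks.sql.backend.sea.metadata_constants import (
        SCHEMA_COLUMNS,
        SCHEMA_TYPE_CODES,
        normalize_metadata_description,
        transform_schemas_data_rows,
    )

    # Stand-in for the object returned by execute_command; the real result
    # set also carries description and rows attributes.
    result = SimpleNamespace(
        description=[("databaseName", "string", None, None, None, None, None)],
        rows=[("my_schema",)],
    )

    # Snapshot the description first: the transform locates values by the
    # original SEA column names, which normalization renames.
    original_description = result.description[:]
    result.description = normalize_metadata_description(
        result.description, SCHEMA_COLUMNS, SCHEMA_TYPE_CODES
    )
    transform_schemas_data_rows(result, "my_catalog", original_description)

    # Rows are now (TABLE_SCHEM, TABLE_CATALOG) pairs:
    assert result.rows == [("my_schema", "my_catalog")]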

src/databricks/sql/backend/sea/metadata_constants.py

Lines changed: 235 additions & 28 deletions
@@ -39,32 +39,34 @@
     ),  # REF_GENERATION_COLUMN (likely None in data)
 ]
 
-# Columns for columns() - matching JDBC COLUMN_COLUMNS exactly
+# Columns for columns() - mapping JDBC columns to actual SEA SHOW COLUMNS output
+# Based on actual SEA output: col_name, catalogName, namespace, tableName, columnType,
+# columnSize, decimalDigits, radix, isNullable, remarks, ordinalPosition, isAutoIncrement, isGenerated
 COLUMN_COLUMNS: List[Tuple[str, str]] = [
-    ("TABLE_CAT", "catalogName"),  # CATALOG_COLUMN
-    ("TABLE_SCHEM", "namespace"),  # SCHEMA_COLUMN
-    ("TABLE_NAME", "tableName"),  # TABLE_NAME_COLUMN
-    ("COLUMN_NAME", "col_name"),  # COL_NAME_COLUMN
-    ("DATA_TYPE", "dataType"),  # DATA_TYPE_COLUMN
-    ("TYPE_NAME", "columnType"),  # COLUMN_TYPE_COLUMN
-    ("COLUMN_SIZE", "columnSize"),  # COLUMN_SIZE_COLUMN
-    ("BUFFER_LENGTH", "bufferLength"),  # BUFFER_LENGTH_COLUMN
-    ("DECIMAL_DIGITS", "decimalDigits"),  # DECIMAL_DIGITS_COLUMN
-    ("NUM_PREC_RADIX", "radix"),  # NUM_PREC_RADIX_COLUMN
-    ("NULLABLE", "Nullable"),  # NULLABLE_COLUMN
-    ("REMARKS", "remarks"),  # REMARKS_COLUMN
-    ("COLUMN_DEF", "columnType"),  # COLUMN_DEF_COLUMN (same source as TYPE_NAME)
-    ("SQL_DATA_TYPE", "SQLDataType"),  # SQL_DATA_TYPE_COLUMN
-    ("SQL_DATETIME_SUB", "SQLDateTimeSub"),  # SQL_DATETIME_SUB_COLUMN
-    ("CHAR_OCTET_LENGTH", "CharOctetLength"),  # CHAR_OCTET_LENGTH_COLUMN
-    ("ORDINAL_POSITION", "ordinalPosition"),  # ORDINAL_POSITION_COLUMN
-    ("IS_NULLABLE", "isNullable"),  # IS_NULLABLE_COLUMN
-    ("SCOPE_CATALOG", "ScopeCatalog"),  # SCOPE_CATALOG_COLUMN
-    ("SCOPE_SCHEMA", "ScopeSchema"),  # SCOPE_SCHEMA_COLUMN
-    ("SCOPE_TABLE", "ScopeTable"),  # SCOPE_TABLE_COLUMN
-    ("SOURCE_DATA_TYPE", "SourceDataType"),  # SOURCE_DATA_TYPE_COLUMN
-    ("IS_AUTOINCREMENT", "isAutoIncrement"),  # IS_AUTO_INCREMENT_COLUMN
-    ("IS_GENERATEDCOLUMN", "isGenerated"),  # IS_GENERATED_COLUMN
+    ("TABLE_CAT", "catalogName"),  # Maps to existing SEA column
+    ("TABLE_SCHEM", "namespace"),  # Maps to existing SEA column
+    ("TABLE_NAME", "tableName"),  # Maps to existing SEA column
+    ("COLUMN_NAME", "col_name"),  # Maps to existing SEA column
+    ("DATA_TYPE", None),  # Calculated from columnType
+    ("TYPE_NAME", "columnType"),  # Maps to existing SEA column
+    ("COLUMN_SIZE", "columnSize"),  # Maps to existing SEA column
+    ("BUFFER_LENGTH", None),  # Not available in SEA - default to None
+    ("DECIMAL_DIGITS", "decimalDigits"),  # Maps to existing SEA column
+    ("NUM_PREC_RADIX", "radix"),  # Maps to existing SEA column
+    ("NULLABLE", None),  # Calculated from isNullable
+    ("REMARKS", "remarks"),  # Maps to existing SEA column
+    ("COLUMN_DEF", None),  # Not available in SEA - default to None
+    ("SQL_DATA_TYPE", None),  # Not available in SEA - default to None
+    ("SQL_DATETIME_SUB", None),  # Not available in SEA - default to None
+    ("CHAR_OCTET_LENGTH", None),  # Not available in SEA - default to None
+    ("ORDINAL_POSITION", "ordinalPosition"),  # Maps to existing SEA column
+    ("IS_NULLABLE", "isNullable"),  # Maps to existing SEA column
+    ("SCOPE_CATALOG", None),  # Not available in SEA - default to None
+    ("SCOPE_SCHEMA", None),  # Not available in SEA - default to None
+    ("SCOPE_TABLE", None),  # Not available in SEA - default to None
+    ("SOURCE_DATA_TYPE", None),  # Not available in SEA - default to None
+    ("IS_AUTO_INCREMENT", "isAutoIncrement"),  # Maps to existing SEA column (renamed from IS_AUTOINCREMENT)
+    # Note: Removing IS_GENERATEDCOLUMN to match Thrift's 23 columns exactly
 ]
 
 # Note: COLUMN_DEF and TYPE_NAME both map to "columnType" - no special handling needed
@@ -111,8 +113,7 @@
     "SCOPE_CATALOG": "string",
     "SCOPE_SCHEMA": "string",
     "SCOPE_TABLE": "string",
-    "IS_AUTOINCREMENT": "string",
-    "IS_GENERATEDCOLUMN": "string",
+    "IS_AUTO_INCREMENT": "string",
 }
 

@@ -203,7 +204,213 @@ def normalize_columns_metadata_description(
     Returns:
         Normalized description matching JDBC COLUMN_COLUMNS with correct type codes
     """
-    # COLUMN_DEF and TYPE_NAME both map to "columnType" so no special handling needed
     return normalize_metadata_description(
         original_description, COLUMN_COLUMNS, COLUMN_TYPE_CODES
     )
+
+
+def transform_schemas_data_rows(result_set, catalog_name: str, original_description: List[Tuple]) -> None:
+    """
+    Transform SEA schemas() data rows to match JDBC format.
+
+    Args:
+        result_set: The SEA result set to modify
+        catalog_name: The catalog name to add as TABLE_CATALOG
+        original_description: Original column descriptions before normalization
+    """
+    if not hasattr(result_set, 'rows') or not result_set.rows:
+        return
+
+    # Build mapping from original column names to their indices
+    original_col_to_idx = {}
+    for idx, col_desc in enumerate(original_description):
+        original_col_to_idx[col_desc[0]] = idx
+
+    # Transform each row to JDBC format: (TABLE_SCHEM, TABLE_CATALOG)
+    new_rows = []
+    for row in result_set.rows:
+        # Convert row to list for easier manipulation
+        if hasattr(row, '_asdict'):
+            row_dict = row._asdict()
+            row_data = [row_dict.get(col_desc[0]) for col_desc in original_description]
+        else:
+            row_data = list(row)
+
+        # Extract schema name from databaseName field
+        schema_name = None
+        if 'databaseName' in original_col_to_idx:
+            idx = original_col_to_idx['databaseName']
+            schema_name = row_data[idx] if idx < len(row_data) else None
+            # Remove quotes if present
+            if schema_name and schema_name.startswith("'") and schema_name.endswith("'"):
+                schema_name = schema_name[1:-1]
+
+        # Create new row: (TABLE_SCHEM, TABLE_CATALOG)
+        new_row_data = (schema_name, catalog_name)
+        new_rows.append(new_row_data)
+
+    # Replace the rows in the result set
+    result_set.rows = new_rows
+
+
+def transform_tables_data_rows(result_set, catalog_name: str, original_description: List[Tuple]) -> None:
+    """
+    Transform SEA tables() data rows to match JDBC format.
+
+    Args:
+        result_set: The SEA result set to modify
+        catalog_name: The catalog name to add as TABLE_CAT
+        original_description: Original column descriptions before normalization
+    """
+    if not hasattr(result_set, 'rows') or not result_set.rows:
+        return
+
+    # Build mapping from original column names to their indices
+    original_col_to_idx = {}
+    for idx, col_desc in enumerate(original_description):
+        original_col_to_idx[col_desc[0]] = idx
+
+    # Transform each row to JDBC format
+    new_rows = []
+    for row in result_set.rows:
+        # Convert row to list for easier manipulation
+        if hasattr(row, '_asdict'):
+            row_dict = row._asdict()
+            row_data = [row_dict.get(col_desc[0]) for col_desc in original_description]
+        else:
+            row_data = list(row)
+
+        # Extract values from original SHOW TABLES output
+        table_schema = None
+        table_name = None
+        is_temporary = None
+
+        if 'database' in original_col_to_idx:
+            idx = original_col_to_idx['database']
+            table_schema = row_data[idx] if idx < len(row_data) else None
+
+        if 'tableName' in original_col_to_idx:
+            idx = original_col_to_idx['tableName']
+            table_name = row_data[idx] if idx < len(row_data) else None
+
+        if 'isTemporary' in original_col_to_idx:
+            idx = original_col_to_idx['isTemporary']
+            is_temporary = row_data[idx] if idx < len(row_data) else None
+
+        # Determine table type based on isTemporary flag
+        table_type = "TEMPORARY TABLE" if is_temporary else "TABLE"
+
+        # Create new row with JDBC format:
+        # (TABLE_CAT, TABLE_SCHEM, TABLE_NAME, TABLE_TYPE, REMARKS, TYPE_CAT, TYPE_SCHEM, TYPE_NAME, SELF_REFERENCING_COL_NAME, REF_GENERATION)
+        new_row_data = (
+            catalog_name,  # TABLE_CAT
+            table_schema,  # TABLE_SCHEM
+            table_name,  # TABLE_NAME
+            table_type,  # TABLE_TYPE
+            "",  # REMARKS (empty string)
+            None,  # TYPE_CAT
+            None,  # TYPE_SCHEM
+            None,  # TYPE_NAME
+            None,  # SELF_REFERENCING_COL_NAME
+            None,  # REF_GENERATION
+        )
+        new_rows.append(new_row_data)
+
+    # Replace the rows in the result set
+    result_set.rows = new_rows
+
+
+def transform_columns_data_rows(result_set, original_description: List[Tuple]) -> None:
+    """
+    Transform SEA columns() data rows to match JDBC format and column order.
+
+    This function modifies the result_set.rows in place to:
+    1. Reorder columns to match JDBC standard
+    2. Transform data types (e.g., string to int for DATA_TYPE)
+    3. Add missing columns with appropriate defaults
+    4. Remove extra columns not in JDBC standard
+
+    Args:
+        result_set: The SEA result set to modify
+        original_description: Original column descriptions before normalization
+    """
+    if not hasattr(result_set, 'rows') or not result_set.rows:
+        return
+
+    # Build mapping from original column names to their indices
+    original_col_to_idx = {}
+    for idx, col_desc in enumerate(original_description):
+        original_col_to_idx[col_desc[0]] = idx
+
+    # SQL type code mapping for DATA_TYPE field
+    TYPE_CODE_MAP = {
+        'INT': 4, 'INTEGER': 4,
+        'BIGINT': -5,
+        'SMALLINT': 5,
+        'TINYINT': -6,
+        'FLOAT': 6,
+        'DOUBLE': 8,
+        'DECIMAL': 3, 'NUMERIC': 3,
+        'STRING': 12, 'VARCHAR': 12,
+        'BOOLEAN': 16,
+        'DATE': 91,
+        'TIMESTAMP': 93,
+        'BINARY': -2,
+        'ARRAY': 2003,
+        'STRUCT': 2002,
+        'MAP': 2003,
+    }
+
+    # Special handling for DECIMAL types with precision/scale
+    def parse_decimal_type(type_str):
+        """Parse DECIMAL(precision,scale) to extract base type."""
+        if type_str and type_str.upper().startswith('DECIMAL'):
+            return 'DECIMAL'
+        return type_str
+
+    # Transform each row
+    new_rows = []
+    for row in result_set.rows:
+        # Convert row to list for easier manipulation
+        if hasattr(row, '_asdict'):
+            row_dict = row._asdict()
+            row_data = [row_dict.get(col_desc[0]) for col_desc in original_description]
+        else:
+            row_data = list(row)
+
+        # Build new row according to JDBC column order
+        new_row_data = []
+
+        for jdbc_col, sea_col in COLUMN_COLUMNS:
+            if sea_col and sea_col in original_col_to_idx:
+                # Column exists in original data
+                original_idx = original_col_to_idx[sea_col]
+                value = row_data[original_idx] if original_idx < len(row_data) else None
+
+                # Special transformations
+                if jdbc_col == "DATA_TYPE" and value:
+                    # Convert type name to SQL type code
+                    base_type = parse_decimal_type(str(value))
+                    value = TYPE_CODE_MAP.get(str(base_type).upper(), 12)  # Default to VARCHAR
+                elif jdbc_col == "NULLABLE" and sea_col == "isNullable":
+                    # Convert boolean string to int (1=nullable, 0=not nullable)
+                    value = 1 if str(value).lower() == 'true' else 0
+
+                new_row_data.append(value)
+            else:
+                # Column doesn't exist in SEA, use appropriate default
+                if jdbc_col == "DATA_TYPE":
+                    new_row_data.append(12)  # Default to VARCHAR
+                elif jdbc_col == "NULLABLE":
+                    new_row_data.append(1)  # Default to nullable
+                elif jdbc_col in ["BUFFER_LENGTH", "SQL_DATA_TYPE", "SQL_DATETIME_SUB",
+                                  "CHAR_OCTET_LENGTH", "COLUMN_DEF", "SCOPE_CATALOG",
+                                  "SCOPE_SCHEMA", "SCOPE_TABLE", "SOURCE_DATA_TYPE"]:
+                    new_row_data.append(None)
+                else:
+                    new_row_data.append(None)
+
+        new_rows.append(tuple(new_row_data))
+
+    # Replace the rows in the result set
+    result_set.rows = new_rows
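
Worth tracing transform_columns_data_rows as it stands in this commit: because COLUMN_COLUMNS maps DATA_TYPE and NULLABLE to None, those two columns take the else-branch defaults (12 and 1), and the in-branch conversions from columnType/isNullable are bypassed. A sketch of that behavior with a hypothetical SimpleNamespace stub (the column names follow the SEA SHOW COLUMNS output described above; the literal values are illustrative):

    from types import SimpleNamespace

    from databricks.sql.backend.sea.metadata_constants import transform_columns_data_rows

    # Three of the SEA SHOW COLUMNS fields, in their original order.
    original_description = [
        ("col_name", "string", None, None, None, None, None),
        ("columnType", "string", None, None, None, None, None),
        ("isNullable", "string", None, None, None, None, None),
    ]
    result = SimpleNamespace(
        description=original_description[:],
        rows=[("id", "DECIMAL(10,2)", "true")],
    )

    transform_columns_data_rows(result, original_description)

    row = result.rows[0]
    assert len(row) == 23             # matches Thrift's 23-column layout
    assert row[3] == "id"             # COLUMN_NAME, carried over from col_name
    assert row[4] == 12               # DATA_TYPE: sea_col is None, so the default applies
    assert row[5] == "DECIMAL(10,2)"  # TYPE_NAME passes the raw string through
    assert row[10] == 1               # NULLABLE: sea_col is None, so the default applies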

tests/unit/test_sea_backend.py

Lines changed: 3 additions & 0 deletions
@@ -683,6 +683,7 @@ def test_get_schemas(self, sea_client, sea_session_id, mock_cursor):
             ("databaseName", "string", None, None, None, None, None),
             ("catalogName", "string", None, None, None, None, None),
         ]
+        mock_result_set.rows = []  # Add empty rows for the transformation function
         with patch.object(
             sea_client, "execute_command", return_value=mock_result_set
         ) as mock_execute:
@@ -754,6 +755,7 @@ def test_get_tables(self, sea_client, sea_session_id, mock_cursor):
             ("tableType", "string", None, None, None, None, None),
             ("remarks", "string", None, None, None, None, None),
         ]
+        mock_result_set.rows = []  # Add empty rows for the transformation function
 
         with patch.object(
             sea_client, "execute_command", return_value=mock_result_set
@@ -847,6 +849,7 @@ def test_get_columns(self, sea_client, sea_session_id, mock_cursor):
             ("dataType", "int", None, None, None, None, None),
             ("columnType", "string", None, None, None, None, None),
         ]
+        mock_result_set.rows = []  # Add empty rows for the transformation function
         with patch.object(
             sea_client, "execute_command", return_value=mock_result_set
         ) as mock_execute:
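
For context on why each test sets mock_result_set.rows = []: the transform helpers guard on rows and return early when the list is empty, so the mocks only need the attribute to exist. A quick illustration with a hypothetical stub:

    from types import SimpleNamespace

    from databricks.sql.backend.sea.metadata_constants import transform_tables_data_rows

    stub = SimpleNamespace(rows=[])
    transform_tables_data_rows(stub, "some_catalog", original_description=[])
    assert stub.rows == []  # untouched: the empty-rows guard returned early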
