@@ -1,11 +1,12 @@
 from __future__ import annotations

-from typing import Any, List, Optional, TYPE_CHECKING
+from typing import Any, List, Optional, TYPE_CHECKING, Dict

 import logging

 from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
 from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter
+from databricks.sql.backend.sea.utils.result_column import ResultColumn

 try:
     import pyarrow
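Note for reviewers: the new `ResultColumn` import is the record type that drives the normalization below. A minimal sketch of the shape this diff assumes, inferred from usage rather than from the actual definition in `databricks/sql/backend/sea/utils/result_column.py`:

```python
# Sketch only: field names inferred from how this diff reads ResultColumn,
# not the real class definition.
from dataclasses import dataclass
from typing import Any, Optional

@dataclass(frozen=True)
class ResultColumn:
    column_name: str                       # normalized JDBC-style name exposed in description
    result_set_column_name: Optional[str]  # name as SEA returns it; None when SEA lacks the column
    column_type: Any                       # type code placed into the cursor description
```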
@@ -82,6 +83,10 @@ def __init__(
             arrow_schema_bytes=execute_response.arrow_schema_bytes,
         )

+        # Initialize metadata columns for post-fetch transformation
+        self._metadata_columns = None
+        self._column_index_mapping = None
+
     def _convert_json_types(self, row: List[str]) -> List[Any]:
         """
         Convert string values in the row to appropriate Python types based on column metadata.
@@ -160,6 +165,7 @@ def fetchmany_json(self, size: int) -> List[List[str]]:
             raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")

         results = self.results.next_n_rows(size)
+        results = self._transform_json_rows(results)
         self._next_row_index += len(results)

         return results
@@ -173,6 +179,7 @@ def fetchall_json(self) -> List[List[str]]:
173179 """
174180
175181 results = self .results .remaining_rows ()
182+ results = self ._transform_json_rows (results )
176183 self ._next_row_index += len (results )
177184
178185 return results
@@ -197,7 +204,12 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":

         results = self.results.next_n_rows(size)
         if isinstance(self.results, JsonQueue):
-            results = self._convert_json_to_arrow_table(results)
+            # Transform JSON first, then convert to Arrow
+            transformed_json = self._transform_json_rows(results)
+            results = self._convert_json_to_arrow_table(transformed_json)
+        else:
+            # Transform Arrow table directly
+            results = self._transform_arrow_table(results)

         self._next_row_index += results.num_rows

@@ -210,7 +222,12 @@ def fetchall_arrow(self) -> "pyarrow.Table":

         results = self.results.remaining_rows()
         if isinstance(self.results, JsonQueue):
-            results = self._convert_json_to_arrow_table(results)
+            # Transform JSON first, then convert to Arrow
+            transformed_json = self._transform_json_rows(results)
+            results = self._convert_json_to_arrow_table(transformed_json)
+        else:
+            # Transform Arrow table directly
+            results = self._transform_arrow_table(results)

         self._next_row_index += results.num_rows

@@ -263,3 +280,108 @@ def fetchall(self) -> List[Row]:
             return self._create_json_table(self.fetchall_json())
         else:
             return self._convert_arrow_table(self.fetchall_arrow())
+
+    def prepare_metadata_columns(self, metadata_columns: List[ResultColumn]) -> None:
+        """
+        Prepare result set for metadata column normalization.
+
+        Args:
+            metadata_columns: List of ResultColumn objects defining the expected columns
+                and their mappings from SEA column names
+        """
+        self._metadata_columns = metadata_columns
+        self._prepare_column_mapping()
+
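Usage sketch for the new hook: after executing a metadata command, the SEA backend would hand the result set its expected column layout. The column names and the constructor argument order below are illustrative assumptions, not the connector's actual mapping tables:

```python
# Hypothetical caller; argument order (jdbc_name, sea_name, type_code) is assumed.
catalog_columns = [
    ResultColumn("TABLE_CAT", "catalog", "string"),
    ResultColumn("TABLE_SCHEM", None, "string"),  # not returned by SEA
]
result_set.prepare_metadata_columns(catalog_columns)
rows = result_set.fetchall()  # rows and description now follow the JDBC layout
```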
+    def _prepare_column_mapping(self) -> None:
+        """
+        Prepare column index mapping for metadata queries.
+        Updates description to use JDBC column names.
+        """
+        # Ensure description is available
+        if not self.description:
+            raise ValueError("Cannot prepare column mapping without result description")
+
+        # Build mapping from SEA column names to their indices
+        sea_column_indices = {}
+        for idx, col in enumerate(self.description):
+            sea_column_indices[col[0]] = idx
+
+        # Create new description and index mapping
+        new_description = []
+        self._column_index_mapping = {}  # Maps new index -> old index
+
+        for new_idx, result_column in enumerate(self._metadata_columns):
+            # Find the corresponding SEA column
+            if (
+                result_column.result_set_column_name
+                and result_column.result_set_column_name in sea_column_indices
+            ):
+                old_idx = sea_column_indices[result_column.result_set_column_name]
+                self._column_index_mapping[new_idx] = old_idx
+                # Use the original column metadata but with JDBC name
+                old_col = self.description[old_idx]
+                new_description.append(
+                    (
+                        result_column.column_name,  # JDBC name
+                        result_column.column_type,  # Expected type
+                        old_col[2],  # display_size
+                        old_col[3],  # internal_size
+                        old_col[4],  # precision
+                        old_col[5],  # scale
+                        old_col[6],  # null_ok
+                    )
+                )
+            else:
+                # Column doesn't exist in SEA - add with None values
+                new_description.append(
+                    (
+                        result_column.column_name,
+                        result_column.column_type,
+                        None,
+                        None,
+                        None,
+                        None,
+                        True,
+                    )
+                )
+                self._column_index_mapping[new_idx] = None
+
+        self.description = new_description
+
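To make the mapping concrete, a worked continuation of the sketch above (values illustrative, private attributes inspected only for demonstration):

```python
# Continuing the illustrative example: SEA described ["catalog", "tableType"],
# and we requested TABLE_CAT (mapped from "catalog") and TABLE_SCHEM (absent).
assert result_set._column_index_mapping == {0: 0, 1: None}
assert result_set.description[0][0] == "TABLE_CAT"  # renamed, metadata copied from "catalog"
assert result_set.description[1] == ("TABLE_SCHEM", "string", None, None, None, None, True)
```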
+    def _transform_arrow_table(self, table: "pyarrow.Table") -> "pyarrow.Table":
+        """Transform arrow table columns for metadata normalization."""
+        if not self._metadata_columns:
+            return table
+
+        # Reorder columns and add missing ones
+        new_columns = []
+        column_names = []
+
+        for new_idx, result_column in enumerate(self._metadata_columns):
+            old_idx = self._column_index_mapping.get(new_idx)
+            if old_idx is not None:
+                new_columns.append(table.column(old_idx))
+            else:
+                # Create null column for missing data
+                null_array = pyarrow.nulls(table.num_rows)
+                new_columns.append(null_array)
+            column_names.append(result_column.column_name)
+
+        return pyarrow.Table.from_arrays(new_columns, names=column_names)
+
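The Arrow path is a column reorder with null fill. Standalone, the same pyarrow calls behave like this (column names and mapping values are illustrative):

```python
import pyarrow

# Illustrative SEA result with two columns, normalized to the layout above.
table = pyarrow.Table.from_pydict({"catalog": ["main"], "tableType": ["VIEW"]})
mapping = {0: 0, 1: None}  # TABLE_CAT <- "catalog"; TABLE_SCHEM missing
names = ["TABLE_CAT", "TABLE_SCHEM"]

columns = [
    table.column(mapping[i]) if mapping[i] is not None else pyarrow.nulls(table.num_rows)
    for i in range(len(names))
]
normalized = pyarrow.Table.from_arrays(columns, names=names)
assert normalized.column_names == ["TABLE_CAT", "TABLE_SCHEM"]
```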
+    def _transform_json_rows(self, rows: List[List[str]]) -> List[List[Any]]:
+        """Transform JSON rows for metadata normalization."""
+        if not self._metadata_columns:
+            return rows
+
+        transformed_rows = []
+        for row in rows:
+            new_row = []
+            for new_idx in range(len(self._metadata_columns)):
+                old_idx = self._column_index_mapping.get(new_idx)
+                if old_idx is not None:
+                    new_row.append(row[old_idx])
+                else:
+                    new_row.append(None)
+            transformed_rows.append(new_row)
+        return transformed_rows
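And the JSON path, exercised with the same illustrative mapping that `_prepare_column_mapping()` would produce:

```python
# Rows as SEA's JsonQueue yields them: lists of string cells in SEA column
# order ["catalog", "tableType"]; mapping values are illustrative.
rows = [["main", "VIEW"], ["samples", "TABLE"]]
mapping = {0: 0, 1: None}

transformed = [
    [row[mapping[i]] if mapping[i] is not None else None for i in range(2)]
    for row in rows
]
assert transformed == [["main", None], ["samples", None]]
```

This is the transformation that fetchmany_json and fetchall_json now apply before updating `_next_row_index`; since it is row-preserving, the row accounting is unchanged.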