Skip to content

Commit a90bfeb

Browse files
case sensitive support for arrow table
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 20c9fbd commit a90bfeb

File tree

1 file changed

+24
-4
lines changed

1 file changed

+24
-4
lines changed

src/databricks/sql/backend/sea/utils/filters.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def _filter_arrow_table(
135135
table: Any, # pyarrow.Table
136136
column_name: str,
137137
allowed_values: List[str],
138+
case_sensitive: bool = True,
138139
) -> Any: # returns pyarrow.Table
139140
"""
140141
Filter a PyArrow table by column values.
@@ -143,6 +144,7 @@ def _filter_arrow_table(
143144
table: The PyArrow table to filter
144145
column_name: The name of the column to filter on
145146
allowed_values: List of allowed values for the column
147+
case_sensitive: Whether to perform case-sensitive comparison
146148
147149
Returns:
148150
A filtered PyArrow table
@@ -153,18 +155,29 @@ def _filter_arrow_table(
153155
if table.num_rows == 0:
154156
return table
155157

158+
# Handle case-insensitive filtering by normalizing both column and allowed values
159+
if not case_sensitive:
160+
# Convert allowed values to uppercase
161+
allowed_values = [v.upper() for v in allowed_values]
162+
# Get column values as uppercase
163+
column = pc.utf8_upper(table[column_name])
164+
else:
165+
# Use column as-is
166+
column = table[column_name]
167+
156168
# Convert allowed_values to PyArrow Array for better performance
157169
allowed_array = pyarrow.array(allowed_values)
158170

159171
# Construct a boolean mask: True where column is in allowed_list
160-
mask = pc.is_in(table[column_name], value_set=allowed_array)
172+
mask = pc.is_in(column, value_set=allowed_array)
161173
return table.filter(mask)
162174

163175
@staticmethod
164176
def _filter_arrow_result_set(
165177
result_set: SeaResultSet,
166178
column_index: int,
167179
allowed_values: List[str],
180+
case_sensitive: bool = True,
168181
) -> SeaResultSet:
169182
"""
170183
Filter a SEA result set that contains Arrow tables.
@@ -173,6 +186,7 @@ def _filter_arrow_result_set(
173186
result_set: The SEA result set to filter (containing Arrow data)
174187
column_index: The index of the column to filter on
175188
allowed_values: List of allowed values for the column
189+
case_sensitive: Whether to perform case-sensitive comparison
176190
177191
Returns:
178192
A filtered SEA result set
@@ -183,7 +197,7 @@ def _filter_arrow_result_set(
183197
# Get all remaining rows as Arrow table and filter it
184198
arrow_table = result_set.results.remaining_rows()
185199
filtered_table = ResultSetFilter._filter_arrow_table(
186-
arrow_table, column_name, allowed_values
200+
arrow_table, column_name, allowed_values, case_sensitive
187201
)
188202

189203
# Convert the filtered table to Arrow stream format for ResultData
@@ -281,10 +295,16 @@ def filter_tables_by_type(
281295
if isinstance(result_set.results, (CloudFetchQueue, ArrowQueue)):
282296
# For Arrow tables, we need to handle filtering differently
283297
return ResultSetFilter._filter_arrow_result_set(
284-
result_set, column_index=5, allowed_values=valid_types
298+
result_set,
299+
column_index=5,
300+
allowed_values=valid_types,
301+
case_sensitive=True,
285302
)
286303
else:
287304
# For JSON data, use the existing filter method
288305
return ResultSetFilter._filter_json_result_set(
289-
result_set, 5, valid_types, case_sensitive=True
306+
result_set,
307+
column_index=5,
308+
allowed_values=valid_types,
309+
case_sensitive=True,
290310
)

0 commit comments

Comments
 (0)