@@ -135,6 +135,7 @@ def _filter_arrow_table(
135135 table : Any , # pyarrow.Table
136136 column_name : str ,
137137 allowed_values : List [str ],
138+ case_sensitive : bool = True ,
138139 ) -> Any : # returns pyarrow.Table
139140 """
140141 Filter a PyArrow table by column values.
@@ -143,6 +144,7 @@ def _filter_arrow_table(
143144 table: The PyArrow table to filter
144145 column_name: The name of the column to filter on
145146 allowed_values: List of allowed values for the column
147+ case_sensitive: Whether to perform case-sensitive comparison
146148
147149 Returns:
148150 A filtered PyArrow table
@@ -153,18 +155,29 @@ def _filter_arrow_table(
153155 if table .num_rows == 0 :
154156 return table
155157
158+ # Handle case-insensitive filtering by normalizing both column and allowed values
159+ if not case_sensitive :
160+ # Convert allowed values to uppercase
161+ allowed_values = [v .upper () for v in allowed_values ]
162+ # Get column values as uppercase
163+ column = pc .utf8_upper (table [column_name ])
164+ else :
165+ # Use column as-is
166+ column = table [column_name ]
167+
156168 # Convert allowed_values to PyArrow Array for better performance
157169 allowed_array = pyarrow .array (allowed_values )
158170
159171 # Construct a boolean mask: True where column is in allowed_list
160- mask = pc .is_in (table [ column_name ] , value_set = allowed_array )
172+ mask = pc .is_in (column , value_set = allowed_array )
161173 return table .filter (mask )
162174
163175 @staticmethod
164176 def _filter_arrow_result_set (
165177 result_set : SeaResultSet ,
166178 column_index : int ,
167179 allowed_values : List [str ],
180+ case_sensitive : bool = True ,
168181 ) -> SeaResultSet :
169182 """
170183 Filter a SEA result set that contains Arrow tables.
@@ -173,6 +186,7 @@ def _filter_arrow_result_set(
173186 result_set: The SEA result set to filter (containing Arrow data)
174187 column_index: The index of the column to filter on
175188 allowed_values: List of allowed values for the column
189+ case_sensitive: Whether to perform case-sensitive comparison
176190
177191 Returns:
178192 A filtered SEA result set
@@ -183,7 +197,7 @@ def _filter_arrow_result_set(
183197 # Get all remaining rows as Arrow table and filter it
184198 arrow_table = result_set .results .remaining_rows ()
185199 filtered_table = ResultSetFilter ._filter_arrow_table (
186- arrow_table , column_name , allowed_values
200+ arrow_table , column_name , allowed_values , case_sensitive
187201 )
188202
189203 # Convert the filtered table to Arrow stream format for ResultData
@@ -281,10 +295,16 @@ def filter_tables_by_type(
281295 if isinstance (result_set .results , (CloudFetchQueue , ArrowQueue )):
282296 # For Arrow tables, we need to handle filtering differently
283297 return ResultSetFilter ._filter_arrow_result_set (
284- result_set , column_index = 5 , allowed_values = valid_types
298+ result_set ,
299+ column_index = 5 ,
300+ allowed_values = valid_types ,
301+ case_sensitive = True ,
285302 )
286303 else :
287304 # For JSON data, use the existing filter method
288305 return ResultSetFilter ._filter_json_result_set (
289- result_set , 5 , valid_types , case_sensitive = True
306+ result_set ,
307+ column_index = 5 ,
308+ allowed_values = valid_types ,
309+ case_sensitive = True ,
290310 )
0 commit comments