|
15 | 15 | // specific language governing permissions and limitations |
16 | 16 | // under the License. |
17 | 17 |
|
18 | | -use arrow::array::{ |
19 | | - make_array, make_comparator, Array, BooleanArray, Capacities, MutableArrayData, |
20 | | - Scalar, |
21 | | -}; |
| 18 | +use arrow::array::{make_array, make_comparator, Array, BooleanArray, Capacities, ListArray, MutableArrayData, Scalar, StructArray}; |
22 | 19 | use arrow::compute::SortOptions; |
23 | 20 | use arrow::datatypes::DataType; |
24 | 21 | use arrow_buffer::NullBuffer; |
25 | | -use datafusion_common::cast::{as_map_array, as_struct_array}; |
| 22 | +use datafusion_common::cast::{as_list_array, as_map_array, as_struct_array}; |
26 | 23 | use datafusion_common::{ |
27 | 24 | exec_err, internal_err, plan_datafusion_err, utils::take_function_args, Result, |
28 | 25 | ScalarValue, |
@@ -138,6 +135,15 @@ impl ScalarUDFImpl for GetFieldFunc { |
138 | 135 | debug_assert_eq!(args.scalar_arguments.len(), 2); |
139 | 136 |
|
140 | 137 | match (&args.arg_types[0], args.scalar_arguments[1].as_ref()) { |
| 138 | + (DataType::List(field), Some(ScalarValue::Utf8(Some(field_name)))) => { |
| 139 | + if let DataType::Struct(fields) = field.data_type() { |
| 140 | + fields.iter().find(|f| f.name() == field_name) |
| 141 | + .ok_or(plan_datafusion_err!("Field {field_name} not found in struct")) |
| 142 | + .map(|f| ReturnInfo::new_nullable(DataType::List(f.clone()))) |
| 143 | + } else { |
| 144 | + exec_err!("Expected a List of Structs") |
| 145 | + } |
| 146 | + } |
141 | 147 | (DataType::Map(fields, _), _) => { |
142 | 148 | match fields.data_type() { |
143 | 149 | DataType::Struct(fields) if fields.len() == 2 => { |
@@ -185,6 +191,39 @@ impl ScalarUDFImpl for GetFieldFunc { |
185 | 191 | } |
186 | 192 | }; |
187 | 193 |
|
| 194 | + pub fn get_field_from_list( |
| 195 | + array: Arc<dyn Array>, |
| 196 | + field_name: &str, |
| 197 | + ) -> Result<ColumnarValue> { |
| 198 | + let list_array = as_list_array(array.as_ref())?; |
| 199 | + match list_array.value_type() { |
| 200 | + DataType::Struct(fields) => { |
| 201 | + let struct_array = as_struct_array(list_array.values()).or_else(|_| { |
| 202 | + exec_err!("Expected a StructArray inside the ListArray") |
| 203 | + })?; |
| 204 | + let Some(field_index) = fields |
| 205 | + .iter() |
| 206 | + .position(|f| f.name() == field_name) |
| 207 | + else { |
| 208 | + return exec_err!("Field {field_name} not found in struct") |
| 209 | + }; |
| 210 | + let projection_array = struct_array.column(field_index); |
| 211 | + |
| 212 | + let (_, offsets, _, nulls) = list_array.clone().into_parts(); |
| 213 | + |
| 214 | + let new_list = ListArray::new( |
| 215 | + fields[field_index].clone(), |
| 216 | + offsets, |
| 217 | + projection_array.to_owned(), |
| 218 | + nulls, |
| 219 | + ); |
| 220 | + |
| 221 | + Ok(ColumnarValue::Array(Arc::new(new_list))) |
| 222 | + } |
| 223 | + _ => exec_err!("Expected a ListArray of Structs"), |
| 224 | + } |
| 225 | + } |
| 226 | + |
188 | 227 | fn process_map_array( |
189 | 228 | array: Arc<dyn Array>, |
190 | 229 | key_array: Arc<dyn Array>, |
@@ -235,6 +274,13 @@ impl ScalarUDFImpl for GetFieldFunc { |
235 | 274 | } |
236 | 275 |
|
237 | 276 | match (array.data_type(), name) { |
| 277 | + (DataType::List(field), ScalarValue::Utf8(Some(k))) => { |
| 278 | + if let DataType::Struct(_) = field.data_type() { |
| 279 | + get_field_from_list(array, &k) |
| 280 | + } else { |
| 281 | + exec_err!("Expected a List of Structs") |
| 282 | + } |
| 283 | + } |
238 | 284 | (DataType::Map(_, _), ScalarValue::List(arr)) => { |
239 | 285 | let key_array: Arc<dyn Array> = arr; |
240 | 286 | process_map_array(array, key_array) |
|
0 commit comments