Skip to content

Commit 5dedb96

Browse files
committed
[HSTACK] - add spark syntax for projecting nested structs from arrays
1 parent 11f90d6 commit 5dedb96

1 file changed

Lines changed: 51 additions & 5 deletions

File tree

datafusion/functions/src/core/getfield.rs

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,11 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::{
19-
make_array, make_comparator, Array, BooleanArray, Capacities, MutableArrayData,
20-
Scalar,
21-
};
18+
use arrow::array::{make_array, make_comparator, Array, BooleanArray, Capacities, ListArray, MutableArrayData, Scalar, StructArray};
2219
use arrow::compute::SortOptions;
2320
use arrow::datatypes::DataType;
2421
use arrow_buffer::NullBuffer;
25-
use datafusion_common::cast::{as_map_array, as_struct_array};
22+
use datafusion_common::cast::{as_list_array, as_map_array, as_struct_array};
2623
use datafusion_common::{
2724
exec_err, internal_err, plan_datafusion_err, utils::take_function_args, Result,
2825
ScalarValue,
@@ -138,6 +135,15 @@ impl ScalarUDFImpl for GetFieldFunc {
138135
debug_assert_eq!(args.scalar_arguments.len(), 2);
139136

140137
match (&args.arg_types[0], args.scalar_arguments[1].as_ref()) {
138+
(DataType::List(field), Some(ScalarValue::Utf8(Some(field_name)))) => {
139+
if let DataType::Struct(fields) = field.data_type() {
140+
fields.iter().find(|f| f.name() == field_name)
141+
.ok_or(plan_datafusion_err!("Field {field_name} not found in struct"))
142+
.map(|f| ReturnInfo::new_nullable(DataType::List(f.clone())))
143+
} else {
144+
exec_err!("Expected a List of Structs")
145+
}
146+
}
141147
(DataType::Map(fields, _), _) => {
142148
match fields.data_type() {
143149
DataType::Struct(fields) if fields.len() == 2 => {
@@ -185,6 +191,39 @@ impl ScalarUDFImpl for GetFieldFunc {
185191
}
186192
};
187193

194+
pub fn get_field_from_list(
195+
array: Arc<dyn Array>,
196+
field_name: &str,
197+
) -> Result<ColumnarValue> {
198+
let list_array = as_list_array(array.as_ref())?;
199+
match list_array.value_type() {
200+
DataType::Struct(fields) => {
201+
let struct_array = as_struct_array(list_array.values()).or_else(|_| {
202+
exec_err!("Expected a StructArray inside the ListArray")
203+
})?;
204+
let Some(field_index) = fields
205+
.iter()
206+
.position(|f| f.name() == field_name)
207+
else {
208+
return exec_err!("Field {field_name} not found in struct")
209+
};
210+
let projection_array = struct_array.column(field_index);
211+
212+
let (_, offsets, _, nulls) = list_array.clone().into_parts();
213+
214+
let new_list = ListArray::new(
215+
fields[field_index].clone(),
216+
offsets,
217+
projection_array.to_owned(),
218+
nulls,
219+
);
220+
221+
Ok(ColumnarValue::Array(Arc::new(new_list)))
222+
}
223+
_ => exec_err!("Expected a ListArray of Structs"),
224+
}
225+
}
226+
188227
fn process_map_array(
189228
array: Arc<dyn Array>,
190229
key_array: Arc<dyn Array>,
@@ -235,6 +274,13 @@ impl ScalarUDFImpl for GetFieldFunc {
235274
}
236275

237276
match (array.data_type(), name) {
277+
(DataType::List(field), ScalarValue::Utf8(Some(k))) => {
278+
if let DataType::Struct(_) = field.data_type() {
279+
get_field_from_list(array, &k)
280+
} else {
281+
exec_err!("Expected a List of Structs")
282+
}
283+
}
238284
(DataType::Map(_, _), ScalarValue::List(arr)) => {
239285
let key_array: Arc<dyn Array> = arr;
240286
process_map_array(array, key_array)

0 commit comments

Comments
 (0)