From a856c00fe962af26a6bcf62acff47de74b06f644 Mon Sep 17 00:00:00 2001 From: ebembi-crdb Date: Wed, 8 Apr 2026 15:59:10 +0530 Subject: [PATCH] feat: implement DFExtensionType for remaining canonical Arrow extension types Closes #21144 Implements DFExtensionType for all remaining canonical Arrow extension types so they are recognized and pretty-printed by the extension type registry: - Bool8: displays Int8 values as 'true'/'false' instead of raw integers - Json: uses default string formatter (values are already valid JSON) - Opaque: uses default formatter - FixedShapeTensor: uses default formatter, storage_type computed from value_type and list_size - VariableShapeTensor: uses default formatter, storage_type computed from value_type and dimensions - TimestampWithOffset: uses default formatter All six types are registered in MemoryExtensionTypeRegistry::new_with_canonical_extension_types() alongside the existing UUID registration. --- .../src/types/canonical_extensions/bool8.rs | 131 +++++++ .../fixed_shape_tensor.rs | 42 +++ .../src/types/canonical_extensions/json.rs | 37 ++ .../src/types/canonical_extensions/mod.rs | 24 ++ .../src/types/canonical_extensions/opaque.rs | 37 ++ .../timestamp_with_offset.rs | 40 +++ .../variable_shape_tensor.rs | 49 +++ .../tests/extension_types/pretty_printing.rs | 167 +++++++++ datafusion/expr/src/registry.rs | 338 +++++++++++++++++- 9 files changed, 862 insertions(+), 3 deletions(-) create mode 100644 datafusion/common/src/types/canonical_extensions/bool8.rs create mode 100644 datafusion/common/src/types/canonical_extensions/fixed_shape_tensor.rs create mode 100644 datafusion/common/src/types/canonical_extensions/json.rs create mode 100644 datafusion/common/src/types/canonical_extensions/mod.rs create mode 100644 datafusion/common/src/types/canonical_extensions/opaque.rs create mode 100644 datafusion/common/src/types/canonical_extensions/timestamp_with_offset.rs create mode 100644 
datafusion/common/src/types/canonical_extensions/variable_shape_tensor.rs create mode 100644 datafusion/core/tests/extension_types/pretty_printing.rs diff --git a/datafusion/common/src/types/canonical_extensions/bool8.rs b/datafusion/common/src/types/canonical_extensions/bool8.rs new file mode 100644 index 000000000000..7c8264ae5a44 --- /dev/null +++ b/datafusion/common/src/types/canonical_extensions/bool8.rs @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::error::_internal_err; +use crate::types::extension::DFExtensionType; +use arrow::array::{Array, Int8Array}; +use arrow::datatypes::DataType; +use arrow::util::display::{ArrayFormatter, DisplayIndex, FormatOptions, FormatResult}; +use std::fmt::Write; + +/// Defines the extension type logic for the canonical `arrow.bool8` extension type. +/// +/// Bool8 values are displayed as `true` or `false`, where `0` maps to `false` and +/// any non-zero value maps to `true`. +/// +/// See [`DFExtensionType`] for information on DataFusion's extension type mechanism. 
+impl DFExtensionType for arrow_schema::extension::Bool8 { + fn storage_type(&self) -> DataType { + DataType::Int8 + } + + fn serialize_metadata(&self) -> Option { + // Bool8 metadata is an empty string per the Arrow spec. + Some(String::new()) + } + + fn create_array_formatter<'fmt>( + &self, + array: &'fmt dyn Array, + options: &FormatOptions<'fmt>, + ) -> crate::Result>> { + if array.data_type() != &DataType::Int8 { + return _internal_err!("Wrong array type for Bool8"); + } + + let display_index = Bool8ValueDisplayIndex { + array: array.as_any().downcast_ref().unwrap(), + null_str: options.null(), + }; + Ok(Some(ArrayFormatter::new( + Box::new(display_index), + options.safe(), + ))) + } +} + +/// Pretty printer for 8-bit Boolean values. +/// +/// Displays `false` for zero values and `true` for any non-zero value. +#[derive(Debug, Clone, Copy)] +struct Bool8ValueDisplayIndex<'a> { + array: &'a Int8Array, + null_str: &'a str, +} + +impl DisplayIndex for Bool8ValueDisplayIndex<'_> { + fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { + if self.array.is_null(idx) { + write!(f, "{}", self.null_str)?; + return Ok(()); + } + + let value = self.array.value(idx); + if value == 0 { + write!(f, "false")?; + } else { + write!(f, "true")?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ScalarValue; + + #[test] + pub fn test_pretty_print_bool8_false() { + let bool8 = ScalarValue::Int8(Some(0)).to_array_of_size(1).unwrap(); + + let extension_type = arrow_schema::extension::Bool8 {}; + let formatter = extension_type + .create_array_formatter(bool8.as_ref(), &FormatOptions::default()) + .unwrap() + .unwrap(); + + assert_eq!(formatter.value(0).to_string(), "false"); + } + + #[test] + pub fn test_pretty_print_bool8_true() { + let bool8 = ScalarValue::Int8(Some(1)).to_array_of_size(1).unwrap(); + + let extension_type = arrow_schema::extension::Bool8 {}; + let formatter = extension_type + .create_array_formatter(bool8.as_ref(), 
&FormatOptions::default()) + .unwrap() + .unwrap(); + + assert_eq!(formatter.value(0).to_string(), "true"); + } + + #[test] + pub fn test_pretty_print_bool8_nonzero_is_true() { + // Any non-zero value should display as "true" + let bool8 = ScalarValue::Int8(Some(42)).to_array_of_size(1).unwrap(); + + let extension_type = arrow_schema::extension::Bool8 {}; + let formatter = extension_type + .create_array_formatter(bool8.as_ref(), &FormatOptions::default()) + .unwrap() + .unwrap(); + + assert_eq!(formatter.value(0).to_string(), "true"); + } +} diff --git a/datafusion/common/src/types/canonical_extensions/fixed_shape_tensor.rs b/datafusion/common/src/types/canonical_extensions/fixed_shape_tensor.rs new file mode 100644 index 000000000000..ac35fd6afc81 --- /dev/null +++ b/datafusion/common/src/types/canonical_extensions/fixed_shape_tensor.rs @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::types::extension::DFExtensionType; +use arrow::datatypes::DataType; +use arrow_schema::extension::ExtensionType; + +/// Defines the extension type logic for the canonical `arrow.fixed_shape_tensor` extension type. 
+/// +/// Fixed shape tensors are stored as `FixedSizeList` arrays; the default Arrow formatter +/// is used for display. +/// +/// See [`DFExtensionType`] for information on DataFusion's extension type mechanism. +impl DFExtensionType for arrow_schema::extension::FixedShapeTensor { + fn storage_type(&self) -> DataType { + DataType::new_fixed_size_list( + self.value_type().clone(), + i32::try_from(self.list_size()).expect("list size overflow"), + false, + ) + } + + fn serialize_metadata(&self) -> Option { + ::serialize_metadata( + self, + ) + } +} diff --git a/datafusion/common/src/types/canonical_extensions/json.rs b/datafusion/common/src/types/canonical_extensions/json.rs new file mode 100644 index 000000000000..88d1ed0a73bd --- /dev/null +++ b/datafusion/common/src/types/canonical_extensions/json.rs @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::types::extension::DFExtensionType; +use arrow::datatypes::DataType; +use arrow_schema::extension::ExtensionType; + +/// Defines the extension type logic for the canonical `arrow.json` extension type. +/// +/// JSON values are already stored as UTF-8 strings, so the default Arrow string +/// formatter is used for display. 
+/// +/// See [`DFExtensionType`] for information on DataFusion's extension type mechanism. +impl DFExtensionType for arrow_schema::extension::Json { + fn storage_type(&self) -> DataType { + // JSON can be stored as Utf8, LargeUtf8, or Utf8View; Utf8 is the most common default. + DataType::Utf8 + } + + fn serialize_metadata(&self) -> Option { + ::serialize_metadata(self) + } +} diff --git a/datafusion/common/src/types/canonical_extensions/mod.rs b/datafusion/common/src/types/canonical_extensions/mod.rs new file mode 100644 index 000000000000..d4359c73497a --- /dev/null +++ b/datafusion/common/src/types/canonical_extensions/mod.rs @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod bool8; +mod fixed_shape_tensor; +mod json; +mod opaque; +mod timestamp_with_offset; +mod uuid; +mod variable_shape_tensor; diff --git a/datafusion/common/src/types/canonical_extensions/opaque.rs b/datafusion/common/src/types/canonical_extensions/opaque.rs new file mode 100644 index 000000000000..2b58b8c4a82b --- /dev/null +++ b/datafusion/common/src/types/canonical_extensions/opaque.rs @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::types::extension::DFExtensionType; +use arrow::datatypes::DataType; +use arrow_schema::extension::ExtensionType; + +/// Defines the extension type logic for the canonical `arrow.opaque` extension type. +/// +/// Opaque represents a type received from an external system that cannot be interpreted. +/// The default Arrow formatter is used for display. +/// +/// See [`DFExtensionType`] for information on DataFusion's extension type mechanism. +impl DFExtensionType for arrow_schema::extension::Opaque { + fn storage_type(&self) -> DataType { + // Opaque supports any storage type; Null is recommended when there is no underlying data. + DataType::Null + } + + fn serialize_metadata(&self) -> Option { + ::serialize_metadata(self) + } +} diff --git a/datafusion/common/src/types/canonical_extensions/timestamp_with_offset.rs b/datafusion/common/src/types/canonical_extensions/timestamp_with_offset.rs new file mode 100644 index 000000000000..a252aceb9772 --- /dev/null +++ b/datafusion/common/src/types/canonical_extensions/timestamp_with_offset.rs @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::types::extension::DFExtensionType; +use arrow::datatypes::DataType; + +/// Defines the extension type logic for the canonical `arrow.timestamp_with_offset` extension type. +/// +/// Timestamp with offset values are stored as `Struct` arrays containing a UTC timestamp +/// and an offset in minutes. The default Arrow formatter is used for display. +/// +/// See [`DFExtensionType`] for information on DataFusion's extension type mechanism. +impl DFExtensionType for arrow_schema::extension::TimestampWithOffset { + fn storage_type(&self) -> DataType { + // TimestampWithOffset stores no internal state to determine the timestamp precision. + // The actual storage type depends on the time unit chosen by the producer. + // Returning Null here is a placeholder; the actual DataType is validated at registration + // time via ExtensionType::supports_data_type. + DataType::Null + } + + fn serialize_metadata(&self) -> Option { + // TimestampWithOffset has no metadata. 
+ None + } +} diff --git a/datafusion/common/src/types/canonical_extensions/variable_shape_tensor.rs b/datafusion/common/src/types/canonical_extensions/variable_shape_tensor.rs new file mode 100644 index 000000000000..53dab987658b --- /dev/null +++ b/datafusion/common/src/types/canonical_extensions/variable_shape_tensor.rs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::types::extension::DFExtensionType; +use arrow::datatypes::{DataType, Field, Fields}; +use arrow_schema::extension::ExtensionType; + +/// Defines the extension type logic for the canonical `arrow.variable_shape_tensor` extension type. +/// +/// Variable shape tensors are stored as `Struct` arrays containing `data` (a list of elements) +/// and `shape` (a fixed-size list of int32 dimensions). The default Arrow formatter is used +/// for display. +/// +/// See [`DFExtensionType`] for information on DataFusion's extension type mechanism. 
+impl DFExtensionType for arrow_schema::extension::VariableShapeTensor { + fn storage_type(&self) -> DataType { + let dims = i32::try_from(self.dimensions()).expect("dimensions overflow"); + DataType::Struct(Fields::from_iter([ + Field::new_list( + "data", + Field::new_list_field(self.value_type().clone(), false), + false, + ), + Field::new( + "shape", + DataType::new_fixed_size_list(DataType::Int32, dims, false), + false, + ), + ])) + } + + fn serialize_metadata(&self) -> Option { + ::serialize_metadata(self) + } +} diff --git a/datafusion/core/tests/extension_types/pretty_printing.rs b/datafusion/core/tests/extension_types/pretty_printing.rs new file mode 100644 index 000000000000..8d5476368512 --- /dev/null +++ b/datafusion/core/tests/extension_types/pretty_printing.rs @@ -0,0 +1,167 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{FixedSizeBinaryArray, Int8Array, RecordBatch, StringArray}; +use arrow_schema::extension::{Bool8, Json, Uuid}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use datafusion::dataframe::DataFrame; +use datafusion::error::Result; +use datafusion::execution::SessionStateBuilder; +use datafusion::prelude::SessionContext; +use datafusion_expr::registry::MemoryExtensionTypeRegistry; +use insta::assert_snapshot; +use std::sync::Arc; + +fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![uuid_field()])) +} + +fn uuid_field() -> Field { + Field::new("my_uuids", DataType::FixedSizeBinary(16), false).with_extension_type(Uuid) +} + +async fn create_test_table() -> Result { + let schema = test_schema(); + + // define data. + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(FixedSizeBinaryArray::from(vec![ + &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 5, 6], + ]))], + )?; + + let state = SessionStateBuilder::default() + .with_extension_type_registry(Arc::new( + MemoryExtensionTypeRegistry::new_with_canonical_extension_types(), + )) + .build(); + let ctx = SessionContext::new_with_state(state); + + ctx.register_batch("test", batch)?; + + ctx.table("test").await +} + +#[tokio::test] +async fn test_pretty_print_extension_type_formatter() -> Result<()> { + let result = create_test_table().await?.to_string().await?; + + assert_snapshot!( + result, + @r" + +--------------------------------------+ + | my_uuids | + +--------------------------------------+ + | 00000000-0000-0000-0000-000000000000 | + | 00010203-0405-0607-0809-000102030506 | + +--------------------------------------+ + " + ); + + Ok(()) +} + +fn bool8_test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("my_bools", DataType::Int8, false).with_extension_type(Bool8), + ])) +} + +async fn create_bool8_test_table() -> Result { + let schema = bool8_test_schema(); + let batch = RecordBatch::try_new( + 
schema, + vec![Arc::new(Int8Array::from(vec![0, 1, 42, -1]))], + )?; + + let state = SessionStateBuilder::default() + .with_extension_type_registry(Arc::new( + MemoryExtensionTypeRegistry::new_with_canonical_extension_types(), + )) + .build(); + let ctx = SessionContext::new_with_state(state); + ctx.register_batch("test", batch)?; + ctx.table("test").await +} + +#[tokio::test] +async fn test_pretty_print_bool8() -> Result<()> { + let result = create_bool8_test_table().await?.to_string().await?; + + assert_snapshot!( + result, + @r" + +----------+ + | my_bools | + +----------+ + | false | + | true | + | true | + | true | + +----------+ + " + ); + + Ok(()) +} + +fn json_test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("my_json", DataType::Utf8, false).with_extension_type(Json::default()), + ])) +} + +async fn create_json_test_table() -> Result { + let schema = json_test_schema(); + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(StringArray::from(vec![ + r#"{"key": "value"}"#, + r#"[1, 2, 3]"#, + ]))], + )?; + + let state = SessionStateBuilder::default() + .with_extension_type_registry(Arc::new( + MemoryExtensionTypeRegistry::new_with_canonical_extension_types(), + )) + .build(); + let ctx = SessionContext::new_with_state(state); + ctx.register_batch("test", batch)?; + ctx.table("test").await +} + +#[tokio::test] +async fn test_pretty_print_json() -> Result<()> { + let result = create_json_test_table().await?.to_string().await?; + + assert_snapshot!( + result, + @r#" + +------------------+ + | my_json | + +------------------+ + | {"key": "value"} | + | [1, 2, 3] | + +------------------+ + "# + ); + + Ok(()) +} diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index 472e065211aa..4e7e9bb4fc27 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -20,10 +20,20 @@ use crate::expr_rewriter::FunctionRewrite; use crate::planner::ExprPlanner; use crate::{AggregateUDF, ScalarUDF, 
UserDefinedLogicalNode, WindowUDF}; -use datafusion_common::{HashMap, Result, not_impl_err, plan_datafusion_err}; +use arrow::datatypes::Field; +use arrow_schema::DataType; +use arrow_schema::extension::{ + Bool8, ExtensionType, FixedShapeTensor, FixedShapeTensorMetadata, Json, JsonMetadata, + Opaque, OpaqueMetadata, TimestampWithOffset, VariableShapeTensor, + VariableShapeTensorMetadata, +}; +use datafusion_common::types::{DFExtensionType, DFExtensionTypeRef}; +use datafusion_common::{ + DataFusionError, HashMap, Result, not_impl_err, plan_datafusion_err, +}; use std::collections::HashSet; -use std::fmt::Debug; -use std::sync::Arc; +use std::fmt::{Debug, Formatter}; +use std::sync::{Arc, RwLock}; /// A registry knows how to build logical expressions out of user-defined function' names pub trait FunctionRegistry { @@ -215,3 +225,325 @@ impl FunctionRegistry for MemoryFunctionRegistry { self.udwfs.keys().cloned().collect() } } + +/// A cheaply cloneable pointer to an [ExtensionTypeRegistration]. +pub type ExtensionTypeRegistrationRef = Arc; + +/// The registration of an extension type. Implementations of this trait are responsible for +/// *creating* instances of [`DFExtensionType`] that represent the entire semantics of an extension +/// type. +/// +/// # Why do we need a Registration? +/// +/// A good question is why this trait is even necessary. Why not directly register the +/// [`DFExtensionType`] in the registry? +/// +/// While this works for extension types requiring no additional metadata (e.g., `arrow.uuid`), it +/// does not work for more complex extension types with metadata. For example, consider an extension +/// type `custom.shortened(n)` that aims to shorten the pretty-printing string to `n` characters. +/// Here, `n` is a parameter of the extension type and should be a field in the struct that +/// implements the [`DFExtensionType`]. 
The job of the registration is to read the metadata from the +/// field and create the corresponding [`DFExtensionType`] instance with the correct `n` set. +/// +/// The [`DefaultExtensionTypeRegistration`] provides a convenient way of creating registrations. +pub trait ExtensionTypeRegistration: Debug + Send + Sync { + /// The name of the extension type. + /// + /// This name will be used to find the correct [ExtensionTypeRegistration] when an extension + /// type is encountered. + fn type_name(&self) -> &str; + + /// Creates an extension type instance from the optional metadata. The name of the extension + /// type is not a parameter as it's already defined by the registration itself. + fn create_df_extension_type( + &self, + storage_type: &DataType, + metadata: Option<&str>, + ) -> Result; +} + +/// A cheaply cloneable pointer to an [ExtensionTypeRegistry]. +pub type ExtensionTypeRegistryRef = Arc; + +/// Manages [`ExtensionTypeRegistration`]s, which allow users to register custom behavior for +/// extension types. +/// +/// Each registration is connected to the extension type name, which can also be looked up to get +/// the registration. +pub trait ExtensionTypeRegistry: Debug + Send + Sync { + /// Returns a reference to registration of an extension type named `name`. + /// + /// Returns an error if there is no extension type with that name. + fn extension_type_registration( + &self, + name: &str, + ) -> Result; + + /// Creates a [`DFExtensionTypeRef`] from the type information in the `field`. + /// + /// The result `Ok(None)` indicates that there is no extension type metadata. Returns an error + /// if the extension type in the metadata is not found. 
+ fn create_extension_type_for_field( + &self, + field: &Field, + ) -> Result> { + let Some(extension_type_name) = field.extension_type_name() else { + return Ok(None); + }; + + let registration = self.extension_type_registration(extension_type_name)?; + registration + .create_df_extension_type(field.data_type(), field.extension_type_metadata()) + .map(Some) + } + + /// Returns all registered [ExtensionTypeRegistration]. + fn extension_type_registrations(&self) -> Vec>; + + /// Registers a new [ExtensionTypeRegistrationRef], returning any previously registered + /// implementation. + /// + /// Returns an error if the type cannot be registered, for example, if the registry is + /// read-only. + fn add_extension_type_registration( + &self, + extension_type: ExtensionTypeRegistrationRef, + ) -> Result>; + + /// Extends the registry with the provided extension types. + /// + /// Returns an error if the type cannot be registered, for example, if the registry is + /// read-only. + fn extend(&self, extension_types: &[ExtensionTypeRegistrationRef]) -> Result<()> { + for extension_type in extension_types.iter().cloned() { + self.add_extension_type_registration(extension_type)?; + } + Ok(()) + } + + /// Deregisters an extension type registration with the name `name`, returning the + /// implementation that was deregistered. + /// + /// Returns an error if the type cannot be deregistered, for example, if the registry is + /// read-only. + fn remove_extension_type_registration( + &self, + name: &str, + ) -> Result>; +} + +/// A factory that creates instances of extension types from a storage [`DataType`] and the +/// metadata. +pub type ExtensionTypeFactory = dyn Fn(&DataType, ::Metadata) -> Result + + Send + + Sync; + +/// A default implementation of [ExtensionTypeRegistration] that parses the metadata from the +/// given extension type and passes it to a constructor function. 
+pub struct DefaultExtensionTypeRegistration< + TExtensionType: ExtensionType + DFExtensionType + 'static, +> { + /// A function that creates an instance of [`DFExtensionTypeRef`] from the storage type and the + /// metadata. + factory: Box>, +} + +impl + DefaultExtensionTypeRegistration +{ + /// Creates a new registration for an extension type. + /// + /// The factory is not required to validate the storage [`DataType`], as the compatibility will + /// be checked by the registration using [`ExtensionType::supports_data_type`]. However, the + /// factory may still choose to do so. + pub fn new_arc( + factory: impl Fn(&DataType, TExtensionType::Metadata) -> Result + + Send + + Sync + + 'static, + ) -> ExtensionTypeRegistrationRef { + Arc::new(Self { + factory: Box::new(factory), + }) + } +} + +impl ExtensionTypeRegistration + for DefaultExtensionTypeRegistration +{ + fn type_name(&self) -> &str { + TExtensionType::NAME + } + + fn create_df_extension_type( + &self, + storage_type: &DataType, + metadata: Option<&str>, + ) -> Result { + let metadata = TExtensionType::deserialize_metadata(metadata)?; + let type_instance = self.factory.as_ref()(storage_type, metadata)?; + type_instance + .supports_data_type(storage_type) + .map_err(|_| { + plan_datafusion_err!( + "Extension type {} obtained from registration does not support the storage data type {}", + TExtensionType::NAME, + storage_type + ) + })?; + + Ok(Arc::new(type_instance) as DFExtensionTypeRef) + } +} + +impl Debug + for DefaultExtensionTypeRegistration +{ + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DefaultExtensionTypeRegistration") + .field("type_name", &TExtensionType::NAME) + .finish() + } +} + +/// An [`ExtensionTypeRegistry`] that uses in memory [`HashMap`]s. +#[derive(Clone, Debug)] +pub struct MemoryExtensionTypeRegistry { + /// Holds a mapping between the name of an extension type and its logical type. 
+ extension_types: Arc<RwLock<HashMap<String, ExtensionTypeRegistrationRef>>>, +} + +impl Default for MemoryExtensionTypeRegistry { + fn default() -> Self { + Self::new_empty() + } +} + +impl MemoryExtensionTypeRegistry { + /// Creates an empty [MemoryExtensionTypeRegistry]. + pub fn new_empty() -> Self { + Self { + extension_types: Arc::new(RwLock::new(HashMap::new())), + } + } + + /// Pre-registers the [canonical extension types](https://arrow.apache.org/docs/format/CanonicalExtensions.html) + /// in the extension type registry (UUID, Bool8, JSON, Opaque, fixed/variable shape tensors, timestamp with offset). + pub fn new_with_canonical_extension_types() -> Self { + let mapping: [ExtensionTypeRegistrationRef; 7] = [ + DefaultExtensionTypeRegistration::new_arc(|_, _| { + Ok(arrow_schema::extension::Uuid {}) + }), + DefaultExtensionTypeRegistration::new_arc(|_, _: &'static str| Ok(Bool8 {})), + DefaultExtensionTypeRegistration::new_arc(|dt, metadata: JsonMetadata| { + <Json as ExtensionType>::try_new(dt, metadata) + .map_err(DataFusionError::from) + }), + DefaultExtensionTypeRegistration::new_arc(|_, metadata: OpaqueMetadata| { + Ok(Opaque::from(metadata)) + }), + DefaultExtensionTypeRegistration::new_arc( + |dt, metadata: FixedShapeTensorMetadata| { + <FixedShapeTensor as ExtensionType>::try_new(dt, metadata) + .map_err(DataFusionError::from) + }, + ), + DefaultExtensionTypeRegistration::new_arc( + |dt, metadata: VariableShapeTensorMetadata| { + <VariableShapeTensor as ExtensionType>::try_new(dt, metadata) + .map_err(DataFusionError::from) + }, + ), + DefaultExtensionTypeRegistration::new_arc(|_, _| Ok(TimestampWithOffset {})), + ]; + + let extension_types = mapping + .into_iter() + .map(|registration| (registration.type_name().to_owned(), registration)) + .collect::<HashMap<_, _>>(); + + Self { + extension_types: Arc::new(RwLock::new(extension_types)), + } + } + + /// Creates a new [MemoryExtensionTypeRegistry] with the provided `types`. + /// + /// # Errors + /// + /// This constructor currently performs no validation and always returns `Ok`. 
+ pub fn new_with_types( + types: impl IntoIterator<Item = ExtensionTypeRegistrationRef>, + ) -> Result<Self> { + let extension_types = types + .into_iter() + .map(|t| (t.type_name().to_owned(), t)) + .collect::<HashMap<_, _>>(); + Ok(Self { + extension_types: Arc::new(RwLock::new(extension_types)), + }) + } + + /// Returns a list of all registered types. + pub fn all_extension_types(&self) -> Vec<ExtensionTypeRegistrationRef> { + self.extension_types + .read() + .expect("Extension type registry lock poisoned") + .values() + .cloned() + .collect() + } +} + +impl ExtensionTypeRegistry for MemoryExtensionTypeRegistry { + fn extension_type_registration( + &self, + name: &str, + ) -> Result<ExtensionTypeRegistrationRef> { + self.extension_types + .read() + .expect("Extension type registry lock poisoned") + .get(name) + .ok_or_else(|| plan_datafusion_err!("Extension type '{name}' not found.")) + .cloned() + } + + fn extension_type_registrations(&self) -> Vec<Arc<dyn ExtensionTypeRegistration>> { + self.extension_types + .read() + .expect("Extension type registry lock poisoned") + .values() + .cloned() + .collect() + } + + fn add_extension_type_registration( + &self, + extension_type: ExtensionTypeRegistrationRef, + ) -> Result<Option<ExtensionTypeRegistrationRef>> { + Ok(self + .extension_types + .write() + .expect("Extension type registry lock poisoned") + .insert(extension_type.type_name().to_owned(), extension_type)) + } + + fn remove_extension_type_registration( + &self, + name: &str, + ) -> Result<Option<ExtensionTypeRegistrationRef>> { + Ok(self + .extension_types + .write() + .expect("Extension type registry lock poisoned") + .remove(name)) + } +} + +impl From<HashMap<String, ExtensionTypeRegistrationRef>> for MemoryExtensionTypeRegistry { + fn from(value: HashMap<String, ExtensionTypeRegistrationRef>) -> Self { + Self { + extension_types: Arc::new(RwLock::new(value)), + } + } +}