Skip to content

Commit 081e886

Browse files
committed
[substrait] Add support for ExtensionTable
1 parent 8b6daaf commit 081e886

3 files changed

Lines changed: 177 additions & 25 deletions

File tree

datafusion/expr/src/registry.rs

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
2020
use crate::expr_rewriter::FunctionRewrite;
2121
use crate::planner::ExprPlanner;
22-
use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF};
23-
use datafusion_common::{not_impl_err, plan_datafusion_err, HashMap, Result};
22+
use crate::{AggregateUDF, ScalarUDF, TableSource, UserDefinedLogicalNode, WindowUDF};
23+
use datafusion_common::{not_impl_err, plan_datafusion_err, DataFusionError, HashMap, Result};
2424
use std::collections::HashSet;
2525
use std::fmt::Debug;
2626
use std::sync::Arc;
@@ -123,22 +123,57 @@ pub trait FunctionRegistry {
123123
}
124124
}
125125

126-
/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode].
126+
/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]
127+
/// and custom table providers for which the name alone is meaningless in the target
128+
/// execution context, e.g. UDTFs, manually registered tables etc.
127129
pub trait SerializerRegistry: Debug + Send + Sync {
128130
/// Serialize this node to a byte array. This serialization should not include
129131
/// input plans.
130132
fn serialize_logical_plan(
131133
&self,
132-
node: &dyn UserDefinedLogicalNode,
133-
) -> Result<Vec<u8>>;
134+
_node: &dyn UserDefinedLogicalNode,
135+
) -> Result<Vec<u8>> {
136+
Err(DataFusionError::Plan("UserDefinedLogicalNode serialization not supported".into()))
137+
}
134138

135139
/// Deserialize user defined logical plan node ([UserDefinedLogicalNode]) from
136140
/// bytes.
137141
fn deserialize_logical_plan(
138142
&self,
139-
name: &str,
140-
bytes: &[u8],
141-
) -> Result<Arc<dyn UserDefinedLogicalNode>>;
143+
_name: &str,
144+
_bytes: &[u8],
145+
) -> Result<Arc<dyn UserDefinedLogicalNode>> {
146+
Err(DataFusionError::Plan("UserDefinedLogicalNode deserialization not supported".into()))
147+
}
148+
149+
/// Binary representation for custom tables, to be converted to substrait extension tables.
150+
/// Should only return success for table implementations that cannot be found by name
151+
/// in the destination execution context, such as UDTFs or manually registered table providers.
152+
fn serialize_custom_table(
153+
&self,
154+
_table: &dyn TableSource,
155+
) -> Result<Vec<u8>> {
156+
Err(DataFusionError::Plan("Custom table serialization not supported".into()))
157+
}
158+
159+
/// Deserialize the custom table with the given name.
160+
/// The name may not be useful as a discriminator if multiple UDTF/TableProvider
161+
/// implementations are expected. This is particularly true for UDTFs in DataFusion,
162+
/// which are always registered under the same name: `tmp_table`, so one should
163+
/// use the binary payload to distinguish between multiple table types.
164+
/// A potential future improvement would be to return a (name, bytes) tuple from
165+
/// [SerializerRegistry::serialize_custom_table] to allow the implementors to assign
166+
/// different names to different table provider implementations (e.g. in the case of proto,
167+
/// by using the actual protobuf `type_url`).
168+
/// But this would mean the table names in the restored plan may no longer match
169+
/// the original ones.
170+
fn deserialize_custom_table(
171+
&self,
172+
_name: &str,
173+
_bytes: &[u8],
174+
) -> Result<Arc<dyn TableSource>> {
175+
Err(DataFusionError::Plan("Custom table deserialization not supported".into()))
176+
}
142177
}
143178

144179
/// A [`FunctionRegistry`] that uses in memory [`HashMap`]s

datafusion/substrait/src/logical_plan/consumer.rs

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,7 @@ use datafusion::common::{
2929
use datafusion::datasource::provider_as_source;
3030
use datafusion::logical_expr::expr::{Exists, InSubquery, Sort};
3131

32-
use datafusion::logical_expr::{
33-
Aggregate, BinaryExpr, Case, EmptyRelation, Expr, ExprSchemable, LogicalPlan,
34-
Operator, Projection, SortExpr, TryCast, Values,
35-
};
32+
use datafusion::logical_expr::{Aggregate, BinaryExpr, Case, EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, SortExpr, TableScan, TryCast, Values};
3633
use substrait::proto::aggregate_rel::Grouping;
3734
use substrait::proto::expression::subquery::set_predicate::PredicateOp;
3835
use substrait::proto::expression_reference::ExprType;
@@ -994,8 +991,27 @@ pub async fn from_substrait_rel(
994991
)
995992
.await
996993
}
997-
_ => {
998-
not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type)
994+
Some(ReadType::ExtensionTable(ext)) => {
995+
if let Some(ext_detail) = &ext.detail {
996+
let source = state.serializer_registry()
997+
.deserialize_custom_table(&ext_detail.type_url, &ext_detail.value)?;
998+
let table_name = if let Some((_, name)) = ext_detail.type_url.rsplit_once('/') {
999+
name
1000+
} else {
1001+
&ext_detail.type_url
1002+
};
1003+
let plan = LogicalPlan::TableScan(
1004+
TableScan::try_new(table_name, source, None, vec![], None)?
1005+
);
1006+
let schema = apply_masking(substrait_schema, &read.projection)?;
1007+
ensure_schema_compatability(plan.schema(), schema.clone())?;
1008+
apply_projection(plan, schema)
1009+
} else {
1010+
substrait_err!("Unexpected empty detail in ExtensionTable")
1011+
}
1012+
},
1013+
None => {
1014+
substrait_err!("Unexpected empty read_type")
9991015
}
10001016
}
10011017
}

datafusion/substrait/src/logical_plan/producer.rs

Lines changed: 112 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ use substrait::proto::expression::literal::{
6464
};
6565
use substrait::proto::expression::subquery::InPredicate;
6666
use substrait::proto::expression::window_function::BoundsType;
67-
use substrait::proto::read_rel::VirtualTable;
67+
use substrait::proto::read_rel::{ExtensionTable, VirtualTable};
6868
use substrait::proto::rel_common::EmitKind;
6969
use substrait::proto::rel_common::EmitKind::Emit;
7070
use substrait::proto::{
@@ -211,6 +211,22 @@ pub fn to_substrait_rel(
211211
let table_schema = scan.source.schema().to_dfschema_ref()?;
212212
let base_schema = to_substrait_named_struct(&table_schema)?;
213213

214+
let table = if let Ok(bytes) = state
215+
.serializer_registry()
216+
.serialize_custom_table(scan.source.as_ref()) {
217+
ReadType::ExtensionTable(ExtensionTable {
218+
detail: Some(ProtoAny {
219+
type_url: scan.table_name.to_string(),
220+
value: bytes.into(),
221+
})
222+
})
223+
} else {
224+
ReadType::NamedTable(NamedTable {
225+
names: scan.table_name.to_vec(),
226+
advanced_extension: None,
227+
})
228+
};
229+
214230
Ok(Box::new(Rel {
215231
rel_type: Some(RelType::Read(Box::new(ReadRel {
216232
common: None,
@@ -219,10 +235,7 @@ pub fn to_substrait_rel(
219235
best_effort_filter: None,
220236
projection,
221237
advanced_extension: None,
222-
read_type: Some(ReadType::NamedTable(NamedTable {
223-
names: scan.table_name.to_vec(),
224-
advanced_extension: None,
225-
})),
238+
read_type: Some(table),
226239
}))),
227240
}))
228241
}
@@ -2200,18 +2213,20 @@ fn substrait_field_ref(index: usize) -> Result<Expression> {
22002213
#[cfg(test)]
22012214
mod test {
22022215
use super::*;
2203-
use crate::logical_plan::consumer::{
2204-
from_substrait_extended_expr, from_substrait_literal_without_names,
2205-
from_substrait_named_struct, from_substrait_type_without_names,
2206-
};
2216+
use crate::logical_plan::consumer::{from_substrait_extended_expr, from_substrait_literal_without_names, from_substrait_named_struct, from_substrait_plan, from_substrait_type_without_names};
22072217
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
22082218
use datafusion::arrow::array::{
22092219
GenericListArray, Int64Builder, MapBuilder, StringBuilder,
22102220
};
22112221
use datafusion::arrow::datatypes::{Field, Fields, Schema};
22122222
use datafusion::common::scalar::ScalarStructBuilder;
2213-
use datafusion::common::DFSchema;
2214-
use datafusion::execution::SessionStateBuilder;
2223+
use datafusion::common::{assert_contains, DFSchema};
2224+
use datafusion::datasource::{DefaultTableSource, TableProvider};
2225+
use datafusion::datasource::empty::EmptyTable;
2226+
use datafusion::execution::registry::SerializerRegistry;
2227+
use datafusion::execution::{SessionState, SessionStateBuilder};
2228+
use datafusion::logical_expr::TableSource;
2229+
use datafusion::prelude::SessionContext;
22152230

22162231
#[test]
22172232
fn round_trip_literals() -> Result<()> {
@@ -2518,4 +2533,90 @@ mod test {
25182533

25192534
assert!(matches!(err, Err(DataFusionError::SchemaError(_, _))));
25202535
}
2536+
2537+
#[tokio::test]
2538+
async fn round_trip_extension_table() -> Result<()> {
2539+
#[derive(Debug)]
2540+
struct Registry {
2541+
table: Arc<dyn TableProvider>,
2542+
}
2543+
impl SerializerRegistry for Registry {
2544+
fn serialize_custom_table(&self, _table: &dyn TableSource) -> Result<Vec<u8>> {
2545+
Ok("expected payload".as_bytes().to_vec())
2546+
}
2547+
fn deserialize_custom_table(&self, _name: &str, _bytes: &[u8]) -> Result<Arc<dyn TableSource>> {
2548+
Ok(Arc::new(DefaultTableSource::new(self.table.clone())))
2549+
}
2550+
}
2551+
2552+
async fn round_trip_logical_plans(
2553+
local: &SessionContext,
2554+
remote: &SessionContext,
2555+
table: Arc<dyn TableProvider>
2556+
) -> Result<(LogicalPlan, LogicalPlan)> {
2557+
local.register_table("custom_table", table)?;
2558+
let initial_plan = local.sql("select id from custom_table")
2559+
.await?
2560+
.logical_plan()
2561+
.clone();
2562+
2563+
// write substrait locally
2564+
let substrait = to_substrait_plan(&initial_plan, &local.state())?;
2565+
2566+
// read substrait remotely
2567+
// this will only succeed if our custom_table was encoded as an ExtensionTable,
2568+
// since there's no `custom_table` registered in the remote context.
2569+
let restored = from_substrait_plan(&remote.state(), &substrait).await?;
2570+
assert_contains!(
2571+
serde_json::to_string(substrait.as_ref()).unwrap(),
2572+
// value == base64("expected payload")
2573+
r#""extensionTable":{"detail":{"typeUrl":"custom_table","value":"ZXhwZWN0ZWQgcGF5bG9hZA=="}}"#
2574+
);
2575+
Ok((initial_plan, restored))
2576+
}
2577+
2578+
let empty = Arc::new(EmptyTable::new(Arc::new(
2579+
Schema::new([
2580+
Arc::new(Field::new("id", DataType::Int32, false)),
2581+
Arc::new(Field::new("name", DataType::Utf8, false)),
2582+
])
2583+
)));
2584+
2585+
let first_attempt = round_trip_logical_plans(
2586+
&SessionContext::new(),
2587+
&SessionContext::new(),
2588+
empty.clone()
2589+
).await;
2590+
assert_eq!(
2591+
first_attempt.unwrap_err().to_string(),
2592+
"Error during planning: No table named 'custom_table'"
2593+
);
2594+
fn proper_state(table: Arc<dyn TableProvider>) -> SessionState {
2595+
SessionStateBuilder::new()
2596+
.with_default_features()
2597+
.with_serializer_registry(Arc::new(Registry { table }))
2598+
.build()
2599+
}
2600+
let local = SessionContext::new_with_state(proper_state(empty.clone()));
2601+
let remote = SessionContext::new_with_state(proper_state(empty.clone()));
2602+
2603+
let (initial_plan, restored) = round_trip_logical_plans(&local, &remote, empty)
2604+
.await
2605+
.expect("Should restore the substrait plan as datafusion logical plan");
2606+
2607+
assert_eq!(
2608+
initial_plan.to_string(),
2609+
restored.to_string()
2610+
// substrait will add an explicit projection with the full schema
2611+
.replace(
2612+
"TableScan: custom_table projection=[id, name]",
2613+
"TableScan: custom_table"
2614+
)
2615+
);
2616+
assert_eq!(
2617+
local.execute_logical_plan(initial_plan).await?.collect().await?,
2618+
remote.execute_logical_plan(restored).await?.collect().await?,
2619+
);
2620+
Ok(())
2621+
}
25212622
}

0 commit comments

Comments
 (0)