@@ -64,7 +64,7 @@ use substrait::proto::expression::literal::{
6464} ;
6565use substrait:: proto:: expression:: subquery:: InPredicate ;
6666use substrait:: proto:: expression:: window_function:: BoundsType ;
67- use substrait:: proto:: read_rel:: VirtualTable ;
67+ use substrait:: proto:: read_rel:: { ExtensionTable , VirtualTable } ;
6868use substrait:: proto:: rel_common:: EmitKind ;
6969use substrait:: proto:: rel_common:: EmitKind :: Emit ;
7070use substrait:: proto:: {
@@ -211,6 +211,22 @@ pub fn to_substrait_rel(
211211 let table_schema = scan. source . schema ( ) . to_dfschema_ref ( ) ?;
212212 let base_schema = to_substrait_named_struct ( & table_schema) ?;
213213
214+ let table = if let Ok ( bytes) = state
215+ . serializer_registry ( )
216+ . serialize_custom_table ( scan. source . as_ref ( ) ) {
217+ ReadType :: ExtensionTable ( ExtensionTable {
218+ detail : Some ( ProtoAny {
219+ type_url : scan. table_name . to_string ( ) ,
220+ value : bytes. into ( ) ,
221+ } )
222+ } )
223+ } else {
224+ ReadType :: NamedTable ( NamedTable {
225+ names : scan. table_name . to_vec ( ) ,
226+ advanced_extension : None ,
227+ } )
228+ } ;
229+
214230 Ok ( Box :: new ( Rel {
215231 rel_type : Some ( RelType :: Read ( Box :: new ( ReadRel {
216232 common : None ,
@@ -219,10 +235,7 @@ pub fn to_substrait_rel(
219235 best_effort_filter : None ,
220236 projection,
221237 advanced_extension : None ,
222- read_type : Some ( ReadType :: NamedTable ( NamedTable {
223- names : scan. table_name . to_vec ( ) ,
224- advanced_extension : None ,
225- } ) ) ,
238+ read_type : Some ( table) ,
226239 } ) ) ) ,
227240 } ) )
228241 }
@@ -2200,18 +2213,20 @@ fn substrait_field_ref(index: usize) -> Result<Expression> {
22002213#[ cfg( test) ]
22012214mod test {
22022215 use super :: * ;
2203- use crate :: logical_plan:: consumer:: {
2204- from_substrait_extended_expr, from_substrait_literal_without_names,
2205- from_substrait_named_struct, from_substrait_type_without_names,
2206- } ;
2216+ use crate :: logical_plan:: consumer:: { from_substrait_extended_expr, from_substrait_literal_without_names, from_substrait_named_struct, from_substrait_plan, from_substrait_type_without_names} ;
22072217 use arrow_buffer:: { IntervalDayTime , IntervalMonthDayNano } ;
22082218 use datafusion:: arrow:: array:: {
22092219 GenericListArray , Int64Builder , MapBuilder , StringBuilder ,
22102220 } ;
22112221 use datafusion:: arrow:: datatypes:: { Field , Fields , Schema } ;
22122222 use datafusion:: common:: scalar:: ScalarStructBuilder ;
2213- use datafusion:: common:: DFSchema ;
2214- use datafusion:: execution:: SessionStateBuilder ;
2223+ use datafusion:: common:: { assert_contains, DFSchema } ;
2224+ use datafusion:: datasource:: { DefaultTableSource , TableProvider } ;
2225+ use datafusion:: datasource:: empty:: EmptyTable ;
2226+ use datafusion:: execution:: registry:: SerializerRegistry ;
2227+ use datafusion:: execution:: { SessionState , SessionStateBuilder } ;
2228+ use datafusion:: logical_expr:: TableSource ;
2229+ use datafusion:: prelude:: SessionContext ;
22152230
22162231 #[ test]
22172232 fn round_trip_literals ( ) -> Result < ( ) > {
@@ -2518,4 +2533,90 @@ mod test {
25182533
25192534 assert ! ( matches!( err, Err ( DataFusionError :: SchemaError ( _, _) ) ) ) ;
25202535 }
2536+
2537+ #[ tokio:: test]
2538+ async fn round_trip_extension_table ( ) -> Result < ( ) > {
2539+ #[ derive( Debug ) ]
2540+ struct Registry {
2541+ table : Arc < dyn TableProvider > ,
2542+ }
2543+ impl SerializerRegistry for Registry {
2544+ fn serialize_custom_table ( & self , _table : & dyn TableSource ) -> Result < Vec < u8 > > {
2545+ Ok ( "expected payload" . as_bytes ( ) . to_vec ( ) )
2546+ }
2547+ fn deserialize_custom_table ( & self , _name : & str , _bytes : & [ u8 ] ) -> Result < Arc < dyn TableSource > > {
2548+ Ok ( Arc :: new ( DefaultTableSource :: new ( self . table . clone ( ) ) ) )
2549+ }
2550+ }
2551+
2552+ async fn round_trip_logical_plans (
2553+ local : & SessionContext ,
2554+ remote : & SessionContext ,
2555+ table : Arc < dyn TableProvider >
2556+ ) -> Result < ( LogicalPlan , LogicalPlan ) > {
2557+ local. register_table ( "custom_table" , table) ?;
2558+ let initial_plan = local. sql ( "select id from custom_table" )
2559+ . await ?
2560+ . logical_plan ( )
2561+ . clone ( ) ;
2562+
2563+ // write substrait locally
2564+ let substrait = to_substrait_plan ( & initial_plan, & local. state ( ) ) ?;
2565+
2566+ // read substrait remotely
2567+ // this will only succeed if our custom_table was encoded as an ExtensionTable,
2568+ // since there's no `custom_table` registered in the remote context.
2569+ let restored = from_substrait_plan ( & remote. state ( ) , & substrait) . await ?;
2570+ assert_contains ! (
2571+ serde_json:: to_string( substrait. as_ref( ) ) . unwrap( ) ,
2572+ // value == base64("expected payload")
2573+ r#""extensionTable":{"detail":{"typeUrl":"custom_table","value":"ZXhwZWN0ZWQgcGF5bG9hZA=="}}"#
2574+ ) ;
2575+ Ok ( ( initial_plan, restored) )
2576+ }
2577+
2578+ let empty = Arc :: new ( EmptyTable :: new ( Arc :: new (
2579+ Schema :: new ( [
2580+ Arc :: new ( Field :: new ( "id" , DataType :: Int32 , false ) ) ,
2581+ Arc :: new ( Field :: new ( "name" , DataType :: Utf8 , false ) ) ,
2582+ ] )
2583+ ) ) ) ;
2584+
2585+ let first_attempt = round_trip_logical_plans (
2586+ & SessionContext :: new ( ) ,
2587+ & SessionContext :: new ( ) ,
2588+ empty. clone ( )
2589+ ) . await ;
2590+ assert_eq ! (
2591+ first_attempt. unwrap_err( ) . to_string( ) ,
2592+ "Error during planning: No table named 'custom_table'"
2593+ ) ;
2594+ fn proper_state ( table : Arc < dyn TableProvider > ) -> SessionState {
2595+ SessionStateBuilder :: new ( )
2596+ . with_default_features ( )
2597+ . with_serializer_registry ( Arc :: new ( Registry { table } ) )
2598+ . build ( )
2599+ }
2600+ let local = SessionContext :: new_with_state ( proper_state ( empty. clone ( ) ) ) ;
2601+ let remote = SessionContext :: new_with_state ( proper_state ( empty. clone ( ) ) ) ;
2602+
2603+ let ( initial_plan, restored) = round_trip_logical_plans ( & local, & remote, empty)
2604+ . await
2605+ . expect ( "Should restore the substrait plan as datafusion logical plan" ) ;
2606+
2607+ assert_eq ! (
2608+ initial_plan. to_string( ) ,
2609+ restored. to_string( )
2610+ // substrait will add an explicit projection with the full schema
2611+ . replace(
2612+ "TableScan: custom_table projection=[id, name]" ,
2613+ "TableScan: custom_table"
2614+ )
2615+ ) ;
2616+ assert_eq ! (
2617+ local. execute_logical_plan( initial_plan) . await ?. collect( ) . await ?,
2618+ remote. execute_logical_plan( restored) . await ?. collect( ) . await ?,
2619+ ) ;
2620+ Ok ( ( ) )
2621+ }
25212622}
0 commit comments