@@ -65,7 +65,7 @@ use substrait::proto::expression::literal::{
6565} ;
6666use substrait:: proto:: expression:: subquery:: InPredicate ;
6767use substrait:: proto:: expression:: window_function:: BoundsType ;
68- use substrait:: proto:: read_rel:: VirtualTable ;
68+ use substrait:: proto:: read_rel:: { ExtensionTable , VirtualTable } ;
6969use substrait:: proto:: rel_common:: EmitKind ;
7070use substrait:: proto:: rel_common:: EmitKind :: Emit ;
7171use substrait:: proto:: {
@@ -212,6 +212,23 @@ pub fn to_substrait_rel(
212212 let table_schema = scan. source . schema ( ) . to_dfschema_ref ( ) ?;
213213 let base_schema = to_substrait_named_struct ( & table_schema) ?;
214214
215+ let table = if let Ok ( bytes) = state
216+ . serializer_registry ( )
217+ . serialize_custom_table ( scan. source . as_ref ( ) )
218+ {
219+ ReadType :: ExtensionTable ( ExtensionTable {
220+ detail : Some ( ProtoAny {
221+ type_url : scan. table_name . to_string ( ) ,
222+ value : bytes. into ( ) ,
223+ } ) ,
224+ } )
225+ } else {
226+ ReadType :: NamedTable ( NamedTable {
227+ names : scan. table_name . to_vec ( ) ,
228+ advanced_extension : None ,
229+ } )
230+ } ;
231+
215232 Ok ( Box :: new ( Rel {
216233 rel_type : Some ( RelType :: Read ( Box :: new ( ReadRel {
217234 common : None ,
@@ -220,10 +237,7 @@ pub fn to_substrait_rel(
220237 best_effort_filter : None ,
221238 projection,
222239 advanced_extension : None ,
223- read_type : Some ( ReadType :: NamedTable ( NamedTable {
224- names : scan. table_name . to_vec ( ) ,
225- advanced_extension : None ,
226- } ) ) ,
240+ read_type : Some ( table) ,
227241 } ) ) ) ,
228242 } ) )
229243 }
@@ -2204,16 +2218,22 @@ mod test {
22042218 use super :: * ;
22052219 use crate :: logical_plan:: consumer:: {
22062220 from_substrait_extended_expr, from_substrait_literal_without_names,
2207- from_substrait_named_struct, from_substrait_type_without_names,
2221+ from_substrait_named_struct, from_substrait_plan,
2222+ from_substrait_type_without_names,
22082223 } ;
22092224 use arrow_buffer:: { IntervalDayTime , IntervalMonthDayNano } ;
22102225 use datafusion:: arrow:: array:: {
22112226 GenericListArray , Int64Builder , MapBuilder , StringBuilder ,
22122227 } ;
22132228 use datafusion:: arrow:: datatypes:: { Field , Fields , Schema } ;
22142229 use datafusion:: common:: scalar:: ScalarStructBuilder ;
2215- use datafusion:: common:: DFSchema ;
2230+ use datafusion:: common:: { assert_contains, DFSchema } ;
2231+ use datafusion:: datasource:: empty:: EmptyTable ;
2232+ use datafusion:: datasource:: { DefaultTableSource , TableProvider } ;
2233+ use datafusion:: execution:: registry:: SerializerRegistry ;
22162234 use datafusion:: execution:: SessionStateBuilder ;
2235+ use datafusion:: logical_expr:: TableSource ;
2236+ use datafusion:: prelude:: SessionContext ;
22172237
22182238 #[ test]
22192239 fn round_trip_literals ( ) -> Result < ( ) > {
@@ -2540,4 +2560,110 @@ mod test {
25402560
25412561 assert ! ( matches!( err, Err ( DataFusionError :: SchemaError ( _, _) ) ) ) ;
25422562 }
2563+
2564+ #[ tokio:: test]
2565+ async fn round_trip_extension_table ( ) {
2566+ const TABLE_NAME : & str = "custom_table" ;
2567+ const SERIALIZED : & [ u8 ] = "table definition" . as_bytes ( ) ;
2568+
2569+ fn custom_table ( ) -> Arc < dyn TableProvider > {
2570+ Arc :: new ( EmptyTable :: new ( Arc :: new ( Schema :: new ( [
2571+ Arc :: new ( Field :: new ( "id" , DataType :: Int32 , false ) ) ,
2572+ Arc :: new ( Field :: new ( "name" , DataType :: Utf8 , false ) ) ,
2573+ ] ) ) ) )
2574+ }
2575+
2576+ #[ derive( Debug ) ]
2577+ struct Registry ;
2578+ impl SerializerRegistry for Registry {
2579+ fn serialize_custom_table ( & self , table : & dyn TableSource ) -> Result < Vec < u8 > > {
2580+ if table. schema ( ) == custom_table ( ) . schema ( ) {
2581+ Ok ( SERIALIZED . to_vec ( ) )
2582+ } else {
2583+ Err ( DataFusionError :: Internal ( "Not our table" . into ( ) ) )
2584+ }
2585+ }
2586+ fn deserialize_custom_table (
2587+ & self ,
2588+ name : & str ,
2589+ bytes : & [ u8 ] ,
2590+ ) -> Result < Arc < dyn TableSource > > {
2591+ if name == TABLE_NAME && bytes == SERIALIZED {
2592+ Ok ( Arc :: new ( DefaultTableSource :: new ( custom_table ( ) ) ) )
2593+ } else {
2594+ panic ! ( "Unexpected extension table: {name}" ) ;
2595+ }
2596+ }
2597+ }
2598+
2599+ async fn round_trip_logical_plans (
2600+ local : & SessionContext ,
2601+ remote : & SessionContext ,
2602+ ) -> Result < ( ) > {
2603+ local. register_table ( TABLE_NAME , custom_table ( ) ) ?;
2604+ remote. table_provider ( TABLE_NAME ) . await . expect_err (
2605+ "The remote context is not supposed to know about custom_table" ,
2606+ ) ;
2607+ let initial_plan = local
2608+ . sql ( & format ! ( "select id from {TABLE_NAME}" ) )
2609+ . await ?
2610+ . logical_plan ( )
2611+ . clone ( ) ;
2612+
2613+ // write substrait locally
2614+ let substrait = to_substrait_plan ( & initial_plan, & local. state ( ) ) ?;
2615+
2616+ // read substrait remotely
2617+ // since we know there's no `custom_table` registered in the remote context, this will only succeed
2618+ // if our table got encoded as an ExtensionTable and is now decoded back to a table source.
2619+ let restored = from_substrait_plan ( & remote. state ( ) , & substrait) . await ?;
2620+ assert_contains ! (
2621+ // confirm that the Substrait plan contains our custom_table as an ExtensionTable
2622+ serde_json:: to_string( substrait. as_ref( ) ) . unwrap( ) ,
2623+ format!( r#""extensionTable":{{"detail":{{"typeUrl":"{TABLE_NAME}","# )
2624+ ) ;
2625+ remote // make sure the restored plan is fully working in the remote context
2626+ . execute_logical_plan ( restored. clone ( ) )
2627+ . await ?
2628+ . collect ( )
2629+ . await
2630+ . expect ( "Restored plan cannot be executed remotely" ) ;
2631+ assert_eq ! (
2632+ // check that the restored plan is functionally equivalent (and almost identical) to the initial one
2633+ initial_plan. to_string( ) ,
2634+ restored. to_string( ) . replace(
2635+ // substrait will add an explicit full-schema projection if the original table had none
2636+ & format!( "TableScan: {TABLE_NAME} projection=[id, name]" ) ,
2637+ & format!( "TableScan: {TABLE_NAME}" ) ,
2638+ )
2639+ ) ;
2640+ Ok ( ( ) )
2641+ }
2642+
2643+ // take 1
2644+ let failed_attempt =
2645+ round_trip_logical_plans ( & SessionContext :: new ( ) , & SessionContext :: new ( ) )
2646+ . await
2647+ . expect_err (
2648+ "The round trip should fail in the absence of a SerializerRegistry" ,
2649+ ) ;
2650+ assert_contains ! (
2651+ failed_attempt. message( ) ,
2652+ format!( "No table named '{TABLE_NAME}'" )
2653+ ) ;
2654+
2655+ // take 2
2656+ fn proper_context ( ) -> SessionContext {
2657+ SessionContext :: new_with_state (
2658+ SessionStateBuilder :: new ( )
2659+ // This will transport our custom_table as a Substrait ExtensionTable
2660+ . with_serializer_registry ( Arc :: new ( Registry ) )
2661+ . build ( ) ,
2662+ )
2663+ }
2664+
2665+ round_trip_logical_plans ( & proper_context ( ) , & proper_context ( ) )
2666+ . await
2667+ . expect ( "Local plan could not be restored remotely" ) ;
2668+ }
25432669}
0 commit comments