@@ -332,6 +332,21 @@ def _setup_signatures() -> None:
332332 ]
333333 _LIB .lbug_connection_create_arrow_rel_table .restype = ctypes .c_int
334334
335+ _LIB .lbug_connection_create_arrow_rel_table_csr .argtypes = [
336+ ctypes .POINTER (_LbugConnection ),
337+ ctypes .c_char_p ,
338+ ctypes .c_char_p ,
339+ ctypes .c_char_p ,
340+ ctypes .POINTER (_ArrowSchema ),
341+ ctypes .POINTER (_ArrowArray ),
342+ ctypes .c_uint64 ,
343+ ctypes .POINTER (_ArrowSchema ),
344+ ctypes .POINTER (_ArrowArray ),
345+ ctypes .c_uint64 ,
346+ ctypes .POINTER (_LbugQueryResult ),
347+ ]
348+ _LIB .lbug_connection_create_arrow_rel_table_csr .restype = ctypes .c_int
349+
335350 _LIB .lbug_connection_drop_arrow_table .argtypes = [
336351 ctypes .POINTER (_LbugConnection ),
337352 ctypes .c_char_p ,
@@ -2324,19 +2339,54 @@ def create_arrow_rel_table(
23242339 dataframe : Any ,
23252340 src_table_name : str ,
23262341 dst_table_name : str ,
2342+ layout : Any = "FLAT" ,
2343+ indptr_dataframe : Any | None = None ,
23272344 ) -> QueryResult :
2345+ layout_value = getattr (layout , "value" , layout )
2346+ layout_value = str (layout_value ).upper ()
2347+ if layout_value not in {"FLAT" , "CSR" }:
2348+ msg = "Arrow relationship table layout must be FLAT or CSR"
2349+ raise RuntimeError (msg )
2350+ if layout_value == "FLAT" and indptr_dataframe is not None :
2351+ msg = "indptr_dataframe is only valid for CSR Arrow relationship tables"
2352+ raise RuntimeError (msg )
2353+ if layout_value == "CSR" and indptr_dataframe is None :
2354+ msg = "indptr_dataframe is required for CSR Arrow relationship tables"
2355+ raise RuntimeError (msg )
2356+
23282357 _table , schema , arrays , _batches = self ._export_arrow_table (dataframe )
23292358 result = _LbugQueryResult ()
2330- state = _LIB .lbug_connection_create_arrow_rel_table (
2331- ctypes .byref (self ._connection ),
2332- table_name .encode ("utf-8" ),
2333- src_table_name .encode ("utf-8" ),
2334- dst_table_name .encode ("utf-8" ),
2335- ctypes .byref (schema ),
2336- arrays ,
2337- len (arrays ),
2338- ctypes .byref (result ),
2339- )
2359+ if layout_value == "FLAT" :
2360+ state = _LIB .lbug_connection_create_arrow_rel_table (
2361+ ctypes .byref (self ._connection ),
2362+ table_name .encode ("utf-8" ),
2363+ src_table_name .encode ("utf-8" ),
2364+ dst_table_name .encode ("utf-8" ),
2365+ ctypes .byref (schema ),
2366+ arrays ,
2367+ len (arrays ),
2368+ ctypes .byref (result ),
2369+ )
2370+ else :
2371+ (
2372+ _indptr_table ,
2373+ indptr_schema ,
2374+ indptr_arrays ,
2375+ _indptr_batches ,
2376+ ) = self ._export_arrow_table (indptr_dataframe )
2377+ state = _LIB .lbug_connection_create_arrow_rel_table_csr (
2378+ ctypes .byref (self ._connection ),
2379+ table_name .encode ("utf-8" ),
2380+ src_table_name .encode ("utf-8" ),
2381+ dst_table_name .encode ("utf-8" ),
2382+ ctypes .byref (schema ),
2383+ arrays ,
2384+ len (arrays ),
2385+ ctypes .byref (indptr_schema ),
2386+ indptr_arrays ,
2387+ len (indptr_arrays ),
2388+ ctypes .byref (result ),
2389+ )
23402390 if state != _LBUG_SUCCESS and not result ._query_result :
23412391 _check_state (state , "Failed to create Arrow relationship table" )
23422392 return QueryResult (result )
0 commit comments