11from __future__ import annotations
22
33import ast
4+ import atexit
45import ctypes
56import ctypes .util
67import datetime as dt
78import os
89import sys
910import threading
1011import uuid
12+ import weakref
1113from dataclasses import dataclass
1214from decimal import Decimal
1315from pathlib import Path
@@ -186,6 +188,24 @@ def _resolve_library_path() -> str:
186188
187189_dlopen_mode = getattr (ctypes , "RTLD_GLOBAL" , 0 ) | getattr (ctypes , "RTLD_NOW" , 0 )
188190_LIB = ctypes .CDLL (_resolve_library_path (), mode = _dlopen_mode )
# Registry of Database objects that have been initialised and not yet closed.
# A WeakSet is used so that membership alone never keeps an object alive.
_CAPI_DATABASES: weakref.WeakSet[Any] = weakref.WeakSet()

# Registry of Connection objects, maintained the same way as the databases.
_CAPI_CONNECTIONS: weakref.WeakSet[Any] = weakref.WeakSet()

# Flipped to True once the interpreter-exit cleanup hook has been registered,
# so the hook is installed at most once.
_ARROW_ATEXIT_REGISTERED = False
194+
195+
def _close_capi_connections() -> None:
    """Close every tracked connection, then every tracked database.

    Connections are handled first and databases second. Each registry is
    copied before it is walked because ``close()`` discards the object from
    the registry it belongs to, which would otherwise mutate the set during
    iteration.
    """
    for registry in (_CAPI_CONNECTIONS, _CAPI_DATABASES):
        # The snapshot is taken only when this registry's turn comes, so the
        # database snapshot reflects the state after connections were closed.
        for resource in tuple(registry):
            resource.close()
201+
202+
def _ensure_arrow_atexit_cleanup() -> None:
    """Register the exit-time cleanup hook the first time this is called.

    Later calls return immediately; the module-level flag records that the
    hook is already installed.
    """
    global _ARROW_ATEXIT_REGISTERED
    if _ARROW_ATEXIT_REGISTERED:
        return
    atexit.register(_close_capi_connections)
    _ARROW_ATEXIT_REGISTERED = True
208+
189209
190210_LBUG_SUCCESS = 0
191211
@@ -288,6 +308,35 @@ def _setup_signatures() -> None:
288308 ]
289309 _LIB .lbug_connection_execute .restype = ctypes .c_int
290310
311+ _LIB .lbug_connection_create_arrow_table .argtypes = [
312+ ctypes .POINTER (_LbugConnection ),
313+ ctypes .c_char_p ,
314+ ctypes .POINTER (_ArrowSchema ),
315+ ctypes .POINTER (_ArrowArray ),
316+ ctypes .c_uint64 ,
317+ ctypes .POINTER (_LbugQueryResult ),
318+ ]
319+ _LIB .lbug_connection_create_arrow_table .restype = ctypes .c_int
320+
321+ _LIB .lbug_connection_create_arrow_rel_table .argtypes = [
322+ ctypes .POINTER (_LbugConnection ),
323+ ctypes .c_char_p ,
324+ ctypes .c_char_p ,
325+ ctypes .c_char_p ,
326+ ctypes .POINTER (_ArrowSchema ),
327+ ctypes .POINTER (_ArrowArray ),
328+ ctypes .c_uint64 ,
329+ ctypes .POINTER (_LbugQueryResult ),
330+ ]
331+ _LIB .lbug_connection_create_arrow_rel_table .restype = ctypes .c_int
332+
333+ _LIB .lbug_connection_drop_arrow_table .argtypes = [
334+ ctypes .POINTER (_LbugConnection ),
335+ ctypes .c_char_p ,
336+ ctypes .POINTER (_LbugQueryResult ),
337+ ]
338+ _LIB .lbug_connection_drop_arrow_table .restype = ctypes .c_int
339+
291340 _LIB .lbug_prepared_statement_destroy .argtypes = [
292341 ctypes .POINTER (_LbugPreparedStatement )
293342 ]
@@ -1065,13 +1114,15 @@ def __init__(
10651114 database_path .encode ("utf-8" ), config , ctypes .byref (self ._database )
10661115 )
10671116 _check_state (state , "Failed to initialize database" )
1117+ _CAPI_DATABASES .add (self )
10681118
def close(self) -> None:
    """Destroy the native database handle and stop tracking this object.

    Safe to call more than once: the native destroy call is only made while
    the wrapped handle is still set, and the handle is cleared afterwards.
    """
    # Bind the library once so the check and the call use the same object.
    native = _LIB
    handle = self._database
    if handle._database:
        if native is not None:
            native.lbug_database_destroy(ctypes.byref(handle))
        handle._database = None
    _CAPI_DATABASES.discard(self)
10751126
10761127 @staticmethod
10771128 def get_version () -> str :
@@ -2014,6 +2065,7 @@ def __init__(self, database: Database, num_threads: int = 0):
20142065 ),
20152066 "Failed to initialize connection" ,
20162067 )
2068+ _CAPI_CONNECTIONS .add (self )
20172069 if num_threads > 0 :
20182070 self .set_max_threads_for_exec (num_threads )
20192071
@@ -2023,6 +2075,7 @@ def close(self) -> None:
20232075 if lib is not None :
20242076 lib .lbug_connection_destroy (ctypes .byref (self ._connection ))
20252077 self ._connection ._connection = None
2078+ _CAPI_CONNECTIONS .discard (self )
20262079
20272080 def set_max_threads_for_exec (self , num_threads : int ) -> None :
20282081 _check_state (
@@ -2119,17 +2172,81 @@ def create_function(self, *_args: Any, **_kwargs: Any) -> None:
def remove_function(self, *_args: Any, **_kwargs: Any) -> None:
    """Placeholder for UDF removal; always raises.

    Raises:
        NotImplementedError: unconditionally, whatever arguments are given.
    """
    reason = "UDF removal is not yet implemented in C-API backend"
    raise NotImplementedError(reason)
21212174
2122- def create_arrow_table (self , * _args : Any , ** _kwargs : Any ) -> Any :
2123- raise NotImplementedError (
2124- "Arrow memory table APIs are not yet implemented in C-API backend"
@staticmethod
def _as_arrow_table(dataframe: Any) -> Any:
    """Return ``dataframe`` as a pyarrow Table.

    The input is classified by the module its type was defined in:
    pandas objects go through ``pa.Table.from_pandas``, polars objects
    through ``to_arrow()``, and a pyarrow ``Table`` is returned unchanged.

    Raises:
        RuntimeError: if the input matches none of the supported kinds.
    """
    import pyarrow as pa

    # Registered only after pyarrow has been imported.
    # NOTE(review): atexit hooks run last-registered-first, so this ordering
    # presumably lets the cleanup run before anything pyarrow registered at
    # import time - confirm before reordering.
    _ensure_arrow_atexit_cleanup()

    origin = type(dataframe).__module__
    if origin.startswith("pandas"):
        return pa.Table.from_pandas(dataframe)
    if origin.startswith("polars"):
        return dataframe.to_arrow()

    is_arrow_table = (
        origin.startswith("pyarrow") and dataframe.__class__.__name__ == "Table"
    )
    if not is_arrow_table:
        msg = "Expected a pyarrow Table, polars DataFrame, or pandas DataFrame"
        raise RuntimeError(msg)
    return dataframe
2192+
@staticmethod
def _export_arrow_table(dataframe: Any) -> tuple[Any, _ArrowSchema, Any, Any]:
    """Export ``dataframe`` into C-level Arrow schema and array structs.

    Returns:
        A 4-tuple ``(table, schema, arrays, batches)``: the pyarrow table,
        the filled ``_ArrowSchema`` struct, a ctypes array holding one
        ``_ArrowArray`` per record batch, and the list of record batches.
    """
    table = Connection._as_arrow_table(dataframe)

    # The schema is exported once for the whole table.
    c_schema = _ArrowSchema()
    table.schema._export_to_c(ctypes.addressof(c_schema))

    # One struct slot per record batch, each filled in place by address.
    record_batches = table.to_batches()
    c_arrays = (_ArrowArray * len(record_batches))()
    for position, record_batch in enumerate(record_batches):
        record_batch._export_to_c(ctypes.addressof(c_arrays[position]))

    return table, c_schema, c_arrays, record_batches
2204+
def create_arrow_table(self, table_name: str, dataframe: Any) -> QueryResult:
    """Create an Arrow-backed table named ``table_name`` from ``dataframe``.

    The dataframe is exported to C Arrow structs and handed to
    ``lbug_connection_create_arrow_table`` together with the batch count.

    Returns:
        A ``QueryResult`` wrapping the native result struct.

    If the native call reports a non-success status and produced no query
    result, the status is passed to ``_check_state`` with an error message.
    """
    # All four exported values are bound to locals so they stay referenced
    # for the duration of the native call.
    source_table, c_schema, c_arrays, source_batches = self._export_arrow_table(
        dataframe
    )
    raw_result = _LbugQueryResult()
    status = _LIB.lbug_connection_create_arrow_table(
        ctypes.byref(self._connection),
        table_name.encode("utf-8"),
        ctypes.byref(c_schema),
        c_arrays,
        len(c_arrays),
        ctypes.byref(raw_result),
    )
    # A failure that still produced a result is returned to the caller
    # inside the QueryResult rather than reported here.
    if not (status == _LBUG_SUCCESS or raw_result._query_result):
        _check_state(status, "Failed to create Arrow table")
    return QueryResult(raw_result)
21262219
2127- def drop_arrow_table (self , * _args : Any , ** _kwargs : Any ) -> Any :
2128- raise NotImplementedError (
2129- "Arrow memory table APIs are not yet implemented in C-API backend"
def drop_arrow_table(self, table_name: str) -> QueryResult:
    """Drop the Arrow-backed table named ``table_name``.

    Returns:
        A ``QueryResult`` wrapping the native result struct.

    If the native call reports a non-success status and produced no query
    result, the status is passed to ``_check_state`` with an error message.
    """
    raw_result = _LbugQueryResult()
    status = _LIB.lbug_connection_drop_arrow_table(
        ctypes.byref(self._connection),
        table_name.encode("utf-8"),
        ctypes.byref(raw_result),
    )
    # A failure that still produced a result is returned to the caller
    # inside the QueryResult rather than reported here.
    if not (status == _LBUG_SUCCESS or raw_result._query_result):
        _check_state(status, "Failed to drop Arrow table")
    return QueryResult(raw_result)
21312230
2132- def create_arrow_rel_table (self , * _args : Any , ** _kwargs : Any ) -> Any :
2133- raise NotImplementedError (
2134- "Arrow memory table APIs are not yet implemented in C-API backend"
def create_arrow_rel_table(
    self,
    table_name: str,
    dataframe: Any,
    src_table_name: str,
    dst_table_name: str,
) -> QueryResult:
    """Create an Arrow-backed relationship table from ``dataframe``.

    The dataframe is exported to C Arrow structs and handed to
    ``lbug_connection_create_arrow_rel_table`` along with the new table's
    name and the names of the source and destination tables.

    Returns:
        A ``QueryResult`` wrapping the native result struct.

    If the native call reports a non-success status and produced no query
    result, the status is passed to ``_check_state`` with an error message.
    """
    # All four exported values are bound to locals so they stay referenced
    # for the duration of the native call.
    source_table, c_schema, c_arrays, source_batches = self._export_arrow_table(
        dataframe
    )
    raw_result = _LbugQueryResult()
    # Note the native argument order: the three names come first, then the
    # Arrow structs, unlike the Python signature.
    status = _LIB.lbug_connection_create_arrow_rel_table(
        ctypes.byref(self._connection),
        table_name.encode("utf-8"),
        src_table_name.encode("utf-8"),
        dst_table_name.encode("utf-8"),
        ctypes.byref(c_schema),
        c_arrays,
        len(c_arrays),
        ctypes.byref(raw_result),
    )
    # A failure that still produced a result is returned to the caller
    # inside the QueryResult rather than reported here.
    if not (status == _LBUG_SUCCESS or raw_result._query_result):
        _check_state(status, "Failed to create Arrow relationship table")
    return QueryResult(raw_result)
0 commit comments