Skip to content

Commit b1bc517

Browse files
committed
Enable C API Arrow table scans
1 parent c3e30c5 commit b1bc517

3 files changed

Lines changed: 250 additions & 31 deletions

File tree

src_py/_lbug_capi.py

Lines changed: 126 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from __future__ import annotations
22

33
import ast
4+
import atexit
45
import ctypes
56
import ctypes.util
67
import datetime as dt
78
import os
89
import sys
910
import threading
1011
import uuid
12+
import weakref
1113
from dataclasses import dataclass
1214
from decimal import Decimal
1315
from pathlib import Path
@@ -186,6 +188,24 @@ def _resolve_library_path() -> str:
186188

187189
_dlopen_mode = getattr(ctypes, "RTLD_GLOBAL", 0) | getattr(ctypes, "RTLD_NOW", 0)
188190
_LIB = ctypes.CDLL(_resolve_library_path(), mode=_dlopen_mode)
191+
_CAPI_DATABASES: weakref.WeakSet[Any] = weakref.WeakSet()
192+
_CAPI_CONNECTIONS: weakref.WeakSet[Any] = weakref.WeakSet()
193+
_ARROW_ATEXIT_REGISTERED = False
194+
195+
196+
def _close_capi_connections() -> None:
197+
for connection in list(_CAPI_CONNECTIONS):
198+
connection.close()
199+
for database in list(_CAPI_DATABASES):
200+
database.close()
201+
202+
203+
def _ensure_arrow_atexit_cleanup() -> None:
204+
global _ARROW_ATEXIT_REGISTERED
205+
if not _ARROW_ATEXIT_REGISTERED:
206+
atexit.register(_close_capi_connections)
207+
_ARROW_ATEXIT_REGISTERED = True
208+
189209

190210
_LBUG_SUCCESS = 0
191211

@@ -288,6 +308,35 @@ def _setup_signatures() -> None:
288308
]
289309
_LIB.lbug_connection_execute.restype = ctypes.c_int
290310

311+
_LIB.lbug_connection_create_arrow_table.argtypes = [
312+
ctypes.POINTER(_LbugConnection),
313+
ctypes.c_char_p,
314+
ctypes.POINTER(_ArrowSchema),
315+
ctypes.POINTER(_ArrowArray),
316+
ctypes.c_uint64,
317+
ctypes.POINTER(_LbugQueryResult),
318+
]
319+
_LIB.lbug_connection_create_arrow_table.restype = ctypes.c_int
320+
321+
_LIB.lbug_connection_create_arrow_rel_table.argtypes = [
322+
ctypes.POINTER(_LbugConnection),
323+
ctypes.c_char_p,
324+
ctypes.c_char_p,
325+
ctypes.c_char_p,
326+
ctypes.POINTER(_ArrowSchema),
327+
ctypes.POINTER(_ArrowArray),
328+
ctypes.c_uint64,
329+
ctypes.POINTER(_LbugQueryResult),
330+
]
331+
_LIB.lbug_connection_create_arrow_rel_table.restype = ctypes.c_int
332+
333+
_LIB.lbug_connection_drop_arrow_table.argtypes = [
334+
ctypes.POINTER(_LbugConnection),
335+
ctypes.c_char_p,
336+
ctypes.POINTER(_LbugQueryResult),
337+
]
338+
_LIB.lbug_connection_drop_arrow_table.restype = ctypes.c_int
339+
291340
_LIB.lbug_prepared_statement_destroy.argtypes = [
292341
ctypes.POINTER(_LbugPreparedStatement)
293342
]
@@ -1065,13 +1114,15 @@ def __init__(
10651114
database_path.encode("utf-8"), config, ctypes.byref(self._database)
10661115
)
10671116
_check_state(state, "Failed to initialize database")
1117+
_CAPI_DATABASES.add(self)
10681118

10691119
def close(self) -> None:
10701120
lib = _LIB
10711121
if self._database._database:
10721122
if lib is not None:
10731123
lib.lbug_database_destroy(ctypes.byref(self._database))
10741124
self._database._database = None
1125+
_CAPI_DATABASES.discard(self)
10751126

10761127
@staticmethod
10771128
def get_version() -> str:
@@ -2014,6 +2065,7 @@ def __init__(self, database: Database, num_threads: int = 0):
20142065
),
20152066
"Failed to initialize connection",
20162067
)
2068+
_CAPI_CONNECTIONS.add(self)
20172069
if num_threads > 0:
20182070
self.set_max_threads_for_exec(num_threads)
20192071

@@ -2023,6 +2075,7 @@ def close(self) -> None:
20232075
if lib is not None:
20242076
lib.lbug_connection_destroy(ctypes.byref(self._connection))
20252077
self._connection._connection = None
2078+
_CAPI_CONNECTIONS.discard(self)
20262079

20272080
def set_max_threads_for_exec(self, num_threads: int) -> None:
20282081
_check_state(
@@ -2119,17 +2172,81 @@ def create_function(self, *_args: Any, **_kwargs: Any) -> None:
21192172
def remove_function(self, *_args: Any, **_kwargs: Any) -> None:
21202173
raise NotImplementedError("UDF removal is not yet implemented in C-API backend")
21212174

2122-
def create_arrow_table(self, *_args: Any, **_kwargs: Any) -> Any:
2123-
raise NotImplementedError(
2124-
"Arrow memory table APIs are not yet implemented in C-API backend"
2175+
@staticmethod
2176+
def _as_arrow_table(dataframe: Any) -> Any:
2177+
import pyarrow as pa
2178+
2179+
_ensure_arrow_atexit_cleanup()
2180+
module_name = type(dataframe).__module__
2181+
if module_name.startswith("pandas"):
2182+
return pa.Table.from_pandas(dataframe)
2183+
if module_name.startswith("polars"):
2184+
return dataframe.to_arrow()
2185+
if (
2186+
module_name.startswith("pyarrow")
2187+
and dataframe.__class__.__name__ == "Table"
2188+
):
2189+
return dataframe
2190+
msg = "Expected a pyarrow Table, polars DataFrame, or pandas DataFrame"
2191+
raise RuntimeError(msg)
2192+
2193+
@staticmethod
2194+
def _export_arrow_table(dataframe: Any) -> tuple[Any, _ArrowSchema, Any, Any]:
2195+
table = Connection._as_arrow_table(dataframe)
2196+
schema = _ArrowSchema()
2197+
table.schema._export_to_c(ctypes.addressof(schema))
2198+
batches = table.to_batches()
2199+
array_type = _ArrowArray * len(batches)
2200+
arrays = array_type()
2201+
for idx, batch in enumerate(batches):
2202+
batch._export_to_c(ctypes.addressof(arrays[idx]))
2203+
return table, schema, arrays, batches
2204+
2205+
def create_arrow_table(self, table_name: str, dataframe: Any) -> QueryResult:
2206+
_table, schema, arrays, _batches = self._export_arrow_table(dataframe)
2207+
result = _LbugQueryResult()
2208+
state = _LIB.lbug_connection_create_arrow_table(
2209+
ctypes.byref(self._connection),
2210+
table_name.encode("utf-8"),
2211+
ctypes.byref(schema),
2212+
arrays,
2213+
len(arrays),
2214+
ctypes.byref(result),
21252215
)
2216+
if state != _LBUG_SUCCESS and not result._query_result:
2217+
_check_state(state, "Failed to create Arrow table")
2218+
return QueryResult(result)
21262219

2127-
def drop_arrow_table(self, *_args: Any, **_kwargs: Any) -> Any:
2128-
raise NotImplementedError(
2129-
"Arrow memory table APIs are not yet implemented in C-API backend"
2220+
def drop_arrow_table(self, table_name: str) -> QueryResult:
2221+
result = _LbugQueryResult()
2222+
state = _LIB.lbug_connection_drop_arrow_table(
2223+
ctypes.byref(self._connection),
2224+
table_name.encode("utf-8"),
2225+
ctypes.byref(result),
21302226
)
2227+
if state != _LBUG_SUCCESS and not result._query_result:
2228+
_check_state(state, "Failed to drop Arrow table")
2229+
return QueryResult(result)
21312230

2132-
def create_arrow_rel_table(self, *_args: Any, **_kwargs: Any) -> Any:
2133-
raise NotImplementedError(
2134-
"Arrow memory table APIs are not yet implemented in C-API backend"
2231+
def create_arrow_rel_table(
2232+
self,
2233+
table_name: str,
2234+
dataframe: Any,
2235+
src_table_name: str,
2236+
dst_table_name: str,
2237+
) -> QueryResult:
2238+
_table, schema, arrays, _batches = self._export_arrow_table(dataframe)
2239+
result = _LbugQueryResult()
2240+
state = _LIB.lbug_connection_create_arrow_rel_table(
2241+
ctypes.byref(self._connection),
2242+
table_name.encode("utf-8"),
2243+
src_table_name.encode("utf-8"),
2244+
dst_table_name.encode("utf-8"),
2245+
ctypes.byref(schema),
2246+
arrays,
2247+
len(arrays),
2248+
ctypes.byref(result),
21352249
)
2250+
if state != _LBUG_SUCCESS and not result._query_result:
2251+
_check_state(state, "Failed to create Arrow relationship table")
2252+
return QueryResult(result)

src_py/connection.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import inspect
44
import json
55
import re
6+
import uuid
67
import warnings
78
from typing import TYPE_CHECKING, Any
89
from weakref import WeakSet
@@ -49,6 +50,7 @@ def __init__(self, database: Database, num_threads: int = 0):
4950
self._prefer_pybind = False
5051
self._query_timeout_ms = 0
5152
self._query_results: WeakSet[QueryResult] = WeakSet()
53+
self._capi_scan_tables: set[str] = set()
5254
self.database._register_connection(self)
5355
self.init_connection()
5456

@@ -174,6 +176,113 @@ def _has_scan_pattern(self, query: str) -> bool:
174176
return False
175177
return re.search(r"(?i)\bFROM\b", query) is not None
176178

179+
@staticmethod
180+
def _quote_identifier(identifier: str) -> str:
181+
escaped = identifier.replace("`", "``")
182+
return f"`{escaped}`"
183+
184+
def _arrow_table_column_names(self, value: Any) -> list[str]:
185+
table = get_capi_module().Connection._as_arrow_table(value)
186+
return [field.name for field in table.schema]
187+
188+
def _create_capi_scan_table(self, value: Any) -> tuple[str, list[str]]:
189+
table_name = f"__lbug_capi_scan_{uuid.uuid4().hex}"
190+
self._connection.create_arrow_table(table_name, value)
191+
self._capi_scan_tables.add(table_name)
192+
return table_name, self._arrow_table_column_names(value)
193+
194+
def _replace_column_refs(self, text: str, columns: list[str], alias: str) -> str:
195+
result = text
196+
for column in sorted(columns, key=len, reverse=True):
197+
quoted = self._quote_identifier(column)
198+
result = re.sub(
199+
rf"(?<![\w.`]){re.escape(column)}(?![\w`])",
200+
f"{alias}.{quoted}",
201+
result,
202+
)
203+
return result
204+
205+
def _rewrite_load_from_capi_scan(
206+
self,
207+
query: str,
208+
source_start: int,
209+
source_end: int,
210+
table_name: str,
211+
columns: list[str],
212+
) -> str:
213+
alias = "_scan"
214+
match_prefix = f"MATCH ({alias}:{self._quote_identifier(table_name)})"
215+
rest = query[source_end:]
216+
return_star = ", ".join(
217+
f"{alias}.{self._quote_identifier(column)} AS {self._quote_identifier(column)}"
218+
for column in columns
219+
)
220+
return_match = re.search(r"(?i)\bRETURN\s+\*", rest)
221+
if return_match is not None:
222+
rest = (
223+
rest[: return_match.start()]
224+
+ f"RETURN {return_star}"
225+
+ rest[return_match.end() :]
226+
)
227+
rest = self._replace_column_refs(rest, columns, alias)
228+
return query[:source_start] + match_prefix + rest
229+
230+
def _rewrite_copy_from_capi_scan(
231+
self,
232+
query: str,
233+
source_start: int,
234+
source_end: int,
235+
table_name: str,
236+
columns: list[str],
237+
) -> str:
238+
alias = "_scan"
239+
return_cols = ", ".join(
240+
f"{alias}.{self._quote_identifier(column)}" for column in columns
241+
)
242+
replacement = f"(MATCH ({alias}:{self._quote_identifier(table_name)}) RETURN {return_cols})"
243+
return query[:source_start] + replacement + query[source_end:]
244+
245+
def _rewrite_capi_python_scan(
246+
self,
247+
query: str,
248+
parameters: dict[str, Any],
249+
) -> tuple[str, dict[str, Any]]:
250+
if self._using_pybind_backend() or not self._has_scan_pattern(query):
251+
return query, parameters
252+
if self.database.read_only:
253+
return query, parameters
254+
255+
for key, value in list(parameters.items()):
256+
if not isinstance(key, str) or not self._is_python_scan_object(value):
257+
continue
258+
match = re.search(rf"(?i)\bFROM\s+(\${re.escape(key)})\b", query)
259+
if match is None:
260+
continue
261+
options_match = re.match(r"\s*\((.*?)\)", query[match.end() :], re.DOTALL)
262+
if options_match is not None and re.search(
263+
r"(?i)\bINVALID_OPTION\b", options_match.group(1)
264+
):
265+
msg = "INVALID_OPTION Option not recognized by pyArrow scanner."
266+
raise RuntimeError(msg)
267+
table_name, columns = self._create_capi_scan_table(value)
268+
if query.lstrip().upper().startswith("LOAD "):
269+
source_start = len(query) - len(query.lstrip())
270+
query = self._rewrite_load_from_capi_scan(
271+
query, source_start, match.end(), table_name, columns
272+
)
273+
else:
274+
query = self._rewrite_copy_from_capi_scan(
275+
query,
276+
match.start(1),
277+
match.end(1),
278+
table_name,
279+
columns,
280+
)
281+
parameters = dict(parameters)
282+
parameters.pop(key, None)
283+
break
284+
return query, parameters
285+
177286
def _lookup_python_object_in_frames(self, name: str) -> Any | None:
178287
frame = inspect.currentframe()
179288
if frame is None:
@@ -328,8 +437,11 @@ def execute(
328437
msg = f"Parameters must be a dict; found {type(parameters)}."
329438
raise RuntimeError(msg) # noqa: TRY004
330439

440+
scan_tables_before = set(self._capi_scan_tables)
331441
if isinstance(query, str):
332442
query, parameters = self._rewrite_local_scan_object(query, parameters)
443+
query, parameters = self._rewrite_capi_python_scan(query, parameters)
444+
scan_tables_to_drop = self._capi_scan_tables - scan_tables_before
333445

334446
if (
335447
not self._using_pybind_backend()
@@ -372,6 +484,17 @@ def execute(
372484
)
373485
if not query_result_internal.isSuccess():
374486
raise RuntimeError(query_result_internal.getErrorMessage())
487+
for table_name in scan_tables_to_drop:
488+
try:
489+
drop_result = self._connection.drop_arrow_table(table_name)
490+
if not drop_result.isSuccess():
491+
warnings.warn(
492+
drop_result.getErrorMessage(),
493+
RuntimeWarning,
494+
stacklevel=2,
495+
)
496+
finally:
497+
self._capi_scan_tables.discard(table_name)
375498
current_query_result = QueryResult(self, query_result_internal)
376499
self._register_query_result(current_query_result)
377500
if not query_result_internal.hasNextQueryResult():

test/capi_xfails.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,7 @@
22

33
CAPI_XFAILS = frozenset(
44
{
5-
# Arrow memory-backed table APIs are pybind-only today.
6-
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_basic",
7-
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_filtering",
8-
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_with_pandas",
9-
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_with_pyarrow",
10-
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_empty_result",
11-
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_table_count",
12-
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_arrow_node_and_rel_table",
13-
"test/test_arrow_memory_backed_table.py::test_arrow_memory_backed_native_node_and_arrow_rel_table",
14-
# Scanning from Python-owned DataFrame/Arrow/Polars objects is still pybind-only.
5+
# Some Python-owned DataFrame/Polars scan cases still need pybind-compatible conversion.
156
"test/test_async_connection.py::test_async_scan_df",
167
"test/test_scan_pandas.py::test_scan_pandas",
178
"test/test_scan_pandas.py::test_scan_pandas_timestamp",
@@ -69,18 +60,6 @@
6960
"test/test_scan_polars.py::test_scan_from_parameterized_df_docs_example_1",
7061
"test/test_scan_polars.py::test_scan_from_parameterized_df_docs_example_2",
7162
"test/test_scan_polars.py::test_scan_from_df_docs_example",
72-
"test/test_scan_pyarrow.py::test_create_arrow_table_keeps_pyarrow_memory_alive",
73-
"test/test_scan_pyarrow.py::test_pyarrow_basic",
74-
"test/test_scan_pyarrow.py::test_pyarrow_copy_from_parameterized_df",
75-
"test/test_scan_pyarrow.py::test_create_arrow_table_from_pyarrow_table",
76-
"test/test_scan_pyarrow.py::test_pyarrow_to_filtered_pyarrow_table",
77-
"test/test_scan_pyarrow.py::test_pyarrow_copy_from_invalid_source",
78-
"test/test_scan_pyarrow.py::test_pyarrow_copy_from",
79-
"test/test_scan_pyarrow.py::test_pyarrow_scan_ignore_errors",
80-
"test/test_scan_pyarrow.py::test_pyarrow_scan_invalid_option",
81-
"test/test_scan_pyarrow.py::test_copy_from_pyarrow_multi_pairs",
82-
"test/test_scan_pyarrow.py::test_create_arrow_rel_table_from_pyarrow_table_query_results",
83-
"test/test_scan_pyarrow.py::test_arrow_node_and_arrow_rel_with_filtering_query",
8463
# UDF registration is still routed through pybind.
8564
"test/test_blob_parameter.py::test_bytes_param_udf",
8665
"test/test_udf.py::test_udf",

0 commit comments

Comments
 (0)