Skip to content

Commit ddd1947

Browse files
committed
Fix Arrow relationship C API adapter
1 parent 1a07b1b commit ddd1947

1 file changed

Lines changed: 60 additions & 10 deletions

File tree

src_py/_lbug_capi.py

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,21 @@ def _setup_signatures() -> None:
332332
]
333333
_LIB.lbug_connection_create_arrow_rel_table.restype = ctypes.c_int
334334

335+
_LIB.lbug_connection_create_arrow_rel_table_csr.argtypes = [
336+
ctypes.POINTER(_LbugConnection),
337+
ctypes.c_char_p,
338+
ctypes.c_char_p,
339+
ctypes.c_char_p,
340+
ctypes.POINTER(_ArrowSchema),
341+
ctypes.POINTER(_ArrowArray),
342+
ctypes.c_uint64,
343+
ctypes.POINTER(_ArrowSchema),
344+
ctypes.POINTER(_ArrowArray),
345+
ctypes.c_uint64,
346+
ctypes.POINTER(_LbugQueryResult),
347+
]
348+
_LIB.lbug_connection_create_arrow_rel_table_csr.restype = ctypes.c_int
349+
335350
_LIB.lbug_connection_drop_arrow_table.argtypes = [
336351
ctypes.POINTER(_LbugConnection),
337352
ctypes.c_char_p,
@@ -2324,19 +2339,54 @@ def create_arrow_rel_table(
23242339
dataframe: Any,
23252340
src_table_name: str,
23262341
dst_table_name: str,
2342+
layout: Any = "FLAT",
2343+
indptr_dataframe: Any | None = None,
23272344
) -> QueryResult:
2345+
layout_value = getattr(layout, "value", layout)
2346+
layout_value = str(layout_value).upper()
2347+
if layout_value not in {"FLAT", "CSR"}:
2348+
msg = "Arrow relationship table layout must be FLAT or CSR"
2349+
raise RuntimeError(msg)
2350+
if layout_value == "FLAT" and indptr_dataframe is not None:
2351+
msg = "indptr_dataframe is only valid for CSR Arrow relationship tables"
2352+
raise RuntimeError(msg)
2353+
if layout_value == "CSR" and indptr_dataframe is None:
2354+
msg = "indptr_dataframe is required for CSR Arrow relationship tables"
2355+
raise RuntimeError(msg)
2356+
23282357
_table, schema, arrays, _batches = self._export_arrow_table(dataframe)
23292358
result = _LbugQueryResult()
2330-
state = _LIB.lbug_connection_create_arrow_rel_table(
2331-
ctypes.byref(self._connection),
2332-
table_name.encode("utf-8"),
2333-
src_table_name.encode("utf-8"),
2334-
dst_table_name.encode("utf-8"),
2335-
ctypes.byref(schema),
2336-
arrays,
2337-
len(arrays),
2338-
ctypes.byref(result),
2339-
)
2359+
if layout_value == "FLAT":
2360+
state = _LIB.lbug_connection_create_arrow_rel_table(
2361+
ctypes.byref(self._connection),
2362+
table_name.encode("utf-8"),
2363+
src_table_name.encode("utf-8"),
2364+
dst_table_name.encode("utf-8"),
2365+
ctypes.byref(schema),
2366+
arrays,
2367+
len(arrays),
2368+
ctypes.byref(result),
2369+
)
2370+
else:
2371+
(
2372+
_indptr_table,
2373+
indptr_schema,
2374+
indptr_arrays,
2375+
_indptr_batches,
2376+
) = self._export_arrow_table(indptr_dataframe)
2377+
state = _LIB.lbug_connection_create_arrow_rel_table_csr(
2378+
ctypes.byref(self._connection),
2379+
table_name.encode("utf-8"),
2380+
src_table_name.encode("utf-8"),
2381+
dst_table_name.encode("utf-8"),
2382+
ctypes.byref(schema),
2383+
arrays,
2384+
len(arrays),
2385+
ctypes.byref(indptr_schema),
2386+
indptr_arrays,
2387+
len(indptr_arrays),
2388+
ctypes.byref(result),
2389+
)
23402390
if state != _LBUG_SUCCESS and not result._query_result:
23412391
_check_state(state, "Failed to create Arrow relationship table")
23422392
return QueryResult(result)

0 commit comments

Comments
 (0)