Skip to content

Commit 0a4a672

Browse files
committed
Add Arrow CSR relationship table Python API
1 parent d5fdcc4 commit 0a4a672

7 files changed

Lines changed: 246 additions & 48 deletions

File tree

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""
2+
Register icebug-disk Parquet files as Arrow memory-backed tables.
3+
4+
The example keeps the data in PyArrow tables and exposes it to Ladybug as
5+
ice-mem/Arrow tables. Relationship tables can be either FLAT or CSR.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import argparse
11+
from pathlib import Path
12+
13+
import ladybug as lb
14+
import pyarrow.parquet as pq
15+
16+
17+
def register_flat(
18+
conn: lb.Connection,
19+
data_dir: Path,
20+
node_table: str,
21+
rel_table: str,
22+
src_table: str,
23+
dst_table: str,
24+
) -> None:
25+
"""Register FLAT icebug-disk Parquet files as Arrow memory-backed tables."""
26+
nodes = pq.read_table(data_dir / f"nodes_{node_table}.parquet")
27+
rels = pq.read_table(data_dir / f"rels_{rel_table}.parquet")
28+
29+
conn.create_arrow_table(node_table, nodes)
30+
conn.create_arrow_rel_table(
31+
rel_table,
32+
rels,
33+
src_table,
34+
dst_table,
35+
layout=lb.ArrowRelTableLayout.FLAT,
36+
)
37+
38+
39+
def register_csr(
40+
conn: lb.Connection,
41+
data_dir: Path,
42+
node_table: str,
43+
rel_table: str,
44+
src_table: str,
45+
dst_table: str,
46+
) -> None:
47+
"""Register CSR icebug-disk Parquet files as Arrow memory-backed tables."""
48+
nodes = pq.read_table(data_dir / f"nodes_{node_table}.parquet")
49+
indices = pq.read_table(data_dir / f"indices_{rel_table}.parquet")
50+
indptr = pq.read_table(data_dir / f"indptr_{rel_table}.parquet")
51+
52+
conn.create_arrow_table(node_table, nodes)
53+
conn.create_arrow_rel_table(
54+
rel_table,
55+
indices,
56+
src_table,
57+
dst_table,
58+
layout=lb.ArrowRelTableLayout.CSR,
59+
indptr_dataframe=indptr,
60+
)
61+
62+
63+
def main() -> None:
64+
"""Run the example."""
65+
parser = argparse.ArgumentParser()
66+
parser.add_argument("data_dir", type=Path)
67+
parser.add_argument("--layout", choices=["flat", "csr"], default="csr")
68+
parser.add_argument("--node-table", required=True)
69+
parser.add_argument("--rel-table", required=True)
70+
parser.add_argument("--src-table")
71+
parser.add_argument("--dst-table")
72+
args = parser.parse_args()
73+
74+
src_table = args.src_table or args.node_table
75+
dst_table = args.dst_table or args.node_table
76+
77+
db = lb.Database(":memory:")
78+
conn = db.connect()
79+
if args.layout == "flat":
80+
register_flat(conn, args.data_dir, args.node_table, args.rel_table, src_table, dst_table)
81+
else:
82+
register_csr(conn, args.data_dir, args.node_table, args.rel_table, src_table, dst_table)
83+
84+
result = conn.execute(f"MATCH (a:{src_table})-[r:{args.rel_table}]->(b:{dst_table}) RETURN COUNT(*)")
85+
print(result.get_next()[0])
86+
87+
88+
if __name__ == "__main__":
89+
main()

src_cpp/include/py_connection.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ class PyConnection {
5656
std::unique_ptr<PyQueryResult> createArrowTable(const std::string& tableName,
5757
py::object arrowTable);
5858
std::unique_ptr<PyQueryResult> createArrowRelTable(const std::string& tableName,
59-
py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName);
59+
py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName,
60+
const std::string& layout, py::object indptrTable);
6061
std::unique_ptr<PyQueryResult> dropArrowTable(const std::string& tableName);
6162

6263
static Value transformPythonValue(const py::handle& val);

src_cpp/py_connection.cpp

Lines changed: 61 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "include/py_connection.h"
22

3+
#include <algorithm>
4+
#include <cctype>
35
#include <utility>
46

57
#include "cached_import/py_cached_import.h"
@@ -52,7 +54,8 @@ void PyConnection::initialize(py::handle& m) {
5254
.def("create_arrow_table", &PyConnection::createArrowTable, py::arg("table_name"),
5355
py::arg("arrow_table"))
5456
.def("create_arrow_rel_table", &PyConnection::createArrowRelTable, py::arg("table_name"),
55-
py::arg("arrow_table"), py::arg("src_table_name"), py::arg("dst_table_name"))
57+
py::arg("arrow_table"), py::arg("src_table_name"), py::arg("dst_table_name"),
58+
py::arg("layout") = "FLAT", py::arg("indptr_table") = py::none())
5659
.def("drop_arrow_table", &PyConnection::dropArrowTable, py::arg("table_name"));
5760
PyDateTime_IMPORT;
5861
}
@@ -1013,79 +1016,91 @@ void PyConnection::removeScalarFunction(const std::string& name) {
10131016
refState().ref().removeUDFFunction(name);
10141017
}
10151018

1016-
std::unique_ptr<PyQueryResult> PyConnection::createArrowTable(const std::string& tableName,
1017-
py::object arrowTable) {
1018-
auto& stateRef = refState();
1019-
py::gil_scoped_acquire acquire;
1019+
struct ExportedArrowTable {
1020+
ArrowSchemaWrapper schema;
1021+
std::vector<ArrowArrayWrapper> arrays;
1022+
py::list keepAlive;
1023+
};
10201024

1021-
// Convert pandas/polars to pyarrow if needed
1025+
static py::object normalizeArrowTable(py::object arrowTable) {
10221026
if (PyConnection::isPandasDataframe(arrowTable)) {
1023-
arrowTable = importCache->pyarrow.lib.Table.from_pandas()(arrowTable);
1024-
} else if (PyConnection::isPolarsDataframe(arrowTable)) {
1025-
arrowTable = arrowTable.attr("to_arrow")();
1027+
return importCache->pyarrow.lib.Table.from_pandas()(arrowTable);
1028+
}
1029+
if (PyConnection::isPolarsDataframe(arrowTable)) {
1030+
return arrowTable.attr("to_arrow")();
10261031
}
1027-
1028-
// Ensure we have a pyarrow table
10291032
if (!PyConnection::isPyArrowTable(arrowTable)) {
10301033
throw RuntimeException("Expected a pyarrow Table, polars DataFrame, or pandas DataFrame");
10311034
}
1035+
return arrowTable;
1036+
}
10321037

1033-
// Export Arrow table to C Data Interface
1034-
// First, get the schema
1035-
ArrowSchemaWrapper schema;
1036-
arrowTable.attr("schema").attr("_export_to_c")(reinterpret_cast<uint64_t>(&schema));
1038+
static ExportedArrowTable exportArrowTable(py::object arrowTable) {
1039+
arrowTable = normalizeArrowTable(std::move(arrowTable));
1040+
1041+
ExportedArrowTable exported;
1042+
arrowTable.attr("schema").attr("_export_to_c")(reinterpret_cast<uint64_t>(&exported.schema));
10371043

1038-
// Get the batches (arrays)
1039-
std::vector<ArrowArrayWrapper> arrays;
10401044
py::list batches = arrowTable.attr("to_batches")();
10411045
for (auto& batch : batches) {
1042-
arrays.emplace_back();
1043-
batch.attr("_export_to_c")(reinterpret_cast<uint64_t>(&arrays.back()));
1046+
exported.arrays.emplace_back();
1047+
batch.attr("_export_to_c")(reinterpret_cast<uint64_t>(&exported.arrays.back()));
10441048
}
10451049

10461050
// Keep pyarrow producers alive while C++ accesses exported Arrow memory.
1047-
py::list keepAlive;
1048-
keepAlive.append(arrowTable);
1049-
keepAlive.append(batches);
1051+
exported.keepAlive.append(arrowTable);
1052+
exported.keepAlive.append(batches);
1053+
return exported;
1054+
}
1055+
1056+
std::unique_ptr<PyQueryResult> PyConnection::createArrowTable(const std::string& tableName,
1057+
py::object arrowTable) {
1058+
auto& stateRef = refState();
1059+
py::gil_scoped_acquire acquire;
10501060

1061+
auto exported = exportArrowTable(std::move(arrowTable));
10511062
auto result = ArrowTableSupport::createViewFromArrowTable(stateRef.ref(), tableName,
1052-
std::move(schema), std::move(arrays));
1063+
std::move(exported.schema), std::move(exported.arrays));
10531064
if (result.queryResult && result.queryResult->isSuccess()) {
1054-
stateRef.arrowTableRefs[tableName] = std::move(keepAlive);
1065+
stateRef.arrowTableRefs[tableName] = std::move(exported.keepAlive);
10551066
}
10561067

10571068
return checkAndWrapQueryResult(result.queryResult, state);
10581069
}
10591070

10601071
std::unique_ptr<PyQueryResult> PyConnection::createArrowRelTable(const std::string& tableName,
1061-
py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName) {
1072+
py::object arrowTable, const std::string& srcTableName, const std::string& dstTableName,
1073+
const std::string& layout, py::object indptrTable) {
10621074
auto& stateRef = refState();
10631075
py::gil_scoped_acquire acquire;
10641076

1065-
if (PyConnection::isPandasDataframe(arrowTable)) {
1066-
arrowTable = importCache->pyarrow.lib.Table.from_pandas()(arrowTable);
1067-
} else if (PyConnection::isPolarsDataframe(arrowTable)) {
1068-
arrowTable = arrowTable.attr("to_arrow")();
1069-
}
1070-
if (!PyConnection::isPyArrowTable(arrowTable)) {
1071-
throw RuntimeException("Expected a pyarrow Table, polars DataFrame, or pandas DataFrame");
1072-
}
1073-
1074-
ArrowSchemaWrapper schema;
1075-
arrowTable.attr("schema").attr("_export_to_c")(reinterpret_cast<uint64_t>(&schema));
1076-
std::vector<ArrowArrayWrapper> arrays;
1077-
py::list batches = arrowTable.attr("to_batches")();
1078-
for (auto& batch : batches) {
1079-
arrays.emplace_back();
1080-
batch.attr("_export_to_c")(reinterpret_cast<uint64_t>(&arrays.back()));
1081-
}
1077+
auto layoutUpper = layout;
1078+
std::transform(layoutUpper.begin(), layoutUpper.end(), layoutUpper.begin(),
1079+
[](unsigned char c) { return static_cast<char>(std::toupper(c)); });
10821080

1081+
auto exported = exportArrowTable(std::move(arrowTable));
1082+
ArrowTableCreationResult result;
10831083
py::list keepAlive;
1084-
keepAlive.append(arrowTable);
1085-
keepAlive.append(batches);
1084+
keepAlive.append(exported.keepAlive);
10861085

1087-
auto result = ArrowTableSupport::createRelTableFromArrowTable(stateRef.ref(), tableName,
1088-
srcTableName, dstTableName, std::move(schema), std::move(arrays));
1086+
if (layoutUpper == "FLAT") {
1087+
if (!py::none().is(indptrTable)) {
1088+
throw RuntimeException("indptr_table is only valid for CSR Arrow relationship tables");
1089+
}
1090+
result = ArrowTableSupport::createRelTableFromArrowTable(stateRef.ref(), tableName,
1091+
srcTableName, dstTableName, std::move(exported.schema), std::move(exported.arrays));
1092+
} else if (layoutUpper == "CSR") {
1093+
if (py::none().is(indptrTable)) {
1094+
throw RuntimeException("indptr_table is required for CSR Arrow relationship tables");
1095+
}
1096+
auto exportedIndptr = exportArrowTable(std::move(indptrTable));
1097+
keepAlive.append(exportedIndptr.keepAlive);
1098+
result = ArrowTableSupport::createRelTableFromArrowCSR(stateRef.ref(), tableName,
1099+
srcTableName, dstTableName, std::move(exported.schema), std::move(exported.arrays),
1100+
std::move(exportedIndptr.schema), std::move(exportedIndptr.arrays));
1101+
} else {
1102+
throw RuntimeException("Arrow relationship table layout must be FLAT or CSR");
1103+
}
10891104
if (result.queryResult && result.queryResult->isSuccess()) {
10901105
stateRef.arrowTableRefs[tableName] = std::move(keepAlive);
10911106
}

src_py/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
from .database import Database # noqa: E402
5858
from .prepared_statement import PreparedStatement # noqa: E402
5959
from .query_result import ArrowQueryResult, CSRResult, QueryResult # noqa: E402
60-
from .types import Type # noqa: E402
60+
from .types import ArrowRelTableLayout, Type # noqa: E402
6161

6262
_VERSION_INFO: tuple[str, int] | None = None
6363

@@ -81,6 +81,7 @@ def __getattr__(name: str) -> str | int:
8181
__all__ = [
8282
"AsyncConnection",
8383
"ArrowQueryResult",
84+
"ArrowRelTableLayout",
8485
"Connection",
8586
"CSRResult",
8687
"Database",

src_py/connection.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from ._backend import get_capi_module, get_pybind_module
1212
from .prepared_statement import PreparedStatement
1313
from .query_result import ArrowQueryResult, QueryResult
14+
from .types import ArrowRelTableLayout
1415

1516
if TYPE_CHECKING:
1617
import sys
@@ -811,6 +812,8 @@ def create_arrow_rel_table(
811812
dataframe: Any,
812813
src_table_name: str,
813814
dst_table_name: str,
815+
layout: ArrowRelTableLayout | str = ArrowRelTableLayout.FLAT,
816+
indptr_dataframe: Any | None = None,
814817
) -> QueryResult:
815818
"""
816819
Create an Arrow memory-backed relationship table from a DataFrame.
@@ -829,19 +832,37 @@ def create_arrow_rel_table(
829832
dst_table_name : str
830833
Destination node table name in the FROM/TO pair.
831834
835+
layout : ArrowRelTableLayout | str
836+
Relationship layout. FLAT expects ``dataframe`` to contain ``from``
837+
and ``to`` endpoint columns. CSR expects ``dataframe`` to contain a
838+
``to`` destination offset column plus properties, and
839+
``indptr_dataframe`` to contain source offsets.
840+
841+
indptr_dataframe : Any | None
842+
A pandas DataFrame, polars DataFrame, or PyArrow table containing
843+
CSR source offsets. Required when ``layout`` is CSR.
844+
832845
Returns
833846
-------
834847
QueryResult
835848
Result of the table creation query.
836849
837850
"""
838851
self.init_connection()
852+
layout_value = (
853+
layout.value if isinstance(layout, ArrowRelTableLayout) else str(layout)
854+
).upper()
855+
if layout_value == ArrowRelTableLayout.CSR.value and indptr_dataframe is None:
856+
msg = "indptr_dataframe is required when layout is CSR"
857+
raise ValueError(msg)
839858
try:
840859
query_result_internal = self._connection.create_arrow_rel_table(
841860
table_name,
842861
dataframe,
843862
src_table_name,
844863
dst_table_name,
864+
layout_value,
865+
indptr_dataframe,
845866
)
846867
except NotImplementedError:
847868
py_connection = self._get_pybind_connection()
@@ -853,6 +874,8 @@ def create_arrow_rel_table(
853874
dataframe,
854875
src_table_name,
855876
dst_table_name,
877+
layout_value,
878+
indptr_dataframe,
856879
)
857880
if not query_result_internal.isSuccess():
858881
raise RuntimeError(query_result_internal.getErrorMessage())

src_py/types.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,10 @@ class Type(Enum):
3737
STRUCT = "STRUCT"
3838
MAP = "MAP"
3939
UNION = "UNION"
40+
41+
42+
class ArrowRelTableLayout(Enum):
43+
"""Arrow-backed relationship table layout."""
44+
45+
FLAT = "FLAT"
46+
CSR = "CSR"

0 commit comments

Comments
 (0)