Skip to content

Commit 2f0444f

Browse files
committed
Add dedicated Arrow CSR result type
1 parent abcec65 commit 2f0444f

9 files changed

Lines changed: 216 additions & 2 deletions

File tree

src_cpp/include/py_connection.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class PyConnection {
2929
const py::dict& params);
3030

3131
std::unique_ptr<PyQueryResult> query(const std::string& statement);
32+
std::unique_ptr<PyQueryResult> queryAsArrow(const std::string& statement,
33+
int64_t chunkSize);
3234

3335
void setMaxNumThreadForExec(uint64_t numThreads);
3436

src_cpp/include/py_query_result.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class PyQueryResult {
3535
py::object getAsDF();
3636

3737
lbug::pyarrow::Table getAsArrow(std::int64_t chunkSize, bool fallbackExtensionTypes);
38+
py::dict getCSR();
3839

3940
py::list getColumnDataTypes();
4041

src_cpp/py_connection.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ void PyConnection::initialize(py::handle& m) {
3131
.def("execute", &PyConnection::execute, py::arg("prepared_statement"),
3232
py::arg("parameters") = py::dict())
3333
.def("query", &PyConnection::query, py::arg("statement"))
34+
.def("query_as_arrow", &PyConnection::queryAsArrow, py::arg("statement"),
35+
py::arg("chunk_size"))
3436
.def("set_max_threads_for_exec", &PyConnection::setMaxNumThreadForExec,
3537
py::arg("num_threads"))
3638
.def("prepare", &PyConnection::prepare, py::arg("query"),
@@ -175,6 +177,14 @@ std::unique_ptr<PyQueryResult> PyConnection::query(const std::string& statement)
175177
return checkAndWrapQueryResult(queryResult);
176178
}
177179

180+
std::unique_ptr<PyQueryResult> PyConnection::queryAsArrow(const std::string& statement,
181+
int64_t chunkSize) {
182+
py::gil_scoped_release release;
183+
auto queryResult = conn->queryAsArrow(statement, chunkSize);
184+
py::gil_scoped_acquire acquire;
185+
return checkAndWrapQueryResult(queryResult);
186+
}
187+
178188
void PyConnection::setMaxNumThreadForExec(uint64_t numThreads) {
179189
conn->setMaxNumThreadForExec(numThreads);
180190
}

src_cpp/py_query_result.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
#include "common/arrow/arrow_row_batch.h"
88
#include "common/constants.h"
99
#include "common/exception/not_implemented.h"
10+
#include "common/exception/runtime.h"
1011
#include "common/types/uuid.h"
1112
#include "common/types/value/nested.h"
1213
#include "common/types/value/node.h"
1314
#include "common/types/value/rel.h"
1415
#include "datetime.h" // python lib
1516
#include "include/py_query_result_converter.h"
17+
#include "main/query_result/arrow_query_result.h"
1618

1719
using namespace lbug::common;
1820
using lbug::importCache;
@@ -30,6 +32,7 @@ void PyQueryResult::initialize(py::handle& m) {
3032
.def("close", &PyQueryResult::close)
3133
.def("getAsDF", &PyQueryResult::getAsDF)
3234
.def("getAsArrow", &PyQueryResult::getAsArrow)
35+
.def("getCSR", &PyQueryResult::getCSR)
3336
.def("getColumnNames", &PyQueryResult::getColumnNames)
3437
.def("getColumnDataTypes", &PyQueryResult::getColumnDataTypes)
3538
.def("resetIterator", &PyQueryResult::resetIterator)
@@ -85,6 +88,30 @@ void PyQueryResult::close() {
8588
}
8689
}
8790

91+
namespace {
92+
93+
py::array_t<int64_t> copyToNumpyArray(const std::vector<int64_t>& values) {
94+
auto result = py::array_t<int64_t>(values.size());
95+
auto* data = static_cast<int64_t*>(result.request().ptr);
96+
std::copy(values.begin(), values.end(), data);
97+
return result;
98+
}
99+
100+
py::dict buildCSRResult(std::vector<int64_t> indptr, std::vector<int64_t> indices,
101+
std::vector<int64_t> edgeIDs, bool includeEdgeIDs) {
102+
py::dict result;
103+
result["indptr"] = copyToNumpyArray(indptr);
104+
result["indices"] = copyToNumpyArray(indices);
105+
if (includeEdgeIDs) {
106+
result["edge_ids"] = copyToNumpyArray(edgeIDs);
107+
} else {
108+
result["edge_ids"] = py::none();
109+
}
110+
return result;
111+
}
112+
113+
} // namespace
114+
88115
static py::object converTimestampToPyObject(timestamp_t& timestamp) {
89116
int32_t year = 0, month = 0, day = 0, hour = 0, min = 0, sec = 0, micros = 0;
90117
date_t date;
@@ -320,6 +347,23 @@ py::object PyQueryResult::getArrowChunks(const std::vector<LogicalType>& types,
320347

321348
lbug::pyarrow::Table PyQueryResult::getAsArrow(std::int64_t chunkSize,
322349
bool fallbackExtensionTypes) {
350+
if (queryResult->getType() == QueryResultType::ARROW) {
351+
auto types = queryResult->getColumnDataTypes();
352+
auto names = queryResult->getColumnNames();
353+
py::list batches;
354+
auto batchImportFunc = importCache->pyarrow.lib.RecordBatch._import_from_c();
355+
while (queryResult->hasNextArrowChunk()) {
356+
auto data = queryResult->getNextArrowChunk(chunkSize);
357+
auto schema = ArrowConverter::toArrowSchema(types, names, fallbackExtensionTypes);
358+
batches.append(
359+
batchImportFunc((std::uint64_t)data.get(), (std::uint64_t)schema.get()));
360+
}
361+
auto schema = ArrowConverter::toArrowSchema(types, names, fallbackExtensionTypes);
362+
auto fromBatchesFunc = importCache->pyarrow.lib.Table.from_batches();
363+
auto schemaImportFunc = importCache->pyarrow.lib.Schema._import_from_c();
364+
auto schemaObj = schemaImportFunc((std::uint64_t)schema.get());
365+
return py::cast<lbug::pyarrow::Table>(fromBatchesFunc(batches, schemaObj));
366+
}
323367
auto types = queryResult->getColumnDataTypes();
324368
auto names = queryResult->getColumnNames();
325369
py::list batches = getArrowChunks(types, names, chunkSize, fallbackExtensionTypes);
@@ -330,6 +374,17 @@ lbug::pyarrow::Table PyQueryResult::getAsArrow(std::int64_t chunkSize,
330374
return py::cast<lbug::pyarrow::Table>(fromBatchesFunc(batches, schemaObj));
331375
}
332376

377+
py::dict PyQueryResult::getCSR() {
378+
if (auto* arrowQueryResult = dynamic_cast<lbug::main::ArrowQueryResult*>(queryResult);
379+
arrowQueryResult != nullptr && arrowQueryResult->hasCSRMetadata()) {
380+
const auto& metadata = arrowQueryResult->getCSRMetadata();
381+
return buildCSRResult(metadata.indptr, metadata.indices, metadata.edgeIDs,
382+
metadata.hasEdgeIDs);
383+
}
384+
throw RuntimeException(
385+
"CSR export is only supported for Arrow query results with native CSR metadata.");
386+
}
387+
333388
py::list PyQueryResult::getColumnDataTypes() {
334389
auto columnDataTypes = queryResult->getColumnDataTypes();
335390
py::tuple result(columnDataTypes.size());

src_py/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
from .connection import Connection # noqa: E402
5757
from .database import Database # noqa: E402
5858
from .prepared_statement import PreparedStatement # noqa: E402
59-
from .query_result import QueryResult # noqa: E402
59+
from .query_result import ArrowQueryResult, CSRResult, QueryResult # noqa: E402
6060
from .types import Type # noqa: E402
6161

6262
_VERSION_INFO: tuple[str, int] | None = None
@@ -80,7 +80,9 @@ def __getattr__(name: str) -> str | int:
8080

8181
__all__ = [
8282
"AsyncConnection",
83+
"ArrowQueryResult",
8384
"Connection",
85+
"CSRResult",
8486
"Database",
8587
"PreparedStatement",
8688
"QueryResult",

src_py/_lbug_capi.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,6 +1229,11 @@ def getAsArrow(self, *_args: Any, **_kwargs: Any) -> Any:
12291229
"Arrow export is not yet implemented in C-API backend"
12301230
)
12311231

1232+
def getCSR(self, *_args: Any, **_kwargs: Any) -> Any:
1233+
raise NotImplementedError(
1234+
"CSR export is not yet implemented in C-API backend"
1235+
)
1236+
12321237
def getAsDF(self) -> Any:
12331238
raise NotImplementedError(
12341239
"DataFrame export is not yet implemented in C-API backend"

src_py/connection.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from ._backend import get_capi_module, get_pybind_module
1010
from .prepared_statement import PreparedStatement
11-
from .query_result import QueryResult
11+
from .query_result import ArrowQueryResult, QueryResult
1212

1313
if TYPE_CHECKING:
1414
import sys
@@ -369,6 +369,27 @@ def execute(
369369
all_query_results.append(next_query_result)
370370
return all_query_results
371371

372+
def query_as_arrow(self, query: str, chunk_size: int) -> ArrowQueryResult:
373+
"""
374+
Execute a query with the native Arrow collector path.
375+
376+
This is the efficient path for CSR-aware Arrow export.
377+
"""
378+
self.init_connection()
379+
if not self._using_pybind_backend():
380+
msg = "query_as_arrow requires the pybind backend"
381+
raise NotImplementedError(msg)
382+
query_result_internal = self._get_pybind_connection().query_as_arrow(
383+
query, chunk_size
384+
)
385+
if not query_result_internal.isSuccess():
386+
raise RuntimeError(query_result_internal.getErrorMessage())
387+
current_query_result = ArrowQueryResult(
388+
self, query_result_internal, native_chunk_size=chunk_size
389+
)
390+
self._register_query_result(current_query_result)
391+
return current_query_result
392+
372393
def _prepare(
373394
self,
374395
query: str,

src_py/query_result.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from dataclasses import dataclass
34
from typing import TYPE_CHECKING
45

56
from .constants import DST, ID, LABEL, NODES, RELS, SRC
@@ -525,6 +526,59 @@ def rows_as_dict(self, state=True) -> Self:
525526
return self
526527

527528

529+
class ArrowQueryResult(QueryResult):
530+
"""QueryResult backed by the native Arrow collector path."""
531+
532+
def __init__(
533+
self, connection: Any, query_result: Any, native_chunk_size: int
534+
) -> None:
535+
super().__init__(connection, query_result)
536+
self._native_chunk_size = native_chunk_size
537+
538+
def get_as_arrow(
539+
self, chunk_size: int | None = None, *, fallbackExtensionTypes: bool = False
540+
) -> pa.Table:
541+
"""
542+
Get the query result as a PyArrow Table.
543+
544+
Arrow-native results preserve the execution-time chunking chosen by
545+
`Connection.query_as_arrow(...)`. Requesting `None`, `0`, or `-1`
546+
reuses that native chunk size instead of rechunking the result.
547+
"""
548+
if chunk_size is None or chunk_size <= 0:
549+
chunk_size = self._native_chunk_size
550+
return super().get_as_arrow(
551+
chunk_size, fallbackExtensionTypes=fallbackExtensionTypes
552+
)
553+
554+
def csr(self) -> CSRResult:
555+
"""
556+
Get native CSR arrays from an Arrow query result.
557+
558+
This is available only for Arrow results with CSR metadata, typically
559+
from `Connection.query_as_arrow(...)` on relationship-shaped projections.
560+
"""
561+
self.check_for_query_result_close()
562+
563+
import pyarrow as pa
564+
565+
csr = self._query_result.getCSR()
566+
return CSRResult(
567+
indptr=pa.array(csr["indptr"]),
568+
indices=pa.array(csr["indices"]),
569+
edge_ids=(
570+
None if csr["edge_ids"] is None else pa.array(csr["edge_ids"])
571+
),
572+
)
573+
574+
575+
@dataclass(frozen=True)
576+
class CSRResult:
577+
indptr: pa.Array
578+
indices: pa.Array
579+
edge_ids: pa.Array | None = None
580+
581+
528582
def _row_to_dict(columns: list[str], row: list[Any]) -> dict[str, Any]:
529583
if len(columns) != len(row):
530584
msg = "Number of columns in output row does not match number of columns"

test/test_arrow.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -772,3 +772,67 @@ def test_to_arrow1(conn: lb.Connection) -> None:
772772
-1
773773
) # what is a chunk size of -1 even supposed to mean?
774774
assert arrow_tbl == []
775+
776+
777+
def test_query_as_arrow_csr_with_rel_ids(conn_db_readonly: ConnDB) -> None:
778+
conn, _ = conn_db_readonly
779+
query = """
780+
MATCH (a:person)-[b:knows]->(c:person)
781+
RETURN a.rowid, b.rowid, c.rowid
782+
"""
783+
rows = conn.execute(query).get_all()
784+
csr = conn.query_as_arrow(query, 8).csr()
785+
786+
assert csr.edge_ids is not None
787+
788+
reconstructed = []
789+
indptr = csr.indptr.to_pylist()
790+
indices = csr.indices.to_pylist()
791+
edge_ids = csr.edge_ids.to_pylist()
792+
for src_rowid in range(len(indptr) - 1):
793+
for idx in range(indptr[src_rowid], indptr[src_rowid + 1]):
794+
reconstructed.append([src_rowid, edge_ids[idx], indices[idx]])
795+
796+
assert reconstructed == rows
797+
798+
799+
def test_query_as_arrow_csr_with_extra_columns(conn_db_readonly: ConnDB) -> None:
800+
conn, _ = conn_db_readonly
801+
query = """
802+
MATCH (a:person)-[b:knows]->(c:person)
803+
RETURN a.rowid, b.rowid, c.rowid, b.date, c.fName
804+
"""
805+
result = conn.query_as_arrow(query, 8)
806+
csr = result.csr()
807+
arrow_tbl = result.get_as_arrow(0)
808+
809+
assert csr.edge_ids is not None
810+
assert arrow_tbl.column_names == ["a.rowid", "b.rowid", "c.rowid", "b.date", "c.fName"]
811+
assert len(csr.indptr) >= 2
812+
813+
814+
def test_query_as_arrow_csr_without_rel_ids(conn_db_readonly: ConnDB) -> None:
815+
conn, _ = conn_db_readonly
816+
query = """
817+
MATCH (a:person)-[:knows]->(c:person)
818+
RETURN a.rowid, c.rowid
819+
"""
820+
rows = conn.execute(query).get_all()
821+
csr = conn.query_as_arrow(query, 8).csr()
822+
823+
assert csr.edge_ids is None
824+
825+
reconstructed = []
826+
indptr = csr.indptr.to_pylist()
827+
indices = csr.indices.to_pylist()
828+
for src_rowid in range(len(indptr) - 1):
829+
for idx in range(indptr[src_rowid], indptr[src_rowid + 1]):
830+
reconstructed.append([src_rowid, indices[idx]])
831+
832+
assert reconstructed == rows
833+
834+
835+
def test_query_as_arrow_csr_rejects_non_csr_shape(conn_db_readonly: ConnDB) -> None:
836+
conn, _ = conn_db_readonly
837+
with pytest.raises(RuntimeError, match="CSR export is only supported"):
838+
conn.query_as_arrow("MATCH (a:person) RETURN a.fName", 8).csr()

0 commit comments

Comments
 (0)