Skip to content

Commit a655e4d

Browse files
committed
Return CSR metadata as PyArrow arrays
1 parent 5dbfff7 commit a655e4d

3 files changed

Lines changed: 23 additions & 20 deletions

File tree

src_cpp/include/cached_import/py_cached_modules.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,14 @@ class PolarsCachedItem : public PythonCachedItem {
103103
};
104104

105105
class PyarrowCachedItem : public PythonCachedItem {
106+
class ArrayCachedItem : public PythonCachedItem {
107+
public:
108+
explicit ArrayCachedItem(PythonCachedItem* parent)
109+
: PythonCachedItem("Array", parent), _import_from_c("_import_from_c", this) {}
110+
111+
PythonCachedItem _import_from_c;
112+
};
113+
106114
class RecordBatchCachedItem : public PythonCachedItem {
107115
public:
108116
explicit RecordBatchCachedItem(PythonCachedItem* parent)
@@ -132,8 +140,10 @@ class PyarrowCachedItem : public PythonCachedItem {
132140
class LibCachedItem : public PythonCachedItem {
133141
public:
134142
explicit LibCachedItem(PythonCachedItem* parent)
135-
: PythonCachedItem("lib", parent), RecordBatch(this), Schema(this), Table(this) {}
143+
: PythonCachedItem("lib", parent), Array(this), RecordBatch(this), Schema(this),
144+
Table(this) {}
136145

146+
ArrayCachedItem Array;
137147
RecordBatchCachedItem RecordBatch;
138148
SchemaCachedItem Schema;
139149
TableCachedItem Table;

src_cpp/py_query_result.cpp

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -90,20 +90,17 @@ void PyQueryResult::close() {
9090

9191
namespace {
9292

93-
py::array_t<int64_t> copyToNumpyArray(const std::vector<int64_t>& values) {
94-
auto result = py::array_t<int64_t>(values.size());
95-
auto* data = static_cast<int64_t*>(result.request().ptr);
96-
std::copy(values.begin(), values.end(), data);
97-
return result;
93+
py::object importCSRArrowArray(lbug::main::ArrowQueryResult::CSRArrowArray& array) {
94+
auto arrayImportFunc = importCache->pyarrow.lib.Array._import_from_c();
95+
return arrayImportFunc((std::uint64_t)&array.array, (std::uint64_t)&array.schema);
9896
}
9997

100-
py::dict buildCSRResult(std::vector<int64_t> indptr, std::vector<int64_t> indices,
101-
std::vector<int64_t> edgeIDs, bool includeEdgeIDs) {
98+
py::dict buildCSRResult(lbug::main::ArrowQueryResult::CSRArrowArrays arrays) {
10299
py::dict result;
103-
result["indptr"] = copyToNumpyArray(indptr);
104-
result["indices"] = copyToNumpyArray(indices);
105-
if (includeEdgeIDs) {
106-
result["edge_ids"] = copyToNumpyArray(edgeIDs);
100+
result["indptr"] = importCSRArrowArray(arrays.indptr);
101+
result["indices"] = importCSRArrowArray(arrays.indices);
102+
if (arrays.edgeIDs.has_value()) {
103+
result["edge_ids"] = importCSRArrowArray(*arrays.edgeIDs);
107104
} else {
108105
result["edge_ids"] = py::none();
109106
}
@@ -377,9 +374,7 @@ lbug::pyarrow::Table PyQueryResult::getAsArrow(std::int64_t chunkSize,
377374
py::dict PyQueryResult::getCSR() {
378375
if (auto* arrowQueryResult = dynamic_cast<lbug::main::ArrowQueryResult*>(queryResult);
379376
arrowQueryResult != nullptr && arrowQueryResult->hasCSRMetadata()) {
380-
const auto& metadata = arrowQueryResult->getCSRMetadata();
381-
return buildCSRResult(metadata.indptr, metadata.indices, metadata.edgeIDs,
382-
metadata.hasEdgeIDs);
377+
return buildCSRResult(arrowQueryResult->getCSRArrowArrays());
383378
}
384379
throw RuntimeException(
385380
"CSR export is only supported for Arrow query results with native CSR metadata.");

src_py/query_result.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -560,13 +560,11 @@ def csr(self) -> CSRResult:
560560
"""
561561
self.check_for_query_result_close()
562562

563-
import pyarrow as pa
564-
565563
csr = self._query_result.getCSR()
566564
return CSRResult(
567-
indptr=pa.array(csr["indptr"]),
568-
indices=pa.array(csr["indices"]),
569-
edge_ids=(None if csr["edge_ids"] is None else pa.array(csr["edge_ids"])),
565+
indptr=csr["indptr"],
566+
indices=csr["indices"],
567+
edge_ids=csr["edge_ids"],
570568
)
571569

572570

0 commit comments

Comments
 (0)