77#include " common/arrow/arrow_row_batch.h"
88#include " common/constants.h"
99#include " common/exception/not_implemented.h"
10+ #include " common/exception/runtime.h"
1011#include " common/types/uuid.h"
1112#include " common/types/value/nested.h"
1213#include " common/types/value/node.h"
1314#include " common/types/value/rel.h"
1415#include " datetime.h" // python lib
1516#include " include/py_query_result_converter.h"
17+ #include " main/query_result/arrow_query_result.h"
1618
1719using namespace lbug ::common;
1820using lbug::importCache;
@@ -30,6 +32,7 @@ void PyQueryResult::initialize(py::handle& m) {
3032 .def (" close" , &PyQueryResult::close)
3133 .def (" getAsDF" , &PyQueryResult::getAsDF)
3234 .def (" getAsArrow" , &PyQueryResult::getAsArrow)
35+ .def (" getCSR" , &PyQueryResult::getCSR)
3336 .def (" getColumnNames" , &PyQueryResult::getColumnNames)
3437 .def (" getColumnDataTypes" , &PyQueryResult::getColumnDataTypes)
3538 .def (" resetIterator" , &PyQueryResult::resetIterator)
@@ -85,6 +88,30 @@ void PyQueryResult::close() {
8588 }
8689}
8790
91+ namespace {
92+
93+ py::array_t <int64_t > copyToNumpyArray (const std::vector<int64_t >& values) {
94+ auto result = py::array_t <int64_t >(values.size ());
95+ auto * data = static_cast <int64_t *>(result.request ().ptr );
96+ std::copy (values.begin (), values.end (), data);
97+ return result;
98+ }
99+
100+ py::dict buildCSRResult (std::vector<int64_t > indptr, std::vector<int64_t > indices,
101+ std::vector<int64_t > edgeIDs, bool includeEdgeIDs) {
102+ py::dict result;
103+ result[" indptr" ] = copyToNumpyArray (indptr);
104+ result[" indices" ] = copyToNumpyArray (indices);
105+ if (includeEdgeIDs) {
106+ result[" edge_ids" ] = copyToNumpyArray (edgeIDs);
107+ } else {
108+ result[" edge_ids" ] = py::none ();
109+ }
110+ return result;
111+ }
112+
113+ } // namespace
114+
88115static py::object converTimestampToPyObject (timestamp_t & timestamp) {
89116 int32_t year = 0 , month = 0 , day = 0 , hour = 0 , min = 0 , sec = 0 , micros = 0 ;
90117 date_t date;
@@ -320,6 +347,23 @@ py::object PyQueryResult::getArrowChunks(const std::vector<LogicalType>& types,
320347
321348lbug::pyarrow::Table PyQueryResult::getAsArrow (std::int64_t chunkSize,
322349 bool fallbackExtensionTypes) {
350+ if (queryResult->getType () == QueryResultType::ARROW) {
351+ auto types = queryResult->getColumnDataTypes ();
352+ auto names = queryResult->getColumnNames ();
353+ py::list batches;
354+ auto batchImportFunc = importCache->pyarrow .lib .RecordBatch ._import_from_c ();
355+ while (queryResult->hasNextArrowChunk ()) {
356+ auto data = queryResult->getNextArrowChunk (chunkSize);
357+ auto schema = ArrowConverter::toArrowSchema (types, names, fallbackExtensionTypes);
358+ batches.append (
359+ batchImportFunc ((std::uint64_t )data.get (), (std::uint64_t )schema.get ()));
360+ }
361+ auto schema = ArrowConverter::toArrowSchema (types, names, fallbackExtensionTypes);
362+ auto fromBatchesFunc = importCache->pyarrow .lib .Table .from_batches ();
363+ auto schemaImportFunc = importCache->pyarrow .lib .Schema ._import_from_c ();
364+ auto schemaObj = schemaImportFunc ((std::uint64_t )schema.get ());
365+ return py::cast<lbug::pyarrow::Table>(fromBatchesFunc (batches, schemaObj));
366+ }
323367 auto types = queryResult->getColumnDataTypes ();
324368 auto names = queryResult->getColumnNames ();
325369 py::list batches = getArrowChunks (types, names, chunkSize, fallbackExtensionTypes);
@@ -330,6 +374,17 @@ lbug::pyarrow::Table PyQueryResult::getAsArrow(std::int64_t chunkSize,
330374 return py::cast<lbug::pyarrow::Table>(fromBatchesFunc (batches, schemaObj));
331375}
332376
377+ py::dict PyQueryResult::getCSR () {
378+ if (auto * arrowQueryResult = dynamic_cast <lbug::main::ArrowQueryResult*>(queryResult);
379+ arrowQueryResult != nullptr && arrowQueryResult->hasCSRMetadata ()) {
380+ const auto & metadata = arrowQueryResult->getCSRMetadata ();
381+ return buildCSRResult (metadata.indptr , metadata.indices , metadata.edgeIDs ,
382+ metadata.hasEdgeIDs );
383+ }
384+ throw RuntimeException (
385+ " CSR export is only supported for Arrow query results with native CSR metadata." );
386+ }
387+
333388py::list PyQueryResult::getColumnDataTypes () {
334389 auto columnDataTypes = queryResult->getColumnDataTypes ();
335390 py::tuple result (columnDataTypes.size ());
0 commit comments