diff --git a/build.rs b/build.rs index 8bd4f53..bc91a48 100644 --- a/build.rs +++ b/build.rs @@ -364,6 +364,11 @@ fn build_ffi( println!("cargo:rerun-if-changed=include/lbug_rs.h"); println!("cargo:rerun-if-changed=src/lbug_rs.cpp"); + println!("cargo:rerun-if-changed={bridge_file}"); + println!("cargo:rerun-if-changed={source_file}"); + if cfg!(feature = "arrow") { + println!("cargo:rerun-if-changed=include/lbug_arrow.h"); + } if bundled { // Note that this should match the lbug-src/* entries in the package.include list in Cargo.toml // Unfortunately they appear to need to be specified individually since the symlink is diff --git a/include/lbug_arrow.h b/include/lbug_arrow.h index a3587a7..eb2ec23 100644 --- a/include/lbug_arrow.h +++ b/include/lbug_arrow.h @@ -10,6 +10,11 @@ namespace lbug_arrow { ArrowSchema query_result_get_arrow_schema(const lbug::main::QueryResult& result); +bool query_result_has_next_arrow_chunk(lbug::main::QueryResult& result); ArrowArray query_result_get_next_arrow_chunk(lbug::main::QueryResult& result, uint64_t chunkSize); +ArrowArray query_result_get_csr_indptr(const lbug::main::QueryResult& result); +ArrowArray query_result_get_csr_indices(const lbug::main::QueryResult& result); +ArrowArray query_result_get_csr_edge_ids(const lbug::main::QueryResult& result); +bool query_result_has_csr_edge_ids(const lbug::main::QueryResult& result); } // namespace lbug_arrow diff --git a/include/lbug_rs.h b/include/lbug_rs.h index ad7fa76..7c4ae65 100644 --- a/include/lbug_rs.h +++ b/include/lbug_rs.h @@ -19,6 +19,18 @@ namespace lbug_rs { +using ArrowArray = ::ArrowArray; +using ArrowSchema = ::ArrowSchema; + +struct ArrowArrayList { + std::vector arrays; +}; + +std::unique_ptr new_arrow_array_list(); +inline void arrow_array_list_push(ArrowArrayList& list, ArrowArray array) { + list.arrays.push_back(array); +} + struct TypeListBuilder { std::vector types; @@ -117,6 +129,23 @@ inline std::unique_ptr connection_query(lbug::main::Con std::string_view query) { return connection.query(query); } +inline std::unique_ptr connection_query_as_arrow( + lbug::main::Connection& connection, std::string_view query, int64_t chunkSize) { + return connection.queryAsArrow(query, chunkSize); +} +std::unique_ptr connection_create_arrow_table( + lbug::main::Connection& connection, std::string_view tableName, ArrowSchema schema, + std::unique_ptr arrays); +std::unique_ptr connection_create_arrow_rel_table( + lbug::main::Connection& connection, std::string_view tableName, std::string_view srcTableName, + std::string_view dstTableName, ArrowSchema schema, std::unique_ptr arrays); +std::unique_ptr connection_create_arrow_rel_table_csr( + lbug::main::Connection& connection, std::string_view tableName, std::string_view srcTableName, + std::string_view dstTableName, ArrowSchema indicesSchema, + std::unique_ptr indicesArrays, ArrowSchema indptrSchema, + std::unique_ptr indptrArrays); +std::unique_ptr connection_drop_arrow_table( + lbug::main::Connection& connection, std::string_view tableName); inline std::unique_ptr connection_prepare( lbug::main::Connection& connection, std::string_view query) { return connection.prepare(query); diff --git a/src/connection.rs b/src/connection.rs index 4034d31..0bebd46 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -7,6 +7,44 @@ use cxx::UniquePtr; use std::cell::UnsafeCell; use std::convert::TryInto; +#[cfg(feature = "arrow")] +fn export_arrow_batches( + batches: &[arrow::record_batch::RecordBatch], +) -> Result< + ( + crate::ffi::arrow::ArrowSchema, + UniquePtr, + ), + Error, +> { + use arrow::array::{Array, StructArray}; + + let Some(first) = batches.first() else { + return Err(Error::ArrowError( + arrow::error::ArrowError::InvalidArgumentError( + "at least one Arrow record batch is required".to_string(), + ), + )); + }; + let schema = first.schema(); + let schema_ffi = arrow::ffi::FFI_ArrowSchema::try_from(schema.as_ref())?; + let mut arrays = crate::ffi::arrow::ffi_arrow::new_arrow_array_list(); + for batch in batches { + if batch.schema() != schema { + return Err(Error::ArrowError(arrow::error::ArrowError::SchemaError( + "all Arrow record batches must have the same schema".to_string(), + ))); + } + let struct_array = StructArray::from(batch.clone()); + let array_ffi = arrow::ffi::FFI_ArrowArray::new(&struct_array.into_data()); + crate::ffi::arrow::ffi_arrow::arrow_array_list_push( + arrays.pin_mut(), + crate::ffi::arrow::ArrowArray(array_ffi), + ); + } + Ok((crate::ffi::arrow::ArrowSchema(schema_ffi), arrays)) +} + /// A prepared stattement is a parameterized query which can avoid planning the same query for /// repeated execution pub struct PreparedStatement { @@ -142,6 +180,12 @@ impl<'a> Connection<'a> { pub fn query(&self, query: &str) -> Result, Error> { let conn = unsafe { (*self.conn.get()).pin_mut() }; let result = ffi::connection_query(conn, ffi::StringView::new(query))?; + Self::query_result_from_ffi(result) + } + + fn query_result_from_ffi( + result: UniquePtr>, + ) -> Result, Error> { if ffi::query_result_is_success(&result) { Ok(QueryResult { result }) } else { @@ -151,6 +195,115 @@ impl<'a> Connection<'a> { } } + #[cfg(feature = "arrow")] + /// Executes the given query with the native Arrow result collector. + /// + /// The returned [`QueryResult`] can be consumed with [`QueryResult::iter_arrow`] and can expose + /// CSR metadata through [`QueryResult::csr`] for relationship-shaped row-id projections. + /// + /// *Requires the `arrow` feature* + pub fn query_as_arrow(&self, query: &str, chunk_size: usize) -> Result, Error> { + let conn = unsafe { (*self.conn.get()).pin_mut() }; + let result = crate::ffi::arrow::ffi_arrow::connection_query_as_arrow( + conn, + ffi::StringView::new(query), + chunk_size as i64, + )?; + Self::query_result_from_ffi(result) + } + + #[cfg(feature = "arrow")] + /// Registers Arrow memory as a node table. + /// + /// The first column is used as the table primary key. + /// + /// *Requires the `arrow` feature* + pub fn create_arrow_table( + &self, + table_name: &str, + batches: &[arrow::record_batch::RecordBatch], + ) -> Result, Error> { + let (schema, arrays) = export_arrow_batches(batches)?; + let conn = unsafe { (*self.conn.get()).pin_mut() }; + let result = crate::ffi::arrow::ffi_arrow::connection_create_arrow_table( + conn, + ffi::StringView::new(table_name), + schema, + arrays, + )?; + Self::query_result_from_ffi(result) + } + + #[cfg(feature = "arrow")] + /// Registers Arrow memory as a relationship table. + /// + /// The Arrow schema must include endpoint columns named `from` and `to`. + /// + /// *Requires the `arrow` feature* + pub fn create_arrow_rel_table( + &self, + table_name: &str, + batches: &[arrow::record_batch::RecordBatch], + src_table_name: &str, + dst_table_name: &str, + ) -> Result, Error> { + let (schema, arrays) = export_arrow_batches(batches)?; + let conn = unsafe { (*self.conn.get()).pin_mut() }; + let result = crate::ffi::arrow::ffi_arrow::connection_create_arrow_rel_table( + conn, + ffi::StringView::new(table_name), + ffi::StringView::new(src_table_name), + ffi::StringView::new(dst_table_name), + schema, + arrays, + )?; + Self::query_result_from_ffi(result) + } + + #[cfg(feature = "arrow")] + /// Registers Arrow memory in CSR form as a relationship table. + /// + /// The `indices_batches` schema must include a destination column named `to`. The + /// `indptr_batches` schema must contain at least one offset column. + /// + /// *Requires the `arrow` feature* + pub fn create_arrow_rel_table_csr( + &self, + table_name: &str, + indices_batches: &[arrow::record_batch::RecordBatch], + indptr_batches: &[arrow::record_batch::RecordBatch], + src_table_name: &str, + dst_table_name: &str, + ) -> Result, Error> { + let (indices_schema, indices_arrays) = export_arrow_batches(indices_batches)?; + let (indptr_schema, indptr_arrays) = export_arrow_batches(indptr_batches)?; + let conn = unsafe { (*self.conn.get()).pin_mut() }; + let result = crate::ffi::arrow::ffi_arrow::connection_create_arrow_rel_table_csr( + conn, + ffi::StringView::new(table_name), + ffi::StringView::new(src_table_name), + ffi::StringView::new(dst_table_name), + indices_schema, + indices_arrays, + indptr_schema, + indptr_arrays, + )?; + Self::query_result_from_ffi(result) + } + + #[cfg(feature = "arrow")] + /// Drops an Arrow memory-backed table. + /// + /// *Requires the `arrow` feature* + pub fn drop_arrow_table(&self, table_name: &str) -> Result, Error> { + let conn = unsafe { (*self.conn.get()).pin_mut() }; + let result = crate::ffi::arrow::ffi_arrow::connection_drop_arrow_table( + conn, + ffi::StringView::new(table_name), + )?; + Self::query_result_from_ffi(result) + } + /// Executes the given prepared statement with args and returns the result. /// /// # Arguments @@ -190,13 +343,7 @@ impl<'a> Connection<'a> { let conn = unsafe { (*self.conn.get()).pin_mut() }; let result = ffi::connection_execute(conn, prepared_statement.statement.pin_mut(), cxx_params)?; - if ffi::query_result_is_success(&result) { - Ok(QueryResult { result }) - } else { - Err(Error::FailedQuery(ffi::query_result_get_error_message( - &result, - ))) - } + Self::query_result_from_ffi(result) } /// Interrupts all queries currently executing within this connection @@ -267,6 +414,209 @@ Invalid input : expected rule oC_SingleQuery (line: 1, o Ok(()) } + #[test] + #[cfg(feature = "arrow")] + fn test_create_arrow_table() -> Result<()> { + use arrow::array::{Int64Array, StringArray}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use std::sync::Arc; + + let temp_dir = tempfile::tempdir()?; + let db = Database::new(temp_dir.path().join("test"), SYSTEM_CONFIG_FOR_TESTS)?; + let conn = Connection::new(&db)?; + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + ])); + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(Int64Array::from(vec![1, 2])), + Arc::new(StringArray::from(vec!["Alice", "Bob"])), + ], + )?; + + conn.create_arrow_table("Person", &[batch])?; + let result = conn.query("MATCH (p:Person) RETURN p.name ORDER BY p.id;")?; + + assert_eq!(result.to_string(), "p.name\nAlice\nBob\n"); + temp_dir.close()?; + Ok(()) + } + + #[test] + #[cfg(feature = "arrow")] + fn test_create_arrow_table_multiple_batches_and_drop() -> Result<()> { + use arrow::array::{Int64Array, StringArray}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use std::sync::Arc; + + let temp_dir = tempfile::tempdir()?; + let db = Database::new(temp_dir.path().join("test"), SYSTEM_CONFIG_FOR_TESTS)?; + let conn = Connection::new(&db)?; + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + ])); + let batch1 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(vec![1, 2])), + Arc::new(StringArray::from(vec!["Alice", "Bob"])), + ], + )?; + let batch2 = RecordBatch::try_new( + schema, + vec![ + Arc::new(Int64Array::from(vec![3])), + Arc::new(StringArray::from(vec!["Carol"])), + ], + )?; + + conn.create_arrow_table("Person", &[batch1, batch2])?; + let result = conn.query("MATCH (p:Person) RETURN p.name ORDER BY p.id;")?; + assert_eq!(result.to_string(), "p.name\nAlice\nBob\nCarol\n"); + + conn.drop_arrow_table("Person")?; + let err = conn + .query("MATCH (p:Person) RETURN p.name;") + .expect_err("dropped Arrow table should no longer be queryable"); + assert!(err.to_string().contains("Table Person does not exist")); + temp_dir.close()?; + Ok(()) + } + + #[test] + #[cfg(feature = "arrow")] + fn test_create_arrow_table_validates_batches() -> Result<()> { + use arrow::array::{Int64Array, StringArray}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use std::sync::Arc; + + let temp_dir = tempfile::tempdir()?; + let db = Database::new(temp_dir.path().join("test"), SYSTEM_CONFIG_FOR_TESTS)?; + let conn = Connection::new(&db)?; + let err = conn + .create_arrow_table("Person", &[]) + .expect_err("empty Arrow table registration should fail"); + assert!(err + .to_string() + .contains("at least one Arrow record batch is required")); + + let schema1 = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])); + let schema2 = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + ])); + let batch1 = RecordBatch::try_new(schema1, vec![Arc::new(Int64Array::from(vec![1]))])?; + let batch2 = RecordBatch::try_new( + schema2, + vec![ + Arc::new(Int64Array::from(vec![2])), + Arc::new(StringArray::from(vec!["Bob"])), + ], + )?; + + let err = conn + .create_arrow_table("Person", &[batch1, batch2]) + .expect_err("mixed Arrow schemas should fail"); + assert!(err + .to_string() + .contains("all Arrow record batches must have the same schema")); + temp_dir.close()?; + Ok(()) + } + + #[test] + #[cfg(feature = "arrow")] + fn test_create_arrow_rel_table() -> Result<()> { + use arrow::array::Int64Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use std::sync::Arc; + + let temp_dir = tempfile::tempdir()?; + let db = Database::new(temp_dir.path().join("test"), SYSTEM_CONFIG_FOR_TESTS)?; + let conn = Connection::new(&db)?; + let node_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])); + let nodes = + RecordBatch::try_new(node_schema, vec![Arc::new(Int64Array::from(vec![0, 1]))])?; + conn.create_arrow_table("Person", &[nodes])?; + + let rel_schema = Arc::new(Schema::new(vec![ + Field::new("from", DataType::Int64, false), + Field::new("to", DataType::Int64, false), + Field::new("weight", DataType::Int64, false), + ])); + let rels = RecordBatch::try_new( + rel_schema, + vec![ + Arc::new(Int64Array::from(vec![0, 1])), + Arc::new(Int64Array::from(vec![1, 0])), + Arc::new(Int64Array::from(vec![7, 9])), + ], + )?; + conn.create_arrow_rel_table("Knows", &[rels], "Person", "Person")?; + + let result = conn.query( + "MATCH (a:Person)-[r:Knows]->(b:Person) \ + RETURN a.id, r.weight, b.id ORDER BY a.id, b.id;", + )?; + assert_eq!(result.to_string(), "a.id|r.weight|b.id\n0|7|1\n1|9|0\n"); + temp_dir.close()?; + Ok(()) + } + + #[test] + #[cfg(feature = "arrow")] + fn test_create_arrow_rel_table_csr() -> Result<()> { + use arrow::array::{Int64Array, UInt64Array}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use std::sync::Arc; + + let temp_dir = tempfile::tempdir()?; + let db = Database::new(temp_dir.path().join("test"), SYSTEM_CONFIG_FOR_TESTS)?; + let conn = Connection::new(&db)?; + let node_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])); + let nodes = + RecordBatch::try_new(node_schema, vec![Arc::new(Int64Array::from(vec![0, 1]))])?; + conn.create_arrow_table("Person", &[nodes])?; + + let indices_schema = Arc::new(Schema::new(vec![ + Field::new("to", DataType::UInt64, false), + Field::new("weight", DataType::Int64, false), + ])); + let indices = RecordBatch::try_new( + indices_schema, + vec![ + Arc::new(UInt64Array::from(vec![1, 0])), + Arc::new(Int64Array::from(vec![7, 9])), + ], + )?; + let indptr_schema = Arc::new(Schema::new(vec![Field::new( + "indptr", + DataType::UInt64, + false, + )])); + let indptr = RecordBatch::try_new( + indptr_schema, + vec![Arc::new(UInt64Array::from(vec![0, 1, 2]))], + )?; + conn.create_arrow_rel_table_csr("Knows", &[indices], &[indptr], "Person", "Person")?; + + let result = conn.query( + "MATCH (a:Person)-[r:Knows]->(b:Person) \ + RETURN a.id, r.weight, b.id ORDER BY a.id, b.id;", + )?; + assert_eq!(result.to_string(), "a.id|r.weight|b.id\n0|7|1\n1|9|0\n"); + temp_dir.close()?; + Ok(()) + } + #[test] fn test_query_result() -> Result<()> { let temp_dir = tempfile::tempdir()?; diff --git a/src/ffi/arrow.rs b/src/ffi/arrow.rs index 786af15..10107f8 100644 --- a/src/ffi/arrow.rs +++ b/src/ffi/arrow.rs @@ -18,6 +18,14 @@ unsafe impl cxx::ExternType for ArrowSchema { pub(crate) mod ffi_arrow { unsafe extern "C++" { include!("lbug/include/lbug_arrow.h"); + include!("lbug/include/lbug_rs.h"); + + #[namespace = "std"] + #[cxx_name = "string_view"] + type StringView<'a> = crate::ffi::StringView<'a>; + + #[namespace = "lbug::main"] + type Connection<'db> = crate::ffi::ffi::Connection<'db>; #[namespace = "lbug::main"] type QueryResult<'db> = crate::ffi::ffi::QueryResult<'db>; @@ -26,11 +34,26 @@ pub(crate) mod ffi_arrow { unsafe extern "C++" { type ArrowArray = crate::ffi::arrow::ArrowArray; + #[namespace = "lbug_arrow"] + fn query_result_has_next_arrow_chunk<'db>(result: Pin<&mut QueryResult<'db>>) -> bool; + #[namespace = "lbug_arrow"] fn query_result_get_next_arrow_chunk<'db>( result: Pin<&mut QueryResult<'db>>, chunk_size: u64, ) -> Result; + + #[namespace = "lbug_arrow"] + fn query_result_get_csr_indptr<'db>(result: &QueryResult<'db>) -> Result; + + #[namespace = "lbug_arrow"] + fn query_result_get_csr_indices<'db>(result: &QueryResult<'db>) -> Result; + + #[namespace = "lbug_arrow"] + fn query_result_get_csr_edge_ids<'db>(result: &QueryResult<'db>) -> Result; + + #[namespace = "lbug_arrow"] + fn query_result_has_csr_edge_ids<'db>(result: &QueryResult<'db>) -> Result; } unsafe extern "C++" { @@ -39,4 +62,51 @@ pub(crate) mod ffi_arrow { #[namespace = "lbug_arrow"] fn query_result_get_arrow_schema<'db>(result: &QueryResult<'db>) -> Result; } + + #[namespace = "lbug_rs"] + unsafe extern "C++" { + type ArrowArrayList; + + fn new_arrow_array_list() -> UniquePtr; + + fn arrow_array_list_push(list: Pin<&mut ArrowArrayList>, array: ArrowArray); + + fn connection_query_as_arrow<'a, 'db>( + connection: Pin<&mut Connection<'db>>, + query: StringView<'a>, + chunk_size: i64, + ) -> Result>>; + + fn connection_create_arrow_table<'a, 'db>( + connection: Pin<&mut Connection<'db>>, + table_name: StringView<'a>, + schema: ArrowSchema, + arrays: UniquePtr, + ) -> Result>>; + + fn connection_create_arrow_rel_table<'a, 'b, 'c, 'db>( + connection: Pin<&mut Connection<'db>>, + table_name: StringView<'a>, + src_table_name: StringView<'b>, + dst_table_name: StringView<'c>, + schema: ArrowSchema, + arrays: UniquePtr, + ) -> Result>>; + + fn connection_create_arrow_rel_table_csr<'a, 'b, 'c, 'db>( + connection: Pin<&mut Connection<'db>>, + table_name: StringView<'a>, + src_table_name: StringView<'b>, + dst_table_name: StringView<'c>, + indices_schema: ArrowSchema, + indices_arrays: UniquePtr, + indptr_schema: ArrowSchema, + indptr_arrays: UniquePtr, + ) -> Result>>; + + fn connection_drop_arrow_table<'a, 'db>( + connection: Pin<&mut Connection<'db>>, + table_name: StringView<'a>, + ) -> Result>>; + } } diff --git a/src/lbug_arrow.cpp b/src/lbug_arrow.cpp index 7d8067e..a5ac1c1 100644 --- a/src/lbug_arrow.cpp +++ b/src/lbug_arrow.cpp @@ -1,5 +1,66 @@ #include "lbug_arrow.h" +#include +#include +#include + +namespace lbug { +namespace main { + +class ArrowQueryResult : public QueryResult { +public: + struct CSRMetadata { + std::vector indptr; + std::vector indices; + std::vector edgeIDs; + bool hasEdgeIDs = false; + }; + + struct CSRArrowArray { + ArrowArray array{}; + ArrowSchema schema{}; + + CSRArrowArray() = default; + ~CSRArrowArray() { release(); } + CSRArrowArray(CSRArrowArray&& other) noexcept : array{other.array}, schema{other.schema} { + other.array.release = nullptr; + other.schema.release = nullptr; + } + CSRArrowArray& operator=(CSRArrowArray&& other) noexcept { + if (this != &other) { + release(); + array = other.array; + schema = other.schema; + other.array.release = nullptr; + other.schema.release = nullptr; + } + return *this; + } + CSRArrowArray(const CSRArrowArray&) = delete; + CSRArrowArray& operator=(const CSRArrowArray&) = delete; + + void release() { + if (schema.release) { + schema.release(&schema); + } + if (array.release) { + array.release(&array); + } + } + }; + + struct CSRArrowArrays { + CSRArrowArray indptr; + CSRArrowArray indices; + std::optional edgeIDs; + }; + + CSRArrowArrays getCSRArrowArrays() const; +}; + +} // namespace main +} // namespace lbug + namespace lbug_arrow { ArrowSchema query_result_get_arrow_schema(const lbug::main::QueryResult& result) { @@ -8,8 +69,50 @@ ArrowSchema query_result_get_arrow_schema(const lbug::main::QueryResult& result) return *result.getArrowSchema(); } +bool query_result_has_next_arrow_chunk(lbug::main::QueryResult& result) { + return result.hasNextArrowChunk(); +} + ArrowArray query_result_get_next_arrow_chunk(lbug::main::QueryResult& result, uint64_t chunkSize) { return *result.getNextArrowChunk(chunkSize); } +static const lbug::main::ArrowQueryResult& get_arrow_query_result( + const lbug::main::QueryResult& result) { + auto arrowResult = dynamic_cast(&result); + if (arrowResult == nullptr) { + throw std::runtime_error( + "CSR export is only supported for Arrow query results with native CSR metadata."); + } + return *arrowResult; +} + +static ArrowArray detach(lbug::main::ArrowQueryResult::CSRArrowArray& array) { + auto result = array.array; + array.array.release = nullptr; + return result; +} + +ArrowArray query_result_get_csr_indptr(const lbug::main::QueryResult& result) { + auto arrays = get_arrow_query_result(result).getCSRArrowArrays(); + return detach(arrays.indptr); +} + +ArrowArray query_result_get_csr_indices(const lbug::main::QueryResult& result) { + auto arrays = get_arrow_query_result(result).getCSRArrowArrays(); + return detach(arrays.indices); +} + +ArrowArray query_result_get_csr_edge_ids(const lbug::main::QueryResult& result) { + auto arrays = get_arrow_query_result(result).getCSRArrowArrays(); + if (!arrays.edgeIDs.has_value()) { + throw std::runtime_error("Arrow query result does not have CSR edge ids."); + } + return detach(*arrays.edgeIDs); +} + +bool query_result_has_csr_edge_ids(const lbug::main::QueryResult& result) { + return get_arrow_query_result(result).getCSRArrowArrays().edgeIDs.has_value(); +} + } // namespace lbug_arrow diff --git a/src/lbug_rs.cpp b/src/lbug_rs.cpp index ac6574f..014eca3 100644 --- a/src/lbug_rs.cpp +++ b/src/lbug_rs.cpp @@ -1,5 +1,7 @@ #include "lbug_rs.h" +#include + using lbug::common::ArrayTypeInfo; using lbug::common::Interval; using lbug::common::ListTypeInfo; @@ -19,6 +21,10 @@ std::unique_ptr new_params() { return std::make_unique(); } +std::unique_ptr new_arrow_array_list() { + return std::make_unique(); +} + std::unique_ptr create_logical_type(lbug::common::LogicalTypeID id) { return std::make_unique(id); } @@ -100,6 +106,93 @@ std::unique_ptr connection_execute(lbug::main::Connecti return connection.executeWithParams(&query, std::move(params->inputParams)); } +std::unique_ptr connection_create_arrow_table( + lbug::main::Connection& connection, std::string_view tableName, ArrowSchema schema, + std::unique_ptr arrays) { + lbug_connection capiConnection{&connection}; + lbug_query_result capiResult{}; + (void)lbug_connection_create_arrow_table(&capiConnection, std::string(tableName).c_str(), + &schema, arrays->arrays.data(), arrays->arrays.size(), &capiResult); + if (capiResult._query_result == nullptr) { + auto error = lbug_get_last_error(); + std::string message = error == nullptr ? "Failed to create Arrow table" : error; + if (error != nullptr) { + lbug_destroy_string(error); + } + throw std::runtime_error(message); + } + return std::unique_ptr( + static_cast(capiResult._query_result)); +} + +std::unique_ptr connection_create_arrow_rel_table( + lbug::main::Connection& connection, std::string_view tableName, std::string_view srcTableName, + std::string_view dstTableName, ArrowSchema schema, std::unique_ptr arrays) { + lbug_connection capiConnection{&connection}; + lbug_query_result capiResult{}; + auto tableNameString = std::string(tableName); + auto srcTableNameString = std::string(srcTableName); + auto dstTableNameString = std::string(dstTableName); + (void)lbug_connection_create_arrow_rel_table(&capiConnection, tableNameString.c_str(), + srcTableNameString.c_str(), dstTableNameString.c_str(), &schema, arrays->arrays.data(), + arrays->arrays.size(), &capiResult); + if (capiResult._query_result == nullptr) { + auto error = lbug_get_last_error(); + std::string message = error == nullptr ? "Failed to create Arrow relationship table" : error; + if (error != nullptr) { + lbug_destroy_string(error); + } + throw std::runtime_error(message); + } + return std::unique_ptr( + static_cast(capiResult._query_result)); +} + +std::unique_ptr connection_create_arrow_rel_table_csr( + lbug::main::Connection& connection, std::string_view tableName, std::string_view srcTableName, + std::string_view dstTableName, ArrowSchema indicesSchema, + std::unique_ptr indicesArrays, ArrowSchema indptrSchema, + std::unique_ptr indptrArrays) { + lbug_connection capiConnection{&connection}; + lbug_query_result capiResult{}; + auto tableNameString = std::string(tableName); + auto srcTableNameString = std::string(srcTableName); + auto dstTableNameString = std::string(dstTableName); + (void)lbug_connection_create_arrow_rel_table_csr(&capiConnection, + tableNameString.c_str(), srcTableNameString.c_str(), dstTableNameString.c_str(), + &indicesSchema, indicesArrays->arrays.data(), indicesArrays->arrays.size(), &indptrSchema, + indptrArrays->arrays.data(), indptrArrays->arrays.size(), &capiResult); + if (capiResult._query_result == nullptr) { + auto error = lbug_get_last_error(); + std::string message = + error == nullptr ? "Failed to create Arrow CSR relationship table" : error; + if (error != nullptr) { + lbug_destroy_string(error); + } + throw std::runtime_error(message); + } + return std::unique_ptr( + static_cast(capiResult._query_result)); +} + +std::unique_ptr connection_drop_arrow_table( + lbug::main::Connection& connection, std::string_view tableName) { + lbug_connection capiConnection{&connection}; + lbug_query_result capiResult{}; + (void)lbug_connection_drop_arrow_table( + &capiConnection, std::string(tableName).c_str(), &capiResult); + if (capiResult._query_result == nullptr) { + auto error = lbug_get_last_error(); + std::string message = error == nullptr ? "Failed to drop Arrow table" : error; + if (error != nullptr) { + lbug_destroy_string(error); + } + throw std::runtime_error(message); + } + return std::unique_ptr( + static_cast(capiResult._query_result)); +} + rust::String prepared_statement_error_message(const lbug::main::PreparedStatement& statement) { return rust::String(statement.getErrorMessage()); } diff --git a/src/lib.rs b/src/lib.rs index 45e02d2..1fdb1d0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -73,7 +73,7 @@ pub use database::{Database, SystemConfig}; pub use error::Error; pub use logical_type::LogicalType; #[cfg(feature = "arrow")] -pub use query_result::ArrowIterator; +pub use query_result::{ArrowIterator, CsrResult}; pub use query_result::{CSVOptions, QueryResult}; pub use value::{InternalID, NodeVal, RelVal, Value}; diff --git a/src/query_result.rs b/src/query_result.rs index 2815050..4bb32d8 100644 --- a/src/query_result.rs +++ b/src/query_result.rs @@ -113,6 +113,46 @@ impl<'db> QueryResult<'db> { schema, }) } + + #[cfg(feature = "arrow")] + /// Returns native CSR arrays from an Arrow query result. + /// + /// This is available for Arrow-native results with CSR metadata, typically results produced by + /// [`Connection::query_as_arrow`](crate::Connection::query_as_arrow) for relationship-shaped + /// row-id projections. + /// + /// *Requires the `arrow` feature* + pub fn csr(&self) -> Result { + use arrow::array::Int64Array; + use arrow::datatypes::DataType; + + fn import_i64_array( + array: crate::ffi::arrow::ArrowArray, + ) -> Result { + let data = unsafe { arrow::ffi::from_ffi_and_data_type(array.0, DataType::Int64) }?; + Ok(Int64Array::from(data)) + } + + let result = self.result.as_ref().unwrap(); + let indptr = import_i64_array(crate::ffi::arrow::ffi_arrow::query_result_get_csr_indptr( + result, + )?)?; + let indices = import_i64_array( + crate::ffi::arrow::ffi_arrow::query_result_get_csr_indices(result)?, + )?; + let edge_ids = if crate::ffi::arrow::ffi_arrow::query_result_has_csr_edge_ids(result)? { + Some(import_i64_array( + crate::ffi::arrow::ffi_arrow::query_result_get_csr_edge_ids(result)?, + )?) + } else { + None + }; + Ok(CsrResult { + indptr, + indices, + edge_ids, + }) + } } // the underlying C++ type is both data and an iterator (sort-of) @@ -139,6 +179,16 @@ impl Iterator for QueryResult<'_> { } } +#[cfg(feature = "arrow")] +/// Native CSR arrays exported from an Arrow query result. +/// +/// *Requires the `arrow` feature* +pub struct CsrResult { + pub indptr: arrow::array::Int64Array, + pub indices: arrow::array::Int64Array, + pub edge_ids: Option, +} + #[cfg(feature = "arrow")] /// Produces an iterator over a `QueryResult` as [`RecordBatch`](arrow::record_batch::RecordBatch)es /// @@ -156,7 +206,7 @@ impl Iterator for ArrowIterator<'_, '_> { type Item = arrow::record_batch::RecordBatch; fn next(&mut self) -> Option { - if ffi::query_result_has_next(self.result.as_ref().unwrap()) { + if crate::ffi::arrow::ffi_arrow::query_result_has_next_arrow_chunk(self.result.pin_mut()) { use crate::ffi::arrow::ffi_arrow; // Generally this panic should be unreachable, since the only exceptions produced by // arrow_converter are for unsupported types, but those would produce an error when we @@ -287,4 +337,39 @@ mod tests { temp_dir.close()?; Ok(()) } + + #[test] + #[cfg(feature = "arrow")] + fn test_query_as_arrow_csr() -> anyhow::Result<()> { + let temp_dir = tempfile::tempdir()?; + let path = temp_dir.path().join("test"); + let db = Database::new(path, SYSTEM_CONFIG_FOR_TESTS)?; + let conn = Connection::new(&db)?; + conn.query("CREATE NODE TABLE Person(id INT64, PRIMARY KEY(id));")?; + conn.query("CREATE (:Person {id: 0});")?; + conn.query("CREATE (:Person {id: 1});")?; + conn.query("CREATE REL TABLE Knows(FROM Person TO Person, weight INT64);")?; + conn.query("MATCH (a:Person), (b:Person) WHERE a.id = 0 AND b.id = 1 CREATE (a)-[:Knows {weight: 7}]->(b);")?; + conn.query("MATCH (a:Person), (b:Person) WHERE a.id = 1 AND b.id = 0 CREATE (a)-[:Knows {weight: 9}]->(b);")?; + + let result = conn.query_as_arrow( + "MATCH (a:Person)-[r:Knows]->(b:Person) RETURN a.rowid, r.rowid, b.rowid", + 8, + )?; + let csr = result.csr()?; + + assert_eq!(csr.indptr.len(), 3); + assert_eq!(csr.indptr.value(0), 0); + assert_eq!(csr.indptr.value(1), 1); + assert_eq!(csr.indptr.value(2), 2); + assert_eq!(csr.indices.len(), 2); + assert_eq!(csr.indices.value(0), 1); + assert_eq!(csr.indices.value(1), 0); + let edge_ids = csr.edge_ids.expect("CSR result should include edge ids"); + assert_eq!(edge_ids.len(), 2); + assert_eq!(edge_ids.value(0), 0); + assert_eq!(edge_ids.value(1), 1); + temp_dir.close()?; + Ok(()) + } }