diff --git a/Cargo.lock b/Cargo.lock index 4bc5ce1..fbf5362 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1206,6 +1206,16 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation" version = "0.10.1" @@ -2776,6 +2786,21 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -3544,7 +3569,7 @@ dependencies = [ [[package]] name = "laminar-connectors" -version = "0.15.0" +version = "0.15.1" dependencies = [ "arrow-array", "arrow-avro", @@ -3587,7 +3612,7 @@ dependencies = [ [[package]] name = "laminar-core" -version = "0.15.0" +version = "0.15.1" dependencies = [ "ahash", "anyhow", @@ -3634,7 +3659,7 @@ dependencies = [ [[package]] name = "laminar-db" -version = "0.15.0" +version = "0.15.1" dependencies = [ "anyhow", "arrow", @@ -3664,7 +3689,7 @@ dependencies = [ [[package]] name = "laminar-sql" -version = "0.15.0" +version = "0.15.1" dependencies = [ "anyhow", "arrow", @@ -3689,7 +3714,7 @@ dependencies = [ [[package]] name = "laminar-storage" -version = "0.15.0" +version = "0.15.1" dependencies = [ "anyhow", "async-trait", @@ -3714,7 +3739,7 @@ dependencies = [ [[package]] name = "laminardb" -version = "0.15.0" +version = "0.15.1" dependencies = [ "anyhow", "arrow", @@ -4125,6 +4150,23 @@ dependencies = [ "zstd", ] +[[package]] +name = "native-tls" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe 0.1.6", + "openssl-sys", + "schannel", + "security-framework 2.11.1", + "security-framework-sys", + "tempfile", +] + [[package]] name = "ndarray" version = "0.17.2" @@ -4310,12 +4352,56 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "openssl" +version = "0.10.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" +dependencies = [ + "bitflags 2.11.0", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + [[package]] name = "openssl-probe" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" +[[package]] +name = "openssl-sys" +version = "0.9.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "option-ext" version = "0.2.0" @@ -5385,10 +5471,10 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" dependencies = [ - "openssl-probe", + "openssl-probe 0.2.1", "rustls-pki-types", "schannel", - "security-framework", + "security-framework 3.7.0", ] [[package]] @@ -5416,7 +5502,7 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784" dependencies = [ - "core-foundation", + "core-foundation 0.10.1", "core-foundation-sys", "jni", "log", @@ -5425,7 +5511,7 @@ dependencies = [ "rustls-native-certs", "rustls-platform-verifier-android", "rustls-webpki", - "security-framework", + "security-framework 3.7.0", "security-framework-sys", "webpki-root-certs", "windows-sys 0.61.2", @@ -5491,6 +5577,19 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags 2.11.0", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + [[package]] name = "security-framework" version = "3.7.0" @@ -5498,7 +5597,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" dependencies = [ "bitflags 2.11.0", - "core-foundation", + "core-foundation 0.10.1", "core-foundation-sys", "libc", "security-framework-sys", @@ -5981,6 +6080,16 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-postgres" version = "0.7.16" @@ -6037,12 +6146,10 @@ checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084" dependencies = [ "futures-util", "log", - "rustls", - "rustls-pki-types", + "native-tls", "tokio", - "tokio-rustls", + "tokio-native-tls", "tungstenite", - "webpki-roots 0.26.11", ] [[package]] @@ -6183,9 +6290,8 @@ dependencies = [ "http 1.4.0", "httparse", "log", + "native-tls", "rand 0.9.2", - "rustls", - "rustls-pki-types", "sha1", "thiserror 2.0.18", "utf-8", diff --git a/Cargo.toml b/Cargo.toml index 6b02f92..66733a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "laminardb" -version = "0.15.0" +version = "0.15.1" edition = "2024" rust-version = "1.85" license = "MIT" diff --git a/pyproject.toml b/pyproject.toml index b33f7c7..d71fc98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "laminardb" -version = "0.15.0" +version = "0.15.1" description = "Python bindings for LaminarDB streaming SQL database" readme = "README.md" requires-python = ">=3.11" diff --git a/python/laminardb/__init__.py b/python/laminardb/__init__.py index 78f32b3..2bb2e0a 100644 --- a/python/laminardb/__init__.py +++ b/python/laminardb/__init__.py @@ -11,6 +11,7 @@ AsyncSubscription Asynchronous continuous query subscription. StreamSubscription Synchronous named-stream subscription. AsyncStreamSubscription Asynchronous named-stream subscription. + CallbackSubscription Push-based subscription with callbacks. MaterializedView High-level wrapper for named streams. Schema Convenience wrapper around PyArrow Schema. ChangeEvent A batch of change rows from a subscription. @@ -39,7 +40,9 @@ from laminardb._laminardb import ( AsyncStreamSubscription, AsyncSubscription, + CallbackSubscription, CheckpointError, + CheckpointResult, Connection, ConnectionError, ConnectorError, @@ -155,6 +158,7 @@ def mv(conn: Connection, name: str, sql_def: str | None = None) -> MaterializedV "execute", "mv", # Core classes + "CheckpointResult", "Connection", "ExecuteResult", "LaminarConfig", @@ -164,6 +168,7 @@ def mv(conn: Connection, name: str, sql_def: str | None = None) -> MaterializedV "AsyncSubscription", "StreamSubscription", "AsyncStreamSubscription", + "CallbackSubscription", # Aliases "Config", "BatchWriter", diff --git a/python/laminardb/_laminardb.pyi b/python/laminardb/_laminardb.pyi index 778b31e..aefc2ef 100644 --- a/python/laminardb/_laminardb.pyi +++ b/python/laminardb/_laminardb.pyi @@ -117,6 +117,22 @@ class LaminarConfig: def __repr__(self) -> str: ... +# --------------------------------------------------------------------------- +# CheckpointResult +# --------------------------------------------------------------------------- + +class CheckpointResult: + """The result of a checkpoint operation.""" + + @property + def checkpoint_id(self) -> int: + """The checkpoint ID assigned by the database.""" + ... + + def __bool__(self) -> bool: ... + def __int__(self) -> int: ... + def __repr__(self) -> str: ... + # --------------------------------------------------------------------------- # ExecuteResult # --------------------------------------------------------------------------- @@ -250,6 +266,34 @@ class Connection: """ ... + def subscribe_callback( + self, + sql: str, + on_data: Callable[[QueryResult], None], + on_error: Callable[[str], None] | None = None, + ) -> CallbackSubscription: + """Subscribe to a continuous query with a callback. + + ``on_data`` is called with a ``QueryResult`` for each batch. + ``on_error`` is called with an error message string on failure; + if absent or if it raises, the subscription stops. + """ + ... + + def subscribe_stream_callback( + self, + name: str, + on_data: Callable[[QueryResult], None], + on_error: Callable[[str], None] | None = None, + ) -> CallbackSubscription: + """Subscribe to a named stream with a callback. + + ``on_data`` is called with a ``QueryResult`` for each batch. + ``on_error`` is called with an error message string on failure; + if absent or if it raises, the subscription stops. + """ + ... + def query_stream(self, name: str, filter: str | None = None) -> QueryResult: """Query a named stream's current data. @@ -282,8 +326,8 @@ class Connection: """Start the streaming pipeline.""" ... - def checkpoint(self) -> int | None: - """Trigger a checkpoint. Returns the checkpoint ID or None.""" + def checkpoint(self) -> CheckpointResult: + """Trigger a checkpoint. Returns a CheckpointResult.""" ... def execute(self, sql: str) -> ExecuteResult: @@ -604,6 +648,28 @@ class StreamSubscription: def __iter__(self) -> Iterator[QueryResult]: ... def __next__(self) -> QueryResult: ... +class CallbackSubscription: + """A push-based subscription that calls a Python function for each batch. + + Created via ``Connection.subscribe_callback()`` or + ``Connection.subscribe_stream_callback()``. + """ + + @property + def is_active(self) -> bool: + """Whether the background thread is still running.""" + ... + + def cancel(self) -> None: + """Stop the subscription. Safe to call multiple times.""" + ... + + def wait(self) -> None: + """Block until the background thread exits. Releases the GIL.""" + ... + + def __repr__(self) -> str: ... + class AsyncStreamSubscription: """An asynchronous subscription to a named stream. diff --git a/python/laminardb/types.py b/python/laminardb/types.py index 21cc7f4..81b560c 100644 --- a/python/laminardb/types.py +++ b/python/laminardb/types.py @@ -87,7 +87,12 @@ class Watermark: @dataclass(frozen=True) class CheckpointStatus: - """Status of the checkpoint system.""" + """Status of the checkpoint system. + + .. deprecated:: + Use :class:`laminardb.CheckpointResult` returned by + ``Connection.checkpoint()`` instead. + """ checkpoint_id: int | None enabled: bool diff --git a/src/callback.rs b/src/callback.rs new file mode 100644 index 0000000..6eef824 --- /dev/null +++ b/src/callback.rs @@ -0,0 +1,239 @@ +//! Callback-based push subscription API. +//! +//! `CallbackSubscription` runs a background thread that polls for new batches +//! and invokes a user-provided Python callable for each one. This provides a +//! push-based alternative to the pull-based iterator subscriptions. + +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::thread::{self, JoinHandle}; +use std::time::Duration; + +use parking_lot::Mutex; +use pyo3::prelude::*; + +use crate::async_support::runtime; +use crate::query::QueryResult; + +// --------------------------------------------------------------------------- +// SubscriptionKind — unifies QueryStream and ArrowSubscription +// --------------------------------------------------------------------------- + +enum SubscriptionKind { + Query(laminar_db::api::QueryStream), + Stream(laminar_db::api::ArrowSubscription), +} + +impl SubscriptionKind { + fn try_next(&mut self) -> Result, laminar_db::api::ApiError> { + match self { + Self::Query(s) => s.try_next(), + Self::Stream(s) => s.try_next(), + } + } + + fn is_active(&self) -> bool { + match self { + Self::Query(s) => s.is_active(), + Self::Stream(s) => s.is_active(), + } + } + + fn cancel(&mut self) { + match self { + Self::Query(s) => s.cancel(), + Self::Stream(s) => s.cancel(), + } + } +} + +// Safety: The inner types are Send (protected by being owned by a single thread). +unsafe impl Send for SubscriptionKind {} + +// --------------------------------------------------------------------------- +// CallbackSubscription pyclass +// --------------------------------------------------------------------------- + +/// A push-based subscription that calls a Python function for each batch. +/// +/// Created via `Connection.subscribe_callback(sql, on_data)` or +/// `Connection.subscribe_stream_callback(name, on_data)`. +/// +/// The subscription runs on a background thread and invokes `on_data` +/// with a `QueryResult` for each batch received. Call `cancel()` to +/// stop the subscription and `wait()` to block until the background +/// thread exits. +#[pyclass(name = "CallbackSubscription")] +pub struct CallbackSubscription { + cancelled: Arc, + active: Arc, + thread: Mutex>>, + kind: String, +} + +unsafe impl Send for CallbackSubscription {} +unsafe impl Sync for CallbackSubscription {} + +#[pymethods] +impl CallbackSubscription { + /// Whether the background thread is still running. + #[getter] + fn is_active(&self) -> bool { + self.active.load(Ordering::Relaxed) + } + + /// Stop the subscription. Safe to call multiple times. + fn cancel(&self) { + self.cancelled.store(true, Ordering::Relaxed); + } + + /// Block until the background thread exits. Releases the GIL. + fn wait(&self, py: Python<'_>) -> PyResult<()> { + let handle = self.thread.lock().take(); + if let Some(h) = handle { + py.allow_threads(|| { + let _ = h.join(); + }); + } + Ok(()) + } + + fn __repr__(&self) -> String { + let state = if self.active.load(Ordering::Relaxed) { + "active" + } else { + "finished" + }; + format!("CallbackSubscription({}, {})", self.kind, state) + } + + fn __del__(&self) { + // Set the cancel flag so the background thread will exit. + // Do NOT join here — the background thread may be inside + // Python::with_gil() while the GC holds the GIL, which + // would deadlock. + self.cancelled.store(true, Ordering::Relaxed); + } +} + +// --------------------------------------------------------------------------- +// Factory methods (pub(crate), not exposed to Python) +// --------------------------------------------------------------------------- + +impl CallbackSubscription { + pub(crate) fn from_query_stream( + stream: laminar_db::api::QueryStream, + on_data: PyObject, + on_error: Option, + ) -> Self { + Self::spawn(SubscriptionKind::Query(stream), on_data, on_error, "query") + } + + pub(crate) fn from_arrow_subscription( + sub: laminar_db::api::ArrowSubscription, + on_data: PyObject, + on_error: Option, + ) -> Self { + Self::spawn(SubscriptionKind::Stream(sub), on_data, on_error, "stream") + } + + fn spawn( + mut sub: SubscriptionKind, + on_data: PyObject, + on_error: Option, + kind: &str, + ) -> Self { + let cancelled = Arc::new(AtomicBool::new(false)); + let active = Arc::new(AtomicBool::new(true)); + let cancelled2 = cancelled.clone(); + let active2 = active.clone(); + let kind_str = kind.to_owned(); + + let handle = thread::Builder::new() + .name(format!("laminardb-callback-{}", kind)) + .spawn(move || { + let _rt = runtime().enter(); + callback_thread_loop(&mut sub, &cancelled2, &on_data, &on_error); + active2.store(false, Ordering::Relaxed); + }) + .expect("failed to spawn callback thread"); + + Self { + cancelled, + active, + thread: Mutex::new(Some(handle)), + kind: kind_str, + } + } +} + +// --------------------------------------------------------------------------- +// Background thread loop +// --------------------------------------------------------------------------- + +fn callback_thread_loop( + sub: &mut SubscriptionKind, + cancelled: &AtomicBool, + on_data: &PyObject, + on_error: &Option, +) { + loop { + if cancelled.load(Ordering::Relaxed) { + sub.cancel(); + break; + } + + if !sub.is_active() { + break; + } + + match sub.try_next() { + Ok(Some(batch)) => { + let result = QueryResult::from_batch(batch); + let should_stop = Python::with_gil(|py| match on_data.call1(py, (result,)) { + Ok(_) => false, + Err(e) => handle_callback_error(py, e, on_error), + }); + if should_stop { + sub.cancel(); + break; + } + } + Ok(None) => { + // No data ready yet — sleep briefly to avoid busy-spin + thread::sleep(Duration::from_millis(50)); + } + Err(e) => { + let msg = format!("{}", e); + let should_stop = Python::with_gil(|py| { + let py_err = pyo3::exceptions::PyRuntimeError::new_err(msg); + handle_callback_error(py, py_err, on_error) + }); + if should_stop { + sub.cancel(); + break; + } + } + } + } +} + +/// Route an error to on_error if provided; returns `true` if the loop should stop. +fn handle_callback_error(py: Python<'_>, error: PyErr, on_error: &Option) -> bool { + match on_error { + Some(handler) => { + let msg = error.to_string(); + match handler.call1(py, (msg,)) { + Ok(_) => false, // on_error handled it, continue + Err(e) => { + eprintln!("laminardb: on_error callback raised: {}", e); + true + } + } + } + None => { + eprintln!("laminardb: callback error (no on_error handler): {}", error); + true + } + } +} diff --git a/src/checkpoint.rs b/src/checkpoint.rs new file mode 100644 index 0000000..6000cd5 --- /dev/null +++ b/src/checkpoint.rs @@ -0,0 +1,39 @@ +//! Checkpoint result class. + +use pyo3::prelude::*; + +/// The result of a checkpoint operation. +#[pyclass(name = "CheckpointResult", frozen)] +pub struct PyCheckpointResult { + checkpoint_id: u64, +} + +unsafe impl Send for PyCheckpointResult {} +unsafe impl Sync for PyCheckpointResult {} + +#[pymethods] +impl PyCheckpointResult { + /// The checkpoint ID assigned by the database. + #[getter] + fn checkpoint_id(&self) -> u64 { + self.checkpoint_id + } + + fn __bool__(&self) -> bool { + true + } + + fn __int__(&self) -> u64 { + self.checkpoint_id + } + + fn __repr__(&self) -> String { + format!("CheckpointResult(checkpoint_id={})", self.checkpoint_id) + } +} + +impl PyCheckpointResult { + pub fn from_id(checkpoint_id: u64) -> Self { + Self { checkpoint_id } + } +} diff --git a/src/connection.rs b/src/connection.rs index 551eacd..9759217 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -12,7 +12,9 @@ use pyo3::prelude::*; use pyo3_arrow::PySchema; use crate::async_support::{AsyncSubscription, runtime}; +use crate::callback::CallbackSubscription; use crate::catalog::{PyQueryInfo, PySinkInfo, PySourceInfo, PyStreamInfo}; +use crate::checkpoint::PyCheckpointResult; use crate::conversion; use crate::error::{ConnectionError, IntoPyResult, QueryError}; use crate::execute::ExecuteResult; @@ -222,6 +224,60 @@ impl PyConnection { }) } + /// Subscribe to a continuous query with a callback. + /// + /// The `on_data` callable receives a `QueryResult` for each batch. + /// If `on_error` is provided, it is called with an error message string + /// when the callback or subscription raises; if it also raises (or is + /// absent), the subscription stops. + #[pyo3(signature = (sql, on_data, on_error = None))] + fn subscribe_callback( + &self, + py: Python<'_>, + sql: &str, + on_data: PyObject, + on_error: Option, + ) -> PyResult { + self.check_closed()?; + let inner = self.inner.clone(); + let sql = sql.to_owned(); + let stream = py.allow_threads(|| { + let _rt = runtime().enter(); + let conn = inner.lock(); + conn.query_stream(&sql).into_pyresult() + })?; + Ok(CallbackSubscription::from_query_stream( + stream, on_data, on_error, + )) + } + + /// Subscribe to a named stream with a callback. + /// + /// The `on_data` callable receives a `QueryResult` for each batch. + /// If `on_error` is provided, it is called with an error message string + /// when the callback or subscription raises; if it also raises (or is + /// absent), the subscription stops. + #[pyo3(signature = (name, on_data, on_error = None))] + fn subscribe_stream_callback( + &self, + py: Python<'_>, + name: &str, + on_data: PyObject, + on_error: Option, + ) -> PyResult { + self.check_closed()?; + let inner = self.inner.clone(); + let name = name.to_owned(); + let sub = py.allow_threads(|| { + let _rt = runtime().enter(); + let conn = inner.lock(); + conn.subscribe(&name).into_pyresult() + })?; + Ok(CallbackSubscription::from_arrow_subscription( + sub, on_data, on_error, + )) + } + /// Get the schema of a source or stream as a PyArrow Schema. fn schema(&self, py: Python<'_>, table: &str) -> PyResult> { self.check_closed()?; @@ -336,15 +392,16 @@ impl PyConnection { self.closed } - /// Trigger a checkpoint. Returns the checkpoint ID on success, or None. - fn checkpoint(&self, py: Python<'_>) -> PyResult> { + /// Trigger a checkpoint. Returns a CheckpointResult on success. + fn checkpoint(&self, py: Python<'_>) -> PyResult { self.check_closed()?; let inner = self.inner.clone(); - py.allow_threads(|| { + let id = py.allow_threads(|| { let _rt = runtime().enter(); let conn = inner.lock(); - conn.checkpoint().into_pyresult().map(Some) - }) + conn.checkpoint().into_pyresult() + })?; + Ok(PyCheckpointResult::from_id(id)) } /// Whether checkpointing is enabled for this connection. diff --git a/src/lib.rs b/src/lib.rs index 7d5f5f6..5b06f86 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,9 @@ #![allow(deprecated)] mod async_support; +mod callback; mod catalog; +mod checkpoint; mod config; mod connection; mod conversion; @@ -42,9 +44,13 @@ fn _laminardb(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; + // Checkpoint + m.add_class::()?; + // Catalog info classes m.add_class::()?; m.add_class::()?; diff --git a/tests/test_callback_subscription.py b/tests/test_callback_subscription.py new file mode 100644 index 0000000..abc18cb --- /dev/null +++ b/tests/test_callback_subscription.py @@ -0,0 +1,251 @@ +"""Tests for callback-based push subscriptions (CallbackSubscription).""" + +import threading +import time + +import pytest + +import laminardb + + +@pytest.fixture +def conn(tmp_path): + """A connection with a source and a named stream.""" + c = laminardb.open(str(tmp_path / "callback_test.db")) + c.create_table( + "events", + {"id": "int64", "msg": "string"}, + ) + c.execute("CREATE STREAM filtered AS SELECT * FROM events WHERE id > 0") + c.start() + yield c + c.close() + + +class TestSubscribeCallback: + """Tests for subscribe_callback (SQL query → callback).""" + + def test_creates_active_handle(self, conn): + handle = conn.subscribe_callback( + "SELECT * FROM events", lambda batch: None + ) + assert handle.is_active + handle.cancel() + handle.wait() + + def test_cancel_sets_inactive(self, conn): + handle = conn.subscribe_callback( + "SELECT * FROM events", lambda batch: None + ) + handle.cancel() + handle.wait() + assert not handle.is_active + + def test_double_cancel_is_safe(self, conn): + handle = conn.subscribe_callback( + "SELECT * FROM events", lambda batch: None + ) + handle.cancel() + handle.cancel() # should not raise + handle.wait() + + def test_wait_returns_after_cancel(self, conn): + handle = conn.subscribe_callback( + "SELECT * FROM events", lambda batch: None + ) + handle.cancel() + handle.wait() # should return promptly + + def test_double_wait_is_noop(self, conn): + handle = conn.subscribe_callback( + "SELECT * FROM events", lambda batch: None + ) + handle.cancel() + handle.wait() + handle.wait() # second wait should be a no-op + + def test_repr_shows_query_active(self, conn): + handle = conn.subscribe_callback( + "SELECT * FROM events", lambda batch: None + ) + r = repr(handle) + assert "query" in r + assert "active" in r + handle.cancel() + handle.wait() + + def test_repr_shows_finished_after_cancel(self, conn): + handle = conn.subscribe_callback( + "SELECT * FROM events", lambda batch: None + ) + handle.cancel() + handle.wait() + r = repr(handle) + assert "finished" in r + + def test_callback_receives_data(self, conn): + received = threading.Event() + results = [] + + def on_data(batch): + results.append(batch) + received.set() + + handle = conn.subscribe_callback("SELECT * FROM events", on_data) + conn.insert("events", {"id": 1, "msg": "hello"}) + got_data = received.wait(timeout=5.0) + handle.cancel() + handle.wait() + if got_data: + assert len(results) > 0 + assert results[0].num_rows > 0 + else: + pytest.skip("data did not arrive within timeout") + + def test_on_error_called_when_callback_raises(self, conn): + errors = [] + error_event = threading.Event() + + def bad_callback(batch): + raise ValueError("test error") + + def on_error(msg): + errors.append(msg) + error_event.set() + + handle = conn.subscribe_callback( + "SELECT * FROM events", bad_callback, on_error + ) + conn.insert("events", {"id": 1, "msg": "hello"}) + got_error = error_event.wait(timeout=5.0) + handle.cancel() + handle.wait() + if got_error: + assert len(errors) > 0 + assert "test error" in errors[0] + + def test_stops_without_on_error_when_callback_raises(self, conn): + called = threading.Event() + + def bad_callback(batch): + called.set() + raise ValueError("boom") + + handle = conn.subscribe_callback( + "SELECT * FROM events", bad_callback + ) + conn.insert("events", {"id": 1, "msg": "hello"}) + got_data = called.wait(timeout=5.0) + # Cancel to ensure thread exits even if data never arrived + handle.cancel() + handle.wait() + if got_data: + assert not handle.is_active + else: + pytest.skip("data did not arrive within timeout") + + def test_wait_releases_gil(self, conn): + """Verify wait() doesn't deadlock when another thread needs the GIL.""" + handle = conn.subscribe_callback( + "SELECT * FROM events", lambda batch: None + ) + + result = [None] + + def other_thread(): + # This runs while wait() is blocking — if wait() held the GIL, + # this thread would deadlock trying to acquire it. + result[0] = 1 + 1 + + handle.cancel() + t = threading.Thread(target=other_thread) + t.start() + handle.wait() + t.join(timeout=2.0) + assert result[0] == 2 + + +class TestSubscribeStreamCallback: + """Tests for subscribe_stream_callback (named stream → callback).""" + + def test_creates_active_handle(self, conn): + handle = conn.subscribe_stream_callback( + "filtered", lambda batch: None + ) + assert handle.is_active + handle.cancel() + handle.wait() + + def test_cancel_sets_inactive(self, conn): + handle = conn.subscribe_stream_callback( + "filtered", lambda batch: None + ) + handle.cancel() + handle.wait() + assert not handle.is_active + + def test_double_cancel_is_safe(self, conn): + handle = conn.subscribe_stream_callback( + "filtered", lambda batch: None + ) + handle.cancel() + handle.cancel() + handle.wait() + + def test_repr_shows_stream(self, conn): + handle = conn.subscribe_stream_callback( + "filtered", lambda batch: None + ) + r = repr(handle) + assert "stream" in r + handle.cancel() + handle.wait() + + def test_callback_receives_stream_data(self, conn): + received = threading.Event() + results = [] + + def on_data(batch): + results.append(batch) + received.set() + + handle = conn.subscribe_stream_callback("filtered", on_data) + conn.insert("events", {"id": 1, "msg": "hello"}) + got_data = received.wait(timeout=5.0) + handle.cancel() + handle.wait() + if got_data: + assert len(results) > 0 + assert results[0].num_rows > 0 + else: + pytest.skip("data did not arrive within timeout") + + def test_on_error_with_stream_callback(self, conn): + errors = [] + error_event = threading.Event() + + def bad_callback(batch): + raise RuntimeError("stream error") + + def on_error(msg): + errors.append(msg) + error_event.set() + + handle = conn.subscribe_stream_callback( + "filtered", bad_callback, on_error + ) + conn.insert("events", {"id": 1, "msg": "hello"}) + got_error = error_event.wait(timeout=5.0) + handle.cancel() + handle.wait() + if got_error: + assert len(errors) > 0 + assert "stream error" in errors[0] + + def test_wait_after_cancel(self, conn): + handle = conn.subscribe_stream_callback( + "filtered", lambda batch: None + ) + handle.cancel() + handle.wait() # should return promptly + assert not handle.is_active diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py new file mode 100644 index 0000000..f7f2790 --- /dev/null +++ b/tests/test_checkpoint.py @@ -0,0 +1,79 @@ +"""Tests for checkpoint APIs.""" + +import pytest + +import laminardb + + +class TestCheckpointDisabled: + """Tests for checkpoint operations when checkpointing is disabled (default).""" + + def test_checkpoint_raises_when_disabled(self, db): + with pytest.raises(laminardb.LaminarError): + db.checkpoint() + + def test_is_checkpoint_enabled_false(self, db): + assert db.is_checkpoint_enabled is False + + +class TestCheckpointEnabled: + """Tests for checkpoint operations when checkpointing is enabled.""" + + @pytest.fixture + def ckpt_db(self, tmp_path): + storage = tmp_path / "storage" + storage.mkdir() + config = laminardb.LaminarConfig( + storage_dir=str(storage), + checkpoint_interval_ms=60_000, + ) + conn = laminardb.open("ckpt_test", config=config) + conn.create_table( + "events", + {"ts": "int64", "value": "float64"}, + ) + conn.start() + yield conn + conn.close() + + def test_is_checkpoint_enabled_true(self, ckpt_db): + assert ckpt_db.is_checkpoint_enabled is True + + def test_checkpoint_returns_result(self, ckpt_db): + result = ckpt_db.checkpoint() + assert isinstance(result, laminardb.CheckpointResult) + assert result.checkpoint_id >= 0 + + def test_checkpoint_result_bool(self, ckpt_db): + result = ckpt_db.checkpoint() + assert bool(result) is True + + def test_checkpoint_result_int(self, ckpt_db): + result = ckpt_db.checkpoint() + assert int(result) == result.checkpoint_id + + def test_checkpoint_result_repr(self, ckpt_db): + result = ckpt_db.checkpoint() + assert "CheckpointResult" in repr(result) + assert str(result.checkpoint_id) in repr(result) + + def test_multiple_checkpoints_increasing_ids(self, ckpt_db): + r1 = ckpt_db.checkpoint() + r2 = ckpt_db.checkpoint() + assert r2.checkpoint_id > r1.checkpoint_id + + +class TestCheckpointAfterClose: + """Tests that checkpoint methods raise after the connection is closed.""" + + def test_checkpoint_after_close_raises(self, tmp_path): + conn = laminardb.open(str(tmp_path / "test.db")) + conn.close() + with pytest.raises(laminardb.ConnectionError, match="closed"): + conn.checkpoint() + + def test_is_checkpoint_enabled_after_close_raises(self, tmp_path): + conn = laminardb.open(str(tmp_path / "test.db")) + conn.close() + with pytest.raises(laminardb.ConnectionError, match="closed"): + _ = conn.is_checkpoint_enabled