Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions python/laminardb/_laminardb.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,18 @@ class AsyncStreamSubscription:
"""The subscription schema as a PyArrow Schema."""
...

def next(self) -> QueryResult | None:
"""Blocking wait for the next batch."""
...

def next_timeout(self, timeout_ms: int) -> QueryResult | None:
"""Blocking wait for the next batch with a timeout in milliseconds."""
...

def try_next(self) -> QueryResult | None:
"""Non-blocking poll for the next batch."""
...

def cancel(self) -> None:
"""Cancel the subscription."""
...
Expand Down
57 changes: 57 additions & 0 deletions src/stream_subscription.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,63 @@ impl AsyncStreamSubscription {
}
}

/// Blocking wait for the next batch.
fn next(&self, py: Python<'_>) -> PyResult<Option<QueryResult>> {
let has_sub = self.inner.lock().is_some();
if !has_sub {
return Ok(None);
}

py.allow_threads(|| {
let _rt = runtime().enter();
let mut guard = self.inner.lock();
let sub = guard.as_mut().unwrap();
match sub.next().into_pyresult()? {
Some(batch) => Ok(Some(QueryResult::from_batch(batch))),
None => Ok(None),
}
})
}

/// Blocking wait for the next batch with a timeout in milliseconds.
///
/// Returns None if the timeout expires without data.
fn next_timeout(&self, py: Python<'_>, timeout_ms: u64) -> PyResult<Option<QueryResult>> {
let has_sub = self.inner.lock().is_some();
if !has_sub {
return Ok(None);
}

let timeout = Duration::from_millis(timeout_ms);
py.allow_threads(|| {
let _rt = runtime().enter();
let mut guard = self.inner.lock();
let sub = guard.as_mut().unwrap();
match sub.next_timeout(timeout).into_pyresult()? {
Some(batch) => Ok(Some(QueryResult::from_batch(batch))),
None => Ok(None),
}
})
}

/// Non-blocking poll for the next batch.
fn try_next(&self, py: Python<'_>) -> PyResult<Option<QueryResult>> {
let has_sub = self.inner.lock().is_some();
if !has_sub {
return Ok(None);
}

py.allow_threads(|| {
let _rt = runtime().enter();
let mut guard = self.inner.lock();
let sub = guard.as_mut().unwrap();
match sub.try_next().into_pyresult()? {
Some(batch) => Ok(Some(QueryResult::from_batch(batch))),
None => Ok(None),
}
})
}

/// Cancel the subscription.
fn cancel(&self, py: Python<'_>) -> PyResult<()> {
py.allow_threads(|| {
Expand Down
4 changes: 2 additions & 2 deletions tests/test_shutdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ def test_shutdown_after_close_raises(self, tmp_path):

class TestCancelQuery:
def test_cancel_query_invalid_id(self, db):
# Cancelling a non-existent query is a no-op (no error raised)
db.cancel_query(999999)
with pytest.raises(laminardb.SchemaError, match="not found"):
db.cancel_query(999999)

def test_cancel_query_after_close_raises(self, tmp_path):
conn = laminardb.open(str(tmp_path / "test.db"))
Expand Down
35 changes: 35 additions & 0 deletions tests/test_stream_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,38 @@ async def test_async_schema_after_cancel_raises(self, conn):
sub.cancel()
with pytest.raises(RuntimeError, match="cancelled"):
_ = sub.schema

@pytest.mark.asyncio
async def test_async_try_next_no_data(self, conn):
sub = await conn.subscribe_stream_async("filtered")
# No data inserted yet, try_next should return None
result = sub.try_next()
assert result is None
sub.cancel()

@pytest.mark.asyncio
async def test_async_try_next_after_cancel(self, conn):
sub = await conn.subscribe_stream_async("filtered")
sub.cancel()
result = sub.try_next()
assert result is None

@pytest.mark.asyncio
async def test_async_next_timeout_no_data(self, conn):
sub = await conn.subscribe_stream_async("filtered")
# No data inserted, timeout raises SubscriptionError
with pytest.raises(laminardb.SubscriptionError, match="timeout"):
sub.next_timeout(100)
sub.cancel()

@pytest.mark.asyncio
async def test_async_subscribe_stream_with_data(self, conn):
sub = await conn.subscribe_stream_async("filtered")
conn.insert("events", {"id": 1, "msg": "hello"})
# Use next_timeout so we don't block forever
result = sub.next_timeout(2000)
if result is not None:
assert result.num_rows > 0
else:
pytest.skip("data did not arrive within timeout")
sub.cancel()