diff --git a/python/laminardb/_laminardb.pyi b/python/laminardb/_laminardb.pyi index 839a3b0..778b31e 100644 --- a/python/laminardb/_laminardb.pyi +++ b/python/laminardb/_laminardb.pyi @@ -619,6 +619,18 @@ class AsyncStreamSubscription: """The subscription schema as a PyArrow Schema.""" ... + def next(self) -> QueryResult | None: + """Blocking wait for the next batch.""" + ... + + def next_timeout(self, timeout_ms: int) -> QueryResult | None: + """Blocking wait for the next batch with a timeout in milliseconds.""" + ... + + def try_next(self) -> QueryResult | None: + """Non-blocking poll for the next batch.""" + ... + def cancel(self) -> None: """Cancel the subscription.""" ... diff --git a/src/stream_subscription.rs b/src/stream_subscription.rs index 0121857..656c717 100644 --- a/src/stream_subscription.rs +++ b/src/stream_subscription.rs @@ -214,6 +214,63 @@ impl AsyncStreamSubscription { } } + /// Blocking wait for the next batch. + fn next(&self, py: Python<'_>) -> PyResult> { + let has_sub = self.inner.lock().is_some(); + if !has_sub { + return Ok(None); + } + + py.allow_threads(|| { + let _rt = runtime().enter(); + let mut guard = self.inner.lock(); + let sub = guard.as_mut().unwrap(); + match sub.next().into_pyresult()? { + Some(batch) => Ok(Some(QueryResult::from_batch(batch))), + None => Ok(None), + } + }) + } + + /// Blocking wait for the next batch with a timeout in milliseconds. + /// + /// Returns None if the timeout expires without data. + fn next_timeout(&self, py: Python<'_>, timeout_ms: u64) -> PyResult> { + let has_sub = self.inner.lock().is_some(); + if !has_sub { + return Ok(None); + } + + let timeout = Duration::from_millis(timeout_ms); + py.allow_threads(|| { + let _rt = runtime().enter(); + let mut guard = self.inner.lock(); + let sub = guard.as_mut().unwrap(); + match sub.next_timeout(timeout).into_pyresult()? { + Some(batch) => Ok(Some(QueryResult::from_batch(batch))), + None => Ok(None), + } + }) + } + + /// Non-blocking poll for the next batch. + fn try_next(&self, py: Python<'_>) -> PyResult> { + let has_sub = self.inner.lock().is_some(); + if !has_sub { + return Ok(None); + } + + py.allow_threads(|| { + let _rt = runtime().enter(); + let mut guard = self.inner.lock(); + let sub = guard.as_mut().unwrap(); + match sub.try_next().into_pyresult()? { + Some(batch) => Ok(Some(QueryResult::from_batch(batch))), + None => Ok(None), + } + }) + } + /// Cancel the subscription. fn cancel(&self, py: Python<'_>) -> PyResult<()> { py.allow_threads(|| { diff --git a/tests/test_shutdown.py b/tests/test_shutdown.py index 7a1717d..a504140 100644 --- a/tests/test_shutdown.py +++ b/tests/test_shutdown.py @@ -20,8 +20,8 @@ def test_shutdown_after_close_raises(self, tmp_path): class TestCancelQuery: def test_cancel_query_invalid_id(self, db): - # Cancelling a non-existent query is a no-op (no error raised) - db.cancel_query(999999) + with pytest.raises(laminardb.SchemaError, match="not found"): + db.cancel_query(999999) def test_cancel_query_after_close_raises(self, tmp_path): conn = laminardb.open(str(tmp_path / "test.db")) diff --git a/tests/test_stream_subscription.py b/tests/test_stream_subscription.py index bb4ebb3..e5a63d2 100644 --- a/tests/test_stream_subscription.py +++ b/tests/test_stream_subscription.py @@ -136,3 +136,38 @@ async def test_async_schema_after_cancel_raises(self, conn): sub.cancel() with pytest.raises(RuntimeError, match="cancelled"): _ = sub.schema + + @pytest.mark.asyncio + async def test_async_try_next_no_data(self, conn): + sub = await conn.subscribe_stream_async("filtered") + # No data inserted yet, try_next should return None + result = sub.try_next() + assert result is None + sub.cancel() + + @pytest.mark.asyncio + async def test_async_try_next_after_cancel(self, conn): + sub = await conn.subscribe_stream_async("filtered") + sub.cancel() + result = sub.try_next() + assert result is None + + @pytest.mark.asyncio + async def test_async_next_timeout_no_data(self, conn): + sub = await conn.subscribe_stream_async("filtered") + # No data inserted, timeout raises SubscriptionError + with pytest.raises(laminardb.SubscriptionError, match="timeout"): + sub.next_timeout(100) + sub.cancel() + + @pytest.mark.asyncio + async def test_async_subscribe_stream_with_data(self, conn): + sub = await conn.subscribe_stream_async("filtered") + conn.insert("events", {"id": 1, "msg": "hello"}) + # Use next_timeout so we don't block forever + result = sub.next_timeout(2000) + if result is not None: + assert result.num_rows > 0 + else: + pytest.skip("data did not arrive within timeout") + sub.cancel()