Skip to content

Commit f9d0cae

Browse files
committed
Allow MJPEGStream to stop cleanly
It's now possible to tell MJPEGStream that the stream has stopped. This terminates the streaming response, and gets around the lack of streaming support in starlette's test client. With a bit of thought, this could potentially fix the longstanding issue that MJPEG Streams prevent a labthings server from restarting.
1 parent 3964d74 commit f9d0cae

File tree

2 files changed

+32
-17
lines changed

2 files changed

+32
-17
lines changed

src/labthings_fastapi/outputs/mjpeg_stream.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,11 @@ def reset(self, ringbuffer_size: Optional[int] = None):
107107
]
108108
self.last_frame_i = -1
109109

110-
def stop(self):
110+
def stop(self, portal: BlockingPortal):
111111
"""Stop the stream"""
112112
with self._lock:
113113
self._streaming = False
114+
portal.start_task_soon(self.notify_stream_stopped)
114115

115116
async def ringbuffer_entry(self, i: int) -> RingbufferEntry:
116117
"""Return the ith frame acquired by the camera
@@ -139,9 +140,13 @@ async def buffer_for_reading(self, i: int) -> AsyncIterator[bytes]:
139140
yield entry.frame
140141

141142
async def next_frame(self) -> int:
142-
"""Wait for the next frame, and return its index"""
143+
"""Wait for the next frame, and return its index
144+
145+
:raises StopAsyncIteration: if the stream has stopped."""
143146
async with self.condition:
144147
await self.condition.wait()
148+
if not self._streaming:
149+
raise StopAsyncIteration()
145150
return self.last_frame_i
146151

147152
async def grab_frame(self) -> bytes:
@@ -170,6 +175,8 @@ async def frame_async_generator(self) -> AsyncGenerator[bytes, None]:
170175
i = await self.next_frame()
171176
async with self.buffer_for_reading(i) as frame:
172177
yield frame
178+
except StopAsyncIteration:
179+
break
173180
except Exception as e:
174181
logging.error(f"Error in stream: {e}, stream stopped")
175182
return
@@ -178,7 +185,7 @@ async def mjpeg_stream_response(self) -> MJPEGStreamResponse:
178185
"""Return a StreamingResponse that streams an MJPEG stream"""
179186
return MJPEGStreamResponse(self.frame_async_generator())
180187

181-
def add_frame(self, frame: bytes, portal: BlockingPortal):
188+
def add_frame(self, frame: bytes, portal: BlockingPortal) -> None:
182189
"""Return the next buffer in the ringbuffer to write to
183190
184191
:param frame: The frame to add
@@ -196,12 +203,18 @@ def add_frame(self, frame: bytes, portal: BlockingPortal):
196203
entry.index = self.last_frame_i + 1
197204
portal.start_task_soon(self.notify_new_frame, entry.index)
198205

199-
async def notify_new_frame(self, i):
206+
async def notify_new_frame(self, i: int) -> None:
200207
"""Notify any waiting tasks that a new frame is available"""
201208
async with self.condition:
202209
self.last_frame_i = i
203210
self.condition.notify_all()
204211

212+
async def notify_stream_stopped(self) -> None:
213+
"""Raise an exception in any waiting tasks to signal the stream has stopped."""
214+
assert self._streaming is False
215+
async with self.condition:
216+
self.condition.notify_all()
217+
205218

206219
class MJPEGStreamDescriptor:
207220
"""A descriptor that returns a MJPEGStream object when accessed

tests/test_mjpeg_stream.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,34 +32,36 @@ def _make_images(self):
3232
image.save(dest, "jpeg")
3333
jpegs.append(dest.getvalue())
3434

35-
i = -1
36-
start_time = time.time()
37-
while self._streaming:
38-
i = (i + 1) % len(jpegs)
39-
print(f"sending frame {i}")
35+
i = 0
36+
while self._streaming and i < len(jpegs):
4037
self.stream.add_frame(jpegs[i], self._labthings_blocking_portal)
4138
time.sleep(1 / self.framerate)
42-
43-
if time.time() - start_time > 10:
44-
break
45-
print("stopped sending frames")
39+
i = i + 1
40+
self.stream.stop(self._labthings_blocking_portal)
4641
self._streaming = False
4742

4843

4944
def test_mjpeg_stream():
45+
"""Verify the MJPEG stream contains at least one frame marker.
46+
47+
A limitation of the TestClient is that it can't actually stream.
48+
This means that all of the frames sent by our test Thing will
49+
arrive in a single packet.
50+
51+
For now, we just check it starts with the frame separator,
52+
but it might be possible in the future to check there are three
53+
images there.
54+
"""
5055
server = lt.ThingServer()
5156
telly = Telly()
5257
server.add_thing(telly, "telly")
5358
with TestClient(server.app) as client:
54-
with client.stream("GET", "/telly/stream", timeout=0.1) as stream:
59+
with client.stream("GET", "/telly/stream") as stream:
5560
stream.raise_for_status()
5661
received = 0
5762
for b in stream.iter_bytes():
5863
received += 1
59-
print(f"Got packet {received}")
6064
assert b.startswith(b"--frame")
61-
if received > 5:
62-
break
6365

6466

6567
if __name__ == "__main__":

0 commit comments

Comments
 (0)