diff --git a/plugins/deepgram/tests/test_deepgram_stt_close.py b/plugins/deepgram/tests/test_deepgram_stt_close.py new file mode 100644 index 000000000..1269f4abe --- /dev/null +++ b/plugins/deepgram/tests/test_deepgram_stt_close.py @@ -0,0 +1,11 @@ +from vision_agents.plugins import deepgram + + +class TestDeepgramSTTClose: + async def test_close_closes_http_client(self): + stt = deepgram.STT(api_key="fake") + httpx_client = stt.client._client_wrapper.httpx_client.httpx_client + + assert httpx_client.is_closed is False + await stt.close() + assert httpx_client.is_closed is True diff --git a/plugins/deepgram/vision_agents/plugins/deepgram/deepgram_stt.py b/plugins/deepgram/vision_agents/plugins/deepgram/deepgram_stt.py index ec7a68b21..c31284439 100644 --- a/plugins/deepgram/vision_agents/plugins/deepgram/deepgram_stt.py +++ b/plugins/deepgram/vision_agents/plugins/deepgram/deepgram_stt.py @@ -312,3 +312,10 @@ async def close(self): self.connection = None self._connection_context = None self._connection_ready.clear() + + # SDK doesn't expose a public aclose() - workaround using internals + wrapper = getattr(self.client, "_client_wrapper", None) + http_client = getattr(wrapper, "httpx_client", None) + httpx_client = getattr(http_client, "httpx_client", None) + if httpx_client is not None: + await httpx_client.aclose() diff --git a/plugins/elevenlabs/tests/test_tts_close.py b/plugins/elevenlabs/tests/test_tts_close.py new file mode 100644 index 000000000..a88ad87ec --- /dev/null +++ b/plugins/elevenlabs/tests/test_tts_close.py @@ -0,0 +1,11 @@ +from vision_agents.plugins import elevenlabs + + +class TestElevenLabsTTSClose: + async def test_close_closes_http_client(self): + tts = elevenlabs.TTS(api_key="fake") + httpx_client = tts.client._client_wrapper.httpx_client.httpx_client + + assert httpx_client.is_closed is False + await tts.close() + assert httpx_client.is_closed is True diff --git a/plugins/elevenlabs/vision_agents/plugins/elevenlabs/tts.py b/plugins/elevenlabs/vision_agents/plugins/elevenlabs/tts.py index 95a8abd69..230329c2a 100644 --- a/plugins/elevenlabs/vision_agents/plugins/elevenlabs/tts.py +++ b/plugins/elevenlabs/vision_agents/plugins/elevenlabs/tts.py @@ -63,6 +63,17 @@ async def stream_audio( audio_stream, sample_rate=16000, channels=1, format=AudioFormat.S16 ) + async def close(self) -> None: + # SDK doesn't expose a public aclose() - workaround using internals + try: + wrapper = getattr(self.client, "_client_wrapper", None) + http_client = getattr(wrapper, "httpx_client", None) + httpx_client = getattr(http_client, "httpx_client", None) + if httpx_client is not None: + await httpx_client.aclose() + finally: + await super().close() + async def stop_audio(self) -> None: """ Clears the queue and stops playing audio. diff --git a/plugins/getstream/tests/test_stream_edge_transport.py b/plugins/getstream/tests/test_stream_edge_transport.py index 62cf397e6..ae7a01c56 100644 --- a/plugins/getstream/tests/test_stream_edge_transport.py +++ b/plugins/getstream/tests/test_stream_edge_transport.py @@ -1,6 +1,6 @@ import asyncio import time -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from uuid import uuid4 import pytest @@ -106,3 +106,16 @@ async def test_create_call_raises_before_authenticate( ): with pytest.raises(RuntimeError, match="not authenticated"): await stream_edge.create_call(call_id="call-1") + + async def test_close_releases_client_resources(self, stream_edge: StreamEdge): + stream_edge._real_connection = AsyncMock() + real_connection = stream_edge._real_connection + + assert stream_edge.client.client.is_closed is False + assert stream_edge._real_connection is not None + + await stream_edge.close() + + assert stream_edge.client.client.is_closed is True + assert stream_edge._real_connection is None + real_connection.leave.assert_called_once() diff --git a/plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py b/plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py index e6eb8b79c..bdaf31356 100644 --- a/plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py +++ b/plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py @@ -418,6 +418,8 @@ async def join( connection = await rtc.join( call, agent.agent_user.id, subscription_config=subscription_config ) + # Store immediately so close() can clean up if join is interrupted + self._real_connection = connection @connection.on("track_added") async def on_track(track_id, track_type, user): @@ -446,7 +448,6 @@ async def on_audio_received(pcm: PcmData): # Start the connection await connection.__aenter__() - self._real_connection = connection self._call = call # Re-publish already published tracks in case somebody is already on the call when we joined. # Otherwise, we won't get the video track from participants joined before us. @@ -496,6 +497,14 @@ def _get_subscription_config(self): ) async def close(self): + if self._real_connection: + try: + await self._real_connection.leave() + except Exception: + logger.exception("Error during connection leave") + self._real_connection = None + if self.client: + await self.client.aclose() self._call = None async def send_custom_event(self, data: dict) -> None: