diff --git a/agents-core/vision_agents/core/agents/agent_session.py b/agents-core/vision_agents/core/agents/agent_session.py index a50a635ec..0d16c4073 100644 --- a/agents-core/vision_agents/core/agents/agent_session.py +++ b/agents-core/vision_agents/core/agents/agent_session.py @@ -1,7 +1,9 @@ import asyncio import contextvars +import typing -from vision_agents.core.agents import Agent +if typing.TYPE_CHECKING: + from .agents import Agent class AgentSessionContextManager: @@ -27,7 +29,7 @@ class AgentSessionContextManager: returned by the edge transport (kept open during the context). """ - def __init__(self, agent: Agent, connection_cm=None): + def __init__(self, agent: "Agent", connection_cm=None): self.agent = agent self._connection_cm = connection_cm diff --git a/agents-core/vision_agents/core/agents/agents.py b/agents-core/vision_agents/core/agents/agents.py index ab2dbd4c8..34d1a7b18 100644 --- a/agents-core/vision_agents/core/agents/agents.py +++ b/agents-core/vision_agents/core/agents/agents.py @@ -62,13 +62,12 @@ from opentelemetry import trace, context as otel_context from opentelemetry.trace import Tracer from opentelemetry.context import Token +from .agent_session import AgentSessionContextManager if TYPE_CHECKING: from vision_agents.plugins.getstream.stream_edge_transport import StreamEdge - from .agent_session import AgentSessionContextManager - logger = logging.getLogger(__name__) tracer: Tracer = trace.get_tracer("agents") @@ -546,18 +545,6 @@ async def join( with self.span("edge.publish_tracks"): await self.edge.publish_tracks(audio_track, video_track) - connection._connection._coordinator_ws_client.on_wildcard( - "*", - lambda event_name, event: self.events.send(event), - ) - - connection._connection._ws_client.on_wildcard( - "*", - lambda event_name, event: self.events.send(event), - ) - - from .agent_session import AgentSessionContextManager - # wait for conversation creation coro at the very end of the join flow self.conversation = await create_conversation_coro # Provide conversation to the LLM so it can access the chat history. diff --git a/plugins/getstream/pyproject.toml b/plugins/getstream/pyproject.toml index 07b405e95..13525a65b 100644 --- a/plugins/getstream/pyproject.toml +++ b/plugins/getstream/pyproject.toml @@ -12,7 +12,7 @@ requires-python = ">=3.10" license = "MIT" dependencies = [ "vision-agents", - "getstream[webrtc,telemetry]>=2.5.0", + "getstream[webrtc,telemetry]>=2.5.18", ] [project.urls] diff --git a/plugins/getstream/tests/test_stream_conversation.py b/plugins/getstream/tests/test_stream_conversation.py index 7818bd1ec..ab53cf630 100644 --- a/plugins/getstream/tests/test_stream_conversation.py +++ b/plugins/getstream/tests/test_stream_conversation.py @@ -7,7 +7,12 @@ from unittest.mock import Mock, AsyncMock from dotenv import load_dotenv -from getstream.models import MessageRequest, ChannelInput, MessagePaginationParams +from getstream.models import ( + Message, + MessageRequest, + ChannelInput, + MessagePaginationParams, +) from getstream import AsyncStream from vision_agents.plugins.getstream.stream_conversation import StreamConversation @@ -48,7 +53,7 @@ def mock_channel(self): def stream_conversation(self, mock_channel): """Create a StreamConversation instance with mocked dependencies.""" instructions = "You are a helpful assistant." - messages = [] + messages: list[Message] = [] conversation = StreamConversation( instructions=instructions, messages=messages, channel=mock_channel ) diff --git a/plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py b/plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py index e148f8fe4..9af21c0cd 100644 --- a/plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py +++ b/plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py @@ -43,9 +43,9 @@ def __init__(self, connection: ConnectionManager): def participants(self) -> ParticipantsState: return self._connection.participants_state - async def close(self): + async def close(self, timeout: float = 2.0): try: - await asyncio.wait_for(self._connection.leave(), timeout=2.0) + await asyncio.wait_for(self._connection.leave(), timeout=timeout) except asyncio.TimeoutError: logger.warning("Connection leave timed out during close") except RuntimeError as e: @@ -83,11 +83,19 @@ def __init__(self, **kwargs): # track_id -> (user_id, session_id, webrtc_type_string) self._pending_tracks: dict = {} + self._real_connection: Optional[ConnectionManager] = None + # Register event handlers self.events.subscribe(self._on_track_published) self.events.subscribe(self._on_track_removed) self.events.subscribe(self._on_call_ended) + @property + def _connection(self) -> ConnectionManager: + if self._real_connection is None: + raise ValueError("Edge connection is not set") + return self._real_connection + def _get_webrtc_kind(self, track_type_int: int) -> str: """Get the expected WebRTC kind (audio/video) for a SFU track type.""" # Map SFU track types to WebRTC kinds @@ -105,96 +113,6 @@ def _get_webrtc_kind(self, track_type_int: int) -> str: # Default to video for unknown types return "video" - async def _subscribe_to_existing_tracks( - self, connection: ConnectionManager - ) -> None: - """Subscribe to tracks from participants who joined before the agent.""" - from vision_agents.core.edge.sfu_events import Participant as SfuParticipant - - participants = connection.participants_state.get_participants() - subscription_manager = connection._subscription_manager - tracks_to_subscribe = [] - - for participant in participants: - if participant.user_id == self.agent_user_id: - continue - - for track_type_int in participant.published_tracks: - # Create a mock event for the subscription manager - class MockTrackPublishedEvent: - def __init__(self, p, track_type): - self.user_id = p.user_id - self.session_id = p.session_id - self.type = track_type - self.participant = p - - mock_event = MockTrackPublishedEvent(participant, track_type_int) - - try: - await subscription_manager.handle_track_published(mock_event) - tracks_to_subscribe.append((participant, track_type_int)) - except Exception as e: - logger.error(f"Failed to subscribe to existing track: {e}") - - # Poll for WebRTC tracks to arrive after subscription - for participant, track_type_int in tracks_to_subscribe: - expected_kind = self._get_webrtc_kind(track_type_int) - track_key = ( - participant.user_id, - participant.session_id, - track_type_int, - ) - - if track_key in self._track_map: - continue - - # Poll for WebRTC track ID with timeout (same pattern as _on_track_published) - track_id = None - timeout = 10.0 - poll_interval = 0.01 - elapsed = 0.0 - - while elapsed < timeout: - for tid, (pending_user, pending_session, pending_kind) in list( - self._pending_tracks.items() - ): - if ( - pending_user == participant.user_id - and pending_session == participant.session_id - and pending_kind == expected_kind - ): - track_id = tid - del self._pending_tracks[tid] - break - - if track_id: - break - - await asyncio.sleep(poll_interval) - elapsed += poll_interval - - if track_id: - self._track_map[track_key] = { - "track_id": track_id, - "published": True, - } - sfu_participant = SfuParticipant.from_proto(participant) - - self.events.send( - events.TrackAddedEvent( - plugin_name="getstream", - track_id=track_id, - track_type=track_type_int, - user=sfu_participant, - participant=sfu_participant, - ) - ) - else: - logger.warning( - f"No pending track for existing participant: " - f"user={participant.user_id}, type={TrackType.Name(track_type_int)}" - ) - async def _on_track_published(self, event: sfu_events.TrackPublishedEvent): """Handle track published events from SFU - spawn TrackAddedEvent with correct type.""" if not event.payload: @@ -366,13 +284,10 @@ async def join(self, agent: "Agent", call: Call) -> StreamConnection: This function - initializes the chat channel - has the agent.agent_user join the call - - connect incoming audio/video to the agent + - connects incoming audio/video to the agent - connecting agent's outgoing audio/video to the call - - TODO: - - process track flow - """ + # Traditional mode - use WebRTC connection # Configure subscription for audio and video subscription_config = SubscriptionConfig( @@ -401,13 +316,20 @@ async def on_audio_received(pcm: PcmData): ) ) - await ( - connection.__aenter__() - ) # TODO: weird API? there should be a manual version - self._connection = connection - - # Subscribe to tracks from participants who joined before the agent - await self._subscribe_to_existing_tracks(connection) + # Re-emit certain events from the underlying RTC stack + # for the Agent to subscribe. + connection.on("participant_joined", self.events.send) + connection.on("participant_left", self.events.send) + connection.on("track_published", self.events.send) + connection.on("track_unpublished", self.events.send) + connection.on("call_ended", self.events.send) + + # Start the connection + await connection.__aenter__() + # Re-publish already published tracks in case somebody is already on the call when we joined. + # Otherwise, we won't get the video track from participants joined before us. + await connection.republish_tracks() + self._real_connection = connection standardize_connection = StreamConnection(connection) return standardize_connection diff --git a/uv.lock b/uv.lock index b0b59703b..bf253724f 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.12" resolution-markers = [ "python_full_version >= '3.13' and sys_platform == 'win32'", @@ -1584,7 +1584,7 @@ wheels = [ [[package]] name = "getstream" -version = "2.5.16" +version = "2.5.18" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "dataclasses-json" }, @@ -1599,7 +1599,7 @@ dependencies = [ { name = "twirp" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/ff/b6/0bbb435bc4bb8f8da2bc85a26afa203bf74a08bb55acc15db74cf2dc6ad6/getstream-2.5.16-py3-none-any.whl", hash = "sha256:b51629a2bc838d4f596c46e153d5601600f2bc3259432f1877af8abe9a2ab76b", size = 247581, upload-time = "2025-11-14T19:20:18.281Z" }, + { url = "https://files.pythonhosted.org/packages/e0/04/184b3c72f9c99dc1cf2ae56aca3ba2286637fdd576905c196d7c08832791/getstream-2.5.18-py3-none-any.whl", hash = "sha256:def4763c666494825eb893d55d0d222cd4db2ff03e7aa1d42ce8ee3c9cc8bd77", size = 247981, upload-time = "2025-12-12T14:20:00.031Z" }, ] [package.optional-dependencies] @@ -6494,7 +6494,7 @@ dev = [ [package.metadata] requires-dist = [ - { name = "getstream", extras = ["telemetry", "webrtc"], specifier = ">=2.5.0" }, + { name = "getstream", extras = ["telemetry", "webrtc"], specifier = ">=2.5.18" }, { name = "vision-agents", editable = "agents-core" }, ]