Skip to content

Commit 64b2d43

Browse files
authored
Fix SFU events handling inside Agent (#245)
1 parent a0042ea commit 64b2d43

6 files changed

Lines changed: 43 additions & 127 deletions

File tree

agents-core/vision_agents/core/agents/agent_session.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import asyncio
22
import contextvars
3+
import typing
34

4-
from vision_agents.core.agents import Agent
5+
if typing.TYPE_CHECKING:
6+
from .agents import Agent
57

68

79
class AgentSessionContextManager:
@@ -27,7 +29,7 @@ class AgentSessionContextManager:
2729
returned by the edge transport (kept open during the context).
2830
"""
2931

30-
def __init__(self, agent: Agent, connection_cm=None):
32+
def __init__(self, agent: "Agent", connection_cm=None):
3133
self.agent = agent
3234
self._connection_cm = connection_cm
3335

agents-core/vision_agents/core/agents/agents.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,12 @@
6262
from opentelemetry import trace, context as otel_context
6363
from opentelemetry.trace import Tracer
6464
from opentelemetry.context import Token
65+
from .agent_session import AgentSessionContextManager
6566

6667

6768
if TYPE_CHECKING:
6869
from vision_agents.plugins.getstream.stream_edge_transport import StreamEdge
6970

70-
from .agent_session import AgentSessionContextManager
71-
7271
logger = logging.getLogger(__name__)
7372

7473
tracer: Tracer = trace.get_tracer("agents")
@@ -546,18 +545,6 @@ async def join(
546545
with self.span("edge.publish_tracks"):
547546
await self.edge.publish_tracks(audio_track, video_track)
548547

549-
connection._connection._coordinator_ws_client.on_wildcard(
550-
"*",
551-
lambda event_name, event: self.events.send(event),
552-
)
553-
554-
connection._connection._ws_client.on_wildcard(
555-
"*",
556-
lambda event_name, event: self.events.send(event),
557-
)
558-
559-
from .agent_session import AgentSessionContextManager
560-
561548
# wait for conversation creation coro at the very end of the join flow
562549
self.conversation = await create_conversation_coro
563550
# Provide conversation to the LLM so it can access the chat history.

plugins/getstream/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ requires-python = ">=3.10"
1212
license = "MIT"
1313
dependencies = [
1414
"vision-agents",
15-
"getstream[webrtc,telemetry]>=2.5.0",
15+
"getstream[webrtc,telemetry]>=2.5.18",
1616
]
1717

1818
[project.urls]

plugins/getstream/tests/test_stream_conversation.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
from unittest.mock import Mock, AsyncMock
88
from dotenv import load_dotenv
99

10-
from getstream.models import MessageRequest, ChannelInput, MessagePaginationParams
10+
from getstream.models import (
11+
Message,
12+
MessageRequest,
13+
ChannelInput,
14+
MessagePaginationParams,
15+
)
1116
from getstream import AsyncStream
1217

1318
from vision_agents.plugins.getstream.stream_conversation import StreamConversation
@@ -48,7 +53,7 @@ def mock_channel(self):
4853
def stream_conversation(self, mock_channel):
4954
"""Create a StreamConversation instance with mocked dependencies."""
5055
instructions = "You are a helpful assistant."
51-
messages = []
56+
messages: list[Message] = []
5257
conversation = StreamConversation(
5358
instructions=instructions, messages=messages, channel=mock_channel
5459
)

plugins/getstream/vision_agents/plugins/getstream/stream_edge_transport.py

Lines changed: 26 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ def __init__(self, connection: ConnectionManager):
4343
def participants(self) -> ParticipantsState:
4444
return self._connection.participants_state
4545

46-
async def close(self):
46+
async def close(self, timeout: float = 2.0):
4747
try:
48-
await asyncio.wait_for(self._connection.leave(), timeout=2.0)
48+
await asyncio.wait_for(self._connection.leave(), timeout=timeout)
4949
except asyncio.TimeoutError:
5050
logger.warning("Connection leave timed out during close")
5151
except RuntimeError as e:
@@ -83,11 +83,19 @@ def __init__(self, **kwargs):
8383
# track_id -> (user_id, session_id, webrtc_type_string)
8484
self._pending_tracks: dict = {}
8585

86+
self._real_connection: Optional[ConnectionManager] = None
87+
8688
# Register event handlers
8789
self.events.subscribe(self._on_track_published)
8890
self.events.subscribe(self._on_track_removed)
8991
self.events.subscribe(self._on_call_ended)
9092

93+
@property
94+
def _connection(self) -> ConnectionManager:
95+
if self._real_connection is None:
96+
raise ValueError("Edge connection is not set")
97+
return self._real_connection
98+
9199
def _get_webrtc_kind(self, track_type_int: int) -> str:
92100
"""Get the expected WebRTC kind (audio/video) for a SFU track type."""
93101
# Map SFU track types to WebRTC kinds
@@ -105,96 +113,6 @@ def _get_webrtc_kind(self, track_type_int: int) -> str:
105113
# Default to video for unknown types
106114
return "video"
107115

108-
async def _subscribe_to_existing_tracks(
109-
self, connection: ConnectionManager
110-
) -> None:
111-
"""Subscribe to tracks from participants who joined before the agent."""
112-
from vision_agents.core.edge.sfu_events import Participant as SfuParticipant
113-
114-
participants = connection.participants_state.get_participants()
115-
subscription_manager = connection._subscription_manager
116-
tracks_to_subscribe = []
117-
118-
for participant in participants:
119-
if participant.user_id == self.agent_user_id:
120-
continue
121-
122-
for track_type_int in participant.published_tracks:
123-
# Create a mock event for the subscription manager
124-
class MockTrackPublishedEvent:
125-
def __init__(self, p, track_type):
126-
self.user_id = p.user_id
127-
self.session_id = p.session_id
128-
self.type = track_type
129-
self.participant = p
130-
131-
mock_event = MockTrackPublishedEvent(participant, track_type_int)
132-
133-
try:
134-
await subscription_manager.handle_track_published(mock_event)
135-
tracks_to_subscribe.append((participant, track_type_int))
136-
except Exception as e:
137-
logger.error(f"Failed to subscribe to existing track: {e}")
138-
139-
# Poll for WebRTC tracks to arrive after subscription
140-
for participant, track_type_int in tracks_to_subscribe:
141-
expected_kind = self._get_webrtc_kind(track_type_int)
142-
track_key = (
143-
participant.user_id,
144-
participant.session_id,
145-
track_type_int,
146-
)
147-
148-
if track_key in self._track_map:
149-
continue
150-
151-
# Poll for WebRTC track ID with timeout (same pattern as _on_track_published)
152-
track_id = None
153-
timeout = 10.0
154-
poll_interval = 0.01
155-
elapsed = 0.0
156-
157-
while elapsed < timeout:
158-
for tid, (pending_user, pending_session, pending_kind) in list(
159-
self._pending_tracks.items()
160-
):
161-
if (
162-
pending_user == participant.user_id
163-
and pending_session == participant.session_id
164-
and pending_kind == expected_kind
165-
):
166-
track_id = tid
167-
del self._pending_tracks[tid]
168-
break
169-
170-
if track_id:
171-
break
172-
173-
await asyncio.sleep(poll_interval)
174-
elapsed += poll_interval
175-
176-
if track_id:
177-
self._track_map[track_key] = {
178-
"track_id": track_id,
179-
"published": True,
180-
}
181-
sfu_participant = SfuParticipant.from_proto(participant)
182-
183-
self.events.send(
184-
events.TrackAddedEvent(
185-
plugin_name="getstream",
186-
track_id=track_id,
187-
track_type=track_type_int,
188-
user=sfu_participant,
189-
participant=sfu_participant,
190-
)
191-
)
192-
else:
193-
logger.warning(
194-
f"No pending track for existing participant: "
195-
f"user={participant.user_id}, type={TrackType.Name(track_type_int)}"
196-
)
197-
198116
async def _on_track_published(self, event: sfu_events.TrackPublishedEvent):
199117
"""Handle track published events from SFU - spawn TrackAddedEvent with correct type."""
200118
if not event.payload:
@@ -366,13 +284,10 @@ async def join(self, agent: "Agent", call: Call) -> StreamConnection:
366284
This function
367285
- initializes the chat channel
368286
- has the agent.agent_user join the call
369-
- connect incoming audio/video to the agent
287+
- connects incoming audio/video to the agent
370288
- connecting agent's outgoing audio/video to the call
371-
372-
TODO:
373-
- process track flow
374-
375289
"""
290+
376291
# Traditional mode - use WebRTC connection
377292
# Configure subscription for audio and video
378293
subscription_config = SubscriptionConfig(
@@ -401,13 +316,20 @@ async def on_audio_received(pcm: PcmData):
401316
)
402317
)
403318

404-
await (
405-
connection.__aenter__()
406-
) # TODO: weird API? there should be a manual version
407-
self._connection = connection
408-
409-
# Subscribe to tracks from participants who joined before the agent
410-
await self._subscribe_to_existing_tracks(connection)
319+
# Re-emit certain events from the underlying RTC stack
320+
# for the Agent to subscribe.
321+
connection.on("participant_joined", self.events.send)
322+
connection.on("participant_left", self.events.send)
323+
connection.on("track_published", self.events.send)
324+
connection.on("track_unpublished", self.events.send)
325+
connection.on("call_ended", self.events.send)
326+
327+
# Start the connection
328+
await connection.__aenter__()
329+
# Re-publish already published tracks in case somebody is already on the call when we joined.
330+
# Otherwise, we won't get the video track from participants joined before us.
331+
await connection.republish_tracks()
332+
self._real_connection = connection
411333

412334
standardize_connection = StreamConnection(connection)
413335
return standardize_connection

uv.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)