Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions agents-core/vision_agents/core/agents/agent_session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
import contextvars
import typing

from vision_agents.core.agents import Agent
if typing.TYPE_CHECKING:
from .agents import Agent


class AgentSessionContextManager:
Expand All @@ -27,7 +29,7 @@ class AgentSessionContextManager:
returned by the edge transport (kept open during the context).
"""

def __init__(self, agent: Agent, connection_cm=None):
def __init__(self, agent: "Agent", connection_cm=None):
self.agent = agent
self._connection_cm = connection_cm

Expand Down
15 changes: 1 addition & 14 deletions agents-core/vision_agents/core/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,12 @@
from opentelemetry import trace, context as otel_context
from opentelemetry.trace import Tracer
from opentelemetry.context import Token
from .agent_session import AgentSessionContextManager


if TYPE_CHECKING:
from vision_agents.plugins.getstream.stream_edge_transport import StreamEdge

from .agent_session import AgentSessionContextManager

logger = logging.getLogger(__name__)

tracer: Tracer = trace.get_tracer("agents")
Expand Down Expand Up @@ -546,18 +545,6 @@ async def join(
with self.span("edge.publish_tracks"):
await self.edge.publish_tracks(audio_track, video_track)

connection._connection._coordinator_ws_client.on_wildcard(
"*",
lambda event_name, event: self.events.send(event),
)

connection._connection._ws_client.on_wildcard(
"*",
lambda event_name, event: self.events.send(event),
)

from .agent_session import AgentSessionContextManager

# wait for conversation creation coro at the very end of the join flow
self.conversation = await create_conversation_coro
# Provide conversation to the LLM so it can access the chat history.
Expand Down
2 changes: 1 addition & 1 deletion plugins/getstream/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ requires-python = ">=3.10"
license = "MIT"
dependencies = [
"vision-agents",
"getstream[webrtc,telemetry]>=2.5.0",
"getstream[webrtc,telemetry]>=2.5.18",
]

[project.urls]
Expand Down
9 changes: 7 additions & 2 deletions plugins/getstream/tests/test_stream_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from unittest.mock import Mock, AsyncMock
from dotenv import load_dotenv

from getstream.models import MessageRequest, ChannelInput, MessagePaginationParams
from getstream.models import (
Message,
MessageRequest,
ChannelInput,
MessagePaginationParams,
)
from getstream import AsyncStream

from vision_agents.plugins.getstream.stream_conversation import StreamConversation
Expand Down Expand Up @@ -48,7 +53,7 @@ def mock_channel(self):
def stream_conversation(self, mock_channel):
"""Create a StreamConversation instance with mocked dependencies."""
instructions = "You are a helpful assistant."
messages = []
messages: list[Message] = []
conversation = StreamConversation(
instructions=instructions, messages=messages, channel=mock_channel
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def __init__(self, connection: ConnectionManager):
def participants(self) -> ParticipantsState:
return self._connection.participants_state

async def close(self):
async def close(self, timeout: float = 2.0):
try:
await asyncio.wait_for(self._connection.leave(), timeout=2.0)
await asyncio.wait_for(self._connection.leave(), timeout=timeout)
except asyncio.TimeoutError:
logger.warning("Connection leave timed out during close")
except RuntimeError as e:
Expand Down Expand Up @@ -83,11 +83,19 @@ def __init__(self, **kwargs):
# track_id -> (user_id, session_id, webrtc_type_string)
self._pending_tracks: dict = {}

self._real_connection: Optional[ConnectionManager] = None

# Register event handlers
self.events.subscribe(self._on_track_published)
self.events.subscribe(self._on_track_removed)
self.events.subscribe(self._on_call_ended)

@property
def _connection(self) -> ConnectionManager:
    """Return the active edge connection.

    Raises:
        ValueError: If ``self._real_connection`` is still ``None``.
    """
    active = self._real_connection
    if active is not None:
        return active
    raise ValueError("Edge connection is not set")

def _get_webrtc_kind(self, track_type_int: int) -> str:
"""Get the expected WebRTC kind (audio/video) for a SFU track type."""
# Map SFU track types to WebRTC kinds
Expand All @@ -105,96 +113,6 @@ def _get_webrtc_kind(self, track_type_int: int) -> str:
# Default to video for unknown types
return "video"

async def _subscribe_to_existing_tracks(
    self, connection: ConnectionManager
) -> None:
    """Subscribe to tracks from participants who joined before the agent.

    Works in two phases:

    1. For every published track of every participant other than the agent
       (``self.agent_user_id``), a synthetic track-published event is handed
       to the connection's subscription manager.
    2. For every successful subscription, wait until a matching entry shows
       up in ``self._pending_tracks``, record it in ``self._track_map`` and
       emit a ``TrackAddedEvent``.

    Args:
        connection: Connection whose ``participants_state`` and
            ``_subscription_manager`` are used.
    """
    from vision_agents.core.edge.sfu_events import Participant as SfuParticipant

    # Stand-in for a track-published event; defined once rather than being
    # re-created for every track inside the loops below.
    class MockTrackPublishedEvent:
        def __init__(self, p, track_type):
            self.user_id = p.user_id
            self.session_id = p.session_id
            self.type = track_type
            self.participant = p

    subscription_manager = connection._subscription_manager
    tracks_to_subscribe = []

    for participant in connection.participants_state.get_participants():
        if participant.user_id == self.agent_user_id:
            continue

        for track_type_int in participant.published_tracks:
            mock_event = MockTrackPublishedEvent(participant, track_type_int)
            try:
                await subscription_manager.handle_track_published(mock_event)
            except Exception as e:
                # Best effort: one failing track must not abort the others.
                # logger.exception keeps the traceback for debugging.
                logger.exception(f"Failed to subscribe to existing track: {e}")
            else:
                tracks_to_subscribe.append((participant, track_type_int))

    timeout = 10.0
    poll_interval = 0.01
    loop = asyncio.get_running_loop()

    # Poll for WebRTC tracks to arrive after subscription
    for participant, track_type_int in tracks_to_subscribe:
        track_key = (
            participant.user_id,
            participant.session_id,
            track_type_int,
        )
        if track_key in self._track_map:
            continue

        expected_kind = self._get_webrtc_kind(track_type_int)
        wanted = (participant.user_id, participant.session_id, expected_kind)

        # Use the loop clock for the deadline: summing sleep intervals
        # ignores scheduling overhead and lets the real wait drift far
        # beyond the intended timeout.
        deadline = loop.time() + timeout
        while True:
            track_id = next(
                (
                    tid
                    for tid, pending in self._pending_tracks.items()
                    if tuple(pending) == wanted
                ),
                None,
            )
            if track_id is not None:
                # Claim the pending entry so it cannot be matched twice.
                del self._pending_tracks[track_id]
                break
            if loop.time() >= deadline:
                break
            await asyncio.sleep(poll_interval)

        if track_id is None:
            logger.warning(
                f"No pending track for existing participant: "
                f"user={participant.user_id}, type={TrackType.Name(track_type_int)}"
            )
            continue

        self._track_map[track_key] = {
            "track_id": track_id,
            "published": True,
        }
        sfu_participant = SfuParticipant.from_proto(participant)

        self.events.send(
            events.TrackAddedEvent(
                plugin_name="getstream",
                track_id=track_id,
                track_type=track_type_int,
                user=sfu_participant,
                participant=sfu_participant,
            )
        )

async def _on_track_published(self, event: sfu_events.TrackPublishedEvent):
"""Handle track published events from SFU - spawn TrackAddedEvent with correct type."""
if not event.payload:
Expand Down Expand Up @@ -366,13 +284,10 @@ async def join(self, agent: "Agent", call: Call) -> StreamConnection:
This function
- initializes the chat channel
- has the agent.agent_user join the call
- connect incoming audio/video to the agent
- connects incoming audio/video to the agent
- connecting agent's outgoing audio/video to the call

TODO:
- process track flow

"""

# Traditional mode - use WebRTC connection
# Configure subscription for audio and video
subscription_config = SubscriptionConfig(
Expand Down Expand Up @@ -401,13 +316,20 @@ async def on_audio_received(pcm: PcmData):
)
)

await (
connection.__aenter__()
) # TODO: weird API? there should be a manual version
self._connection = connection

# Subscribe to tracks from participants who joined before the agent
await self._subscribe_to_existing_tracks(connection)
# Re-emit certain events from the underlying RTC stack
# for the Agent to subscribe.
connection.on("participant_joined", self.events.send)
connection.on("participant_left", self.events.send)
connection.on("track_published", self.events.send)
connection.on("track_unpublished", self.events.send)
connection.on("call_ended", self.events.send)

# Start the connection
await connection.__aenter__()
# Re-publish already-published tracks in case someone was already on the call when we joined.
# Otherwise, we won't receive the video track from participants who joined before us.
await connection.republish_tracks()
self._real_connection = connection

standardize_connection = StreamConnection(connection)
return standardize_connection
Expand Down
8 changes: 4 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading