Skip to content

Commit e919787

Browse files
authored
Decouple vision agents from getstream (#330)
1 parent 633f464 commit e919787

29 files changed

Lines changed: 1070 additions & 825 deletions

File tree

agents-core/vision_agents/core/agents/agents.py

Lines changed: 61 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -7,35 +7,30 @@
77
from contextlib import asynccontextmanager, contextmanager
88
from pathlib import Path
99
from typing import (
10-
TYPE_CHECKING,
1110
Any,
1211
AsyncIterator,
13-
Dict,
1412
Iterator,
15-
List,
1613
Optional,
1714
TypeGuard,
1815
)
1916
from uuid import uuid4
2017

21-
import getstream.models
2218
from aiortc import VideoStreamTrack
23-
from getstream.video.rtc import Call
24-
from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import TrackType
19+
from getstream.video.rtc import AudioStreamTrack, PcmData
2520
from opentelemetry import context as otel_context
2621
from opentelemetry import trace
2722
from opentelemetry.context import Token
2823
from opentelemetry.trace import Tracer, set_span_in_context
2924
from opentelemetry.trace.propagation import Context, Span
3025

31-
from ..edge import sfu_events
26+
from ..edge import Call, EdgeTransport
3227
from ..edge.events import (
3328
AudioReceivedEvent,
3429
CallEndedEvent,
3530
TrackAddedEvent,
3631
TrackRemovedEvent,
3732
)
38-
from ..edge.types import OutputAudioTrack, Participant, PcmData, User
33+
from ..edge.types import Connection, Participant, TrackType, User
3934
from ..events.manager import EventManager
4035
from ..instructions import Instructions
4136
from ..llm import events as llm_events
@@ -78,12 +73,6 @@
7873
from .conversation import Conversation
7974
from .transcript_buffer import TranscriptBuffer
8075

81-
if TYPE_CHECKING:
82-
from vision_agents.plugins.getstream.stream_edge_transport import (
83-
StreamConnection,
84-
StreamEdge,
85-
)
86-
8776
logger = logging.getLogger(__name__)
8877

8978
tracer: Tracer = trace.get_tracer("agents")
@@ -114,15 +103,15 @@ class Agent:
114103
Note: Don't reuse the agent object. Create a new agent object each time.
115104
116105
Dev guidelines
117-
- Small methods so its easy to subclass/change behaviour
106+
- Small methods so it's easy to subclass/change behaviour
118107
"""
119108

120109
options: AgentOptions
121110

122111
def __init__(
123112
self,
124113
# edge network for video & audio
125-
edge: "StreamEdge",
114+
edge: EdgeTransport,
126115
# llm, optionally with sts/realtime capabilities
127116
llm: LLM | AudioLLM | VideoLLM,
128117
# the agent's user info
@@ -137,9 +126,9 @@ def __init__(
137126
# - roboflow/ yolo typically run continuously
138127
# - often combined with API calls to fetch stats etc
139128
# - state from each processor is passed to the LLM
140-
processors: Optional[List[Processor]] = None,
129+
processors: Optional[list[Processor]] = None,
141130
# MCP servers for external tool and resource access
142-
mcp_servers: Optional[List[MCPBaseServer]] = None,
131+
mcp_servers: Optional[list[MCPBaseServer]] = None,
143132
options: Optional[AgentOptions] = None,
144133
tracer: Tracer = trace.get_tracer("agents"),
145134
profiler: Optional[Profiler] = None,
@@ -178,9 +167,7 @@ def __init__(
178167
self.logger = _AgentLoggerAdapter(logger, {"agent_id": self.agent_user.id})
179168

180169
self.events = EventManager()
181-
self.events.register_events_from_module(getstream.models, "call.")
182170
self.events.register_events_from_module(events)
183-
self.events.register_events_from_module(sfu_events)
184171
self.events.register_events_from_module(llm_events)
185172

186173
self.llm = llm
@@ -205,13 +192,13 @@ def __init__(
205192
self.conversation: Optional[Conversation] = None
206193

207194
# Track pending transcripts for turn-based response triggering
208-
self._pending_user_transcripts: Dict[str, TranscriptBuffer] = defaultdict(
195+
self._pending_user_transcripts: dict[str, TranscriptBuffer] = defaultdict(
209196
TranscriptBuffer
210197
)
211198

212199
# Merge plugin events BEFORE subscribing to any events
213200
for plugin in [stt, tts, turn_detection, llm, edge, profiler]:
214-
if plugin and hasattr(plugin, "events"):
201+
if plugin is not None:
215202
self.logger.debug(f"Register events from plugin {plugin}")
216203
self.events.merge(plugin.events)
217204

@@ -224,16 +211,16 @@ def __init__(
224211
self.events.subscribe(self._on_agent_say)
225212

226213
# Track metadata: track_id -> TrackInfo
227-
self._active_video_tracks: Dict[str, TrackInfo] = {}
228-
self._video_forwarders: List[VideoForwarder] = []
229-
self._connection: Optional[StreamConnection] = None
214+
self._active_video_tracks: dict[str, TrackInfo] = {}
215+
self._video_forwarders: list[VideoForwarder] = []
216+
self._connection: Optional[Connection] = None
230217

231218
# Optional local video track override for debugging.
232219
# This track will play instead of any incoming video track.
233220
self._video_track_override_path: Optional[str | Path] = None
234221

235222
# the outgoing audio track
236-
self._audio_track: Optional[OutputAudioTrack] = None
223+
self._audio_track: Optional[AudioStreamTrack] = None
237224

238225
# the outgoing video track
239226
self._video_track: Optional[VideoStreamTrack] = None
@@ -323,15 +310,23 @@ async def _on_tts_audio_write_to_output(event: TTSAudioEvent):
323310
# listen to video tracks added/removed
324311
@self.edge.events.subscribe
325312
async def on_video_track_added(event: TrackAddedEvent | TrackRemovedEvent):
326-
if event.track_id is None or event.track_type is None or event.user is None:
313+
if (
314+
event.track_id is None
315+
or event.track_type is None
316+
or event.participant is None
317+
):
327318
return
328319
if isinstance(event, TrackRemovedEvent):
329320
asyncio.create_task(
330-
self._on_track_removed(event.track_id, event.track_type, event.user)
321+
self._on_track_removed(
322+
event.track_id, event.track_type, event.participant
323+
)
331324
)
332325
else:
333326
asyncio.create_task(
334-
self._on_track_added(event.track_id, event.track_type, event.user)
327+
self._on_track_added(
328+
event.track_id, event.track_type, event.participant
329+
)
335330
)
336331

337332
# audio event for the user talking to the AI
@@ -343,7 +338,7 @@ async def on_audio_received(event: AudioReceivedEvent):
343338
await self._incoming_audio_queue.put(event.pcm_data)
344339

345340
@self.edge.events.subscribe
346-
async def on_call_ended(event: CallEndedEvent):
341+
async def on_call_ended(_: CallEndedEvent):
347342
if self._call_ended_event is not None:
348343
self._call_ended_event.set()
349344

@@ -711,8 +706,14 @@ def _start_tracing(self, call: Call) -> None:
711706
self._context_token = otel_context.attach(self._root_ctx)
712707

713708
async def _apply(self, function_name: str, *args, **kwargs):
714-
subclasses = [self.llm, self.stt, self.tts, self.turn_detection, self.edge]
715-
subclasses.extend(self.processors)
709+
subclasses = [
710+
self.llm,
711+
self.stt,
712+
self.tts,
713+
self.turn_detection,
714+
self.edge,
715+
*self.processors,
716+
]
716717
for subclass in subclasses:
717718
if (
718719
subclass is not None
@@ -858,9 +859,9 @@ async def create_user(self) -> None:
858859

859860
async def create_call(self, call_type: str, call_id: str) -> Call:
860861
"""Shortcut for creating a call/room etc."""
861-
call = self.edge.client.video.call(call_type, call_id)
862-
await call.get_or_create(data={"created_by_id": self.agent_user.id})
863-
862+
call = await self.edge.create_call(
863+
call_id=call_id, agent_user_id=self.agent_user.id, call_type=call_type
864+
)
864865
return call
865866

866867
def _on_rtc_reconnect(self):
@@ -886,13 +887,18 @@ async def _on_agent_say(self, event: events.AgentSayEvent):
886887
start_time = time.time()
887888

888889
if self.tts is not None:
889-
# Call TTS with user metadata
890-
user_metadata = {"user_id": event.user_id}
891-
if event.metadata:
892-
user_metadata.update(event.metadata)
890+
# Create participant from event
891+
participant = (
892+
Participant(
893+
original=event.metadata or {},
894+
user_id=event.user_id,
895+
)
896+
if event.user_id
897+
else None
898+
)
893899

894900
sanitized_text = self._sanitize_text(event.text)
895-
await self.tts.send(sanitized_text, user_metadata)
901+
await self.tts.send(sanitized_text, participant)
896902

897903
# Calculate duration
898904
duration_ms = (time.time() - start_time) * 1000
@@ -928,7 +934,7 @@ async def say(
928934
self,
929935
text: str,
930936
user_id: Optional[str] = None,
931-
metadata: Optional[Dict[str, Any]] = None,
937+
metadata: Optional[dict[str, Any]] = None,
932938
):
933939
"""
934940
Make the agent say something using TTS.
@@ -1058,21 +1064,17 @@ async def _track_to_video_processors(self, track: TrackInfo):
10581064
)
10591065

10601066
async def _on_track_removed(
1061-
self, track_id: str, track_type: int, participant: Participant
1067+
self, track_id: str, track_type: TrackType, participant: Participant
10621068
):
10631069
# We only process video tracks (camera video or screenshare)
10641070
if track_type not in (
1065-
TrackType.TRACK_TYPE_VIDEO,
1066-
TrackType.TRACK_TYPE_SCREEN_SHARE,
1071+
TrackType.VIDEO,
1072+
TrackType.SCREEN_SHARE,
10671073
):
10681074
return
1069-
track_type_name = (
1070-
"SCREEN_SHARE"
1071-
if track_type == TrackType.TRACK_TYPE_SCREEN_SHARE
1072-
else "VIDEO"
1073-
)
1075+
10741076
self.logger.info(
1075-
f"📺 Track removed: {track_type_name} from {participant.user_id}"
1077+
f"📺 Track removed: {track_type.name} from {participant.user_id}"
10761078
)
10771079

10781080
track = self._active_video_tracks.pop(track_id, None)
@@ -1081,7 +1083,7 @@ async def _on_track_removed(
10811083
track.track.stop()
10821084
await self._on_track_change(track_id)
10831085

1084-
async def _on_track_change(self, track_id: str):
1086+
async def _on_track_change(self, _: str):
10851087
# shared logic between track remove and added
10861088
# Select a track. Prioritize screenshare over regular
10871089
# This is the track without processing
@@ -1120,24 +1122,20 @@ async def _on_track_change(self, track_id: str):
11201122
)
11211123

11221124
async def _on_track_added(
1123-
self, track_id: str, track_type: int, participant: Participant
1125+
self, track_id: str, track_type: TrackType, participant: Participant
11241126
):
11251127
# We only process video tracks (camera video or screenshare)
11261128
if track_type not in (
1127-
TrackType.TRACK_TYPE_VIDEO,
1128-
TrackType.TRACK_TYPE_SCREEN_SHARE,
1129+
TrackType.VIDEO,
1130+
TrackType.SCREEN_SHARE,
11291131
):
11301132
return
11311133

1132-
track_type_name = (
1133-
"SCREEN_SHARE"
1134-
if track_type == TrackType.TRACK_TYPE_SCREEN_SHARE
1135-
else "VIDEO"
1136-
)
11371134
self.logger.info(
1138-
f"📺 Track added: {track_type_name} from {participant.user_id}"
1135+
f"📺 Track added: {track_type.name} from {participant.user_id}"
11391136
)
11401137

1138+
track: VideoStreamTrack | None
11411139
if self._video_track_override_path is not None:
11421140
# If local video track is set, we override all other video tracks with it.
11431141
# We override tracks instead of simply playing one in order to keep the same lifecycle within the call.
@@ -1163,7 +1161,7 @@ async def _on_track_added(
11631161
processor="",
11641162
track=track,
11651163
participant=participant,
1166-
priority=1 if track_type == TrackType.TRACK_TYPE_SCREEN_SHARE else 0,
1164+
priority=1 if track_type == TrackType.SCREEN_SHARE else 0,
11671165
forwarder=forwarder,
11681166
)
11691167

@@ -1401,11 +1399,7 @@ def _prepare_rtc(self):
14011399
# Variables are now initialized in __init__
14021400

14031401
if self.publish_audio:
1404-
framerate = 48000
1405-
stereo = True
1406-
self._audio_track = self.edge.create_audio_track(
1407-
framerate=framerate, stereo=stereo
1408-
)
1402+
self._audio_track = self.edge.create_audio_track()
14091403

14101404
@self.events.subscribe
14111405
async def forward_audio(event: RealtimeAudioOutputEvent):
@@ -1426,7 +1420,7 @@ async def forward_audio(event: RealtimeAudioOutputEvent):
14261420
)
14271421
self._active_video_tracks[self._video_track.id] = TrackInfo(
14281422
id=self._video_track.id,
1429-
type=TrackType.TRACK_TYPE_VIDEO,
1423+
type=TrackType.VIDEO.value,
14301424
processor=video_publisher.name,
14311425
track=self._video_track,
14321426
participant=None,
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
"""Stream Edge Transport Package.
1+
"""Edge Transport Package.
22
3-
This package provides edge transport abstraction for Stream Agents.
3+
This package provides edge transport abstraction for vision agents.
44
"""
55

6+
from vision_agents.core.edge.call import Call
67
from vision_agents.core.edge.edge_transport import EdgeTransport
7-
from vision_agents.core.edge import sfu_events
88

9-
__all__ = ["EdgeTransport", "sfu_events"]
9+
__all__ = ["Call", "EdgeTransport"]
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from typing import Protocol
2+
3+
4+
class Call(Protocol):
    """Protocol for call/room abstraction.

    Any EdgeTransport implementation must return objects conforming to this
    protocol from their create_call or join methods.

    Conformance is structural: an object satisfies ``Call`` by exposing a
    readable ``id`` attribute or property of type ``str``; it does not need
    to inherit from this class.
    """

    @property
    def id(self) -> str:
        """The unique identifier of the call."""
        ...

0 commit comments

Comments
 (0)