77from contextlib import asynccontextmanager , contextmanager
88from pathlib import Path
99from typing import (
10- TYPE_CHECKING ,
1110 Any ,
1211 AsyncIterator ,
13- Dict ,
1412 Iterator ,
15- List ,
1613 Optional ,
1714 TypeGuard ,
1815)
1916from uuid import uuid4
2017
21- import getstream .models
2218from aiortc import VideoStreamTrack
23- from getstream .video .rtc import Call
24- from getstream .video .rtc .pb .stream .video .sfu .models .models_pb2 import TrackType
19+ from getstream .video .rtc import AudioStreamTrack , PcmData
2520from opentelemetry import context as otel_context
2621from opentelemetry import trace
2722from opentelemetry .context import Token
2823from opentelemetry .trace import Tracer , set_span_in_context
2924from opentelemetry .trace .propagation import Context , Span
3025
31- from ..edge import sfu_events
26+ from ..edge import Call , EdgeTransport
3227from ..edge .events import (
3328 AudioReceivedEvent ,
3429 CallEndedEvent ,
3530 TrackAddedEvent ,
3631 TrackRemovedEvent ,
3732)
38- from ..edge .types import OutputAudioTrack , Participant , PcmData , User
33+ from ..edge .types import Connection , Participant , TrackType , User
3934from ..events .manager import EventManager
4035from ..instructions import Instructions
4136from ..llm import events as llm_events
7873from .conversation import Conversation
7974from .transcript_buffer import TranscriptBuffer
8075
81- if TYPE_CHECKING :
82- from vision_agents .plugins .getstream .stream_edge_transport import (
83- StreamConnection ,
84- StreamEdge ,
85- )
86-
8776logger = logging .getLogger (__name__ )
8877
8978tracer : Tracer = trace .get_tracer ("agents" )
@@ -114,15 +103,15 @@ class Agent:
114103 Note: Don't reuse the agent object. Create a new agent object each time.
115104
116105 Dev guidelines
117- - Small methods so its easy to subclass/change behaviour
106+ - Small methods so it's easy to subclass/change behaviour
118107 """
119108
120109 options : AgentOptions
121110
122111 def __init__ (
123112 self ,
124113 # edge network for video & audio
125- edge : "StreamEdge" ,
114+ edge : EdgeTransport ,
126115 # llm, optionally with sts/realtime capabilities
127116 llm : LLM | AudioLLM | VideoLLM ,
128117 # the agent's user info
@@ -137,9 +126,9 @@ def __init__(
137126 # - roboflow/ yolo typically run continuously
138127 # - often combined with API calls to fetch stats etc
139128 # - state from each processor is passed to the LLM
140- processors : Optional [List [Processor ]] = None ,
129+ processors : Optional [list [Processor ]] = None ,
141130 # MCP servers for external tool and resource access
142- mcp_servers : Optional [List [MCPBaseServer ]] = None ,
131+ mcp_servers : Optional [list [MCPBaseServer ]] = None ,
143132 options : Optional [AgentOptions ] = None ,
144133 tracer : Tracer = trace .get_tracer ("agents" ),
145134 profiler : Optional [Profiler ] = None ,
@@ -178,9 +167,7 @@ def __init__(
178167 self .logger = _AgentLoggerAdapter (logger , {"agent_id" : self .agent_user .id })
179168
180169 self .events = EventManager ()
181- self .events .register_events_from_module (getstream .models , "call." )
182170 self .events .register_events_from_module (events )
183- self .events .register_events_from_module (sfu_events )
184171 self .events .register_events_from_module (llm_events )
185172
186173 self .llm = llm
@@ -205,13 +192,13 @@ def __init__(
205192 self .conversation : Optional [Conversation ] = None
206193
207194 # Track pending transcripts for turn-based response triggering
208- self ._pending_user_transcripts : Dict [str , TranscriptBuffer ] = defaultdict (
195+ self ._pending_user_transcripts : dict [str , TranscriptBuffer ] = defaultdict (
209196 TranscriptBuffer
210197 )
211198
212199 # Merge plugin events BEFORE subscribing to any events
213200 for plugin in [stt , tts , turn_detection , llm , edge , profiler ]:
214- if plugin and hasattr ( plugin , "events" ) :
201+ if plugin is not None :
215202 self .logger .debug (f"Register events from plugin { plugin } " )
216203 self .events .merge (plugin .events )
217204
@@ -224,16 +211,16 @@ def __init__(
224211 self .events .subscribe (self ._on_agent_say )
225212
226213 # Track metadata: track_id -> TrackInfo
227- self ._active_video_tracks : Dict [str , TrackInfo ] = {}
228- self ._video_forwarders : List [VideoForwarder ] = []
229- self ._connection : Optional [StreamConnection ] = None
214+ self ._active_video_tracks : dict [str , TrackInfo ] = {}
215+ self ._video_forwarders : list [VideoForwarder ] = []
216+ self ._connection : Optional [Connection ] = None
230217
231218 # Optional local video track override for debugging.
232219 # This track will play instead of any incoming video track.
233220 self ._video_track_override_path : Optional [str | Path ] = None
234221
235222 # the outgoing audio track
236- self ._audio_track : Optional [OutputAudioTrack ] = None
223+ self ._audio_track : Optional [AudioStreamTrack ] = None
237224
238225 # the outgoing video track
239226 self ._video_track : Optional [VideoStreamTrack ] = None
@@ -323,15 +310,23 @@ async def _on_tts_audio_write_to_output(event: TTSAudioEvent):
323310 # listen to video tracks added/removed
324311 @self .edge .events .subscribe
325312 async def on_video_track_added (event : TrackAddedEvent | TrackRemovedEvent ):
326- if event .track_id is None or event .track_type is None or event .user is None :
313+ if (
314+ event .track_id is None
315+ or event .track_type is None
316+ or event .participant is None
317+ ):
327318 return
328319 if isinstance (event , TrackRemovedEvent ):
329320 asyncio .create_task (
330- self ._on_track_removed (event .track_id , event .track_type , event .user )
321+ self ._on_track_removed (
322+ event .track_id , event .track_type , event .participant
323+ )
331324 )
332325 else :
333326 asyncio .create_task (
334- self ._on_track_added (event .track_id , event .track_type , event .user )
327+ self ._on_track_added (
328+ event .track_id , event .track_type , event .participant
329+ )
335330 )
336331
337332 # audio event for the user talking to the AI
@@ -343,7 +338,7 @@ async def on_audio_received(event: AudioReceivedEvent):
343338 await self ._incoming_audio_queue .put (event .pcm_data )
344339
345340 @self .edge .events .subscribe
346- async def on_call_ended (event : CallEndedEvent ):
341+ async def on_call_ended (_ : CallEndedEvent ):
347342 if self ._call_ended_event is not None :
348343 self ._call_ended_event .set ()
349344
@@ -711,8 +706,14 @@ def _start_tracing(self, call: Call) -> None:
711706 self ._context_token = otel_context .attach (self ._root_ctx )
712707
713708 async def _apply (self , function_name : str , * args , ** kwargs ):
714- subclasses = [self .llm , self .stt , self .tts , self .turn_detection , self .edge ]
715- subclasses .extend (self .processors )
709+ subclasses = [
710+ self .llm ,
711+ self .stt ,
712+ self .tts ,
713+ self .turn_detection ,
714+ self .edge ,
715+ * self .processors ,
716+ ]
716717 for subclass in subclasses :
717718 if (
718719 subclass is not None
@@ -858,9 +859,9 @@ async def create_user(self) -> None:
858859
859860 async def create_call (self , call_type : str , call_id : str ) -> Call :
860861 """Shortcut for creating a call/room etc."""
861- call = self .edge .client . video . call ( call_type , call_id )
862- await call . get_or_create ( data = { "created_by_id" : self .agent_user .id })
863-
862+ call = await self .edge .create_call (
863+ call_id = call_id , agent_user_id = self .agent_user .id , call_type = call_type
864+ )
864865 return call
865866
866867 def _on_rtc_reconnect (self ):
@@ -886,13 +887,18 @@ async def _on_agent_say(self, event: events.AgentSayEvent):
886887 start_time = time .time ()
887888
888889 if self .tts is not None :
889- # Call TTS with user metadata
890- user_metadata = {"user_id" : event .user_id }
891- if event .metadata :
892- user_metadata .update (event .metadata )
890+ # Create participant from event
891+ participant = (
892+ Participant (
893+ original = event .metadata or {},
894+ user_id = event .user_id ,
895+ )
896+ if event .user_id
897+ else None
898+ )
893899
894900 sanitized_text = self ._sanitize_text (event .text )
895- await self .tts .send (sanitized_text , user_metadata )
901+ await self .tts .send (sanitized_text , participant )
896902
897903 # Calculate duration
898904 duration_ms = (time .time () - start_time ) * 1000
@@ -928,7 +934,7 @@ async def say(
928934 self ,
929935 text : str ,
930936 user_id : Optional [str ] = None ,
931- metadata : Optional [Dict [str , Any ]] = None ,
937+ metadata : Optional [dict [str , Any ]] = None ,
932938 ):
933939 """
934940 Make the agent say something using TTS.
@@ -1058,21 +1064,17 @@ async def _track_to_video_processors(self, track: TrackInfo):
10581064 )
10591065
10601066 async def _on_track_removed (
1061- self , track_id : str , track_type : int , participant : Participant
1067+ self , track_id : str , track_type : TrackType , participant : Participant
10621068 ):
10631069 # We only process video tracks (camera video or screenshare)
10641070 if track_type not in (
1065- TrackType .TRACK_TYPE_VIDEO ,
1066- TrackType .TRACK_TYPE_SCREEN_SHARE ,
1071+ TrackType .VIDEO ,
1072+ TrackType .SCREEN_SHARE ,
10671073 ):
10681074 return
1069- track_type_name = (
1070- "SCREEN_SHARE"
1071- if track_type == TrackType .TRACK_TYPE_SCREEN_SHARE
1072- else "VIDEO"
1073- )
1075+
10741076 self .logger .info (
1075- f"📺 Track removed: { track_type_name } from { participant .user_id } "
1077+ f"📺 Track removed: { track_type . name } from { participant .user_id } "
10761078 )
10771079
10781080 track = self ._active_video_tracks .pop (track_id , None )
@@ -1081,7 +1083,7 @@ async def _on_track_removed(
10811083 track .track .stop ()
10821084 await self ._on_track_change (track_id )
10831085
1084- async def _on_track_change (self , track_id : str ):
1086+ async def _on_track_change (self , _ : str ):
10851087 # shared logic between track remove and added
10861088 # Select a track. Prioritize screenshare over regular
10871089 # This is the track without processing
@@ -1120,24 +1122,20 @@ async def _on_track_change(self, track_id: str):
11201122 )
11211123
11221124 async def _on_track_added (
1123- self , track_id : str , track_type : int , participant : Participant
1125+ self , track_id : str , track_type : TrackType , participant : Participant
11241126 ):
11251127 # We only process video tracks (camera video or screenshare)
11261128 if track_type not in (
1127- TrackType .TRACK_TYPE_VIDEO ,
1128- TrackType .TRACK_TYPE_SCREEN_SHARE ,
1129+ TrackType .VIDEO ,
1130+ TrackType .SCREEN_SHARE ,
11291131 ):
11301132 return
11311133
1132- track_type_name = (
1133- "SCREEN_SHARE"
1134- if track_type == TrackType .TRACK_TYPE_SCREEN_SHARE
1135- else "VIDEO"
1136- )
11371134 self .logger .info (
1138- f"📺 Track added: { track_type_name } from { participant .user_id } "
1135+ f"📺 Track added: { track_type . name } from { participant .user_id } "
11391136 )
11401137
1138+ track : VideoStreamTrack | None
11411139 if self ._video_track_override_path is not None :
11421140 # If local video track is set, we override all other video tracks with it.
11431141 # We override tracks instead of simply playing one in order to keep the same lifecycle within the call.
@@ -1163,7 +1161,7 @@ async def _on_track_added(
11631161 processor = "" ,
11641162 track = track ,
11651163 participant = participant ,
1166- priority = 1 if track_type == TrackType .TRACK_TYPE_SCREEN_SHARE else 0 ,
1164+ priority = 1 if track_type == TrackType .SCREEN_SHARE else 0 ,
11671165 forwarder = forwarder ,
11681166 )
11691167
@@ -1401,11 +1399,7 @@ def _prepare_rtc(self):
14011399 # Variables are now initialized in __init__
14021400
14031401 if self .publish_audio :
1404- framerate = 48000
1405- stereo = True
1406- self ._audio_track = self .edge .create_audio_track (
1407- framerate = framerate , stereo = stereo
1408- )
1402+ self ._audio_track = self .edge .create_audio_track ()
14091403
14101404 @self .events .subscribe
14111405 async def forward_audio (event : RealtimeAudioOutputEvent ):
@@ -1426,7 +1420,7 @@ async def forward_audio(event: RealtimeAudioOutputEvent):
14261420 )
14271421 self ._active_video_tracks [self ._video_track .id ] = TrackInfo (
14281422 id = self ._video_track .id ,
1429- type = TrackType .TRACK_TYPE_VIDEO ,
1423+ type = TrackType .VIDEO . value ,
14301424 processor = video_publisher .name ,
14311425 track = self ._video_track ,
14321426 participant = None ,
0 commit comments