11import asyncio
22import logging
33import weakref
4- from typing import TYPE_CHECKING , Callable , Coroutine , Optional
5- from uuid import uuid4
4+ from asyncio import Task
5+ from dataclasses import dataclass , field
6+ from datetime import datetime , timezone
7+ from typing import (
8+ TYPE_CHECKING ,
9+ Any ,
10+ AsyncIterator ,
11+ Callable ,
12+ Coroutine ,
13+ Optional ,
14+ cast ,
15+ )
616
717from vision_agents .core .utils .utils import await_or_run , cancel_and_wait
818from vision_agents .core .warmup import Warmable , WarmupCache
1323logger = logging .getLogger (__name__ )
1424
1525
16- class AgentNotFoundError (Exception ): ...
26+ class SessionNotFoundError (Exception ): ...
1727
1828
29+ @dataclass
30+ class AgentSession :
31+ agent : "Agent"
32+ call_id : str
33+ started_at : datetime
34+ task : asyncio .Task
35+ config : dict = field (default_factory = dict )
36+
37+ @property
38+ def finished (self ) -> bool :
39+ return self .task .done ()
40+
41+ @property
42+ def id (self ) -> str :
43+ return self .agent .id
44+
45+ async def wait (self ):
46+ """
47+ Wait for the session task to finish running.
48+ """
49+ return await self .task
50+
51+
52+ # TODO: Rename to `AgentManager`.
1953class AgentLauncher :
2054 """
2155 Agent launcher that handles warmup and lifecycle management.
@@ -26,8 +60,8 @@ class AgentLauncher:
2660
2761 def __init__ (
2862 self ,
29- create_agent : Callable [..., "Agent" | Coroutine ["Agent" , ..., ... ]],
30- join_call : Callable [["Agent" , ..., ... ], Coroutine [ None , ..., ...] ],
63+ create_agent : Callable [..., "Agent" | Coroutine [Any , Any , "Agent" ]],
64+ join_call : Callable [["Agent" , str , str ], Coroutine ],
3165 agent_idle_timeout : float = 60.0 ,
3266 agent_idle_cleanup_interval : float = 5.0 ,
3367 ):
@@ -59,7 +93,7 @@ def __init__(
5993 self ._running = False
6094 self ._cleanup_task : Optional [asyncio .Task ] = None
6195 self ._warmed_up : bool = False
62- self ._call_tasks : dict [str , asyncio . Task ] = {}
96+ self ._sessions : dict [str , AgentSession ] = {}
6397
6498 async def start (self ):
6599 if self ._running :
@@ -76,10 +110,10 @@ async def stop(self):
76110 if self ._cleanup_task :
77111 await cancel_and_wait (self ._cleanup_task )
78112
79- coros = [cancel_and_wait (t ) for t in self ._call_tasks .values ()]
80- async for result in asyncio .as_completed (coros ):
113+ coros = [cancel_and_wait (s . task ) for s in self ._sessions .values ()]
114+ async for result in cast ( AsyncIterator [ Task ], asyncio .as_completed (coros ) ):
81115 if result .done () and not result .cancelled () and result .exception ():
82- logger .error (f"Failed to cancel the call task: { result .exception ()} " )
116+ logger .error (f"Failed to cancel the agent task: { result .exception ()} " )
83117
84118 logger .debug ("AgentLauncher stopped" )
85119
@@ -108,6 +142,10 @@ async def warmup(self) -> None:
108142 def warmed_up (self ) -> bool :
109143 return self ._warmed_up
110144
145+ @property
146+ def running (self ) -> bool :
147+ return self ._running
148+
111149 async def launch (self , ** kwargs ) -> "Agent" :
112150 """
113151 Launch the agent.
@@ -123,33 +161,51 @@ async def launch(self, **kwargs) -> "Agent":
123161 self ._active_agents .add (agent )
124162 return agent
125163
126- # TODO: Typing
127- async def join (self , call_id : str , call_type : str = "default" ):
128- agent : "Agent" = await await_or_run (self ._create_agent )
129- await self ._warmup_agent (agent )
130- self ._active_agents .add (agent )
164+ async def start_session (
165+ self ,
166+ call_id : str ,
167+ call_type : str = "default" ,
168+ video_track_override_path : Optional [str ] = None ,
169+ ) -> AgentSession :
170+ agent : "Agent" = await self .launch ()
171+ if video_track_override_path :
172+ agent .set_video_track_override_path (video_track_override_path )
131173
132- agent_id = str (uuid4 ())
133174 task = asyncio .create_task (
134- self ._join_call (agent , call_type , call_id ), name = f"agent-{ agent_id } "
175+ self ._join_call (agent , call_type , call_id ), name = f"agent-{ agent . id } "
135176 )
136- self ._call_tasks [agent_id ] = task
137177
138- # Remove the task reference when it's done
139- task .add_done_callback (
140- lambda t , agent_id_ = agent_id : self ._call_tasks .pop (agent_id_ , None )
178+ # Remove the session when the task is done
179+ def _done_cb (_ , agent_id_ = agent .id ):
180+ self ._sessions .pop (agent_id_ , None )
181+
182+ task .add_done_callback (_done_cb )
183+ session = AgentSession (
184+ agent = agent ,
185+ task = task ,
186+ started_at = datetime .now (timezone .utc ),
187+ call_id = call_id ,
141188 )
142- return agent_id
189+ self ._sessions [agent .id ] = session
190+ return session
143191
144- async def close_agent (self , agent_id : str , wait : bool = False ) -> None :
145- task = self ._call_tasks .pop (agent_id , None )
146- if task is None :
147- raise AgentNotFoundError (f"Agent with id { agent_id } not found" )
192+ async def close_session (self , session_id : str , wait : bool = False ) -> None :
193+ # TODO: Test
194+ session = self ._sessions .pop (session_id , None )
195+ if session is None :
196+ raise SessionNotFoundError (f"Session with id { session_id } not found" )
148197
149198 if wait :
150- await cancel_and_wait (task )
199+ await cancel_and_wait (session . task )
151200 else :
152- task .cancel ()
201+ session .task .cancel ()
202+
203+ def get_session (self , session_id : str ) -> AgentSession :
204+ # TODO: Test
205+ session = self ._sessions .get (session_id )
206+ if session is None :
207+ raise SessionNotFoundError (f"Session with id { session_id } not found" )
208+ return session
153209
154210 async def _warmup_agent (self , agent : "Agent" ) -> None :
155211 """
0 commit comments