11import asyncio
22import logging
33import weakref
4- from typing import TYPE_CHECKING , Awaitable , Callable , Optional
4+ from asyncio import Task
5+ from dataclasses import dataclass , field
6+ from datetime import datetime , timezone
7+ from typing import (
8+ TYPE_CHECKING ,
9+ Any ,
10+ AsyncIterator ,
11+ Callable ,
12+ Coroutine ,
13+ Optional ,
14+ cast ,
15+ )
516
617from vision_agents .core .utils .utils import await_or_run , cancel_and_wait
718from vision_agents .core .warmup import Warmable , WarmupCache
1223logger = logging .getLogger (__name__ )
1324
1425
26+ @dataclass
27+ class AgentSession :
28+ agent : "Agent"
29+ call_id : str
30+ started_at : datetime
31+ task : asyncio .Task
32+ config : dict = field (default_factory = dict )
33+ created_by : Optional [Any ] = None
34+
35+ @property
36+ def finished (self ) -> bool :
37+ return self .task .done ()
38+
39+ @property
40+ def id (self ) -> str :
41+ return self .agent .id
42+
43+ async def wait (self ):
44+ """
45+ Wait for the session task to finish running.
46+ """
47+ return await self .task
48+
49+
50+ # TODO: Rename to `AgentManager`.
1551class AgentLauncher :
1652 """
1753 Agent launcher that handles warmup and lifecycle management.
@@ -22,8 +58,8 @@ class AgentLauncher:
2258
2359 def __init__ (
2460 self ,
25- create_agent : Callable [..., "Agent" | Awaitable [ "Agent" ]],
26- join_call : Callable [..., None | Awaitable [ None ]] | None = None ,
61+ create_agent : Callable [..., "Agent" | Coroutine [ Any , Any , "Agent" ]],
62+ join_call : Callable [[ "Agent" , str , str ], Coroutine ] ,
2763 agent_idle_timeout : float = 60.0 ,
2864 agent_idle_cleanup_interval : float = 5.0 ,
2965 ):
@@ -37,8 +73,8 @@ def __init__(
3773 `0` means idle agents won't leave the call until it's ended.
3874
3975 """
40- self .create_agent = create_agent
41- self .join_call = join_call
76+ self ._create_agent = create_agent
77+ self ._join_call = join_call
4278 self ._warmup_lock = asyncio .Lock ()
4379 self ._warmup_cache = WarmupCache ()
4480
@@ -55,6 +91,7 @@ def __init__(
5591 self ._running = False
5692 self ._cleanup_task : Optional [asyncio .Task ] = None
5793 self ._warmed_up : bool = False
94+ self ._sessions : dict [str , AgentSession ] = {}
5895
5996 async def start (self ):
6097 if self ._running :
@@ -70,6 +107,12 @@ async def stop(self):
70107 self ._running = False
71108 if self ._cleanup_task :
72109 await cancel_and_wait (self ._cleanup_task )
110+
111+ coros = [cancel_and_wait (s .task ) for s in self ._sessions .values ()]
112+ async for result in cast (AsyncIterator [Task ], asyncio .as_completed (coros )):
113+ if result .done () and not result .cancelled () and result .exception ():
114+ logger .error (f"Failed to cancel the agent task: { result .exception ()} " )
115+
73116 logger .debug ("AgentLauncher stopped" )
74117
75118 async def warmup (self ) -> None :
@@ -86,13 +129,25 @@ async def warmup(self) -> None:
86129 logger .info ("Creating agent..." )
87130
88131 # Create a dry-run Agent instance and warmup its components for the first time.
89- agent : "Agent" = await await_or_run (self .create_agent )
132+ agent : "Agent" = await await_or_run (self ._create_agent )
90133 logger .info ("Warming up agent components..." )
91134 await self ._warmup_agent (agent )
92135 self ._warmed_up = True
93136
94137 logger .info ("Agent warmup completed" )
95138
139+ @property
140+ def warmed_up (self ) -> bool :
141+ return self ._warmed_up
142+
143+ @property
144+ def running (self ) -> bool :
145+ return self ._running
146+
147+ @property
148+ def ready (self ) -> bool :
149+ return self .warmed_up and self .running
150+
96151 async def launch (self , ** kwargs ) -> "Agent" :
97152 """
98153 Launch the agent.
@@ -103,11 +158,70 @@ async def launch(self, **kwargs) -> "Agent":
103158 Returns:
104159 The Agent instance
105160 """
106- agent : "Agent" = await await_or_run (self .create_agent , ** kwargs )
161+ agent : "Agent" = await await_or_run (self ._create_agent , ** kwargs )
107162 await self ._warmup_agent (agent )
108163 self ._active_agents .add (agent )
109164 return agent
110165
166+ async def start_session (
167+ self ,
168+ call_id : str ,
169+ call_type : str = "default" ,
170+ created_by : Optional [Any ] = None ,
171+ video_track_override_path : Optional [str ] = None ,
172+ ) -> AgentSession :
173+ agent : "Agent" = await self .launch ()
174+ if video_track_override_path :
175+ agent .set_video_track_override_path (video_track_override_path )
176+
177+ task = asyncio .create_task (
178+ self ._join_call (agent , call_type , call_id ), name = f"agent-{ agent .id } "
179+ )
180+
181+ # Remove the session when the task is done
182+ def _done_cb (_ , agent_id_ = agent .id ):
183+ self ._sessions .pop (agent_id_ , None )
184+
185+ task .add_done_callback (_done_cb )
186+ session = AgentSession (
187+ agent = agent ,
188+ task = task ,
189+ started_at = datetime .now (timezone .utc ),
190+ call_id = call_id ,
191+ created_by = created_by ,
192+ )
193+ self ._sessions [agent .id ] = session
194+ logger .info (f"Start agent session with id { session .id } " )
195+ return session
196+
197+ async def close_session (self , session_id : str , wait : bool = False ) -> bool :
198+ """
199+ Close session with id `session_id`.
200+ Returns `True` if session was found and closed, `False` otherwise.
201+
202+ Args:
203+ session_id: session id
204+ wait: when True, wait for the underlying agent to finish.
205+ Otherwise, just cancel the task and return.
206+
207+ Returns:
208+ `True` if session was found and closed, `False` otherwise.
209+ """
210+ session = self ._sessions .pop (session_id , None )
211+ if session is None :
212+ # The session is either closed or doesn't exist, exit early
213+ return False
214+
215+ logger .info (f"Closing agent session with id { session .id } " )
216+ if wait :
217+ await cancel_and_wait (session .task )
218+ else :
219+ session .task .cancel ()
220+ return True
221+
222+ def get_session (self , session_id : str ) -> Optional [AgentSession ]:
223+ return self ._sessions .get (session_id )
224+
111225 async def _warmup_agent (self , agent : "Agent" ) -> None :
112226 """
113227 Go over the Agent's dependencies and trigger `.warmup()` on them.
0 commit comments