Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 61 additions & 67 deletions agents-core/vision_agents/core/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,30 @@
from contextlib import asynccontextmanager, contextmanager
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
TypeGuard,
)
from uuid import uuid4

import getstream.models
from aiortc import VideoStreamTrack
from getstream.video.rtc import Call
from getstream.video.rtc.pb.stream.video.sfu.models.models_pb2 import TrackType
from getstream.video.rtc import AudioStreamTrack, PcmData
from opentelemetry import context as otel_context
from opentelemetry import trace
from opentelemetry.context import Token
from opentelemetry.trace import Tracer, set_span_in_context
from opentelemetry.trace.propagation import Context, Span

from ..edge import sfu_events
from ..edge import Call, EdgeTransport
from ..edge.events import (
AudioReceivedEvent,
CallEndedEvent,
TrackAddedEvent,
TrackRemovedEvent,
)
from ..edge.types import OutputAudioTrack, Participant, PcmData, User
from ..edge.types import Connection, Participant, TrackType, User
from ..events.manager import EventManager
from ..instructions import Instructions
from ..llm import events as llm_events
Expand Down Expand Up @@ -78,12 +73,6 @@
from .conversation import Conversation
from .transcript_buffer import TranscriptBuffer

if TYPE_CHECKING:
from vision_agents.plugins.getstream.stream_edge_transport import (
StreamConnection,
StreamEdge,
)

logger = logging.getLogger(__name__)

tracer: Tracer = trace.get_tracer("agents")
Expand Down Expand Up @@ -114,15 +103,15 @@ class Agent:
Note: Don't reuse the agent object. Create a new agent object each time.

Dev guidelines
- Small methods so its easy to subclass/change behaviour
- Small methods so it's easy to subclass/change behaviour
"""

options: AgentOptions

def __init__(
self,
# edge network for video & audio
edge: "StreamEdge",
edge: EdgeTransport,
# llm, optionally with sts/realtime capabilities
llm: LLM | AudioLLM | VideoLLM,
# the agent's user info
Expand All @@ -137,9 +126,9 @@ def __init__(
# - roboflow/ yolo typically run continuously
# - often combined with API calls to fetch stats etc
# - state from each processor is passed to the LLM
processors: Optional[List[Processor]] = None,
processors: Optional[list[Processor]] = None,
# MCP servers for external tool and resource access
mcp_servers: Optional[List[MCPBaseServer]] = None,
mcp_servers: Optional[list[MCPBaseServer]] = None,
options: Optional[AgentOptions] = None,
tracer: Tracer = trace.get_tracer("agents"),
profiler: Optional[Profiler] = None,
Expand Down Expand Up @@ -178,9 +167,7 @@ def __init__(
self.logger = _AgentLoggerAdapter(logger, {"agent_id": self.agent_user.id})

self.events = EventManager()
self.events.register_events_from_module(getstream.models, "call.")
self.events.register_events_from_module(events)
self.events.register_events_from_module(sfu_events)
self.events.register_events_from_module(llm_events)

self.llm = llm
Expand All @@ -205,13 +192,13 @@ def __init__(
self.conversation: Optional[Conversation] = None

# Track pending transcripts for turn-based response triggering
self._pending_user_transcripts: Dict[str, TranscriptBuffer] = defaultdict(
self._pending_user_transcripts: dict[str, TranscriptBuffer] = defaultdict(
TranscriptBuffer
)

# Merge plugin events BEFORE subscribing to any events
for plugin in [stt, tts, turn_detection, llm, edge, profiler]:
if plugin and hasattr(plugin, "events"):
if plugin is not None:
self.logger.debug(f"Register events from plugin {plugin}")
self.events.merge(plugin.events)

Expand All @@ -224,16 +211,16 @@ def __init__(
self.events.subscribe(self._on_agent_say)

# Track metadata: track_id -> TrackInfo
self._active_video_tracks: Dict[str, TrackInfo] = {}
self._video_forwarders: List[VideoForwarder] = []
self._connection: Optional[StreamConnection] = None
self._active_video_tracks: dict[str, TrackInfo] = {}
self._video_forwarders: list[VideoForwarder] = []
self._connection: Optional[Connection] = None

# Optional local video track override for debugging.
# This track will play instead of any incoming video track.
self._video_track_override_path: Optional[str | Path] = None

# the outgoing audio track
self._audio_track: Optional[OutputAudioTrack] = None
self._audio_track: Optional[AudioStreamTrack] = None

# the outgoing video track
self._video_track: Optional[VideoStreamTrack] = None
Expand Down Expand Up @@ -323,15 +310,23 @@ async def _on_tts_audio_write_to_output(event: TTSAudioEvent):
# listen to video tracks added/removed
@self.edge.events.subscribe
async def on_video_track_added(event: TrackAddedEvent | TrackRemovedEvent):
if event.track_id is None or event.track_type is None or event.user is None:
if (
event.track_id is None
or event.track_type is None
or event.participant is None
):
return
if isinstance(event, TrackRemovedEvent):
asyncio.create_task(
self._on_track_removed(event.track_id, event.track_type, event.user)
self._on_track_removed(
event.track_id, event.track_type, event.participant
)
)
else:
asyncio.create_task(
self._on_track_added(event.track_id, event.track_type, event.user)
self._on_track_added(
event.track_id, event.track_type, event.participant
)
)

# audio event for the user talking to the AI
Expand All @@ -343,7 +338,7 @@ async def on_audio_received(event: AudioReceivedEvent):
await self._incoming_audio_queue.put(event.pcm_data)

@self.edge.events.subscribe
async def on_call_ended(event: CallEndedEvent):
async def on_call_ended(_: CallEndedEvent):
if self._call_ended_event is not None:
self._call_ended_event.set()

Expand Down Expand Up @@ -711,8 +706,14 @@ def _start_tracing(self, call: Call) -> None:
self._context_token = otel_context.attach(self._root_ctx)

async def _apply(self, function_name: str, *args, **kwargs):
subclasses = [self.llm, self.stt, self.tts, self.turn_detection, self.edge]
subclasses.extend(self.processors)
subclasses = [
self.llm,
self.stt,
self.tts,
self.turn_detection,
self.edge,
*self.processors,
]
for subclass in subclasses:
if (
subclass is not None
Expand Down Expand Up @@ -858,9 +859,9 @@ async def create_user(self) -> None:

async def create_call(self, call_type: str, call_id: str) -> Call:
"""Shortcut for creating a call/room etc."""
call = self.edge.client.video.call(call_type, call_id)
await call.get_or_create(data={"created_by_id": self.agent_user.id})

call = await self.edge.create_call(
call_id=call_id, agent_user_id=self.agent_user.id, call_type=call_type
)
return call

def _on_rtc_reconnect(self):
Expand All @@ -886,13 +887,18 @@ async def _on_agent_say(self, event: events.AgentSayEvent):
start_time = time.time()

if self.tts is not None:
# Call TTS with user metadata
user_metadata = {"user_id": event.user_id}
if event.metadata:
user_metadata.update(event.metadata)
# Create participant from event
participant = (
Participant(
original=event.metadata or {},
user_id=event.user_id,
)
if event.user_id
else None
)

sanitized_text = self._sanitize_text(event.text)
await self.tts.send(sanitized_text, user_metadata)
await self.tts.send(sanitized_text, participant)

# Calculate duration
duration_ms = (time.time() - start_time) * 1000
Expand Down Expand Up @@ -928,7 +934,7 @@ async def say(
self,
text: str,
user_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
):
"""
Make the agent say something using TTS.
Expand Down Expand Up @@ -1058,21 +1064,17 @@ async def _track_to_video_processors(self, track: TrackInfo):
)

async def _on_track_removed(
self, track_id: str, track_type: int, participant: Participant
self, track_id: str, track_type: TrackType, participant: Participant
):
# We only process video tracks (camera video or screenshare)
if track_type not in (
TrackType.TRACK_TYPE_VIDEO,
TrackType.TRACK_TYPE_SCREEN_SHARE,
TrackType.VIDEO,
TrackType.SCREEN_SHARE,
):
return
track_type_name = (
"SCREEN_SHARE"
if track_type == TrackType.TRACK_TYPE_SCREEN_SHARE
else "VIDEO"
)

self.logger.info(
f"📺 Track removed: {track_type_name} from {participant.user_id}"
f"📺 Track removed: {track_type.name} from {participant.user_id}"
)

track = self._active_video_tracks.pop(track_id, None)
Expand All @@ -1081,7 +1083,7 @@ async def _on_track_removed(
track.track.stop()
await self._on_track_change(track_id)

async def _on_track_change(self, track_id: str):
async def _on_track_change(self, _: str):
# shared logic between track remove and added
# Select a track. Prioritize screenshare over regular
# This is the track without processing
Expand Down Expand Up @@ -1120,24 +1122,20 @@ async def _on_track_change(self, track_id: str):
)

async def _on_track_added(
self, track_id: str, track_type: int, participant: Participant
self, track_id: str, track_type: TrackType, participant: Participant
):
# We only process video tracks (camera video or screenshare)
if track_type not in (
TrackType.TRACK_TYPE_VIDEO,
TrackType.TRACK_TYPE_SCREEN_SHARE,
TrackType.VIDEO,
TrackType.SCREEN_SHARE,
):
return

track_type_name = (
"SCREEN_SHARE"
if track_type == TrackType.TRACK_TYPE_SCREEN_SHARE
else "VIDEO"
)
self.logger.info(
f"📺 Track added: {track_type_name} from {participant.user_id}"
f"📺 Track added: {track_type.name} from {participant.user_id}"
)

track: VideoStreamTrack | None
if self._video_track_override_path is not None:
# If local video track is set, we override all other video tracks with it.
# We override tracks instead of simply playing one in order to keep the same lifecycle within the call.
Expand All @@ -1163,7 +1161,7 @@ async def _on_track_added(
processor="",
track=track,
participant=participant,
priority=1 if track_type == TrackType.TRACK_TYPE_SCREEN_SHARE else 0,
priority=1 if track_type == TrackType.SCREEN_SHARE else 0,
forwarder=forwarder,
)

Expand Down Expand Up @@ -1401,11 +1399,7 @@ def _prepare_rtc(self):
# Variables are now initialized in __init__

if self.publish_audio:
framerate = 48000
stereo = True
self._audio_track = self.edge.create_audio_track(
framerate=framerate, stereo=stereo
)
self._audio_track = self.edge.create_audio_track()

@self.events.subscribe
async def forward_audio(event: RealtimeAudioOutputEvent):
Expand All @@ -1426,7 +1420,7 @@ async def forward_audio(event: RealtimeAudioOutputEvent):
)
self._active_video_tracks[self._video_track.id] = TrackInfo(
id=self._video_track.id,
type=TrackType.TRACK_TYPE_VIDEO,
type=TrackType.VIDEO.value,
processor=video_publisher.name,
track=self._video_track,
participant=None,
Expand Down
8 changes: 4 additions & 4 deletions agents-core/vision_agents/core/edge/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Stream Edge Transport Package.
"""Edge Transport Package.

This package provides edge transport abstraction for Stream Agents.
This package provides edge transport abstraction for vision agents.
"""

from vision_agents.core.edge.call import Call
from vision_agents.core.edge.edge_transport import EdgeTransport
from vision_agents.core.edge import sfu_events

__all__ = ["EdgeTransport", "sfu_events"]
__all__ = ["Call", "EdgeTransport"]
14 changes: 14 additions & 0 deletions agents-core/vision_agents/core/edge/call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Protocol


class Call(Protocol):
"""Protocol for call/room abstraction.

Any EdgeTransport implementation must return objects conforming to this protocol
from their create_call or join methods.
"""

@property
def id(self) -> str:
"""The unique identifier of the call."""
...
Loading
Loading