diff --git a/plugins/aws/README.md b/plugins/aws/README.md index d24761d8f..e2a11e5d9 100644 --- a/plugins/aws/README.md +++ b/plugins/aws/README.md @@ -1,6 +1,6 @@ # AWS Plugin for Vision Agents -AWS (Bedrock) integration for Vision Agents framework with support for standard LLM, realtime with Nova Sonic, and text-to-speech with automatic session resumption. +AWS integration for Vision Agents framework with support for standard LLM (Bedrock), realtime with Nova Sonic, text-to-speech (Polly), and streaming speech-to-text (Transcribe). ## Installation @@ -80,17 +80,18 @@ See `example/aws_realtime_nova_example.py` for a complete example. ### Text-to-Speech (TTS) -AWS Polly TTS is available for converting text to speech: +AWS Polly synthesises speech from text and streams the resulting audio. Supports both standard and neural engines, plain-text or SSML input, and Polly lexicons for pronunciation overrides. ```python from vision_agents.plugins import aws tts = aws.TTS( region_name="us-east-1", - voice_id="Joanna", # AWS Polly voice ID - engine="neural", # 'standard' or 'neural' - text_type="text", # 'text' or 'ssml' - language_code="en-US" + voice_id="Joanna", # any Polly voice ID + engine="neural", # "standard" | "neural" + text_type="text", # "text" | "ssml" + language_code="en-US", + lexicon_names=None, # optional list of Polly lexicons ) # Use in agent @@ -101,6 +102,35 @@ agent = Agent( ) ``` +Credentials follow the standard boto3 chain (env vars, `~/.aws/credentials`, SSO, instance profile, etc.). Pass `aws_access_key_id` + `aws_secret_access_key` (both required together, plus `aws_session_token` for temporary credentials from STS / SSO / assumed roles) or `aws_profile` to override. You may also inject a pre-built boto3 Polly client via `client=...`. `region_name` falls back to `AWS_REGION` / `AWS_DEFAULT_REGION` and finally `us-east-1`. + +### Speech-to-Text (STT) + +AWS Transcribe streaming STT converts audio to text in realtime. The connection auto-reconnects with exponential backoff on idle timeouts, audio-length limits, and transient errors. + +```python +from vision_agents.plugins import aws + +stt = aws.STT( + language_code="en-US", + region_name="us-east-1", + show_speaker_label=False, + enable_partial_results_stabilization=False, + partial_results_stability=None, # "high" | "medium" | "low" +) + +# Use in agent +agent = Agent( + llm=aws.LLM(model="qwen.qwen3-32b-v1:0"), + stt=stt, + # ... other components +) +``` + +Credentials follow the standard boto3 chain (env vars, `~/.aws/credentials`, SSO, instance profile, etc.). Pass `aws_access_key_id` + `aws_secret_access_key` (both required together, plus `aws_session_token` for temporary credentials from STS / SSO / assumed roles) or `aws_profile` to override. + +See `example/aws_pipeline_example.py` for a complete STT - LLM - TTS pipeline using only AWS components. + ## Function Calling ### Standard LLM (aws.LLM) diff --git a/plugins/aws/example/aws_pipeline_example.py b/plugins/aws/example/aws_pipeline_example.py new file mode 100644 index 000000000..1b38e5545 --- /dev/null +++ b/plugins/aws/example/aws_pipeline_example.py @@ -0,0 +1,48 @@ +""" +AWS STT - LLM - TTS Pipeline Example + +Voice agent built entirely from AWS components: +- STT: AWS Transcribe streaming +- LLM: AWS Bedrock (Qwen) +- TTS: AWS Polly +""" + +import asyncio +import logging + +from dotenv import load_dotenv +from vision_agents.core import Agent, Runner, User +from vision_agents.core.agents import AgentLauncher +from vision_agents.plugins import aws, getstream + +logger = logging.getLogger(__name__) + +load_dotenv() + + +async def create_agent(**kwargs) -> Agent: + agent = Agent( + edge=getstream.Edge(), + agent_user=User(name="AWS Voice Agent", id="agent"), + instructions="You are a voice agent. Keep replies short and " + "conversational. Do not use special characters or formatting.", + llm=aws.LLM(model="qwen.qwen3-32b-v1:0", region_name="us-east-1"), + stt=aws.STT(language_code="en-US", region_name="us-east-1"), + tts=aws.TTS(region_name="us-east-1", voice_id="Joanna", engine="neural"), + ) + + return agent + + +async def join_call(agent: Agent, call_type: str, call_id: str, **kwargs) -> None: + call = await agent.create_call(call_type, call_id) + + async with agent.join(call): + await asyncio.sleep(5) + await agent.simple_response(text="Ask the user about their day.") + + await agent.finish() + + +if __name__ == "__main__": + Runner(AgentLauncher(create_agent=create_agent, join_call=join_call)).cli() diff --git a/plugins/aws/pyproject.toml b/plugins/aws/pyproject.toml index deea1383f..aa47eb1fc 100644 --- a/plugins/aws/pyproject.toml +++ b/plugins/aws/pyproject.toml @@ -5,9 +5,9 @@ build-backend = "hatchling.build" [project] name = "vision-agents-plugins-aws" dynamic = ["version"] -description = "AWS (Bedrock) LLM integration for Vision Agents" +description = "AWS (Bedrock LLM, Transcribe STT, Polly TTS) integration for Vision Agents" readme = "README.md" -keywords = ["aws", "bedrock", "LLM", "AI", "voice agents", "agents"] +keywords = ["aws", "bedrock", "transcribe", "polly", "STT", "TTS", "LLM", "AI", "voice agents", "agents"] requires-python = ">=3.12" license = "MIT" dependencies = [ @@ -15,6 +15,7 @@ dependencies = [ "onnxruntime>=1.16.1,<2", "boto3>=1.42.65,<2", "aws-sdk-bedrock-runtime>=0.4.0,<1", + "aws-sdk-transcribe-streaming>=0.5.0,<1", ] [project.urls] diff --git a/plugins/aws/tests/test_aws_stt.py b/plugins/aws/tests/test_aws_stt.py new file mode 100644 index 000000000..9e6b28d32 --- /dev/null +++ b/plugins/aws/tests/test_aws_stt.py @@ -0,0 +1,170 @@ +import asyncio + +import pytest +from aws_sdk_transcribe_streaming.models import ( + Alternative, + Result, + Transcript, + TranscriptEvent, +) +from dotenv import load_dotenv +from vision_agents.core.turn_detection import TurnEndedEvent, TurnStartedEvent +from vision_agents.plugins import aws + +from conftest import STTSession + +load_dotenv() + + +class TestTranscribeSTT: + @pytest.fixture + def transcript_event_factory(self): + def factory( + text: str, *, is_partial: bool, start_time: float + ) -> TranscriptEvent: + return TranscriptEvent( + transcript=Transcript( + results=[ + Result( + result_id="r", + start_time=start_time, + end_time=start_time + 1.0, + is_partial=is_partial, + alternatives=[Alternative(transcript=text, items=[])], + ) + ] + ) + ) + + return factory + + async def test_partial_result_emits_partial_transcript_and_turn_started( + self, participant, transcript_event_factory + ): + stt = aws.STT(language_code="en-US") + stt._current_participant = participant + session = STTSession(stt) + turn_started: list[TurnStartedEvent] = [] + + @stt.events.subscribe + async def on_turn_started(event: TurnStartedEvent): + turn_started.append(event) + + stt._handle_transcript_event( + transcript_event_factory("hello", is_partial=True, start_time=0.0) + ) + await asyncio.sleep(0.05) + + assert [e.text for e in session.partial_transcripts] == ["hello"] + assert not session.transcripts + assert len(turn_started) == 1 + assert turn_started[0].participant == participant + + async def test_final_result_emits_transcript_and_turn_ended( + self, participant, transcript_event_factory + ): + stt = aws.STT(language_code="en-US") + stt._current_participant = participant + session = STTSession(stt) + turn_ended: list[TurnEndedEvent] = [] + + @stt.events.subscribe + async def on_turn_ended(event: TurnEndedEvent): + turn_ended.append(event) + + stt._handle_transcript_event( + transcript_event_factory("hello world", is_partial=False, start_time=0.0) + ) + await asyncio.sleep(0.05) + + assert [e.text for e in session.transcripts] == ["hello world"] + assert len(turn_ended) == 1 + assert turn_ended[0].participant == participant + + def test_partial_static_credentials_rejected(self): + with pytest.raises(ValueError, match="provided together"): + aws.STT(aws_access_key_id="AKIA...") + with pytest.raises(ValueError, match="provided together"): + aws.STT(aws_secret_access_key="secret") + + async def test_clear_drops_results_before_watermark( + self, participant, transcript_event_factory + ): + stt = aws.STT(language_code="en-US") + stt._audio_sent_seconds = 5.0 + await stt.clear() + + stt._current_participant = participant + session = STTSession(stt) + + stt._handle_transcript_event( + transcript_event_factory("stale", is_partial=False, start_time=2.0) + ) + stt._handle_transcript_event( + transcript_event_factory("fresh", is_partial=False, start_time=6.0) + ) + await asyncio.sleep(0.05) + + assert [e.text for e in session.transcripts] == ["fresh"] + + +@pytest.mark.integration +class TestTranscribeSTTIntegration: + @pytest.fixture + async def stt(self): + stt = aws.STT(language_code="en-US") + try: + await stt.start() + yield stt + finally: + await stt.close() + + async def test_transcribe_mia_audio_16khz( + self, stt, mia_audio_16khz_chunked, participant + ): + session = STTSession(stt) + + for chunk in mia_audio_16khz_chunked: + await stt.process_audio(chunk, participant=participant) + + await session.wait_for_result(timeout=30.0) + assert not session.errors, f"Errors occurred: {session.errors}" + + full_transcript = session.get_full_transcript().lower() + assert any( + word in full_transcript for word in ["village", "quiet", "mia", "treasures"] + ), f"Transcript did not match expected content: {full_transcript!r}" + + async def test_partial_transcripts_emitted( + self, stt, mia_audio_16khz_chunked, participant + ): + session = STTSession(stt) + + for chunk in mia_audio_16khz_chunked: + await stt.process_audio(chunk, participant=participant) + + await session.wait_for_result(timeout=30.0) + assert session.partial_transcripts, "No partial transcripts received" + + async def test_turn_events_emitted(self, stt, mia_audio_16khz_chunked, participant): + session = STTSession(stt) + turn_started: list[TurnStartedEvent] = [] + turn_ended: list[TurnEndedEvent] = [] + + @stt.events.subscribe + async def on_turn_started(event: TurnStartedEvent): + turn_started.append(event) + + @stt.events.subscribe + async def on_turn_ended(event: TurnEndedEvent): + turn_ended.append(event) + + for chunk in mia_audio_16khz_chunked: + await stt.process_audio(chunk, participant=participant) + + await session.wait_for_result(timeout=30.0) + + assert turn_started, "No TurnStartedEvent received" + assert turn_ended, "No TurnEndedEvent received" + assert turn_started[0].participant == participant + assert turn_ended[0].participant == participant diff --git a/plugins/aws/tests/test_tts.py b/plugins/aws/tests/test_tts.py index eeb20d7c7..8901d1e69 100644 --- a/plugins/aws/tests/test_tts.py +++ b/plugins/aws/tests/test_tts.py @@ -30,6 +30,14 @@ async def tts(self) -> aws.TTS: # type: ignore[name-defined] return aws.TTS(voice_id=os.environ.get("AWS_POLLY_VOICE", "Joanna")) +class TestAWSPollyTTS: + def test_partial_static_credentials_rejected(self): + with pytest.raises(ValueError, match="provided together"): + aws.TTS(aws_access_key_id="AKIA...") + with pytest.raises(ValueError, match="provided together"): + aws.TTS(aws_secret_access_key="secret") + + @pytest.mark.skip() @pytest.mark.integration class TestAWSPollyTTSIntegration: diff --git a/plugins/aws/vision_agents/plugins/aws/__init__.py b/plugins/aws/vision_agents/plugins/aws/__init__.py index aafbcf679..1b9a16718 100644 --- a/plugins/aws/vision_agents/plugins/aws/__init__.py +++ b/plugins/aws/vision_agents/plugins/aws/__init__.py @@ -1,5 +1,6 @@ from .aws_llm import BedrockLLM as LLM from .aws_realtime import Realtime +from .stt import TranscribeSTT as STT from .tts import TTS -__all__ = ["LLM", "Realtime", "TTS"] +__all__ = ["LLM", "Realtime", "STT", "TTS"] diff --git a/plugins/aws/vision_agents/plugins/aws/_credentials.py b/plugins/aws/vision_agents/plugins/aws/_credentials.py new file mode 100644 index 000000000..9c4ad4f9f --- /dev/null +++ b/plugins/aws/vision_agents/plugins/aws/_credentials.py @@ -0,0 +1,43 @@ +import asyncio +from typing import Any, Optional + +import boto3 +from smithy_aws_core.identity.components import ( + AWSCredentialsIdentity, + AWSIdentityProperties, +) +from smithy_core.aio.interfaces.identity import IdentityResolver + + +class Boto3CredentialsResolver( + IdentityResolver[AWSCredentialsIdentity, AWSIdentityProperties] +): + """IdentityResolver that delegates to boto3.Session for credential resolution. + + Supports the full boto3 credential chain: env vars, shared credentials files, + AWS profiles, SSO, EC2 instance profiles, etc. + """ + + def __init__(self, profile_name: Optional[str] = None) -> None: + self._session = boto3.Session(profile_name=profile_name) + + async def get_identity( + self, *, properties: AWSIdentityProperties, **kwargs: Any + ) -> AWSCredentialsIdentity: + # Both calls can block: get_credentials() walks the provider chain + # (file I/O, IMDS, SSO, STS) on first access, and get_frozen_credentials() + # triggers refresh on RefreshableCredentials. + credentials = await asyncio.to_thread(self._session.get_credentials) + if not credentials: + raise ValueError("Unable to load AWS credentials via boto3") + + creds = await asyncio.to_thread(credentials.get_frozen_credentials) + if not creds.access_key or not creds.secret_key: + raise ValueError("AWS credentials are incomplete") + + return AWSCredentialsIdentity( + access_key_id=creds.access_key, + secret_access_key=creds.secret_key, + session_token=creds.token or None, + expiration=None, + ) diff --git a/plugins/aws/vision_agents/plugins/aws/aws_realtime.py b/plugins/aws/vision_agents/plugins/aws/aws_realtime.py index 59fab23d9..a2799d958 100644 --- a/plugins/aws/vision_agents/plugins/aws/aws_realtime.py +++ b/plugins/aws/vision_agents/plugins/aws/aws_realtime.py @@ -7,7 +7,6 @@ from typing import Any, Dict, List, Optional import aiortc -import boto3 from aws_sdk_bedrock_runtime.client import ( BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput, @@ -19,11 +18,6 @@ ) from getstream.video.rtc import PcmData from getstream.video.rtc.audio_track import AudioStreamTrack -from smithy_aws_core.identity.components import ( - AWSCredentialsIdentity, - AWSIdentityProperties, -) -from smithy_core.aio.interfaces.identity import IdentityResolver from vision_agents.core.agents.agent_types import AgentOptions from vision_agents.core.edge.types import Participant from vision_agents.core.llm import realtime @@ -32,6 +26,8 @@ from vision_agents.core.vad.silero import SileroVADSession, SileroVADSessionPool from vision_agents.core.warmup import Warmable +from ._credentials import Boto3CredentialsResolver + logger = logging.getLogger(__name__) @@ -40,44 +36,6 @@ FORCE_RECONNECT_IN_MINUTES = 7.0 -class Boto3CredentialsResolver( - IdentityResolver[AWSCredentialsIdentity, AWSIdentityProperties] -): - """IdentityResolver that delegates to boto3.Session for credential resolution. - - Supports the full boto3 credential chain: env vars, shared credentials files, - AWS profiles, SSO, EC2 instance profiles, etc. - """ - - def __init__(self, profile_name: Optional[str] = None) -> None: - self._session = boto3.Session(profile_name=profile_name) - self._cached: Optional[AWSCredentialsIdentity] = None - - async def get_identity( - self, *, properties: AWSIdentityProperties, **kwargs: Any - ) -> AWSCredentialsIdentity: - if self._cached is not None: - return self._cached - - credentials = self._session.get_credentials() - if not credentials: - raise ValueError("Unable to load AWS credentials via boto3") - - creds = credentials.get_frozen_credentials() - if not creds.access_key or not creds.secret_key: - raise ValueError("AWS credentials are incomplete") - - expiry = getattr(credentials, "_expiry_time", None) - - self._cached = AWSCredentialsIdentity( - access_key_id=creds.access_key, - secret_access_key=creds.secret_key, - session_token=creds.token or None, - expiration=expiry, - ) - return self._cached - - class RealtimeConnection: """Encapsulates a single AWS Bedrock bidirectional stream connection. diff --git a/plugins/aws/vision_agents/plugins/aws/stt.py b/plugins/aws/vision_agents/plugins/aws/stt.py new file mode 100644 index 000000000..2b526b718 --- /dev/null +++ b/plugins/aws/vision_agents/plugins/aws/stt.py @@ -0,0 +1,445 @@ +import asyncio +import logging +import time +from typing import Any, Literal, Optional + +from aws_sdk_transcribe_streaming.client import ( + Config, + StartStreamTranscriptionInput, + StartStreamTranscriptionOutput, + TranscribeStreamingClient, +) +from aws_sdk_transcribe_streaming.models import ( + AudioEvent, + AudioStream, + AudioStreamAudioEvent, + Item, + Result, + TranscriptEvent, + TranscriptResultStream, + TranscriptResultStreamInternalFailureException, + TranscriptResultStreamServiceUnavailableException, + TranscriptResultStreamTranscriptEvent, +) +from getstream.video.rtc import PcmData +from smithy_core.aio.eventstream import DuplexEventStream +from smithy_core.aio.interfaces.eventstream import EventPublisher, EventReceiver +from vision_agents.core import stt +from vision_agents.core.edge.types import Participant +from vision_agents.core.stt import TranscriptResponse +from vision_agents.core.utils.utils import cancel_and_wait + +from ._credentials import Boto3CredentialsResolver + +logger = logging.getLogger(__name__) + +_RETRIABLE_STREAM_ERRORS = ( + TranscriptResultStreamInternalFailureException, + TranscriptResultStreamServiceUnavailableException, +) + + +class TranscribeSTT(stt.STT): + """ + AWS Transcribe streaming Speech-to-Text implementation. + + Uses the smithy-based ``aws-sdk-transcribe-streaming`` client. Each + "natural speech segment" detected by AWS is mapped to a turn: + partials carry ``is_partial=True`` and the finalised segment arrives + with ``is_partial=False``. + + Docs: + - https://docs.aws.amazon.com/transcribe/latest/dg/streaming.html + - https://docs.aws.amazon.com/transcribe/latest/dg/streaming-partial-results.html + """ + + turn_detection: bool = True + + def __init__( + self, + language_code: str = "en-US", + region_name: str = "us-east-1", + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + aws_profile: Optional[str] = None, + show_speaker_label: bool = False, + enable_partial_results_stabilization: bool = False, + partial_results_stability: Optional[Literal["high", "medium", "low"]] = None, + max_reconnect_backoff_seconds: float = 30.0, + ): + """ + Initialize AWS Transcribe streaming STT. + + Args: + language_code: BCP-47 language code, e.g. ``"en-US"``. + region_name: AWS region for the streaming endpoint. + aws_access_key_id: Optional explicit access key. + aws_secret_access_key: Optional explicit secret key. + aws_session_token: Optional session token (for temporary creds). + aws_profile: Optional named profile from ``~/.aws/credentials``. + Resolved via boto3 to a static key/secret pair. + show_speaker_label: Enable speaker diarization labels on items. + enable_partial_results_stabilization: Stabilise the trailing words + of partial transcripts to reduce flicker. + partial_results_stability: ``"high"``, ``"medium"`` or ``"low"``. + Only meaningful when stabilization is enabled. + max_reconnect_backoff_seconds: Cap on the exponential backoff + between reconnect attempts after a transient error. The + sequence starts at 1s and doubles up to this cap, then + stays there for subsequent attempts. Reconnects are + unlimited. + """ + if bool(aws_access_key_id) != bool(aws_secret_access_key): + raise ValueError( + "aws_access_key_id and aws_secret_access_key must be provided together" + ) + + super().__init__(provider_name="aws") + + self.language_code = language_code + self.region_name = region_name + self._aws_access_key_id = aws_access_key_id + self._aws_secret_access_key = aws_secret_access_key + self._aws_session_token = aws_session_token + self._aws_profile = aws_profile + self.show_speaker_label = show_speaker_label + self.enable_partial_results_stabilization = enable_partial_results_stabilization + self.partial_results_stability = partial_results_stability + self.max_reconnect_backoff_seconds = max_reconnect_backoff_seconds + # AWS Transcribe accepts 8000 (telephony) or 16000 (high quality). + # The bytes we send must match this declared rate exactly. + self._sample_rate = 16_000 + + self._client: Optional[TranscribeStreamingClient] = None + self._stream: Optional[ + DuplexEventStream[ + AudioStream, TranscriptResultStream, StartStreamTranscriptionOutput + ] + ] = None + self._input_stream: Optional[EventPublisher[AudioStream]] = None + self._output_stream: Optional[EventReceiver[TranscriptResultStream]] = None + self._recv_task: Optional[asyncio.Task[None]] = None + self._supervisor_task: Optional[asyncio.Task[None]] = None + self._reconnect_event = asyncio.Event() + self._current_participant: Optional[Participant] = None + self._turn_in_progress = False + self._audio_start_time: Optional[float] = None + # Media-time watermarks, in seconds. Total audio fed into the stream + # so far, and the cutoff snapshotted by clear(). Results whose + # start_time precedes the watermark are suppressed. The lock + # serialises increment-on-send with clear()'s snapshot so a chunk + # mid-flight cannot escape the cutoff. + self._audio_sent_seconds: float = 0.0 + self._start_time_watermark: float = 0.0 + self._watermark_lock = asyncio.Lock() + + async def start(self): + await super().start() + try: + await self._open_stream() + self._supervisor_task = asyncio.create_task(self._supervisor_loop()) + except BaseException: + self.started = False + if self._supervisor_task is not None: + await cancel_and_wait(self._supervisor_task) + self._supervisor_task = None + await self._close_streams() + raise + logger.info( + "AWS Transcribe streaming connection established (region=%s, lang=%s)", + self.region_name, + self.language_code, + ) + + async def clear(self): + # AWS Transcribe has no native way to drop in-flight events. Snapshot + # a media-time watermark so any result whose segment began before + # this point is suppressed in _handle_transcript_event. + await super().clear() + async with self._watermark_lock: + self._start_time_watermark = self._audio_sent_seconds + self._audio_start_time = None + self._turn_in_progress = False + self._current_participant = None + + async def close(self): + await super().close() + # Wake the supervisor so it observes self.closed and exits. + self._reconnect_event.set() + if self._supervisor_task is not None: + await cancel_and_wait(self._supervisor_task) + self._supervisor_task = None + self._audio_start_time = None + await self._close_streams() + + async def process_audio( + self, + pcm_data: PcmData, + participant: Optional[Participant] = None, + ): + resampled = pcm_data.resample(self._sample_rate, 1) + self._current_participant = participant + if self._audio_start_time is None: + self._audio_start_time = time.perf_counter() + + async with self._watermark_lock: + if self.closed or self._input_stream is None: + return + await self._input_stream.send( + AudioStreamAudioEvent( + value=AudioEvent(audio_chunk=resampled.samples.tobytes()) + ) + ) + self._audio_sent_seconds += resampled.duration + + async def _open_stream(self, timeout: float = 10.0): + client = TranscribeStreamingClient(config=await self._build_config()) + + async def _connect(): + _stream = await client.start_stream_transcription( + input=self._build_transcription_input() + ) + try: + _, _output_stream = await _stream.await_output() + return _stream, _output_stream + except asyncio.CancelledError: + await _stream.close() + raise + except Exception: + await _stream.close() + raise + + stream, output_stream = await asyncio.wait_for(_connect(), timeout=timeout) + + # New stream restarts AWS's media-time clock at 0. Reset counters + # and publish the new input stream atomically so process_audio() + # never observes a half-built stream and chunks mid-flight cannot + # escape the cutoff. + async with self._watermark_lock: + self._client = client + self._stream = stream + self._output_stream = output_stream + self._audio_sent_seconds = 0.0 + self._start_time_watermark = 0.0 + self._input_stream = stream.input_stream + self._recv_task = asyncio.create_task(self._recv_loop()) + + async def _build_config(self) -> Config: + kwargs: dict[str, Any] = { + "region": self.region_name, + "endpoint_uri": ( + f"https://transcribestreaming.{self.region_name}.amazonaws.com" + ), + } + + if self._aws_access_key_id and self._aws_secret_access_key: + kwargs["aws_access_key_id"] = self._aws_access_key_id + kwargs["aws_secret_access_key"] = self._aws_secret_access_key + if self._aws_session_token: + kwargs["aws_session_token"] = self._aws_session_token + else: + kwargs["aws_credentials_identity_resolver"] = await asyncio.to_thread( + Boto3CredentialsResolver, profile_name=self._aws_profile + ) + + return Config(**kwargs) + + def _build_transcription_input(self) -> StartStreamTranscriptionInput: + kwargs: dict[str, Any] = { + "language_code": self.language_code, + "media_sample_rate_hertz": self._sample_rate, + "media_encoding": "pcm", + } + if self.show_speaker_label: + kwargs["show_speaker_label"] = True + if self.enable_partial_results_stabilization: + kwargs["enable_partial_results_stabilization"] = True + if self.partial_results_stability is not None: + kwargs["partial_results_stability"] = self.partial_results_stability + return StartStreamTranscriptionInput(**kwargs) + + async def _recv_loop(self): + if self._output_stream is None: + return + + try: + async for event in self._output_stream: + if isinstance(event, TranscriptResultStreamTranscriptEvent): + self._handle_transcript_event(event.value) + elif isinstance(event, _RETRIABLE_STREAM_ERRORS): + logger.warning( + "Retriable AWS Transcribe error, will reconnect: %r", + event, + ) + if not self.closed: + # Stop accepting audio immediately; the supervisor may + # not run for up to max_reconnect_backoff_seconds. + self._input_stream = None + self._reconnect_event.set() + return + else: + logger.error("Permanent AWS Transcribe error: %r", event) + self._emit_error_event( + RuntimeError(f"AWS Transcribe error: {event!r}"), + participant=self._current_participant, + context="aws_transcribe", + ) + # Stop accepting audio; supervisor stays idle (no reconnect). + self._input_stream = None + return + # Stream ended cleanly. AWS closes on idle and on audio-length + # limits; treat that as retriable. + if not self.closed: + logger.info("AWS Transcribe stream ended, will reconnect") + self._input_stream = None + self._reconnect_event.set() + except asyncio.CancelledError: + raise + except Exception: + if self.closed: + return + logger.exception("AWS Transcribe receive loop failed, will reconnect") + self._input_stream = None + self._reconnect_event.set() + + def _handle_transcript_event(self, event: TranscriptEvent): + if event.transcript is None or not event.transcript.results: + return + + participant = self._current_participant + if participant is None: + logger.warning("Received transcript but no participant set") + return + + for result in event.transcript.results: + if result.start_time < self._start_time_watermark: + continue + if not result.alternatives: + continue + text = (result.alternatives[0].transcript or "").strip() + if not text: + continue + + response = self._result_to_response(result) + + if result.is_partial: + if not self._turn_in_progress: + self._turn_in_progress = True + self._emit_turn_started_event(participant) + self._emit_partial_transcript_event(text, participant, response) + else: + self._emit_transcript_event(text, participant, response) + self._audio_start_time = None + self._turn_in_progress = False + self._emit_turn_ended_event(participant) + + def _result_to_response(self, result: Result) -> TranscriptResponse: + items: list[Item] = [] + if result.alternatives: + items = result.alternatives[0].items or [] + scores = [i.confidence for i in items if i.confidence is not None] + confidence = sum(scores) / len(scores) if scores else None + + other: dict[str, Any] = { + "result_id": result.result_id, + "start_time": result.start_time, + "end_time": result.end_time, + } + if items: + other["items"] = [ + { + "type": item.type, + "content": item.content, + "start_time": item.start_time, + "end_time": item.end_time, + "speaker": item.speaker, + "confidence": item.confidence, + "stable": item.stable, + } + for item in items + ] + if result.channel_id: + other["channel_id"] = result.channel_id + + processing_time_ms: Optional[float] = None + if self._audio_start_time is not None: + processing_time_ms = (time.perf_counter() - self._audio_start_time) * 1000 + + return TranscriptResponse( + confidence=confidence, + language=self.language_code, + model_name="aws-transcribe-streaming", + other=other, + processing_time_ms=processing_time_ms, + ) + + async def _supervisor_loop(self): + """Wait for retriable failures and rebuild the stream. + + Triggered by ``_recv_loop`` setting ``_reconnect_event``. Each + attempt sleeps ``min(2**n, max_reconnect_backoff_seconds)`` and + then tears down the old stream and opens a new one. Retries are + unlimited; the counter resets after a successful reconnect. + """ + attempt = 0 + while not self.closed: + await self._reconnect_event.wait() + self._reconnect_event.clear() + if self.closed: + return + backoff = min(2.0**attempt, self.max_reconnect_backoff_seconds) + attempt += 1 + logger.info( + "Reconnecting to AWS Transcribe in %.1fs (attempt %d)", + backoff, + attempt, + ) + await asyncio.sleep(backoff) + try: + if self._turn_in_progress and self._current_participant is not None: + self._emit_turn_ended_event(self._current_participant) + self._turn_in_progress = False + self._audio_start_time = None + # Slow close+open runs outside the lock so concurrent + # process_audio calls observe _input_stream is None and + # drop audio immediately instead of stalling on the lock + # for the whole reconnect. _open_stream takes the lock + # briefly to swap in the new stream and reset watermarks + # atomically. + await self._close_streams() + await self._open_stream() + attempt = 0 + logger.info("AWS Transcribe reconnected") + except asyncio.CancelledError: + raise + except Exception: + logger.exception("AWS Transcribe reconnect failed") + self._reconnect_event.set() + + async def _close_streams(self, timeout=5.0): + # Close the input first so AWS sees END_STREAM and closes the stream. + if self._input_stream is not None: + try: + await self._input_stream.close() + except Exception: + logger.warning("Error closing input stream", exc_info=True) + if self._recv_task is not None: + try: + # Here the _recv_task is expected to exit + # when the input stream is closed. + # The cancel below is only a fallback for the + # case where AWS doesn't drain quickly enough. + await asyncio.wait_for(asyncio.shield(self._recv_task), timeout=timeout) + except Exception: + await cancel_and_wait(self._recv_task) + self._recv_task = None + if self._stream is not None: + try: + await self._stream.close() + except Exception: + logger.warning("Error closing stream", exc_info=True) + self._stream = None + self._input_stream = None + self._output_stream = None + self._client = None diff --git a/plugins/aws/vision_agents/plugins/aws/tts.py b/plugins/aws/vision_agents/plugins/aws/tts.py index 4685b01fd..8b4fb0b8b 100644 --- a/plugins/aws/vision_agents/plugins/aws/tts.py +++ b/plugins/aws/vision_agents/plugins/aws/tts.py @@ -17,7 +17,10 @@ class TTS(BaseTTS): - TextType can be 'text' or 'ssml' (auto-detected unless overridden) - Optional Engine ('standard' or 'neural'), LanguageCode, LexiconNames - Credentials are resolved via standard AWS SDK chain (env vars, profiles, roles). + Credentials are resolved via the standard boto3 chain (env vars, profiles, + SSO, instance profile, etc.). Override with explicit access key + secret + (plus optional session token) or with a named profile, or inject a + pre-built boto3 Polly client via ``client=...``. """ def __init__( @@ -29,8 +32,17 @@ def __init__( engine: Optional[str] = None, # 'standard' | 'neural' language_code: Optional[str] = None, lexicon_names: Optional[List[str]] = None, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + aws_profile: Optional[str] = None, client: Optional[Any] = None, ) -> None: + if bool(aws_access_key_id) != bool(aws_secret_access_key): + raise ValueError( + "aws_access_key_id and aws_secret_access_key must be provided together" + ) + super().__init__(provider_name="aws_polly") self.region_name = ( region_name @@ -49,13 +61,24 @@ def __init__( self.engine = engine self.language_code = language_code self.lexicon_names = lexicon_names + + session_kwargs: dict[str, Any] = {"region_name": self.region_name} + if aws_profile: + session_kwargs["profile_name"] = aws_profile + if aws_access_key_id: + session_kwargs["aws_access_key_id"] = aws_access_key_id + if aws_secret_access_key: + session_kwargs["aws_secret_access_key"] = aws_secret_access_key + if aws_session_token: + session_kwargs["aws_session_token"] = aws_session_token + self._session_kwargs = session_kwargs self._client = client @property async def client(self): if self._client is None: self._client = await asyncio.to_thread( - lambda: boto3.client("polly", region_name=self.region_name) + lambda: boto3.Session(**self._session_kwargs).client("polly") ) return self._client diff --git a/uv.lock b/uv.lock index 992a756a5..d97af9bfe 100644 --- a/uv.lock +++ b/uv.lock @@ -567,6 +567,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e3/f3/cb2cd5177bb6c21d339b96079ece8e30ffe9954540494bae73a6498677ac/aws_sdk_signers-0.2.0-py3-none-any.whl", hash = "sha256:e770dcc390e18093840ef4ce1ac70fc419fe6ac8737809ffce9a9fab56d8ac2d", size = 21621, upload-time = "2026-04-07T19:24:08.57Z" }, ] +[[package]] +name = "aws-sdk-transcribe-streaming" +version = "0.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "smithy-aws-core", extra = ["eventstream", "json"] }, + { name = "smithy-core" }, + { name = "smithy-http", extra = ["awscrt"] }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fa/a9/98b060633ef1457520860825a3331943c256c7f31983150b35fb39e71add/aws_sdk_transcribe_streaming-0.5.0.tar.gz", hash = "sha256:87753ab869c79cb260028971898fe82aa3c623b9b3167b22e0f7fbccd9395556", size = 456922, upload-time = "2026-04-07T19:24:22.454Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/9f/dfa16632c4330b2386e5567cb2424ba57c49760f4935f4b925d1a7b97f35/aws_sdk_transcribe_streaming-0.5.0-py3-none-any.whl", hash = "sha256:41f5a0a045b143b3a034043a0245bebc3a11fa6e57b2d92d7b2388f4fd60bd60", size = 60965, upload-time = "2026-04-07T19:24:23.577Z" }, +] + [[package]] name = "awscrt" version = "0.32.1" @@ -7397,6 +7411,7 @@ name = "vision-agents-plugins-aws" source = { editable = "plugins/aws" } dependencies = [ { name = "aws-sdk-bedrock-runtime" }, + { name = "aws-sdk-transcribe-streaming" }, { name = "boto3" }, { name = "onnxruntime" }, { name = "vision-agents" }, @@ -7411,6 +7426,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "aws-sdk-bedrock-runtime", specifier = ">=0.4.0,<1" }, + { name = "aws-sdk-transcribe-streaming", specifier = ">=0.5.0,<1" }, { name = "boto3", specifier = ">=1.42.65,<2" }, { name = "onnxruntime", specifier = ">=1.16.1,<2" }, { name = "vision-agents", editable = "agents-core" },