From f5e98c02654b424184b701dd185776954a1df29e Mon Sep 17 00:00:00 2001 From: Daniil Gusev Date: Thu, 30 Apr 2026 22:57:11 +0200 Subject: [PATCH 01/14] Implement AWS Transcribe STT plugin --- plugins/aws/pyproject.toml | 5 +- plugins/aws/tests/test_aws_stt.py | 77 +++++ .../aws/vision_agents/plugins/aws/__init__.py | 3 +- .../vision_agents/plugins/aws/_credentials.py | 46 +++ plugins/aws/vision_agents/plugins/aws/stt.py | 322 ++++++++++++++++++ uv.lock | 16 + 6 files changed, 466 insertions(+), 3 deletions(-) create mode 100644 plugins/aws/tests/test_aws_stt.py create mode 100644 plugins/aws/vision_agents/plugins/aws/_credentials.py create mode 100644 plugins/aws/vision_agents/plugins/aws/stt.py diff --git a/plugins/aws/pyproject.toml b/plugins/aws/pyproject.toml index deea1383f..aa47eb1fc 100644 --- a/plugins/aws/pyproject.toml +++ b/plugins/aws/pyproject.toml @@ -5,9 +5,9 @@ build-backend = "hatchling.build" [project] name = "vision-agents-plugins-aws" dynamic = ["version"] -description = "AWS (Bedrock) LLM integration for Vision Agents" +description = "AWS (Bedrock LLM, Transcribe STT, Polly TTS) integration for Vision Agents" readme = "README.md" -keywords = ["aws", "bedrock", "LLM", "AI", "voice agents", "agents"] +keywords = ["aws", "bedrock", "transcribe", "polly", "STT", "TTS", "LLM", "AI", "voice agents", "agents"] requires-python = ">=3.12" license = "MIT" dependencies = [ @@ -15,6 +15,7 @@ dependencies = [ "onnxruntime>=1.16.1,<2", "boto3>=1.42.65,<2", "aws-sdk-bedrock-runtime>=0.4.0,<1", + "aws-sdk-transcribe-streaming>=0.5.0,<1", ] [project.urls] diff --git a/plugins/aws/tests/test_aws_stt.py b/plugins/aws/tests/test_aws_stt.py new file mode 100644 index 000000000..44b7452ad --- /dev/null +++ b/plugins/aws/tests/test_aws_stt.py @@ -0,0 +1,77 @@ +"""Integration tests for AWS Transcribe STT.""" + +import pytest +from dotenv import load_dotenv +from vision_agents.core.turn_detection import TurnEndedEvent, TurnStartedEvent +from vision_agents.plugins import aws + +from conftest import STTSession + +load_dotenv() + + +@pytest.mark.integration +class TestTranscribeSTT: + """Integration tests against real AWS Transcribe streaming.""" + + @pytest.fixture + async def stt(self): + stt = aws.STT(language_code="en-US") + try: + await stt.start() + yield stt + finally: + await stt.close() + + async def test_transcribe_mia_audio_16khz( + self, stt, mia_audio_16khz_chunked, participant + ): + session = STTSession(stt) + + for chunk in mia_audio_16khz_chunked: + await stt.process_audio(chunk, participant=participant) + + await session.wait_for_result(timeout=30.0) + assert not session.errors, f"Errors occurred: {session.errors}" + + full_transcript = session.get_full_transcript().lower() + assert any( + word in full_transcript + for word in ["village", "quiet", "mia", "treasures"] + ), f"Transcript did not match expected content: {full_transcript!r}" + + async def test_partial_transcripts_emitted( + self, stt, mia_audio_16khz_chunked, participant + ): + session = STTSession(stt) + + for chunk in mia_audio_16khz_chunked: + await stt.process_audio(chunk, participant=participant) + + await session.wait_for_result(timeout=30.0) + assert session.partial_transcripts, "No partial transcripts received" + + async def test_turn_events_emitted( + self, stt, mia_audio_16khz_chunked, participant + ): + session = STTSession(stt) + turn_started: list[TurnStartedEvent] = [] + turn_ended: list[TurnEndedEvent] = [] + + @stt.events.subscribe + async def on_turn_started(event: TurnStartedEvent): + turn_started.append(event) + + @stt.events.subscribe + async def on_turn_ended(event: TurnEndedEvent): + turn_ended.append(event) + + for chunk in mia_audio_16khz_chunked: + await stt.process_audio(chunk, participant=participant) + + await session.wait_for_result(timeout=30.0) + + assert turn_started, "No TurnStartedEvent received" + assert turn_ended, "No TurnEndedEvent received" + assert turn_started[0].participant == participant + assert turn_ended[0].participant == participant diff --git a/plugins/aws/vision_agents/plugins/aws/__init__.py b/plugins/aws/vision_agents/plugins/aws/__init__.py index aafbcf679..1b9a16718 100644 --- a/plugins/aws/vision_agents/plugins/aws/__init__.py +++ b/plugins/aws/vision_agents/plugins/aws/__init__.py @@ -1,5 +1,6 @@ from .aws_llm import BedrockLLM as LLM from .aws_realtime import Realtime +from .stt import TranscribeSTT as STT from .tts import TTS -__all__ = ["LLM", "Realtime", "TTS"] +__all__ = ["LLM", "Realtime", "STT", "TTS"] diff --git a/plugins/aws/vision_agents/plugins/aws/_credentials.py b/plugins/aws/vision_agents/plugins/aws/_credentials.py new file mode 100644 index 000000000..40e358307 --- /dev/null +++ b/plugins/aws/vision_agents/plugins/aws/_credentials.py @@ -0,0 +1,46 @@ +from typing import Any, Optional + +import boto3 +from smithy_aws_core.identity.components import ( + AWSCredentialsIdentity, + AWSIdentityProperties, +) +from smithy_core.aio.interfaces.identity import IdentityResolver + + +class Boto3CredentialsResolver( + IdentityResolver[AWSCredentialsIdentity, AWSIdentityProperties] +): + """IdentityResolver that delegates to boto3.Session for credential resolution. + + Supports the full boto3 credential chain: env vars, shared credentials files, + AWS profiles, SSO, EC2 instance profiles, etc. + """ + + def __init__(self, profile_name: Optional[str] = None) -> None: + self._session = boto3.Session(profile_name=profile_name) + self._cached: Optional[AWSCredentialsIdentity] = None + + async def get_identity( + self, *, properties: AWSIdentityProperties, **kwargs: Any + ) -> AWSCredentialsIdentity: + if self._cached is not None: + return self._cached + + credentials = self._session.get_credentials() + if not credentials: + raise ValueError("Unable to load AWS credentials via boto3") + + creds = credentials.get_frozen_credentials() + if not creds.access_key or not creds.secret_key: + raise ValueError("AWS credentials are incomplete") + + expiry = getattr(credentials, "_expiry_time", None) + + self._cached = AWSCredentialsIdentity( + access_key_id=creds.access_key, + secret_access_key=creds.secret_key, + session_token=creds.token or None, + expiration=expiry, + ) + return self._cached diff --git a/plugins/aws/vision_agents/plugins/aws/stt.py b/plugins/aws/vision_agents/plugins/aws/stt.py new file mode 100644 index 000000000..02cf95782 --- /dev/null +++ b/plugins/aws/vision_agents/plugins/aws/stt.py @@ -0,0 +1,322 @@ +import asyncio +import logging +import time +from typing import Any, Optional + +from aws_sdk_transcribe_streaming.client import ( + Config, + StartStreamTranscriptionInput, + StartStreamTranscriptionOutput, + TranscribeStreamingClient, +) +from aws_sdk_transcribe_streaming.models import ( + AudioEvent, + AudioStream, + AudioStreamAudioEvent, + Result, + TranscriptEvent, + TranscriptResultStream, + TranscriptResultStreamTranscriptEvent, +) +from getstream.video.rtc.track_util import PcmData +from smithy_core.aio.eventstream import DuplexEventStream +from smithy_core.aio.interfaces.eventstream import EventPublisher, EventReceiver +from vision_agents.core import stt +from vision_agents.core.edge.types import Participant +from vision_agents.core.stt import TranscriptResponse +from vision_agents.core.utils.utils import cancel_and_wait + +from ._credentials import Boto3CredentialsResolver + +# TODO(stt): reconnection on transport errors is deferred (see plan). + +logger = logging.getLogger(__name__) + + +class TranscribeSTT(stt.STT): + """ + AWS Transcribe streaming Speech-to-Text implementation. + + Uses the smithy-based ``aws-sdk-transcribe-streaming`` client. Each + "natural speech segment" detected by AWS is mapped to a turn: + partials carry ``is_partial=True`` and the finalised segment arrives + with ``is_partial=False``. + + Docs: + - https://docs.aws.amazon.com/transcribe/latest/dg/streaming.html + - https://docs.aws.amazon.com/transcribe/latest/dg/streaming-partial-results.html + """ + + turn_detection: bool = True + + def __init__( + self, + language_code: str = "en-US", + region_name: str = "us-east-1", + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + aws_profile: Optional[str] = None, + show_speaker_label: bool = False, + enable_partial_results_stabilization: bool = False, + partial_results_stability: Optional[str] = None, + ): + """ + Initialize AWS Transcribe streaming STT. + + Args: + language_code: BCP-47 language code, e.g. ``"en-US"``. + region_name: AWS region for the streaming endpoint. + aws_access_key_id: Optional explicit access key. + aws_secret_access_key: Optional explicit secret key. + aws_session_token: Optional session token (for temporary creds). + aws_profile: Optional named profile from ``~/.aws/credentials``. + Resolved via boto3 to a static key/secret pair. + show_speaker_label: Enable speaker diarization labels on items. + enable_partial_results_stabilization: Stabilise the trailing words + of partial transcripts to reduce flicker. + partial_results_stability: ``"high"``, ``"medium"`` or ``"low"``. + Only meaningful when stabilization is enabled. + """ + super().__init__(provider_name="aws") + + self.language_code = language_code + self.region_name = region_name + self._aws_access_key_id = aws_access_key_id + self._aws_secret_access_key = aws_secret_access_key + self._aws_session_token = aws_session_token + self._aws_profile = aws_profile + self.show_speaker_label = show_speaker_label + self.enable_partial_results_stabilization = ( + enable_partial_results_stabilization + ) + self.partial_results_stability = partial_results_stability + + self._client: Optional[TranscribeStreamingClient] = None + self._stream: Optional[ + DuplexEventStream[ + AudioStream, TranscriptResultStream, StartStreamTranscriptionOutput + ] + ] = None + self._input_stream: Optional[EventPublisher[AudioStream]] = None + self._output_stream: Optional[EventReceiver[TranscriptResultStream]] = None + self._recv_task: Optional[asyncio.Task[None]] = None + self._current_participant: Optional[Participant] = None + self._turn_in_progress = False + self._audio_start_time: Optional[float] = None + # Media-time watermarks, in seconds. Total audio fed into the stream + # so far, and the cutoff snapshotted by clear(). Results whose + # start_time precedes the watermark are suppressed. The lock + # serialises increment-on-send with clear()'s snapshot so a chunk + # mid-flight cannot escape the cutoff. + self._audio_sent_seconds: float = 0.0 + self._start_time_watermark: float = 0.0 + self._watermark_lock = asyncio.Lock() + + async def start(self): + await super().start() + + self._client = TranscribeStreamingClient(config=await self._build_config()) + + stream = await asyncio.wait_for( + self._client.start_stream_transcription( + input=self._build_transcription_input() + ), + timeout=10.0, + ) + self._stream = stream + self._input_stream = stream.input_stream + _, self._output_stream = await stream.await_output() + self._recv_task = asyncio.create_task(self._recv_loop()) + + logger.info( + "AWS Transcribe streaming connection established (region=%s, lang=%s)", + self.region_name, + self.language_code, + ) + + async def _build_config(self) -> Config: + kwargs: dict[str, Any] = { + "region": self.region_name, + "endpoint_uri": ( + f"https://transcribestreaming.{self.region_name}.amazonaws.com" + ), + } + + if self._aws_access_key_id and self._aws_secret_access_key: + kwargs["aws_access_key_id"] = self._aws_access_key_id + kwargs["aws_secret_access_key"] = self._aws_secret_access_key + if self._aws_session_token: + kwargs["aws_session_token"] = self._aws_session_token + else: + # boto3.Session() reads ~/.aws/credentials synchronously, so the + # resolver has to be constructed off the event loop. + kwargs["aws_credentials_identity_resolver"] = await asyncio.to_thread( + Boto3CredentialsResolver, profile_name=self._aws_profile + ) + + return Config(**kwargs) + + def _build_transcription_input(self) -> StartStreamTranscriptionInput: + kwargs: dict[str, Any] = { + "language_code": self.language_code, + "media_sample_rate_hertz": 16000, + "media_encoding": "pcm", + } + if self.show_speaker_label: + kwargs["show_speaker_label"] = True + if self.enable_partial_results_stabilization: + kwargs["enable_partial_results_stabilization"] = True + if self.partial_results_stability is not None: + kwargs["partial_results_stability"] = ( + self.partial_results_stability + ) + return StartStreamTranscriptionInput(**kwargs) + + async def process_audio( + self, + pcm_data: PcmData, + participant: Optional[Participant] = None, + ): + if self.closed or self._input_stream is None: + return + + resampled = pcm_data.resample(16_000, 1) + self._current_participant = participant + if self._audio_start_time is None: + self._audio_start_time = time.perf_counter() + + async with self._watermark_lock: + await self._input_stream.send( + AudioStreamAudioEvent( + value=AudioEvent(audio_chunk=resampled.samples.tobytes()) + ) + ) + self._audio_sent_seconds += resampled.duration + + async def _recv_loop(self): + if self._output_stream is None: + return + try: + async for event in self._output_stream: + if isinstance(event, TranscriptResultStreamTranscriptEvent): + self._handle_transcript_event(event.value) + except asyncio.CancelledError: + raise + except Exception: + if not self.closed: + logger.exception("AWS Transcribe receive loop failed") + + def _handle_transcript_event(self, event: TranscriptEvent): + if event.transcript is None or not event.transcript.results: + return + + participant = self._current_participant + if participant is None: + raise ValueError( + "No participant set - audio must be processed with a participant" + ) + + for result in event.transcript.results: + if result.start_time < self._start_time_watermark: + continue + if not result.alternatives: + continue + text = (result.alternatives[0].transcript or "").strip() + if not text: + continue + + response = self._result_to_response(result) + + if result.is_partial: + if not self._turn_in_progress: + self._turn_in_progress = True + self._emit_turn_started_event(participant) + self._emit_partial_transcript_event(text, participant, response) + else: + self._emit_transcript_event(text, participant, response) + self._audio_start_time = None + self._turn_in_progress = False + self._emit_turn_ended_event(participant) + + def _result_to_response(self, result: Result) -> TranscriptResponse: + items = result.alternatives[0].items or [] + scores = [i.confidence for i in items if i.confidence is not None] + confidence = sum(scores) / len(scores) if scores else None + + other: dict[str, Any] = { + "result_id": result.result_id, + "start_time": result.start_time, + "end_time": result.end_time, + } + if items: + other["items"] = [ + { + "type": item.type, + "content": item.content, + "start_time": item.start_time, + "end_time": item.end_time, + "speaker": item.speaker, + "confidence": item.confidence, + "stable": item.stable, + } + for item in items + ] + if result.channel_id: + other["channel_id"] = result.channel_id + + processing_time_ms: Optional[float] = None + if self._audio_start_time is not None: + processing_time_ms = ( + time.perf_counter() - self._audio_start_time + ) * 1000 + + return TranscriptResponse( + confidence=confidence, + language=self.language_code, + model_name="aws-transcribe-streaming", + other=other, + processing_time_ms=processing_time_ms, + ) + + async def clear(self): + # AWS Transcribe has no native way to drop in-flight events. Snapshot + # a media-time watermark so any result whose segment began before + # this point is suppressed in _handle_transcript_event. + await super().clear() + async with self._watermark_lock: + self._start_time_watermark = self._audio_sent_seconds + self._audio_start_time = None + self._turn_in_progress = False + self._current_participant = None + + async def close(self): + await super().close() + + self._audio_start_time = None + + if self._input_stream is not None: + try: + await self._input_stream.close() + except Exception: + logger.warning( + "Error closing AWS Transcribe input stream", exc_info=True + ) + + if self._recv_task is not None: + await cancel_and_wait(self._recv_task) + self._recv_task = None + + if self._stream is not None: + try: + await self._stream.close() + except Exception: + logger.warning( + "Error closing AWS Transcribe stream", exc_info=True + ) + finally: + self._stream = None + + self._input_stream = None + self._output_stream = None + self._client = None diff --git a/uv.lock b/uv.lock index 992a756a5..d97af9bfe 100644 --- a/uv.lock +++ b/uv.lock @@ -567,6 +567,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e3/f3/cb2cd5177bb6c21d339b96079ece8e30ffe9954540494bae73a6498677ac/aws_sdk_signers-0.2.0-py3-none-any.whl", hash = "sha256:e770dcc390e18093840ef4ce1ac70fc419fe6ac8737809ffce9a9fab56d8ac2d", size = 21621, upload-time = "2026-04-07T19:24:08.57Z" }, ] +[[package]] +name = "aws-sdk-transcribe-streaming" +version = "0.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "smithy-aws-core", extra = ["eventstream", "json"] }, + { name = "smithy-core" }, + { name = "smithy-http", extra = ["awscrt"] }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fa/a9/98b060633ef1457520860825a3331943c256c7f31983150b35fb39e71add/aws_sdk_transcribe_streaming-0.5.0.tar.gz", hash = "sha256:87753ab869c79cb260028971898fe82aa3c623b9b3167b22e0f7fbccd9395556", size = 456922, upload-time = "2026-04-07T19:24:22.454Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/9f/dfa16632c4330b2386e5567cb2424ba57c49760f4935f4b925d1a7b97f35/aws_sdk_transcribe_streaming-0.5.0-py3-none-any.whl", hash = "sha256:41f5a0a045b143b3a034043a0245bebc3a11fa6e57b2d92d7b2388f4fd60bd60", size = 60965, upload-time = "2026-04-07T19:24:23.577Z" }, +] + [[package]] name = "awscrt" version = "0.32.1" @@ -7397,6 +7411,7 @@ name = "vision-agents-plugins-aws" source = { editable = "plugins/aws" } dependencies = [ { name = "aws-sdk-bedrock-runtime" }, + { name = "aws-sdk-transcribe-streaming" }, { name = "boto3" }, { name = "onnxruntime" }, { name = "vision-agents" }, @@ -7411,6 +7426,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "aws-sdk-bedrock-runtime", specifier = ">=0.4.0,<1" }, + { name = "aws-sdk-transcribe-streaming", specifier = ">=0.5.0,<1" }, { name = "boto3", specifier = ">=1.42.65,<2" }, { name = "onnxruntime", specifier = ">=1.16.1,<2" }, { name = "vision-agents", editable = "agents-core" }, From 985a4ea7524a9778abf29c434c5b5653e0fe8123 Mon Sep 17 00:00:00 2001 From: Daniil Gusev Date: Thu, 30 Apr 2026 22:58:08 +0200 Subject: [PATCH 02/14] aws: Move Boto3CredentialsResolver to a module to share with STT --- .../vision_agents/plugins/aws/aws_realtime.py | 46 +------------------ 1 file changed, 2 insertions(+), 44 deletions(-) diff --git a/plugins/aws/vision_agents/plugins/aws/aws_realtime.py b/plugins/aws/vision_agents/plugins/aws/aws_realtime.py index 59fab23d9..a2799d958 100644 --- a/plugins/aws/vision_agents/plugins/aws/aws_realtime.py +++ b/plugins/aws/vision_agents/plugins/aws/aws_realtime.py @@ -7,7 +7,6 @@ from typing import Any, Dict, List, Optional import aiortc -import boto3 from aws_sdk_bedrock_runtime.client import ( BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput, @@ -19,11 +18,6 @@ ) from getstream.video.rtc import PcmData from getstream.video.rtc.audio_track import AudioStreamTrack -from smithy_aws_core.identity.components import ( - AWSCredentialsIdentity, - AWSIdentityProperties, -) -from smithy_core.aio.interfaces.identity import IdentityResolver from vision_agents.core.agents.agent_types import AgentOptions from vision_agents.core.edge.types import Participant from vision_agents.core.llm import realtime @@ -32,6 +26,8 @@ from vision_agents.core.vad.silero import SileroVADSession, SileroVADSessionPool from vision_agents.core.warmup import Warmable +from ._credentials import Boto3CredentialsResolver + logger = logging.getLogger(__name__) @@ -40,44 +36,6 @@ FORCE_RECONNECT_IN_MINUTES = 7.0 -class Boto3CredentialsResolver( - IdentityResolver[AWSCredentialsIdentity, AWSIdentityProperties] -): - """IdentityResolver that delegates to boto3.Session for credential resolution. - - Supports the full boto3 credential chain: env vars, shared credentials files, - AWS profiles, SSO, EC2 instance profiles, etc. - """ - - def __init__(self, profile_name: Optional[str] = None) -> None: - self._session = boto3.Session(profile_name=profile_name) - self._cached: Optional[AWSCredentialsIdentity] = None - - async def get_identity( - self, *, properties: AWSIdentityProperties, **kwargs: Any - ) -> AWSCredentialsIdentity: - if self._cached is not None: - return self._cached - - credentials = self._session.get_credentials() - if not credentials: - raise ValueError("Unable to load AWS credentials via boto3") - - creds = credentials.get_frozen_credentials() - if not creds.access_key or not creds.secret_key: - raise ValueError("AWS credentials are incomplete") - - expiry = getattr(credentials, "_expiry_time", None) - - self._cached = AWSCredentialsIdentity( - access_key_id=creds.access_key, - secret_access_key=creds.secret_key, - session_token=creds.token or None, - expiration=expiry, - ) - return self._cached - - class RealtimeConnection: """Encapsulates a single AWS Bedrock bidirectional stream connection. From 81a046aadd1e2d7556c44305a5a4fa75733594de Mon Sep 17 00:00:00 2001 From: Daniil Gusev Date: Thu, 30 Apr 2026 23:40:09 +0200 Subject: [PATCH 03/14] aws.TranscribeSTT: reconnection handling, cleanup, more tests --- plugins/aws/tests/test_aws_stt.py | 103 ++++++++- plugins/aws/vision_agents/plugins/aws/stt.py | 230 +++++++++++++------ 2 files changed, 250 insertions(+), 83 deletions(-) diff --git a/plugins/aws/tests/test_aws_stt.py b/plugins/aws/tests/test_aws_stt.py index 44b7452ad..6283526bf 100644 --- a/plugins/aws/tests/test_aws_stt.py +++ b/plugins/aws/tests/test_aws_stt.py @@ -1,6 +1,12 @@ -"""Integration tests for AWS Transcribe STT.""" +import asyncio import pytest +from aws_sdk_transcribe_streaming.models import ( + Alternative, + Result, + Transcript, + TranscriptEvent, +) from dotenv import load_dotenv from vision_agents.core.turn_detection import TurnEndedEvent, TurnStartedEvent from vision_agents.plugins import aws @@ -10,10 +16,94 @@ load_dotenv() -@pytest.mark.integration class TestTranscribeSTT: - """Integration tests against real AWS Transcribe streaming.""" + @pytest.fixture + def transcript_event_factory(self): + def factory( + text: str, *, is_partial: bool, start_time: float + ) -> TranscriptEvent: + return TranscriptEvent( + transcript=Transcript( + results=[ + Result( + result_id="r", + start_time=start_time, + end_time=start_time + 1.0, + is_partial=is_partial, + alternatives=[Alternative(transcript=text, items=[])], + ) + ] + ) + ) + + return factory + + async def test_partial_result_emits_partial_transcript_and_turn_started( + self, participant, transcript_event_factory + ): + stt = aws.STT(language_code="en-US") + stt._current_participant = participant + session = STTSession(stt) + turn_started: list[TurnStartedEvent] = [] + + @stt.events.subscribe + async def on_turn_started(event: TurnStartedEvent): + turn_started.append(event) + + stt._handle_transcript_event( + transcript_event_factory("hello", is_partial=True, start_time=0.0) + ) + await asyncio.sleep(0.05) + + assert [e.text for e in session.partial_transcripts] == ["hello"] + assert not session.transcripts + assert len(turn_started) == 1 + assert turn_started[0].participant == participant + + async def test_final_result_emits_transcript_and_turn_ended( + self, participant, transcript_event_factory + ): + stt = aws.STT(language_code="en-US") + stt._current_participant = participant + session = STTSession(stt) + turn_ended: list[TurnEndedEvent] = [] + @stt.events.subscribe + async def on_turn_ended(event: TurnEndedEvent): + turn_ended.append(event) + + stt._handle_transcript_event( + transcript_event_factory("hello world", is_partial=False, start_time=0.0) + ) + await asyncio.sleep(0.05) + + assert [e.text for e in session.transcripts] == ["hello world"] + assert len(turn_ended) == 1 + assert turn_ended[0].participant == participant + + async def test_clear_drops_results_before_watermark( + self, participant, transcript_event_factory + ): + stt = aws.STT(language_code="en-US") + stt._audio_sent_seconds = 5.0 + await stt.clear() + + stt._current_participant = participant + session = STTSession(stt) + + stt._handle_transcript_event( + transcript_event_factory("stale", is_partial=False, start_time=2.0) + ) + stt._handle_transcript_event( + transcript_event_factory("fresh", is_partial=False, start_time=6.0) + ) + await asyncio.sleep(0.05) + + assert [e.text for e in session.transcripts] == ["fresh"] + + +@pytest.mark.integration +class TestTranscribeSTTIntegration: @pytest.fixture async def stt(self): stt = aws.STT(language_code="en-US") @@ -36,8 +126,7 @@ async def test_transcribe_mia_audio_16khz( full_transcript = session.get_full_transcript().lower() assert any( - word in full_transcript - for word in ["village", "quiet", "mia", "treasures"] + word in full_transcript for word in ["village", "quiet", "mia", "treasures"] ), f"Transcript did not match expected content: {full_transcript!r}" async def test_partial_transcripts_emitted( @@ -51,9 +140,7 @@ async def test_partial_transcripts_emitted( await session.wait_for_result(timeout=30.0) assert session.partial_transcripts, "No partial transcripts received" - async def test_turn_events_emitted( - self, stt, mia_audio_16khz_chunked, participant - ): + async def test_turn_events_emitted(self, stt, mia_audio_16khz_chunked, participant): session = STTSession(stt) turn_started: list[TurnStartedEvent] = [] turn_ended: list[TurnEndedEvent] = [] diff --git a/plugins/aws/vision_agents/plugins/aws/stt.py b/plugins/aws/vision_agents/plugins/aws/stt.py index 02cf95782..842cb0cde 100644 --- a/plugins/aws/vision_agents/plugins/aws/stt.py +++ b/plugins/aws/vision_agents/plugins/aws/stt.py @@ -1,7 +1,7 @@ import asyncio import logging import time -from typing import Any, Optional +from typing import Any, Literal, Optional from aws_sdk_transcribe_streaming.client import ( Config, @@ -13,12 +13,15 @@ AudioEvent, AudioStream, AudioStreamAudioEvent, + Item, Result, TranscriptEvent, TranscriptResultStream, + TranscriptResultStreamInternalFailureException, + TranscriptResultStreamServiceUnavailableException, TranscriptResultStreamTranscriptEvent, ) -from getstream.video.rtc.track_util import PcmData +from getstream.video.rtc import PcmData from smithy_core.aio.eventstream import DuplexEventStream from smithy_core.aio.interfaces.eventstream import EventPublisher, EventReceiver from vision_agents.core import stt @@ -28,10 +31,13 @@ from ._credentials import Boto3CredentialsResolver -# TODO(stt): reconnection on transport errors is deferred (see plan). - logger = logging.getLogger(__name__) +_RETRIABLE_STREAM_ERRORS = ( + TranscriptResultStreamInternalFailureException, + TranscriptResultStreamServiceUnavailableException, +) + class TranscribeSTT(stt.STT): """ @@ -59,7 +65,8 @@ def __init__( aws_profile: Optional[str] = None, show_speaker_label: bool = False, enable_partial_results_stabilization: bool = False, - partial_results_stability: Optional[str] = None, + partial_results_stability: Optional[Literal["high", "medium", "low"]] = None, + max_reconnect_backoff_seconds: float = 30.0, ): """ Initialize AWS Transcribe streaming STT. @@ -77,6 +84,11 @@ def __init__( of partial transcripts to reduce flicker. partial_results_stability: ``"high"``, ``"medium"`` or ``"low"``. Only meaningful when stabilization is enabled. + max_reconnect_backoff_seconds: Cap on the exponential backoff + between reconnect attempts after a transient error. The + sequence starts at 1s and doubles up to this cap, then + stays there for subsequent attempts. Reconnects are + unlimited. """ super().__init__(provider_name="aws") @@ -87,10 +99,12 @@ def __init__( self._aws_session_token = aws_session_token self._aws_profile = aws_profile self.show_speaker_label = show_speaker_label - self.enable_partial_results_stabilization = ( - enable_partial_results_stabilization - ) + self.enable_partial_results_stabilization = enable_partial_results_stabilization self.partial_results_stability = partial_results_stability + self.max_reconnect_backoff_seconds = max_reconnect_backoff_seconds + # AWS Transcribe accepts 8000 (telephony) or 16000 (high quality). + # The bytes we send must match this declared rate exactly. + self._sample_rate = 16_000 self._client: Optional[TranscribeStreamingClient] = None self._stream: Optional[ @@ -101,6 +115,8 @@ def __init__( self._input_stream: Optional[EventPublisher[AudioStream]] = None self._output_stream: Optional[EventReceiver[TranscriptResultStream]] = None self._recv_task: Optional[asyncio.Task[None]] = None + self._supervisor_task: Optional[asyncio.Task[None]] = None + self._reconnect_event = asyncio.Event() self._current_participant: Optional[Participant] = None self._turn_in_progress = False self._audio_start_time: Optional[float] = None @@ -115,7 +131,56 @@ def __init__( async def start(self): await super().start() + await self._open_stream() + self._supervisor_task = asyncio.create_task(self._supervisor_loop()) + logger.info( + "AWS Transcribe streaming connection established (region=%s, lang=%s)", + self.region_name, + self.language_code, + ) + + async def clear(self): + # AWS Transcribe has no native way to drop in-flight events. Snapshot + # a media-time watermark so any result whose segment began before + # this point is suppressed in _handle_transcript_event. + await super().clear() + async with self._watermark_lock: + self._start_time_watermark = self._audio_sent_seconds + self._audio_start_time = None + self._turn_in_progress = False + self._current_participant = None + + async def close(self): + await super().close() + # Wake the supervisor so it observes self.closed and exits. + self._reconnect_event.set() + if self._supervisor_task is not None: + await cancel_and_wait(self._supervisor_task) + self._supervisor_task = None + self._audio_start_time = None + await self._close_streams() + + async def process_audio( + self, + pcm_data: PcmData, + participant: Optional[Participant] = None, + ): + resampled = pcm_data.resample(self._sample_rate, 1) + self._current_participant = participant + if self._audio_start_time is None: + self._audio_start_time = time.perf_counter() + async with self._watermark_lock: + if self.closed or self._input_stream is None: + return + await self._input_stream.send( + AudioStreamAudioEvent( + value=AudioEvent(audio_chunk=resampled.samples.tobytes()) + ) + ) + self._audio_sent_seconds += resampled.duration + + async def _open_stream(self): self._client = TranscribeStreamingClient(config=await self._build_config()) stream = await asyncio.wait_for( @@ -129,12 +194,6 @@ async def start(self): _, self._output_stream = await stream.await_output() self._recv_task = asyncio.create_task(self._recv_loop()) - logger.info( - "AWS Transcribe streaming connection established (region=%s, lang=%s)", - self.region_name, - self.language_code, - ) - async def _build_config(self) -> Config: kwargs: dict[str, Any] = { "region": self.region_name, @@ -160,7 +219,7 @@ async def _build_config(self) -> Config: def _build_transcription_input(self) -> StartStreamTranscriptionInput: kwargs: dict[str, Any] = { "language_code": self.language_code, - "media_sample_rate_hertz": 16000, + "media_sample_rate_hertz": self._sample_rate, "media_encoding": "pcm", } if self.show_speaker_label: @@ -168,32 +227,9 @@ def _build_transcription_input(self) -> StartStreamTranscriptionInput: if self.enable_partial_results_stabilization: kwargs["enable_partial_results_stabilization"] = True if self.partial_results_stability is not None: - kwargs["partial_results_stability"] = ( - self.partial_results_stability - ) + kwargs["partial_results_stability"] = self.partial_results_stability return StartStreamTranscriptionInput(**kwargs) - async def process_audio( - self, - pcm_data: PcmData, - participant: Optional[Participant] = None, - ): - if self.closed or self._input_stream is None: - return - - resampled = pcm_data.resample(16_000, 1) - self._current_participant = participant - if self._audio_start_time is None: - self._audio_start_time = time.perf_counter() - - async with self._watermark_lock: - await self._input_stream.send( - AudioStreamAudioEvent( - value=AudioEvent(audio_chunk=resampled.samples.tobytes()) - ) - ) - self._audio_sent_seconds += resampled.duration - async def _recv_loop(self): if self._output_stream is None: return @@ -201,11 +237,36 @@ async def _recv_loop(self): async for event in self._output_stream: if isinstance(event, TranscriptResultStreamTranscriptEvent): self._handle_transcript_event(event.value) + elif isinstance(event, _RETRIABLE_STREAM_ERRORS): + logger.warning( + "Retriable AWS Transcribe error, will reconnect: %r", + event, + ) + if not self.closed: + self._reconnect_event.set() + return + else: + logger.error("Permanent AWS Transcribe error: %r", event) + self._emit_error_event( + RuntimeError(f"AWS Transcribe error: {event!r}"), + participant=self._current_participant, + context="aws_transcribe", + ) + # Stop accepting audio; supervisor stays idle (no reconnect). + self._input_stream = None + return + # Stream ended cleanly. AWS closes on idle and at the 15-minute + # session limit; treat that as retriable. + if not self.closed: + logger.info("AWS Transcribe stream ended, will reconnect") + self._reconnect_event.set() except asyncio.CancelledError: raise except Exception: - if not self.closed: - logger.exception("AWS Transcribe receive loop failed") + if self.closed: + return + logger.exception("AWS Transcribe receive loop failed, will reconnect") + self._reconnect_event.set() def _handle_transcript_event(self, event: TranscriptEvent): if event.transcript is None or not event.transcript.results: @@ -213,9 +274,8 @@ def _handle_transcript_event(self, event: TranscriptEvent): participant = self._current_participant if participant is None: - raise ValueError( - "No participant set - audio must be processed with a participant" - ) + logger.warning("Received transcript but no participant set") + return for result in event.transcript.results: if result.start_time < self._start_time_watermark: @@ -240,7 +300,9 @@ def _handle_transcript_event(self, event: TranscriptEvent): self._emit_turn_ended_event(participant) def _result_to_response(self, result: Result) -> TranscriptResponse: - items = result.alternatives[0].items or [] + items: list[Item] = [] + if result.alternatives: + items = result.alternatives[0].items or [] scores = [i.confidence for i in items if i.confidence is not None] confidence = sum(scores) / len(scores) if scores else None @@ -267,9 +329,7 @@ def _result_to_response(self, result: Result) -> TranscriptResponse: processing_time_ms: Optional[float] = None if self._audio_start_time is not None: - processing_time_ms = ( - time.perf_counter() - self._audio_start_time - ) * 1000 + processing_time_ms = (time.perf_counter() - self._audio_start_time) * 1000 return TranscriptResponse( confidence=confidence, @@ -279,44 +339,64 @@ def _result_to_response(self, result: Result) -> TranscriptResponse: processing_time_ms=processing_time_ms, ) - async def clear(self): - # AWS Transcribe has no native way to drop in-flight events. Snapshot - # a media-time watermark so any result whose segment began before - # this point is suppressed in _handle_transcript_event. - await super().clear() - async with self._watermark_lock: - self._start_time_watermark = self._audio_sent_seconds - self._audio_start_time = None - self._turn_in_progress = False - self._current_participant = None - - async def close(self): - await super().close() - - self._audio_start_time = None + async def _supervisor_loop(self): + """Wait for retriable failures and rebuild the stream. - if self._input_stream is not None: + Triggered by ``_recv_loop`` setting ``_reconnect_event``. Each + attempt sleeps ``min(2**n, max_reconnect_backoff_seconds)`` and + then tears down the old stream and opens a new one. Retries are + unlimited; the counter resets after a successful reconnect. + """ + attempt = 0 + while not self.closed: + await self._reconnect_event.wait() + self._reconnect_event.clear() + if self.closed: + return + backoff = min(2.0**attempt, self.max_reconnect_backoff_seconds) + attempt += 1 + logger.info( + "Reconnecting to AWS Transcribe in %.1fs (attempt %d)", + backoff, + attempt, + ) + await asyncio.sleep(backoff) try: - await self._input_stream.close() + if self._turn_in_progress and self._current_participant is not None: + self._emit_turn_ended_event(self._current_participant) + self._turn_in_progress = False + self._audio_start_time = None + # New stream restarts AWS's media-time clock at 0. Hold the + # lock across close+open so process_audio cannot send into a + # half-torn-down or half-built stream. + async with self._watermark_lock: + self._audio_sent_seconds = 0.0 + self._start_time_watermark = 0.0 + await self._close_streams() + await self._open_stream() + attempt = 0 + logger.info("AWS Transcribe reconnected") + except asyncio.CancelledError: + raise except Exception: - logger.warning( - "Error closing AWS Transcribe input stream", exc_info=True - ) + logger.exception("AWS Transcribe reconnect failed") + self._reconnect_event.set() + async def _close_streams(self): if self._recv_task is not None: await cancel_and_wait(self._recv_task) self._recv_task = None - + if self._input_stream is not None: + try: + await self._input_stream.close() + except Exception: + logger.warning("Error closing input stream", exc_info=True) if self._stream is not None: try: await self._stream.close() except Exception: - logger.warning( - "Error closing AWS Transcribe stream", exc_info=True - ) - finally: - self._stream = None - + logger.warning("Error closing stream", exc_info=True) + self._stream = None self._input_stream = None self._output_stream = None self._client = None From f5a5c5b706e4e038391d649ae64cca544fef49f4 Mon Sep 17 00:00:00 2001 From: Daniil Gusev Date: Tue, 5 May 2026 16:34:22 +0200 Subject: [PATCH 04/14] Fix InvalidStateError when closing aws.TranscribeSTT --- plugins/aws/vision_agents/plugins/aws/stt.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/plugins/aws/vision_agents/plugins/aws/stt.py b/plugins/aws/vision_agents/plugins/aws/stt.py index 842cb0cde..cb7fa7aaa 100644 --- a/plugins/aws/vision_agents/plugins/aws/stt.py +++ b/plugins/aws/vision_agents/plugins/aws/stt.py @@ -233,6 +233,7 @@ def _build_transcription_input(self) -> StartStreamTranscriptionInput: async def _recv_loop(self): if self._output_stream is None: return + try: async for event in self._output_stream: if isinstance(event, TranscriptResultStreamTranscriptEvent): @@ -382,15 +383,23 @@ async def _supervisor_loop(self): logger.exception("AWS Transcribe reconnect failed") self._reconnect_event.set() - async def _close_streams(self): - if self._recv_task is not None: - await cancel_and_wait(self._recv_task) - self._recv_task = None + async def _close_streams(self, timeout=5.0): + # Close the input first so AWS sees END_STREAM and closes the stream. if self._input_stream is not None: try: await self._input_stream.close() except Exception: logger.warning("Error closing input stream", exc_info=True) + if self._recv_task is not None: + try: + # Here the _recv_task is expected to exit + # when the input stream is closed. + # The cancel below is only a fallback for the + # case where AWS doesn't drain quickly enough. + await asyncio.wait_for(asyncio.shield(self._recv_task), timeout=timeout) + except Exception: + await cancel_and_wait(self._recv_task) + self._recv_task = None if self._stream is not None: try: await self._stream.close() From 7c8a762dac7edcc3917c03b6f6c483355e3ed48c Mon Sep 17 00:00:00 2001 From: Daniil Gusev Date: Tue, 5 May 2026 16:54:50 +0200 Subject: [PATCH 05/14] Fix coderabbit --- plugins/aws/tests/test_aws_stt.py | 6 +++++ .../vision_agents/plugins/aws/_credentials.py | 11 ++-------- plugins/aws/vision_agents/plugins/aws/stt.py | 22 +++++++++++++++++-- 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/plugins/aws/tests/test_aws_stt.py b/plugins/aws/tests/test_aws_stt.py index 6283526bf..9e6b28d32 100644 --- a/plugins/aws/tests/test_aws_stt.py +++ b/plugins/aws/tests/test_aws_stt.py @@ -81,6 +81,12 @@ async def on_turn_ended(event: TurnEndedEvent): assert len(turn_ended) == 1 assert turn_ended[0].participant == participant + def test_partial_static_credentials_rejected(self): + with pytest.raises(ValueError, match="provided together"): + aws.STT(aws_access_key_id="AKIA...") + with pytest.raises(ValueError, match="provided together"): + aws.STT(aws_secret_access_key="secret") + async def test_clear_drops_results_before_watermark( self, participant, transcript_event_factory ): diff --git a/plugins/aws/vision_agents/plugins/aws/_credentials.py b/plugins/aws/vision_agents/plugins/aws/_credentials.py index 40e358307..64a69ed16 100644 --- a/plugins/aws/vision_agents/plugins/aws/_credentials.py +++ b/plugins/aws/vision_agents/plugins/aws/_credentials.py @@ -19,14 +19,10 @@ class Boto3CredentialsResolver( def __init__(self, profile_name: Optional[str] = None) -> None: self._session = boto3.Session(profile_name=profile_name) - self._cached: Optional[AWSCredentialsIdentity] = None async def get_identity( self, *, properties: AWSIdentityProperties, **kwargs: Any ) -> AWSCredentialsIdentity: - if self._cached is not None: - return self._cached - credentials = self._session.get_credentials() if not credentials: raise ValueError("Unable to load AWS credentials via boto3") @@ -35,12 +31,9 @@ async def get_identity( if not creds.access_key or not creds.secret_key: raise ValueError("AWS credentials are incomplete") - expiry = getattr(credentials, "_expiry_time", None) - - self._cached = AWSCredentialsIdentity( + return AWSCredentialsIdentity( access_key_id=creds.access_key, secret_access_key=creds.secret_key, session_token=creds.token or None, - expiration=expiry, + expiration=None, ) - return self._cached diff --git a/plugins/aws/vision_agents/plugins/aws/stt.py b/plugins/aws/vision_agents/plugins/aws/stt.py index cb7fa7aaa..8c29170a0 100644 --- a/plugins/aws/vision_agents/plugins/aws/stt.py +++ b/plugins/aws/vision_agents/plugins/aws/stt.py @@ -90,6 +90,11 @@ def __init__( stays there for subsequent attempts. Reconnects are unlimited. """ + if bool(aws_access_key_id) != bool(aws_secret_access_key): + raise ValueError( + "aws_access_key_id and aws_secret_access_key must be provided together" + ) + super().__init__(provider_name="aws") self.language_code = language_code @@ -131,8 +136,16 @@ def __init__( async def start(self): await super().start() - await self._open_stream() - self._supervisor_task = asyncio.create_task(self._supervisor_loop()) + try: + await self._open_stream() + self._supervisor_task = asyncio.create_task(self._supervisor_loop()) + except BaseException: + self.started = False + if self._supervisor_task is not None: + await cancel_and_wait(self._supervisor_task) + self._supervisor_task = None + await self._close_streams() + raise logger.info( "AWS Transcribe streaming connection established (region=%s, lang=%s)", self.region_name, @@ -244,6 +257,9 @@ async def _recv_loop(self): event, ) if not self.closed: + # Stop accepting audio immediately; the supervisor may + # not run for up to max_reconnect_backoff_seconds. + self._input_stream = None self._reconnect_event.set() return else: @@ -260,6 +276,7 @@ async def _recv_loop(self): # session limit; treat that as retriable. if not self.closed: logger.info("AWS Transcribe stream ended, will reconnect") + self._input_stream = None self._reconnect_event.set() except asyncio.CancelledError: raise @@ -267,6 +284,7 @@ async def _recv_loop(self): if self.closed: return logger.exception("AWS Transcribe receive loop failed, will reconnect") + self._input_stream = None self._reconnect_event.set() def _handle_transcript_event(self, event: TranscriptEvent): From 6967f1037b5b87fefa8e38dac13daae6a7a460a0 Mon Sep 17 00:00:00 2001 From: Daniil Gusev Date: Tue, 5 May 2026 17:05:49 +0200 Subject: [PATCH 06/14] Wrap boto3 calls with asyncio.to_thread() --- plugins/aws/vision_agents/plugins/aws/_credentials.py | 8 ++++++-- plugins/aws/vision_agents/plugins/aws/stt.py | 2 -- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/plugins/aws/vision_agents/plugins/aws/_credentials.py b/plugins/aws/vision_agents/plugins/aws/_credentials.py index 64a69ed16..9c4ad4f9f 100644 --- a/plugins/aws/vision_agents/plugins/aws/_credentials.py +++ b/plugins/aws/vision_agents/plugins/aws/_credentials.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any, Optional import boto3 @@ -23,11 +24,14 @@ def __init__(self, profile_name: Optional[str] = None) -> None: async def get_identity( self, *, properties: AWSIdentityProperties, **kwargs: Any ) -> AWSCredentialsIdentity: - credentials = self._session.get_credentials() + # Both calls can block: get_credentials() walks the provider chain + # (file I/O, IMDS, SSO, STS) on first access, and get_frozen_credentials() + # triggers refresh on RefreshableCredentials. + credentials = await asyncio.to_thread(self._session.get_credentials) if not credentials: raise ValueError("Unable to load AWS credentials via boto3") - creds = credentials.get_frozen_credentials() + creds = await asyncio.to_thread(credentials.get_frozen_credentials) if not creds.access_key or not creds.secret_key: raise ValueError("AWS credentials are incomplete") diff --git a/plugins/aws/vision_agents/plugins/aws/stt.py b/plugins/aws/vision_agents/plugins/aws/stt.py index 8c29170a0..944b30d63 100644 --- a/plugins/aws/vision_agents/plugins/aws/stt.py +++ b/plugins/aws/vision_agents/plugins/aws/stt.py @@ -221,8 +221,6 @@ async def _build_config(self) -> Config: if self._aws_session_token: kwargs["aws_session_token"] = self._aws_session_token else: - # boto3.Session() reads ~/.aws/credentials synchronously, so the - # resolver has to be constructed off the event loop. kwargs["aws_credentials_identity_resolver"] = await asyncio.to_thread( Boto3CredentialsResolver, profile_name=self._aws_profile ) From 4ff775138ea37e62028166ffb74c71dd1f34d9ab Mon Sep 17 00:00:00 2001 From: Daniil Gusev Date: Tue, 5 May 2026 17:13:31 +0200 Subject: [PATCH 07/14] Update aws plugin README --- plugins/aws/README.md | 40 ++++++++++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/plugins/aws/README.md b/plugins/aws/README.md index d24761d8f..f65e236ed 100644 --- a/plugins/aws/README.md +++ b/plugins/aws/README.md @@ -1,6 +1,6 @@ # AWS Plugin for Vision Agents -AWS (Bedrock) integration for Vision Agents framework with support for standard LLM, realtime with Nova Sonic, and text-to-speech with automatic session resumption. +AWS integration for Vision Agents framework with support for standard LLM (Bedrock), realtime with Nova Sonic, text-to-speech (Polly), and streaming speech-to-text (Transcribe). ## Installation @@ -80,17 +80,18 @@ See `example/aws_realtime_nova_example.py` for a complete example. ### Text-to-Speech (TTS) -AWS Polly TTS is available for converting text to speech: +AWS Polly synthesises speech from text and streams the resulting audio. Supports both standard and neural engines, plain-text or SSML input, and Polly lexicons for pronunciation overrides. ```python from vision_agents.plugins import aws tts = aws.TTS( region_name="us-east-1", - voice_id="Joanna", # AWS Polly voice ID - engine="neural", # 'standard' or 'neural' - text_type="text", # 'text' or 'ssml' - language_code="en-US" + voice_id="Joanna", # any Polly voice ID + engine="neural", # "standard" | "neural" + text_type="text", # "text" | "ssml" + language_code="en-US", + lexicon_names=None, # optional list of Polly lexicons ) # Use in agent @@ -101,6 +102,33 @@ agent = Agent( ) ``` +Credentials follow the standard boto3 chain (env vars, `~/.aws/credentials`, SSO, instance profile, etc.). `region_name` falls back to `AWS_REGION` / `AWS_DEFAULT_REGION` and finally `us-east-1`. + +### Speech-to-Text (STT) + +AWS Transcribe streaming STT converts audio to text in realtime. The connection auto-reconnects with exponential backoff after AWS's 15-minute session limit or transient errors. + +```python +from vision_agents.plugins import aws + +stt = aws.STT( + language_code="en-US", + region_name="us-east-1", + show_speaker_label=False, + enable_partial_results_stabilization=False, + partial_results_stability=None, # "high" | "medium" | "low" +) + +# Use in agent +agent = Agent( + llm=aws.LLM(model="qwen.qwen3-32b-v1:0"), + stt=stt, + # ... other components +) +``` + +Credentials follow the standard boto3 chain (env vars, `~/.aws/credentials`, SSO, instance profile, etc.). Pass `aws_access_key_id` + `aws_secret_access_key` (both required together) or `aws_profile` to override. + ## Function Calling ### Standard LLM (aws.LLM) From fb29cca6617cf558f480d182b3081afa8758aa80 Mon Sep 17 00:00:00 2001 From: Daniil Gusev Date: Tue, 5 May 2026 17:21:41 +0200 Subject: [PATCH 08/14] Add an example with AWS agent pipeline --- plugins/aws/README.md | 2 + plugins/aws/example/aws_pipeline_example.py | 48 +++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 plugins/aws/example/aws_pipeline_example.py diff --git a/plugins/aws/README.md b/plugins/aws/README.md index f65e236ed..1d4fc0435 100644 --- a/plugins/aws/README.md +++ b/plugins/aws/README.md @@ -129,6 +129,8 @@ agent = Agent( Credentials follow the standard boto3 chain (env vars, `~/.aws/credentials`, SSO, instance profile, etc.). Pass `aws_access_key_id` + `aws_secret_access_key` (both required together) or `aws_profile` to override. +See `example/aws_pipeline_example.py` for a complete STT - LLM - TTS pipeline using only AWS components. + ## Function Calling ### Standard LLM (aws.LLM) diff --git a/plugins/aws/example/aws_pipeline_example.py b/plugins/aws/example/aws_pipeline_example.py new file mode 100644 index 000000000..1b38e5545 --- /dev/null +++ b/plugins/aws/example/aws_pipeline_example.py @@ -0,0 +1,48 @@ +""" +AWS STT - LLM - TTS Pipeline Example + +Voice agent built entirely from AWS components: +- STT: AWS Transcribe streaming +- LLM: AWS Bedrock (Qwen) +- TTS: AWS Polly +""" + +import asyncio +import logging + +from dotenv import load_dotenv +from vision_agents.core import Agent, Runner, User +from vision_agents.core.agents import AgentLauncher +from vision_agents.plugins import aws, getstream + +logger = logging.getLogger(__name__) + +load_dotenv() + + +async def create_agent(**kwargs) -> Agent: + agent = Agent( + edge=getstream.Edge(), + agent_user=User(name="AWS Voice Agent", id="agent"), + instructions="You are a voice agent. Keep replies short and " + "conversational. Do not use special characters or formatting.", + llm=aws.LLM(model="qwen.qwen3-32b-v1:0", region_name="us-east-1"), + stt=aws.STT(language_code="en-US", region_name="us-east-1"), + tts=aws.TTS(region_name="us-east-1", voice_id="Joanna", engine="neural"), + ) + + return agent + + +async def join_call(agent: Agent, call_type: str, call_id: str, **kwargs) -> None: + call = await agent.create_call(call_type, call_id) + + async with agent.join(call): + await asyncio.sleep(5) + await agent.simple_response(text="Ask the user about their day.") + + await agent.finish() + + +if __name__ == "__main__": + Runner(AgentLauncher(create_agent=create_agent, join_call=join_call)).cli() From 77c7f16960ee43d021cd79c8956634c8ff63221a Mon Sep 17 00:00:00 2001 From: Daniil Gusev Date: Tue, 5 May 2026 17:30:44 +0200 Subject: [PATCH 09/14] Coderabbit fixes --- plugins/aws/README.md | 4 ++-- plugins/aws/vision_agents/plugins/aws/stt.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/plugins/aws/README.md b/plugins/aws/README.md index 1d4fc0435..a09c371fc 100644 --- a/plugins/aws/README.md +++ b/plugins/aws/README.md @@ -106,7 +106,7 @@ Credentials follow the standard boto3 chain (env vars, `~/.aws/credentials`, SSO ### Speech-to-Text (STT) -AWS Transcribe streaming STT converts audio to text in realtime. The connection auto-reconnects with exponential backoff after AWS's 15-minute session limit or transient errors. +AWS Transcribe streaming STT converts audio to text in realtime. The connection auto-reconnects with exponential backoff on idle timeouts, audio-length limits, and transient errors. ```python from vision_agents.plugins import aws @@ -127,7 +127,7 @@ agent = Agent( ) ``` -Credentials follow the standard boto3 chain (env vars, `~/.aws/credentials`, SSO, instance profile, etc.). Pass `aws_access_key_id` + `aws_secret_access_key` (both required together) or `aws_profile` to override. +Credentials follow the standard boto3 chain (env vars, `~/.aws/credentials`, SSO, instance profile, etc.). Pass `aws_access_key_id` + `aws_secret_access_key` (both required together, plus `aws_session_token` for temporary credentials from STS / SSO / assumed roles) or `aws_profile` to override. See `example/aws_pipeline_example.py` for a complete STT - LLM - TTS pipeline using only AWS components. diff --git a/plugins/aws/vision_agents/plugins/aws/stt.py b/plugins/aws/vision_agents/plugins/aws/stt.py index 944b30d63..fcb665151 100644 --- a/plugins/aws/vision_agents/plugins/aws/stt.py +++ b/plugins/aws/vision_agents/plugins/aws/stt.py @@ -270,8 +270,8 @@ async def _recv_loop(self): # Stop accepting audio; supervisor stays idle (no reconnect). self._input_stream = None return - # Stream ended cleanly. AWS closes on idle and at the 15-minute - # session limit; treat that as retriable. + # Stream ended cleanly. AWS closes on idle and on audio-length + # limits; treat that as retriable. if not self.closed: logger.info("AWS Transcribe stream ended, will reconnect") self._input_stream = None From cef2d91a5e779910b529cb1406c208ab3a0f828a Mon Sep 17 00:00:00 2001 From: Daniil Gusev Date: Tue, 5 May 2026 17:35:58 +0200 Subject: [PATCH 10/14] Add timeout to TranscribeSTT._open_stream (10s by default) --- plugins/aws/vision_agents/plugins/aws/stt.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/plugins/aws/vision_agents/plugins/aws/stt.py b/plugins/aws/vision_agents/plugins/aws/stt.py index fcb665151..3f1a0b05f 100644 --- a/plugins/aws/vision_agents/plugins/aws/stt.py +++ b/plugins/aws/vision_agents/plugins/aws/stt.py @@ -193,18 +193,20 @@ async def process_audio( ) self._audio_sent_seconds += resampled.duration - async def _open_stream(self): + async def _open_stream(self, timeout: float = 10.0): self._client = TranscribeStreamingClient(config=await self._build_config()) - stream = await asyncio.wait_for( - self._client.start_stream_transcription( + async def _connect(): + _stream = await self._client.start_stream_transcription( input=self._build_transcription_input() - ), - timeout=10.0, - ) + ) + _, _output_stream = await _stream.await_output() + return _stream, _output_stream + + stream, output_stream = await asyncio.wait_for(_connect(), timeout=timeout) self._stream = stream self._input_stream = stream.input_stream - _, self._output_stream = await stream.await_output() + self._output_stream = output_stream self._recv_task = asyncio.create_task(self._recv_loop()) async def _build_config(self) -> Config: From effbfcc8aee49810f017788707173aff4efd6b10 Mon Sep 17 00:00:00 2001 From: Daniil Gusev Date: Tue, 5 May 2026 17:42:13 +0200 Subject: [PATCH 11/14] Fix mypy in TranscribeSTT --- plugins/aws/vision_agents/plugins/aws/stt.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/plugins/aws/vision_agents/plugins/aws/stt.py b/plugins/aws/vision_agents/plugins/aws/stt.py index 3f1a0b05f..2497e5c3d 100644 --- a/plugins/aws/vision_agents/plugins/aws/stt.py +++ b/plugins/aws/vision_agents/plugins/aws/stt.py @@ -194,10 +194,11 @@ async def process_audio( self._audio_sent_seconds += resampled.duration async def _open_stream(self, timeout: float = 10.0): - self._client = TranscribeStreamingClient(config=await self._build_config()) + client = TranscribeStreamingClient(config=await self._build_config()) + self._client = client async def _connect(): - _stream = await self._client.start_stream_transcription( + _stream = await client.start_stream_transcription( input=self._build_transcription_input() ) _, _output_stream = await _stream.await_output() From 9946dfeb5e67f7e84b859b3af8d2d4e8df5f22de Mon Sep 17 00:00:00 2001 From: Daniil Gusev Date: Tue, 5 May 2026 17:46:22 +0200 Subject: [PATCH 12/14] Update Polly TTS to accept the same creds as STT --- plugins/aws/README.md | 2 +- plugins/aws/tests/test_tts.py | 8 ++++++ plugins/aws/vision_agents/plugins/aws/tts.py | 27 ++++++++++++++++++-- 3 files changed, 34 insertions(+), 3 deletions(-) diff --git a/plugins/aws/README.md b/plugins/aws/README.md index a09c371fc..e2a11e5d9 100644 --- a/plugins/aws/README.md +++ b/plugins/aws/README.md @@ -102,7 +102,7 @@ agent = Agent( ) ``` -Credentials follow the standard boto3 chain (env vars, `~/.aws/credentials`, SSO, instance profile, etc.). `region_name` falls back to `AWS_REGION` / `AWS_DEFAULT_REGION` and finally `us-east-1`. +Credentials follow the standard boto3 chain (env vars, `~/.aws/credentials`, SSO, instance profile, etc.). Pass `aws_access_key_id` + `aws_secret_access_key` (both required together, plus `aws_session_token` for temporary credentials from STS / SSO / assumed roles) or `aws_profile` to override. You may also inject a pre-built boto3 Polly client via `client=...`. `region_name` falls back to `AWS_REGION` / `AWS_DEFAULT_REGION` and finally `us-east-1`. ### Speech-to-Text (STT) diff --git a/plugins/aws/tests/test_tts.py b/plugins/aws/tests/test_tts.py index eeb20d7c7..8901d1e69 100644 --- a/plugins/aws/tests/test_tts.py +++ b/plugins/aws/tests/test_tts.py @@ -30,6 +30,14 @@ async def tts(self) -> aws.TTS: # type: ignore[name-defined] return aws.TTS(voice_id=os.environ.get("AWS_POLLY_VOICE", "Joanna")) +class TestAWSPollyTTS: + def test_partial_static_credentials_rejected(self): + with pytest.raises(ValueError, match="provided together"): + aws.TTS(aws_access_key_id="AKIA...") + with pytest.raises(ValueError, match="provided together"): + aws.TTS(aws_secret_access_key="secret") + + @pytest.mark.skip() @pytest.mark.integration class TestAWSPollyTTSIntegration: diff --git a/plugins/aws/vision_agents/plugins/aws/tts.py b/plugins/aws/vision_agents/plugins/aws/tts.py index 4685b01fd..8b4fb0b8b 100644 --- a/plugins/aws/vision_agents/plugins/aws/tts.py +++ b/plugins/aws/vision_agents/plugins/aws/tts.py @@ -17,7 +17,10 @@ class TTS(BaseTTS): - TextType can be 'text' or 'ssml' (auto-detected unless overridden) - Optional Engine ('standard' or 'neural'), LanguageCode, LexiconNames - Credentials are resolved via standard AWS SDK chain (env vars, profiles, roles). + Credentials are resolved via the standard boto3 chain (env vars, profiles, + SSO, instance profile, etc.). Override with explicit access key + secret + (plus optional session token) or with a named profile, or inject a + pre-built boto3 Polly client via ``client=...``. """ def __init__( @@ -29,8 +32,17 @@ def __init__( engine: Optional[str] = None, # 'standard' | 'neural' language_code: Optional[str] = None, lexicon_names: Optional[List[str]] = None, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + aws_profile: Optional[str] = None, client: Optional[Any] = None, ) -> None: + if bool(aws_access_key_id) != bool(aws_secret_access_key): + raise ValueError( + "aws_access_key_id and aws_secret_access_key must be provided together" + ) + super().__init__(provider_name="aws_polly") self.region_name = ( region_name @@ -49,13 +61,24 @@ def __init__( self.engine = engine self.language_code = language_code self.lexicon_names = lexicon_names + + session_kwargs: dict[str, Any] = {"region_name": self.region_name} + if aws_profile: + session_kwargs["profile_name"] = aws_profile + if aws_access_key_id: + session_kwargs["aws_access_key_id"] = aws_access_key_id + if aws_secret_access_key: + session_kwargs["aws_secret_access_key"] = aws_secret_access_key + if aws_session_token: + session_kwargs["aws_session_token"] = aws_session_token + self._session_kwargs = session_kwargs self._client = client @property async def client(self): if self._client is None: self._client = await asyncio.to_thread( - lambda: boto3.client("polly", region_name=self.region_name) + lambda: boto3.Session(**self._session_kwargs).client("polly") ) return self._client From 8df60bd8f67063604fb22f9654811c0addd0ee21 Mon Sep 17 00:00:00 2001 From: Daniil Gusev Date: Tue, 5 May 2026 20:56:13 +0200 Subject: [PATCH 13/14] TranscribeSTT: close stream on connect failure --- plugins/aws/vision_agents/plugins/aws/stt.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/plugins/aws/vision_agents/plugins/aws/stt.py b/plugins/aws/vision_agents/plugins/aws/stt.py index 2497e5c3d..acbf099d9 100644 --- a/plugins/aws/vision_agents/plugins/aws/stt.py +++ b/plugins/aws/vision_agents/plugins/aws/stt.py @@ -201,8 +201,15 @@ async def _connect(): _stream = await client.start_stream_transcription( input=self._build_transcription_input() ) - _, _output_stream = await _stream.await_output() - return _stream, _output_stream + try: + _, _output_stream = await _stream.await_output() + return _stream, _output_stream + except asyncio.CancelledError: + await stream.close() + raise + except Exception: + await stream.close() + raise stream, output_stream = await asyncio.wait_for(_connect(), timeout=timeout) self._stream = stream From da8d175bdc6f7f514052b75fc2dca549d5f200ca Mon Sep 17 00:00:00 2001 From: Daniil Gusev Date: Tue, 5 May 2026 22:55:05 +0200 Subject: [PATCH 14/14] TranscribeSTT: do not hold _watermark_lock across reconnect I/O. --- plugins/aws/vision_agents/plugins/aws/stt.py | 36 ++++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/plugins/aws/vision_agents/plugins/aws/stt.py b/plugins/aws/vision_agents/plugins/aws/stt.py index acbf099d9..2b526b718 100644 --- a/plugins/aws/vision_agents/plugins/aws/stt.py +++ b/plugins/aws/vision_agents/plugins/aws/stt.py @@ -195,7 +195,6 @@ async def process_audio( async def _open_stream(self, timeout: float = 10.0): client = TranscribeStreamingClient(config=await self._build_config()) - self._client = client async def _connect(): _stream = await client.start_stream_transcription( @@ -205,16 +204,25 @@ async def _connect(): _, _output_stream = await _stream.await_output() return _stream, _output_stream except asyncio.CancelledError: - await stream.close() + await _stream.close() raise except Exception: - await stream.close() + await _stream.close() raise stream, output_stream = await asyncio.wait_for(_connect(), timeout=timeout) - self._stream = stream - self._input_stream = stream.input_stream - self._output_stream = output_stream + + # New stream restarts AWS's media-time clock at 0. Reset counters + # and publish the new input stream atomically so process_audio() + # never observes a half-built stream and chunks mid-flight cannot + # escape the cutoff. + async with self._watermark_lock: + self._client = client + self._stream = stream + self._output_stream = output_stream + self._audio_sent_seconds = 0.0 + self._start_time_watermark = 0.0 + self._input_stream = stream.input_stream self._recv_task = asyncio.create_task(self._recv_loop()) async def _build_config(self) -> Config: @@ -393,14 +401,14 @@ async def _supervisor_loop(self): self._emit_turn_ended_event(self._current_participant) self._turn_in_progress = False self._audio_start_time = None - # New stream restarts AWS's media-time clock at 0. Hold the - # lock across close+open so process_audio cannot send into a - # half-torn-down or half-built stream. - async with self._watermark_lock: - self._audio_sent_seconds = 0.0 - self._start_time_watermark = 0.0 - await self._close_streams() - await self._open_stream() + # Slow close+open runs outside the lock so concurrent + # process_audio calls observe _input_stream is None and + # drop audio immediately instead of stalling on the lock + # for the whole reconnect. _open_stream takes the lock + # briefly to swap in the new stream and reset watermarks + # atomically. + await self._close_streams() + await self._open_stream() attempt = 0 logger.info("AWS Transcribe reconnected") except asyncio.CancelledError: