From 51572c089239013dd63986fe9cd7ab209a23caf4 Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Tue, 8 Jul 2025 20:12:47 -0700 Subject: [PATCH 1/8] Add REFACTORING_PLAN.md --- REFACTORING_PLAN.md | 90 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 REFACTORING_PLAN.md diff --git a/REFACTORING_PLAN.md b/REFACTORING_PLAN.md new file mode 100644 index 000000000..32867abbb --- /dev/null +++ b/REFACTORING_PLAN.md @@ -0,0 +1,90 @@ +# Agent-CLI Refactoring Plan + +## 1. Goal + +The primary goal of this refactoring is to improve the overall organization of the `agent-cli` package. This involves restructuring the project to better separate concerns, reduce cross-dependencies between modules, and make the codebase more intuitive, maintainable, and extensible. + +## 2. Proposed File Structure + +The new architecture will introduce `core` and `services` packages to logically group related functionality. + +``` +agent_cli/ +├── __init__.py +├── cli.py +├── constants.py +├── py.typed +├── agents/ +│ ├── __init__.py +│ ├── _cli_options.py +│ ├── _tts_common.py +│ ├── _voice_agent_common.py +│ ├── assistant.py +│ ├── autocorrect.py +│ ├── chat.py +│ ├── speak.py +│ ├── transcribe.py +│ └── voice_edit.py +├── config.py # New unified config module +├── core/ # New package for core logic +│ ├── __init__.py +│ ├── audio.py # For audio device I/O +│ ├── process.py # For process management +│ └── utils.py # For generic utilities +└── services/ # New package for external service integrations + ├── __init__.py + ├── base.py # Abstract base classes for services + ├── factory.py # Factory to get the correct service + ├── local.py # Implementations for local services (Wyoming/Ollama) + └── openai.py # Implementations for OpenAI services +``` + +## 3. Detailed Migration Plan + +### Step 1: Consolidate Configuration + +- **Action:** Create a new `agent_cli/config.py` file. +- **Source Logic:** Merge the contents of `agent_cli/config_loader.py` and `agent_cli/agents/config.py`. +- **Content:** + - **Loading Logic:** `load_config()`, `_replace_dashed_keys()` from `config_loader.py`. + - **Pydantic Models:** All configuration models (`ProviderSelection`, `Ollama`, `OpenAILLM`, `AudioInput`, `WyomingASR`, `OpenAIASR`, `AudioOutput`, `WyomingTTS`, `OpenAITTS`, `WakeWord`, `General`, `History`). +- **Cleanup:** Delete `agent_cli/config_loader.py` and `agent_cli/agents/config.py`. + +### Step 2: Create `core` Package + +- **Action:** Create a new directory `agent_cli/core/`. +- **`agent_cli/core/audio.py`**: + - **Action:** Move `agent_cli/audio.py` to `agent_cli/core/audio.py`. + - **Content:** All PyAudio device management and streaming logic. +- **`agent_cli/core/process.py`**: + - **Action:** Move `agent_cli/process_manager.py` to `agent_cli/core/process.py`. + - **Content:** All PID file and process management functions. +- **`agent_cli/core/utils.py`**: + - **Action:** Create `agent_cli/core/utils.py` and move generic helpers from `agent_cli/utils.py`. + - **Content:** `console`, `InteractiveStopEvent`, `signal_handling_context`, `live_timer`, `print_*_panel`, `get_clipboard_text`. + +### Step 3: Create `services` Package + +- **Action:** Create a new directory `agent_cli/services/`. +- **`agent_cli/services/base.py`** (New File): + - **Content:** Define Abstract Base Classes (ABCs) for `ASRService`, `LLMService`, and `TTSService`. +- **`agent_cli/services/local.py`** (New File): + - **Content:** Implementations for all local services. + - **Wyoming ASR:** Logic from `asr.py`. + - **Wyoming TTS:** Logic from `tts.py`. + - **Wyoming Wake Word:** Logic from `wake_word.py`. + - **Ollama LLM:** Logic from `llm.py`. + - **Wyoming Utils:** `wyoming_client_context` from `wyoming_utils.py`. +- **`agent_cli/services/openai.py`** (New File): + - **Content:** Implementations for all OpenAI services. + - **OpenAI ASR:** Logic from `services.py` and `asr.py`. + - **OpenAI TTS:** Logic from `services.py` and `tts.py`. + - **OpenAI LLM:** Logic from `llm.py`. +- **`agent_cli/services/factory.py`** (New File): + - **Content:** Factory functions (`get_asr_service`, `get_llm_service`, `get_tts_service`) that return the correct service implementation based on the user's configuration. + +### Step 4: Refactor and Cleanup + +- **Action:** Update all imports across the project to reflect the new structure. +- **Action:** Delete the old, now-empty files: `asr.py`, `llm.py`, `tts.py`, `wake_word.py`, `services.py`, `process_manager.py`, `config_loader.py`, `wyoming_utils.py`, and `agents/config.py`. +- **Action:** Refactor `agent_cli/utils.py` to remove the functions that were moved to `core/utils.py`. From bf8762c5a707a914788407ac248134438388b943 Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Tue, 8 Jul 2025 20:17:29 -0700 Subject: [PATCH 2/8] refactor: Move core modules to agent_cli/core --- agent_cli/agents/assistant.py | 4 +++- agent_cli/agents/chat.py | 3 ++- agent_cli/agents/speak.py | 2 +- agent_cli/agents/transcribe.py | 3 ++- agent_cli/agents/voice_edit.py | 3 ++- agent_cli/asr.py | 1 + agent_cli/core/__init__.py | 1 + agent_cli/{ => core}/audio.py | 0 agent_cli/{process_manager.py => core/process.py} | 0 agent_cli/{ => core}/utils.py | 0 agent_cli/services/__init__.py | 1 + agent_cli/tts.py | 1 + 12 files changed, 14 insertions(+), 5 deletions(-) create mode 100644 agent_cli/core/__init__.py rename agent_cli/{ => core}/audio.py (100%) rename agent_cli/{process_manager.py => core/process.py} (100%) rename agent_cli/{ => core}/utils.py (100%) create mode 100644 agent_cli/services/__init__.py diff --git a/agent_cli/agents/assistant.py b/agent_cli/agents/assistant.py index b0d4be36d..c4ff5aa56 100644 --- a/agent_cli/agents/assistant.py +++ b/agent_cli/agents/assistant.py @@ -34,7 +34,9 @@ from typing import TYPE_CHECKING import agent_cli.agents._cli_options as opts -from agent_cli import asr, audio, process_manager, wake_word +from agent_cli import asr, wake_word +from agent_cli.core import audio +from agent_cli.core import process as process_manager from agent_cli.agents import config from agent_cli.agents._voice_agent_common import ( get_instruction_from_audio, diff --git a/agent_cli/agents/chat.py b/agent_cli/agents/chat.py index 741355cd1..b7491631a 100644 --- a/agent_cli/agents/chat.py +++ b/agent_cli/agents/chat.py @@ -25,7 +25,8 @@ import typer import agent_cli.agents._cli_options as opts -from agent_cli import asr, process_manager +from agent_cli import asr +from agent_cli.core import process as process_manager from agent_cli.agents import config from agent_cli.agents._tts_common import handle_tts_playback from agent_cli.audio import pyaudio_context, setup_devices diff --git a/agent_cli/agents/speak.py b/agent_cli/agents/speak.py index 6dfc7ed15..f3ee35293 100644 --- a/agent_cli/agents/speak.py +++ b/agent_cli/agents/speak.py @@ -10,7 +10,7 @@ import typer import agent_cli.agents._cli_options as opts -from agent_cli import process_manager +from agent_cli.core import process as process_manager from agent_cli.agents import config from agent_cli.agents._tts_common import handle_tts_playback from agent_cli.audio import pyaudio_context, setup_devices diff --git a/agent_cli/agents/transcribe.py b/agent_cli/agents/transcribe.py index bca448c7e..35f414466 100644 --- a/agent_cli/agents/transcribe.py +++ b/agent_cli/agents/transcribe.py @@ -11,7 +11,8 @@ import pyperclip import agent_cli.agents._cli_options as opts -from agent_cli import asr, process_manager +from agent_cli import asr +from agent_cli.core import process as process_manager from agent_cli.agents import config from agent_cli.audio import pyaudio_context, setup_devices from agent_cli.cli import app, setup_logging diff --git a/agent_cli/agents/voice_edit.py b/agent_cli/agents/voice_edit.py index 283dade5a..5868c199b 100644 --- a/agent_cli/agents/voice_edit.py +++ b/agent_cli/agents/voice_edit.py @@ -39,7 +39,8 @@ from pathlib import Path # noqa: TC003 import agent_cli.agents._cli_options as opts -from agent_cli import asr, process_manager +from agent_cli import asr +from agent_cli.core import process as process_manager from agent_cli.agents import config from agent_cli.agents._voice_agent_common import ( get_instruction_from_audio, diff --git a/agent_cli/asr.py b/agent_cli/asr.py index 78974b6e1..c28d45eae 100644 --- a/agent_cli/asr.py +++ b/agent_cli/asr.py @@ -17,6 +17,7 @@ read_from_queue, setup_input_stream, ) + from agent_cli.services import transcribe_audio_openai from agent_cli.wyoming_utils import manage_send_receive_tasks, wyoming_client_context diff --git a/agent_cli/core/__init__.py b/agent_cli/core/__init__.py new file mode 100644 index 000000000..c5895c9f5 --- /dev/null +++ b/agent_cli/core/__init__.py @@ -0,0 +1 @@ +"""Core package for agent-cli.""" diff --git a/agent_cli/audio.py b/agent_cli/core/audio.py similarity index 100% rename from agent_cli/audio.py rename to agent_cli/core/audio.py diff --git a/agent_cli/process_manager.py b/agent_cli/core/process.py similarity index 100% rename from agent_cli/process_manager.py rename to agent_cli/core/process.py diff --git a/agent_cli/utils.py b/agent_cli/core/utils.py similarity index 100% rename from agent_cli/utils.py rename to agent_cli/core/utils.py diff --git a/agent_cli/services/__init__.py b/agent_cli/services/__init__.py new file mode 100644 index 000000000..67ed73c1c --- /dev/null +++ b/agent_cli/services/__init__.py @@ -0,0 +1 @@ +"""Services package for agent-cli.""" diff --git a/agent_cli/tts.py b/agent_cli/tts.py index cded7a074..1fb42e182 100644 --- a/agent_cli/tts.py +++ b/agent_cli/tts.py @@ -14,6 +14,7 @@ from agent_cli import constants from agent_cli.audio import open_pyaudio_stream, pyaudio_context, setup_output_stream + from agent_cli.services import synthesize_speech_openai from agent_cli.utils import InteractiveStopEvent, live_timer, print_error_message, print_with_style from agent_cli.wyoming_utils import manage_send_receive_tasks, wyoming_client_context From dc0013921b78d9aa2c09b1bf8b0d424e56cd8b5f Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Tue, 8 Jul 2025 20:19:36 -0700 Subject: [PATCH 3/8] refactor: Consolidate configuration into agent_cli/config.py --- agent_cli/agents/_tts_common.py | 2 +- agent_cli/agents/_voice_agent_common.py | 2 +- agent_cli/agents/assistant.py | 2 +- agent_cli/agents/autocorrect.py | 2 +- agent_cli/agents/chat.py | 2 +- agent_cli/agents/speak.py | 2 +- agent_cli/agents/transcribe.py | 2 +- agent_cli/agents/voice_edit.py | 2 +- agent_cli/{agents => }/config.py | 47 +++++++++++++++++++++++-- agent_cli/config_loader.py | 43 ---------------------- 10 files changed, 53 insertions(+), 53 deletions(-) rename agent_cli/{agents => }/config.py (67%) delete mode 100644 agent_cli/config_loader.py diff --git a/agent_cli/agents/_tts_common.py b/agent_cli/agents/_tts_common.py index 0f8cea732..d8d01b937 100644 --- a/agent_cli/agents/_tts_common.py +++ b/agent_cli/agents/_tts_common.py @@ -14,7 +14,7 @@ from rich.live import Live - from agent_cli.agents import config + from agent_cli import config async def _save_audio_file( diff --git a/agent_cli/agents/_voice_agent_common.py b/agent_cli/agents/_voice_agent_common.py index 01d293a3a..58e2a6053 100644 --- a/agent_cli/agents/_voice_agent_common.py +++ b/agent_cli/agents/_voice_agent_common.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from rich.live import Live - from agent_cli.agents import config + from agent_cli import config LOGGER = logging.getLogger() diff --git a/agent_cli/agents/assistant.py b/agent_cli/agents/assistant.py index c4ff5aa56..2eba108bf 100644 --- a/agent_cli/agents/assistant.py +++ b/agent_cli/agents/assistant.py @@ -37,7 +37,7 @@ from agent_cli import asr, wake_word from agent_cli.core import audio from agent_cli.core import process as process_manager -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents._voice_agent_common import ( get_instruction_from_audio, process_instruction_and_respond, diff --git a/agent_cli/agents/autocorrect.py b/agent_cli/agents/autocorrect.py index a740001a6..8d51b36e1 100644 --- a/agent_cli/agents/autocorrect.py +++ b/agent_cli/agents/autocorrect.py @@ -12,7 +12,7 @@ import typer import agent_cli.agents._cli_options as opts -from agent_cli.agents import config +from agent_cli import config from agent_cli.cli import app, setup_logging from agent_cli.llm import build_agent from agent_cli.utils import ( diff --git a/agent_cli/agents/chat.py b/agent_cli/agents/chat.py index b7491631a..7cd8940fc 100644 --- a/agent_cli/agents/chat.py +++ b/agent_cli/agents/chat.py @@ -27,7 +27,7 @@ import agent_cli.agents._cli_options as opts from agent_cli import asr from agent_cli.core import process as process_manager -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents._tts_common import handle_tts_playback from agent_cli.audio import pyaudio_context, setup_devices from agent_cli.cli import app, setup_logging diff --git a/agent_cli/agents/speak.py b/agent_cli/agents/speak.py index f3ee35293..3e601c209 100644 --- a/agent_cli/agents/speak.py +++ b/agent_cli/agents/speak.py @@ -11,7 +11,7 @@ import agent_cli.agents._cli_options as opts from agent_cli.core import process as process_manager -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents._tts_common import handle_tts_playback from agent_cli.audio import pyaudio_context, setup_devices from agent_cli.cli import app, setup_logging diff --git a/agent_cli/agents/transcribe.py b/agent_cli/agents/transcribe.py index 35f414466..a94e20509 100644 --- a/agent_cli/agents/transcribe.py +++ b/agent_cli/agents/transcribe.py @@ -13,7 +13,7 @@ import agent_cli.agents._cli_options as opts from agent_cli import asr from agent_cli.core import process as process_manager -from agent_cli.agents import config +from agent_cli import config from agent_cli.audio import pyaudio_context, setup_devices from agent_cli.cli import app, setup_logging from agent_cli.llm import process_and_update_clipboard diff --git a/agent_cli/agents/voice_edit.py b/agent_cli/agents/voice_edit.py index 5868c199b..eff459b8e 100644 --- a/agent_cli/agents/voice_edit.py +++ b/agent_cli/agents/voice_edit.py @@ -41,7 +41,7 @@ import agent_cli.agents._cli_options as opts from agent_cli import asr from agent_cli.core import process as process_manager -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents._voice_agent_common import ( get_instruction_from_audio, process_instruction_and_respond, diff --git a/agent_cli/agents/config.py b/agent_cli/config.py similarity index 67% rename from agent_cli/agents/config.py rename to agent_cli/config.py index 452727ad8..37ee761fd 100644 --- a/agent_cli/agents/config.py +++ b/agent_cli/config.py @@ -1,11 +1,54 @@ -"""Pydantic models for agent configurations, aligned with CLI option groups.""" +"""Pydantic models for agent configurations and config file loading.""" from __future__ import annotations +import tomllib from pathlib import Path -from typing import Literal +from typing import Any, Literal from pydantic import BaseModel, field_validator +from rich.console import Console + +console = Console() + +# --- Config File Loading --- + +CONFIG_PATH = Path.home() / ".config" / "agent-cli" / "config.toml" +CONFIG_PATH_2 = Path("agent-cli-config.toml") + + +def _replace_dashed_keys(cfg: dict[str, Any]) -> dict[str, Any]: + """Replace dashed keys with underscores in the config options.""" + return {k.replace("-", "_"): v for k, v in cfg.items()} + + +def load_config(config_path_str: str | None = None) -> dict[str, Any]: + """Load the TOML configuration file and process it for nested structures.""" + # Determine which config path to use + if config_path_str: + config_path = Path(config_path_str) + elif CONFIG_PATH.exists(): + config_path = CONFIG_PATH + elif CONFIG_PATH_2.exists(): + config_path = CONFIG_PATH_2 + else: + return {} + + # Try to load and process the config + if config_path.exists(): + with config_path.open("rb") as f: + cfg = tomllib.load(f) + return {k: _replace_dashed_keys(v) for k, v in cfg.items()} + + # Report error only if an explicit path was given + if config_path_str: + console.print( + f"[bold red]Config file not found at {config_path_str}[/bold red]", + ) + return {} + + +# --- Pydantic Models for Configuration --- # --- Panel: Provider Selection --- diff --git a/agent_cli/config_loader.py b/agent_cli/config_loader.py deleted file mode 100644 index 826271f50..000000000 --- a/agent_cli/config_loader.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Handles loading and parsing of the agent-cli configuration file.""" - -from __future__ import annotations - -import tomllib -from pathlib import Path -from typing import Any - -from .utils import console - -CONFIG_PATH = Path.home() / ".config" / "agent-cli" / "config.toml" -CONFIG_PATH_2 = Path("agent-cli-config.toml") - - -def _replace_dashed_keys(cfg: dict[str, Any]) -> dict[str, Any]: - """Replace dashed keys with underscores in the config options.""" - return {k.replace("-", "_"): v for k, v in cfg.items()} - - -def load_config(config_path_str: str | None = None) -> dict[str, Any]: - """Load the TOML configuration file and process it for nested structures.""" - # Determine which config path to use - if config_path_str: - config_path = Path(config_path_str) - elif CONFIG_PATH.exists(): - config_path = CONFIG_PATH - elif CONFIG_PATH_2.exists(): - config_path = CONFIG_PATH_2 - else: - return {} - - # Try to load and process the config - if config_path.exists(): - with config_path.open("rb") as f: - cfg = tomllib.load(f) - return {k: _replace_dashed_keys(v) for k, v in cfg.items()} - - # Report error only if an explicit path was given - if config_path_str: - console.print( - f"[bold red]Config file not found at {config_path_str}[/bold red]", - ) - return {} From beaf50623a65b3ce98e5d86a7a8675e0339382b9 Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Tue, 8 Jul 2025 20:22:18 -0700 Subject: [PATCH 4/8] refactor: Create services package --- agent_cli/services/base.py | 56 +++++++ agent_cli/services/factory.py | 76 +++++++++ agent_cli/services/local.py | 301 ++++++++++++++++++++++++++++++++++ agent_cli/services/openai.py | 125 ++++++++++++++ 4 files changed, 558 insertions(+) create mode 100644 agent_cli/services/base.py create mode 100644 agent_cli/services/factory.py create mode 100644 agent_cli/services/local.py create mode 100644 agent_cli/services/openai.py diff --git a/agent_cli/services/base.py b/agent_cli/services/base.py new file mode 100644 index 000000000..d9890a842 --- /dev/null +++ b/agent_cli/services/base.py @@ -0,0 +1,56 @@ +"""Base classes for external services.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import logging + + from rich.live import Live + + from agent_cli import config + + +class ASRService(ABC): + """Abstract base class for Automatic Speech Recognition services.""" + + @abstractmethod + async def transcribe(self, audio_data: bytes) -> str: + """Transcribe audio data to text.""" + pass + + +class LLMService(ABC): + """Abstract base class for Language Model services.""" + + @abstractmethod + async def get_response( + self, + *, + system_prompt: str, + agent_instructions: str, + user_input: str, + tools: list | None = None, + ) -> str | None: + """Get a response from the language model.""" + pass + + +class TTSService(ABC): + """Abstract base class for Text-to-Speech services.""" + + @abstractmethod + async def synthesize(self, text: str) -> bytes | None: + """Synthesize text to speech audio data.""" + pass + + +class WakeWordService(ABC): + """Abstract base class for Wake Word detection services.""" + + @abstractmethod + async def detect(self) -> str | None: + """Detect the wake word.""" + pass diff --git a/agent_cli/services/factory.py b/agent_cli/services/factory.py new file mode 100644 index 000000000..264097f24 --- /dev/null +++ b/agent_cli/services/factory.py @@ -0,0 +1,76 @@ +"""Factory functions for creating service instances.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from agent_cli.services.base import ASRService, LLMService, TTSService +from agent_cli.services.local import ( + OllamaLLMService, + WyomingTranscriptionService, + WyomingTTSService, +) +from agent_cli.services.openai import ( + OpenAILLMService, + OpenAITranscriptionService, + OpenAITTSService, +) + +if TYPE_CHECKING: + import logging + + from agent_cli import config + + +def get_asr_service( + provider_config: config.ProviderSelection, + wyoming_asr_config: config.WyomingASR, + openai_asr_config: config.OpenAIASR, + openai_llm_config: config.OpenAILLM, + logger: logging.Logger, + *, + quiet: bool = False, +) -> ASRService: + """Get the appropriate ASR service based on the provider.""" + if provider_config.asr_provider == "openai": + return OpenAITranscriptionService( + openai_asr_config=openai_asr_config, + openai_llm_config=openai_llm_config, + logger=logger, + ) + return WyomingTranscriptionService( + wyoming_asr_config=wyoming_asr_config, logger=logger, quiet=quiet + ) + + +def get_llm_service( + provider_config: config.ProviderSelection, + ollama_config: config.Ollama, + openai_llm_config: config.OpenAILLM, + logger: logging.Logger, +) -> LLMService: + """Get the appropriate LLM service based on the provider.""" + if provider_config.llm_provider == "openai": + return OpenAILLMService(openai_llm_config=openai_llm_config, logger=logger) + return OllamaLLMService(ollama_config=ollama_config, logger=logger) + + +def get_tts_service( + provider_config: config.ProviderSelection, + wyoming_tts_config: config.WyomingTTS, + openai_tts_config: config.OpenAITTS, + openai_llm_config: config.OpenAILLM, + logger: logging.Logger, + *, + quiet: bool = False, +) -> TTSService: + """Get the appropriate TTS service based on the provider.""" + if provider_config.tts_provider == "openai": + return OpenAITTSService( + openai_tts_config=openai_tts_config, + openai_llm_config=openai_llm_config, + logger=logger, + ) + return WyomingTTSService( + wyoming_tts_config=wyoming_tts_config, logger=logger, quiet=quiet + ) diff --git a/agent_cli/services/local.py b/agent_cli/services/local.py new file mode 100644 index 000000000..dfcbd963a --- /dev/null +++ b/agent_cli/services/local.py @@ -0,0 +1,301 @@ +"""Module for interacting with local services like Wyoming and Ollama.""" + +from __future__ import annotations + +import asyncio +import io +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING + +from wyoming.asr import Transcribe, Transcript, TranscriptChunk, TranscriptStart, TranscriptStop +from wyoming.audio import AudioChunk, AudioStart, AudioStop +from wyoming.client import AsyncClient +from wyoming.tts import Synthesize, SynthesizeVoice +from wyoming.wake import Detect, Detection, NotDetected + +from agent_cli import constants +from agent_cli.services.base import ASRService, LLMService, TTSService, WakeWordService + +if TYPE_CHECKING: + import logging + from collections.abc import AsyncGenerator, Awaitable, Callable + + from rich.live import Live + + from agent_cli import config + + +@asynccontextmanager +async def wyoming_client_context( + host: str, + port: int, + service_name: str, + logger: logging.Logger, + *, + quiet: bool = False, +) -> AsyncGenerator[AsyncClient, None]: + """Connect to a Wyoming server.""" + if not quiet: + logger.info("Connecting to %s server at %s:%s", service_name, host, port) + try: + async with AsyncClient(host, port) as client: + yield client + except ConnectionRefusedError: + logger.error("Connection refused to %s server at %s:%s", service_name, host, port) + raise + + +async def manage_send_receive_tasks( + send_task: Awaitable, + recv_task: Awaitable, + return_when: str = asyncio.ALL_COMPLETED, +) -> tuple[asyncio.Task, asyncio.Task]: + """Manage send and receive tasks for a Wyoming client.""" + send = asyncio.create_task(send_task) + recv = asyncio.create_task(recv_task) + done, pending = await asyncio.wait({send, recv}, return_when=return_when) + for task in pending: + task.cancel() + return send, recv + + +class WyomingTranscriptionService(ASRService): + """Transcription service using a Wyoming ASR server.""" + + def __init__( + self, + wyoming_asr_config: config.WyomingASR, + logger: logging.Logger, + *, + quiet: bool = False, + ): + self.wyoming_asr_config = wyoming_asr_config + self.logger = logger + self.quiet = quiet + + async def transcribe(self, audio_data: bytes) -> str: + """Transcribe audio using a Wyoming ASR server.""" + async with wyoming_client_context( + self.wyoming_asr_config.wyoming_asr_ip, + self.wyoming_asr_config.wyoming_asr_port, + "ASR", + self.logger, + quiet=self.quiet, + ) as client: + await client.write_event(Transcribe().event()) + await client.write_event(AudioStart(**constants.WYOMING_AUDIO_CONFIG).event()) + + chunk_size = constants.PYAUDIO_CHUNK_SIZE * 2 + for i in range(0, len(audio_data), chunk_size): + chunk = audio_data[i : i + chunk_size] + await client.write_event( + AudioChunk(audio=chunk, **constants.WYOMING_AUDIO_CONFIG).event() + ) + await client.write_event(AudioStop().event()) + return await self._receive_transcript(client) + + async def _receive_transcript(self, client: AsyncClient) -> str: + """Receive transcription events and return the final transcript.""" + transcript_text = "" + while True: + event = await client.read_event() + if event is None: + self.logger.warning("Connection to ASR server lost.") + break + if Transcript.is_type(event.type): + transcript = Transcript.from_event(event) + transcript_text = transcript.text + break + return transcript_text + + +class OllamaLLMService(LLMService): + """LLM service using an Ollama server.""" + + def __init__( + self, + ollama_config: config.Ollama, + logger: logging.Logger, + ): + from pydantic_ai.models.openai import OpenAIModel # noqa: PLC0415 + from pydantic_ai.providers.openai import OpenAIProvider # noqa: PLC0415 + + self.ollama_config = ollama_config + self.logger = logger + provider = OpenAIProvider(base_url=f"{self.ollama_config.ollama_host}/v1") + self.model = OpenAIModel( + model_name=self.ollama_config.ollama_model, provider=provider + ) + + async def get_response( + self, + *, + system_prompt: str, + agent_instructions: str, + user_input: str, + tools: list | None = None, + ) -> str | None: + """Get a response from the language model.""" + from pydantic_ai import Agent # noqa: PLC0415 + + agent = Agent( + model=self.model, + system_prompt=system_prompt or (), + instructions=agent_instructions, + tools=tools or [], + ) + result = await agent.run(user_input) + return result.output + + +class WyomingTTSService(TTSService): + """TTS service using a Wyoming TTS server.""" + + def __init__( + self, + wyoming_tts_config: config.WyomingTTS, + logger: logging.Logger, + *, + quiet: bool = False, + ): + self.wyoming_tts_config = wyoming_tts_config + self.logger = logger + self.quiet = quiet + + async def synthesize(self, text: str) -> bytes | None: + """Synthesize speech using a Wyoming TTS server.""" + async with wyoming_client_context( + self.wyoming_tts_config.wyoming_tts_ip, + self.wyoming_tts_config.wyoming_tts_port, + "TTS", + self.logger, + quiet=self.quiet, + ) as client: + synthesize_event = self._create_synthesis_request(text) + _send_task, recv_task = await manage_send_receive_tasks( + client.write_event(synthesize_event.event()), + self._process_audio_events(client), + ) + audio_data, sample_rate, sample_width, channels = recv_task.result() + if sample_rate and sample_width and channels and audio_data: + return self._create_wav_data(audio_data, sample_rate, sample_width, channels) + return None + + def _create_synthesis_request(self, text: str) -> Synthesize: + """Create a synthesis request with optional voice parameters.""" + synthesize_event = Synthesize(text=text) + if ( + self.wyoming_tts_config.wyoming_voice + or self.wyoming_tts_config.wyoming_tts_language + or self.wyoming_tts_config.wyoming_speaker + ): + synthesize_event.voice = SynthesizeVoice( + name=self.wyoming_tts_config.wyoming_voice, + language=self.wyoming_tts_config.wyoming_tts_language, + speaker=self.wyoming_tts_config.wyoming_speaker, + ) + return synthesize_event + + async def _process_audio_events( + self, client: AsyncClient + ) -> tuple[bytes, int | None, int | None, int | None]: + """Process audio events from TTS server and return audio data with metadata.""" + audio_data = io.BytesIO() + sample_rate, sample_width, channels = None, None, None + while True: + event = await client.read_event() + if event is None: + break + if AudioStart.is_type(event.type): + audio_start = AudioStart.from_event(event) + sample_rate, sample_width, channels = ( + audio_start.rate, + audio_start.width, + audio_start.channels, + ) + elif AudioChunk.is_type(event.type): + audio_data.write(AudioChunk.from_event(event).audio) + elif AudioStop.is_type(event.type): + break + return audio_data.getvalue(), sample_rate, sample_width, channels + + def _create_wav_data( + self, audio_data: bytes, sample_rate: int, sample_width: int, channels: int + ) -> bytes: + """Convert raw audio data to WAV format.""" + import wave + + wav_data = io.BytesIO() + with wave.open(wav_data, "wb") as wav_file: + wav_file.setnchannels(channels) + wav_file.setsampwidth(sample_width) + wav_file.setframerate(sample_rate) + wav_file.writeframes(audio_data) + return wav_data.getvalue() + + +class WyomingWakeWordService(WakeWordService): + """Wake word detection service using a Wyoming wake word server.""" + + def __init__( + self, + wake_word_config: config.WakeWord, + logger: logging.Logger, + queue: asyncio.Queue, + *, + live: Live | None = None, + quiet: bool = False, + ): + self.wake_word_config = wake_word_config + self.logger = logger + self.queue = queue + self.live = live + self.quiet = quiet + + async def detect(self) -> str | None: + """Detect the wake word.""" + async with wyoming_client_context( + self.wake_word_config.wake_server_ip, + self.wake_word_config.wake_server_port, + "wake word", + self.logger, + quiet=self.quiet, + ) as client: + await client.write_event(Detect(names=[self.wake_word_config.wake_word_name]).event()) + _send_task, recv_task = await manage_send_receive_tasks( + self._send_audio_from_queue_for_wake_detection(client), + self._receive_wake_detection(client), + return_when=asyncio.FIRST_COMPLETED, + ) + if recv_task.done() and not recv_task.cancelled(): + return recv_task.result() + return None + + async def _send_audio_from_queue_for_wake_detection(self, client: AsyncClient) -> None: + """Read from a queue and send to Wyoming wake word server.""" + from agent_cli.core.audio import read_from_queue + + await client.write_event(AudioStart(**constants.WYOMING_AUDIO_CONFIG).event()) + try: + await read_from_queue( + queue=self.queue, + chunk_handler=lambda chunk: client.write_event( + AudioChunk(audio=chunk, **constants.WYOMING_AUDIO_CONFIG).event() + ), + logger=self.logger, + ) + finally: + if client._writer is not None: + await client.write_event(AudioStop().event()) + + async def _receive_wake_detection(self, client: AsyncClient) -> str | None: + """Receive wake word detection events.""" + while True: + event = await client.read_event() + if event is None: + break + if Detection.is_type(event.type): + return Detection.from_event(event).name + if NotDetected.is_type(event.type): + break + return None diff --git a/agent_cli/services/openai.py b/agent_cli/services/openai.py new file mode 100644 index 000000000..4d1c911eb --- /dev/null +++ b/agent_cli/services/openai.py @@ -0,0 +1,125 @@ +"""Module for interacting with OpenAI services.""" + +from __future__ import annotations + +import io +from typing import TYPE_CHECKING + +from agent_cli.services.base import ASRService, LLMService, TTSService + +if TYPE_CHECKING: + import logging + + from openai import AsyncOpenAI + + from agent_cli import config + + +def _get_openai_client(api_key: str) -> AsyncOpenAI: + """Get an OpenAI client instance.""" + from openai import AsyncOpenAI # noqa: PLC0415 + + if not api_key: + msg = "OpenAI API key is not set." + raise ValueError(msg) + return AsyncOpenAI(api_key=api_key) + + +class OpenAITranscriptionService(ASRService): + """Transcription service using OpenAI's Whisper API.""" + + def __init__( + self, + openai_asr_config: config.OpenAIASR, + openai_llm_config: config.OpenAILLM, + logger: logging.Logger, + ): + self.openai_asr_config = openai_asr_config + self.openai_llm_config = openai_llm_config + self.logger = logger + if not self.openai_llm_config.openai_api_key: + msg = "OpenAI API key is not set." + raise ValueError(msg) + self.client = _get_openai_client(api_key=self.openai_llm_config.openai_api_key) + + async def transcribe(self, audio_data: bytes) -> str: + """Transcribe audio using OpenAI's Whisper API.""" + self.logger.info("Transcribing audio with OpenAI Whisper...") + audio_file = io.BytesIO(audio_data) + audio_file.name = "audio.wav" + response = await self.client.audio.transcriptions.create( + model=self.openai_asr_config.openai_asr_model, + file=audio_file, + ) + return response.text + + +class OpenAILLMService(LLMService): + """LLM service using OpenAI's API.""" + + def __init__( + self, + openai_llm_config: config.OpenAILLM, + logger: logging.Logger, + ): + from pydantic_ai.models.openai import OpenAIModel # noqa: PLC0415 + from pydantic_ai.providers.openai import OpenAIProvider # noqa: PLC0415 + + self.openai_llm_config = openai_llm_config + self.logger = logger + if not self.openai_llm_config.openai_api_key: + msg = "OpenAI API key is not set." + raise ValueError(msg) + provider = OpenAIProvider(api_key=self.openai_llm_config.openai_api_key) + self.model = OpenAIModel( + model_name=self.openai_llm_config.openai_llm_model, provider=provider + ) + + async def get_response( + self, + *, + system_prompt: str, + agent_instructions: str, + user_input: str, + tools: list | None = None, + ) -> str | None: + """Get a response from the language model.""" + from pydantic_ai import Agent # noqa: PLC0415 + + agent = Agent( + model=self.model, + system_prompt=system_prompt or (), + instructions=agent_instructions, + tools=tools or [], + ) + result = await agent.run(user_input) + return result.output + + +class OpenAITTSService(TTSService): + """TTS service using OpenAI's API.""" + + def __init__( + self, + openai_tts_config: config.OpenAITTS, + openai_llm_config: config.OpenAILLM, + logger: logging.Logger, + ): + self.openai_tts_config = openai_tts_config + self.openai_llm_config = openai_llm_config + self.logger = logger + if not self.openai_llm_config.openai_api_key: + msg = "OpenAI API key is not set." + raise ValueError(msg) + self.client = _get_openai_client(api_key=self.openai_llm_config.openai_api_key) + + async def synthesize(self, text: str) -> bytes: + """Synthesize speech using OpenAI's TTS API.""" + self.logger.info("Synthesizing speech with OpenAI TTS...") + response = await self.client.audio.speech.create( + model=self.openai_tts_config.openai_tts_model, + voice=self.openai_tts_config.openai_tts_voice, + input=text, + response_format="wav", + ) + return response.content From 22e5c1fda4c8b9b26df360e3708a2266dc580f7d Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Tue, 8 Jul 2025 20:26:41 -0700 Subject: [PATCH 5/8] refactor: Update modules to use service factory --- agent_cli/asr.py | 149 ++++++++++++++++++++++++----------------------- agent_cli/llm.py | 54 +++++------------ agent_cli/tts.py | 54 ++++++++--------- 3 files changed, 115 insertions(+), 142 deletions(-) diff --git a/agent_cli/asr.py b/agent_cli/asr.py index c28d45eae..2aee41564 100644 --- a/agent_cli/asr.py +++ b/agent_cli/asr.py @@ -18,8 +18,19 @@ setup_input_stream, ) -from agent_cli.services import transcribe_audio_openai -from agent_cli.wyoming_utils import manage_send_receive_tasks, wyoming_client_context +from agent_cli.core.audio import ( + open_pyaudio_stream, + read_audio_stream, + read_from_queue, + setup_input_stream, +) +from agent_cli.services.factory import get_asr_service +from agent_cli.services.local import ( + WyomingTranscriptionService, + manage_send_receive_tasks, + wyoming_client_context, +) +from agent_cli.services.openai import OpenAITranscriptionService if TYPE_CHECKING: import logging @@ -39,29 +50,56 @@ def get_transcriber( wyoming_asr_config: config.WyomingASR, openai_asr_config: config.OpenAIASR, openai_llm_config: config.OpenAILLM, + logger: logging.Logger, + *, + quiet: bool = False, ) -> Callable[..., Awaitable[str | None]]: """Return the appropriate transcriber for live audio based on the provider.""" - if provider_config.asr_provider == "openai": - return partial( - transcribe_live_audio_openai, - audio_input_config=audio_input_config, - openai_asr_config=openai_asr_config, - openai_llm_config=openai_llm_config, - ) - return partial( - transcribe_live_audio_wyoming, - audio_input_config=audio_input_config, - wyoming_asr_config=wyoming_asr_config, + asr_service = get_asr_service( + provider_config, + wyoming_asr_config, + openai_asr_config, + openai_llm_config, + logger, + quiet=quiet, ) + async def transcribe_live_audio(p: pyaudio.PyAudio, stop_event: InteractiveStopEvent, live: Live) -> str | None: + """Record and transcribe live audio.""" + audio_data = await record_audio_with_manual_stop( + p, + audio_input_config.input_device_index, + stop_event, + logger, + quiet=quiet, + live=live, + ) + if not audio_data: + return None + return await asr_service.transcribe(audio_data) + + return transcribe_live_audio + def get_recorded_audio_transcriber( provider_config: config.ProviderSelection, -) -> Callable[..., Awaitable[str]]: + wyoming_asr_config: config.WyomingASR, + openai_asr_config: config.OpenAIASR, + openai_llm_config: config.OpenAILLM, + logger: logging.Logger, + *, + quiet: bool = False, +) -> Callable[[bytes], Awaitable[str]]: """Return the appropriate transcriber for recorded audio based on the provider.""" - if provider_config.asr_provider == "openai": - return transcribe_audio_openai - return transcribe_recorded_audio_wyoming + asr_service = get_asr_service( + provider_config, + wyoming_asr_config, + openai_asr_config, + openai_llm_config, + logger, + quiet=quiet, + ) + return asr_service.transcribe async def _send_audio( @@ -188,31 +226,8 @@ async def transcribe_recorded_audio_wyoming( **_kwargs: object, ) -> str: """Process pre-recorded audio data with Wyoming ASR server.""" - try: - async with wyoming_client_context( - wyoming_asr_config.wyoming_asr_ip, - wyoming_asr_config.wyoming_asr_port, - "ASR", - logger, - quiet=quiet, - ) as client: - await client.write_event(Transcribe().event()) - await client.write_event(AudioStart(**constants.WYOMING_AUDIO_CONFIG).event()) - - chunk_size = constants.PYAUDIO_CHUNK_SIZE * 2 - for i in range(0, len(audio_data), chunk_size): - chunk = audio_data[i : i + chunk_size] - await client.write_event( - AudioChunk(audio=chunk, **constants.WYOMING_AUDIO_CONFIG).event(), - ) - logger.debug("Sent %d byte(s) of audio", len(chunk)) - - await client.write_event(AudioStop().event()) - logger.debug("Sent AudioStop") - - return await _receive_transcript(client, logger) - except (ConnectionRefusedError, Exception): - return "" + service = WyomingTranscriptionService(wyoming_asr_config, logger, quiet=quiet) + return await service.transcribe(audio_data) async def transcribe_live_audio_wyoming( @@ -229,29 +244,23 @@ async def transcribe_live_audio_wyoming( **_kwargs: object, ) -> str | None: """Unified ASR transcription function.""" - try: - async with wyoming_client_context( - wyoming_asr_config.wyoming_asr_ip, - wyoming_asr_config.wyoming_asr_port, - "ASR", - logger, - quiet=quiet, - ) as client: - stream_config = setup_input_stream(audio_input_config.input_device_index) - with open_pyaudio_stream(p, **stream_config) as stream: - _, recv_task = await manage_send_receive_tasks( - _send_audio(client, stream, stop_event, logger, live=live, quiet=quiet), - _receive_transcript( - client, - logger, - chunk_callback=chunk_callback, - final_callback=final_callback, - ), - return_when=asyncio.ALL_COMPLETED, - ) - return recv_task.result() - except (ConnectionRefusedError, Exception): + service = WyomingTranscriptionService(wyoming_asr_config, logger, quiet=quiet) + audio_data = await record_audio_with_manual_stop( + p, + audio_input_config.input_device_index, + stop_event, + logger, + quiet=quiet, + live=live, + ) + if not audio_data: return None + transcript = await service.transcribe(audio_data) + if chunk_callback: + chunk_callback(transcript) + if final_callback: + final_callback(transcript) + return transcript async def transcribe_live_audio_openai( @@ -267,6 +276,7 @@ async def transcribe_live_audio_openai( **_kwargs: object, ) -> str | None: """Record and transcribe live audio using OpenAI Whisper.""" + service = OpenAITranscriptionService(openai_asr_config, openai_llm_config, logger) audio_data = await record_audio_with_manual_stop( p, audio_input_config.input_device_index, @@ -277,13 +287,4 @@ async def transcribe_live_audio_openai( ) if not audio_data: return None - try: - return await transcribe_audio_openai( - audio_data, - openai_asr_config, - openai_llm_config, - logger, - ) - except Exception: - logger.exception("Error during transcription") - return "" + return await service.transcribe(audio_data) diff --git a/agent_cli/llm.py b/agent_cli/llm.py index 7b5bbe76f..34121ae99 100644 --- a/agent_cli/llm.py +++ b/agent_cli/llm.py @@ -9,7 +9,8 @@ import pyperclip from rich.live import Live -from agent_cli.utils import console, live_timer, print_error_message, print_output_panel +from agent_cli.core.utils import console, live_timer, print_error_message, print_output_panel +from agent_cli.services.factory import get_llm_service if TYPE_CHECKING: import logging @@ -20,37 +21,6 @@ from agent_cli.agents import config -def build_agent( - provider_config: config.ProviderSelection, - ollama_config: config.Ollama, - openai_config: config.OpenAILLM, - *, - system_prompt: str | None = None, - instructions: str | None = None, - tools: list[Tool] | None = None, -) -> Agent: - """Construct and return a PydanticAI agent.""" - from pydantic_ai import Agent # noqa: PLC0415 - from pydantic_ai.models.openai import OpenAIModel # noqa: PLC0415 - from pydantic_ai.providers.openai import OpenAIProvider # noqa: PLC0415 - - if provider_config.llm_provider == "openai": - if not openai_config.openai_api_key: - msg = "OpenAI API key is not set." - raise ValueError(msg) - provider = OpenAIProvider(api_key=openai_config.openai_api_key) - model_name = openai_config.openai_llm_model - else: - provider = OpenAIProvider(base_url=f"{ollama_config.ollama_host}/v1") - model_name = ollama_config.ollama_model - - llm_model = OpenAIModel(model_name=model_name, provider=provider) - return Agent( - model=llm_model, - system_prompt=system_prompt or (), - instructions=instructions, - tools=tools or [], - ) # --- LLM (Editing) Logic --- @@ -83,13 +53,8 @@ async def get_llm_response( exit_on_error: bool = False, ) -> str | None: """Get a response from the LLM with optional clipboard and output handling.""" - agent = build_agent( - provider_config=provider_config, - ollama_config=ollama_config, - openai_config=openai_config, - system_prompt=system_prompt, - instructions=agent_instructions, - tools=tools, + llm_service = get_llm_service( + provider_config, ollama_config, openai_config, logger ) start_time = time.monotonic() @@ -107,10 +72,17 @@ async def get_llm_response( style="bold yellow", quiet=quiet, ): - result = await agent.run(user_input) + result_text = await llm_service.get_response( + system_prompt=system_prompt, + agent_instructions=agent_instructions, + user_input=user_input, + tools=tools, + ) elapsed = time.monotonic() - start_time - result_text = result.output + + if not result_text: + return None if clipboard: pyperclip.copy(result_text) diff --git a/agent_cli/tts.py b/agent_cli/tts.py index 1fb42e182..01046f184 100644 --- a/agent_cli/tts.py +++ b/agent_cli/tts.py @@ -15,9 +15,15 @@ from agent_cli import constants from agent_cli.audio import open_pyaudio_stream, pyaudio_context, setup_output_stream -from agent_cli.services import synthesize_speech_openai -from agent_cli.utils import InteractiveStopEvent, live_timer, print_error_message, print_with_style -from agent_cli.wyoming_utils import manage_send_receive_tasks, wyoming_client_context +from agent_cli.core.audio import open_pyaudio_stream, pyaudio_context, setup_output_stream +from agent_cli.core.utils import ( + InteractiveStopEvent, + live_timer, + print_error_message, + print_with_style, +) +from agent_cli.services.factory import get_tts_service +from agent_cli.services.local import manage_send_receive_tasks, wyoming_client_context if TYPE_CHECKING: import logging @@ -37,17 +43,28 @@ def get_synthesizer( wyoming_tts_config: config.WyomingTTS, openai_tts_config: config.OpenAITTS, openai_llm_config: config.OpenAILLM, + logger: logging.Logger, + *, + quiet: bool = False, ) -> Callable[..., Awaitable[bytes | None]]: """Return the appropriate synthesizer based on the config.""" if not audio_output_config.enable_tts: return _dummy_synthesizer - if provider_config.tts_provider == "openai": - return partial( - _synthesize_speech_openai, - openai_tts_config=openai_tts_config, - openai_llm_config=openai_llm_config, - ) - return partial(_synthesize_speech_wyoming, wyoming_tts_config=wyoming_tts_config) + + tts_service = get_tts_service( + provider_config, + wyoming_tts_config, + openai_tts_config, + openai_llm_config, + logger, + quiet=quiet, + ) + + async def synthesize(text: str) -> bytes | None: + """Synthesize speech from text.""" + return await tts_service.synthesize(text) + + return synthesize def _create_synthesis_request( @@ -134,23 +151,6 @@ async def _dummy_synthesizer(**_kwargs: object) -> bytes | None: return None -async def _synthesize_speech_openai( - *, - text: str, - openai_tts_config: config.OpenAITTS, - openai_llm_config: config.OpenAILLM, - logger: logging.Logger, - **_kwargs: object, -) -> bytes | None: - """Synthesize speech from text using OpenAI TTS server.""" - return await synthesize_speech_openai( - text=text, - openai_tts_config=openai_tts_config, - openai_llm_config=openai_llm_config, - logger=logger, - ) - - async def _synthesize_speech_wyoming( *, text: str, From 4bf2126a1d92f33f8877e5b9361480663c572bad Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Tue, 8 Jul 2025 20:27:33 -0700 Subject: [PATCH 6/8] refactor: Remove old service modules --- agent_cli/asr.py | 290 ---------------------------------- agent_cli/services.py | 65 -------- agent_cli/tts.py | 313 ------------------------------------- agent_cli/wake_word.py | 136 ---------------- agent_cli/wyoming_utils.py | 95 ----------- 5 files changed, 899 deletions(-) delete mode 100644 agent_cli/asr.py delete mode 100644 agent_cli/services.py delete mode 100644 agent_cli/tts.py delete mode 100644 agent_cli/wake_word.py delete mode 100644 agent_cli/wyoming_utils.py diff --git a/agent_cli/asr.py b/agent_cli/asr.py deleted file mode 100644 index 2aee41564..000000000 --- a/agent_cli/asr.py +++ /dev/null @@ -1,290 +0,0 @@ -"""Module for Automatic Speech Recognition using Wyoming or OpenAI.""" - -from __future__ import annotations - -import asyncio -import io -from functools import partial -from typing import TYPE_CHECKING - -from wyoming.asr import Transcribe, Transcript, TranscriptChunk, TranscriptStart, TranscriptStop -from wyoming.audio import AudioChunk, AudioStart, AudioStop - -from agent_cli import constants -from agent_cli.audio import ( - open_pyaudio_stream, - read_audio_stream, - read_from_queue, - setup_input_stream, -) - -from agent_cli.core.audio import ( - open_pyaudio_stream, - read_audio_stream, - read_from_queue, - setup_input_stream, -) -from agent_cli.services.factory import get_asr_service -from agent_cli.services.local import ( - WyomingTranscriptionService, - manage_send_receive_tasks, - wyoming_client_context, -) -from agent_cli.services.openai import OpenAITranscriptionService - -if TYPE_CHECKING: - import logging - from collections.abc import Awaitable, Callable - - import pyaudio - from rich.live import Live - from wyoming.client import AsyncClient - - from agent_cli.agents import config - from agent_cli.utils import InteractiveStopEvent - - -def get_transcriber( - provider_config: config.ProviderSelection, - audio_input_config: config.AudioInput, - wyoming_asr_config: config.WyomingASR, - openai_asr_config: config.OpenAIASR, - openai_llm_config: config.OpenAILLM, - logger: logging.Logger, - *, - quiet: bool = False, -) -> Callable[..., Awaitable[str | None]]: - """Return the appropriate transcriber for live audio based on the provider.""" - asr_service = get_asr_service( - provider_config, - wyoming_asr_config, - openai_asr_config, - openai_llm_config, - logger, - quiet=quiet, - ) - - async def transcribe_live_audio(p: pyaudio.PyAudio, stop_event: InteractiveStopEvent, live: Live) -> str | None: - """Record and transcribe live audio.""" - audio_data = await record_audio_with_manual_stop( - p, - audio_input_config.input_device_index, - stop_event, - logger, - quiet=quiet, - live=live, - ) - if not audio_data: - return None - return await asr_service.transcribe(audio_data) - - return transcribe_live_audio - - -def get_recorded_audio_transcriber( - provider_config: config.ProviderSelection, - wyoming_asr_config: config.WyomingASR, - openai_asr_config: config.OpenAIASR, - openai_llm_config: config.OpenAILLM, - logger: logging.Logger, - *, - quiet: bool = False, -) -> Callable[[bytes], Awaitable[str]]: - """Return the appropriate transcriber for recorded audio based on the provider.""" - asr_service = get_asr_service( - provider_config, - wyoming_asr_config, - openai_asr_config, - openai_llm_config, - logger, - quiet=quiet, - ) - return asr_service.transcribe - - -async def _send_audio( - client: AsyncClient, - stream: pyaudio.Stream, - stop_event: InteractiveStopEvent, - logger: logging.Logger, - *, - live: Live, - quiet: bool = False, -) -> None: - """Read from mic and send to Wyoming server.""" - await client.write_event(Transcribe().event()) - await client.write_event(AudioStart(**constants.WYOMING_AUDIO_CONFIG).event()) - - async def send_chunk(chunk: bytes) -> None: - """Send audio chunk to ASR server.""" - await client.write_event(AudioChunk(audio=chunk, **constants.WYOMING_AUDIO_CONFIG).event()) - - try: - await read_audio_stream( - stream=stream, - stop_event=stop_event, - chunk_handler=send_chunk, - logger=logger, - live=live, - quiet=quiet, - progress_message="Listening", - progress_style="blue", - ) - finally: - await client.write_event(AudioStop().event()) - logger.debug("Sent AudioStop") - - -async def record_audio_to_buffer( - queue: asyncio.Queue, - logger: logging.Logger, -) -> bytes: - """Record audio from a queue to a buffer.""" - audio_buffer = io.BytesIO() - - def buffer_chunk(chunk: bytes) -> None: - """Buffer audio chunk.""" - audio_buffer.write(chunk) - - await read_from_queue(queue=queue, chunk_handler=buffer_chunk, logger=logger) - - return audio_buffer.getvalue() - - -async def _receive_transcript( - client: AsyncClient, - logger: logging.Logger, - *, - chunk_callback: Callable[[str], None] | None = None, - final_callback: Callable[[str], None] | None = None, -) -> str: - """Receive transcription events and return the final transcript.""" - transcript_text = "" - while True: - event = await client.read_event() - if event is None: - logger.warning("Connection to ASR server lost.") - break - - if Transcript.is_type(event.type): - transcript = Transcript.from_event(event) - transcript_text = transcript.text - logger.info("Final transcript: %s", transcript_text) - if final_callback: - final_callback(transcript_text) - break - if TranscriptChunk.is_type(event.type): - chunk = TranscriptChunk.from_event(event) - logger.debug("Transcript chunk: %s", chunk.text) - if chunk_callback: - chunk_callback(chunk.text) - elif TranscriptStart.is_type(event.type) or TranscriptStop.is_type(event.type): - logger.debug("Received %s", event.type) - else: - logger.debug("Ignoring event type: %s", event.type) - - return transcript_text - - -async def record_audio_with_manual_stop( - p: pyaudio.PyAudio, - input_device_index: int | None, - stop_event: InteractiveStopEvent, - logger: logging.Logger, - *, - quiet: bool = False, - live: Live | None = None, -) -> bytes: - """Record audio to a buffer using a manual stop signal.""" - audio_buffer = io.BytesIO() - - def buffer_chunk(chunk: bytes) -> None: - """Buffer audio chunk.""" - audio_buffer.write(chunk) - - stream_config = setup_input_stream(input_device_index) - with open_pyaudio_stream(p, **stream_config) as stream: - await read_audio_stream( - stream=stream, - stop_event=stop_event, - chunk_handler=buffer_chunk, - logger=logger, - live=live, - quiet=quiet, - progress_message="Recording", - progress_style="green", - ) - return audio_buffer.getvalue() - - -async def transcribe_recorded_audio_wyoming( - *, - audio_data: bytes, - wyoming_asr_config: config.WyomingASR, - logger: logging.Logger, - quiet: bool = False, - **_kwargs: object, -) -> str: - """Process pre-recorded audio data with Wyoming ASR server.""" - service = WyomingTranscriptionService(wyoming_asr_config, logger, quiet=quiet) - return await service.transcribe(audio_data) - - -async def transcribe_live_audio_wyoming( - *, - audio_input_config: config.AudioInput, - wyoming_asr_config: config.WyomingASR, - logger: logging.Logger, - p: pyaudio.PyAudio, - stop_event: InteractiveStopEvent, - live: Live, - quiet: bool = False, - chunk_callback: Callable[[str], None] | None = None, - final_callback: Callable[[str], None] | None = None, - **_kwargs: object, -) -> str | None: - """Unified ASR transcription function.""" - service = WyomingTranscriptionService(wyoming_asr_config, logger, quiet=quiet) - audio_data = await record_audio_with_manual_stop( - p, - audio_input_config.input_device_index, - stop_event, - logger, - quiet=quiet, - live=live, - ) - if not audio_data: - return None - transcript = await service.transcribe(audio_data) - if chunk_callback: - chunk_callback(transcript) - if final_callback: - final_callback(transcript) - return transcript - - -async def transcribe_live_audio_openai( - *, - audio_input_config: config.AudioInput, - openai_asr_config: config.OpenAIASR, - openai_llm_config: config.OpenAILLM, - logger: logging.Logger, - p: pyaudio.PyAudio, - stop_event: InteractiveStopEvent, - live: Live, - quiet: bool = False, - **_kwargs: object, -) -> str | None: - """Record and transcribe live audio using OpenAI Whisper.""" - service = OpenAITranscriptionService(openai_asr_config, openai_llm_config, logger) - audio_data = await record_audio_with_manual_stop( - p, - audio_input_config.input_device_index, - stop_event, - logger, - quiet=quiet, - live=live, - ) - if not audio_data: - return None - return await service.transcribe(audio_data) diff --git a/agent_cli/services.py b/agent_cli/services.py deleted file mode 100644 index d33cfaa65..000000000 --- a/agent_cli/services.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Module for interacting with online services like OpenAI.""" - -from __future__ import annotations - -import io -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - import logging - - from openai import AsyncOpenAI - - from agent_cli.agents import config - - -def _get_openai_client(api_key: str) -> AsyncOpenAI: - """Get an OpenAI client instance.""" - from openai import AsyncOpenAI # noqa: PLC0415 - - if not api_key: - msg = "OpenAI API key is not set." - raise ValueError(msg) - return AsyncOpenAI(api_key=api_key) - - -async def transcribe_audio_openai( - audio_data: bytes, - openai_asr_config: config.OpenAIASR, - openai_llm_config: config.OpenAILLM, - logger: logging.Logger, -) -> str: - """Transcribe audio using OpenAI's Whisper API.""" - logger.info("Transcribing audio with OpenAI Whisper...") - if not openai_llm_config.openai_api_key: - msg = "OpenAI API key is not set." - raise ValueError(msg) - client = _get_openai_client(api_key=openai_llm_config.openai_api_key) - audio_file = io.BytesIO(audio_data) - audio_file.name = "audio.wav" - response = await client.audio.transcriptions.create( - model=openai_asr_config.openai_asr_model, - file=audio_file, - ) - return response.text - - -async def synthesize_speech_openai( - text: str, - openai_tts_config: config.OpenAITTS, - openai_llm_config: config.OpenAILLM, - logger: logging.Logger, -) -> bytes: - """Synthesize speech using OpenAI's TTS API.""" - logger.info("Synthesizing speech with OpenAI TTS...") - if not openai_llm_config.openai_api_key: - msg = "OpenAI API key is not set." - raise ValueError(msg) - client = _get_openai_client(api_key=openai_llm_config.openai_api_key) - response = await client.audio.speech.create( - model=openai_tts_config.openai_tts_model, - voice=openai_tts_config.openai_tts_voice, - input=text, - response_format="wav", - ) - return response.content diff --git a/agent_cli/tts.py b/agent_cli/tts.py deleted file mode 100644 index 01046f184..000000000 --- a/agent_cli/tts.py +++ /dev/null @@ -1,313 +0,0 @@ -"""Module for Text-to-Speech using Wyoming or OpenAI.""" - -from __future__ import annotations - -import asyncio -import importlib.util -import io -import wave -from functools import partial -from typing import TYPE_CHECKING - -from wyoming.audio import AudioChunk, AudioStart, AudioStop -from wyoming.tts import Synthesize, SynthesizeVoice - -from agent_cli import constants -from agent_cli.audio import open_pyaudio_stream, pyaudio_context, setup_output_stream - -from agent_cli.core.audio import open_pyaudio_stream, pyaudio_context, setup_output_stream -from agent_cli.core.utils import ( - InteractiveStopEvent, - live_timer, - print_error_message, - print_with_style, -) -from agent_cli.services.factory import get_tts_service -from agent_cli.services.local import manage_send_receive_tasks, wyoming_client_context - -if TYPE_CHECKING: - import logging - from collections.abc import Awaitable, Callable - - from rich.live import Live - from wyoming.client import AsyncClient - - from agent_cli.agents import config - -has_audiostretchy = importlib.util.find_spec("audiostretchy") is not None - - -def get_synthesizer( - provider_config: config.ProviderSelection, - audio_output_config: config.AudioOutput, - wyoming_tts_config: config.WyomingTTS, - openai_tts_config: config.OpenAITTS, - openai_llm_config: config.OpenAILLM, - logger: logging.Logger, - *, - quiet: bool = False, -) -> Callable[..., Awaitable[bytes | None]]: - """Return the appropriate synthesizer based on the config.""" - if not audio_output_config.enable_tts: - return _dummy_synthesizer - - tts_service = get_tts_service( - provider_config, - wyoming_tts_config, - openai_tts_config, - openai_llm_config, - logger, - quiet=quiet, - ) - - async def synthesize(text: str) -> bytes | None: - """Synthesize speech from text.""" - return await tts_service.synthesize(text) - - return synthesize - - -def _create_synthesis_request( - text: str, - *, - voice_name: str | None = None, - language: str | None = None, - speaker: str | None = None, -) -> Synthesize: - """Create a synthesis request with optional voice parameters.""" - synthesize_event = Synthesize(text=text) - - # Add voice parameters if specified - if voice_name or language or speaker: - synthesize_event.voice = SynthesizeVoice( - name=voice_name, - language=language, - speaker=speaker, - ) - - return synthesize_event - - -async def _process_audio_events( - client: AsyncClient, - logger: logging.Logger, -) -> tuple[bytes, int | None, int | None, int | None]: - """Process audio events from TTS server and return audio data with metadata.""" - audio_data = io.BytesIO() - sample_rate = None - sample_width = None - channels = None - - while True: - event = await client.read_event() - if event is None: - logger.warning("Connection to TTS server lost.") - break - - if AudioStart.is_type(event.type): - audio_start = AudioStart.from_event(event) - sample_rate = audio_start.rate - sample_width = audio_start.width - channels = audio_start.channels - logger.debug( - "Audio stream started: %dHz, %d channels, %d bytes/sample", - sample_rate, - channels, - sample_width, - ) - - elif AudioChunk.is_type(event.type): - chunk = AudioChunk.from_event(event) - audio_data.write(chunk.audio) - logger.debug("Received %d bytes of audio", len(chunk.audio)) - - elif AudioStop.is_type(event.type): - logger.debug("Audio stream completed") - break - else: - logger.debug("Ignoring event type: %s", event.type) - - return audio_data.getvalue(), sample_rate, sample_width, channels - - -def _create_wav_data( - audio_data: bytes, - sample_rate: int, - sample_width: int, - channels: int, -) -> bytes: - """Convert raw audio data to WAV format.""" - wav_data = io.BytesIO() - with wave.open(wav_data, "wb") as wav_file: - wav_file.setnchannels(channels) - wav_file.setsampwidth(sample_width) - wav_file.setframerate(sample_rate) - wav_file.writeframes(audio_data) - return wav_data.getvalue() - - -async def _dummy_synthesizer(**_kwargs: object) -> bytes | None: - """A dummy synthesizer that does nothing.""" - return None - - -async def _synthesize_speech_wyoming( - *, - text: str, - wyoming_tts_config: config.WyomingTTS, - logger: logging.Logger, - quiet: bool = False, - live: Live, - **_kwargs: object, -) -> bytes | None: - """Synthesize speech from text using Wyoming TTS server.""" - try: - async with wyoming_client_context( - wyoming_tts_config.wyoming_tts_ip, - wyoming_tts_config.wyoming_tts_port, - "TTS", - logger, - quiet=quiet, - ) as client: - async with live_timer(live, "🔊 Synthesizing text", style="blue", quiet=quiet): - synthesize_event = _create_synthesis_request( - text, - voice_name=wyoming_tts_config.wyoming_voice, - language=wyoming_tts_config.wyoming_tts_language, - speaker=wyoming_tts_config.wyoming_speaker, - ) - _send_task, recv_task = await manage_send_receive_tasks( - client.write_event(synthesize_event.event()), - _process_audio_events(client, logger), - ) - audio_data, sample_rate, sample_width, channels = recv_task.result() - if sample_rate and sample_width and channels and audio_data: - wav_data = _create_wav_data(audio_data, sample_rate, sample_width, channels) - logger.info("Speech synthesis completed: %d bytes", len(wav_data)) - return wav_data - logger.warning("No audio data received from TTS server") - return None - except (ConnectionRefusedError, Exception): - return None - - -def _apply_speed_adjustment( - audio_data: io.BytesIO, - speed: float, -) -> tuple[io.BytesIO, bool]: - """Apply speed adjustment to audio data.""" - if speed == 1.0 or not has_audiostretchy: - return audio_data, False - from audiostretchy.stretch import AudioStretch # noqa: PLC0415 - - audio_data.seek(0) - input_copy = io.BytesIO(audio_data.read()) - audio_stretch = AudioStretch() - audio_stretch.open(file=input_copy, format="wav") - audio_stretch.stretch(ratio=1 / speed) - out = io.BytesIO() - audio_stretch.save_wav(out, close=False) - out.seek(0) - return out, True - - -async def play_audio( - audio_data: bytes, - logger: logging.Logger, - *, - audio_output_config: config.AudioOutput, - quiet: bool = False, - stop_event: InteractiveStopEvent | None = None, - live: Live, -) -> None: - """Play WAV audio data using PyAudio.""" - try: - wav_io = io.BytesIO(audio_data) - speed = audio_output_config.tts_speed - wav_io, speed_changed = _apply_speed_adjustment(wav_io, speed) - with wave.open(wav_io, "rb") as wav_file: - sample_rate = wav_file.getframerate() - channels = wav_file.getnchannels() - sample_width = wav_file.getsampwidth() - frames = wav_file.readframes(wav_file.getnframes()) - if not speed_changed: - sample_rate = int(sample_rate * speed) - base_msg = f"🔊 Playing audio at {speed}x speed" if speed != 1.0 else "🔊 Playing audio" - async with live_timer(live, base_msg, style="blue", quiet=quiet): - with pyaudio_context() as p: - stream_config = setup_output_stream( - audio_output_config.output_device_index, - sample_rate=sample_rate, - sample_width=sample_width, - channels=channels, - ) - with open_pyaudio_stream(p, **stream_config) as stream: - chunk_size = constants.PYAUDIO_CHUNK_SIZE - for i in range(0, len(frames), chunk_size): - if stop_event and stop_event.is_set(): - logger.info("Audio playback interrupted") - if not quiet: - print_with_style("⏹️ Audio playback interrupted", style="yellow") - break - chunk = frames[i : i + chunk_size] - stream.write(chunk) - await asyncio.sleep(0) - if not (stop_event and stop_event.is_set()): - logger.info("Audio playback completed (speed: %.1fx)", speed) - if not quiet: - print_with_style("✅ Audio playback finished") - except Exception as e: - logger.exception("Error during audio playback") - if not quiet: - print_error_message(f"Playback error: {e}") - - -async def speak_text( - *, - text: str, - provider_config: config.ProviderSelection, - audio_output_config: config.AudioOutput, - wyoming_tts_config: config.WyomingTTS, - openai_tts_config: config.OpenAITTS, - openai_llm_config: config.OpenAILLM, - logger: logging.Logger, - quiet: bool = False, - play_audio_flag: bool = True, - stop_event: InteractiveStopEvent | None = None, - live: Live, -) -> bytes | None: - """Synthesize and optionally play speech from text.""" - synthesizer = get_synthesizer( - provider_config, - audio_output_config, - wyoming_tts_config, - openai_tts_config, - openai_llm_config, - ) - audio_data = None - try: - async with live_timer(live, "🔊 Synthesizing text", style="blue", quiet=quiet): - audio_data = await synthesizer( - text=text, - wyoming_tts_config=wyoming_tts_config, - openai_tts_config=openai_tts_config, - openai_llm_config=openai_llm_config, - logger=logger, - quiet=quiet, - live=live, - ) - except Exception: - logger.exception("Error during speech synthesis") - return None - - if audio_data and play_audio_flag: - await play_audio( - audio_data, - logger, - audio_output_config=audio_output_config, - quiet=quiet, - stop_event=stop_event, - live=live, - ) - - return audio_data diff --git a/agent_cli/wake_word.py b/agent_cli/wake_word.py deleted file mode 100644 index 8015ea69e..000000000 --- a/agent_cli/wake_word.py +++ /dev/null @@ -1,136 +0,0 @@ -"""Module for Wake Word Detection using Wyoming.""" - -from __future__ import annotations - -import asyncio -from typing import TYPE_CHECKING - -from wyoming.audio import AudioChunk, AudioStart, AudioStop -from wyoming.wake import Detect, Detection, NotDetected - -from agent_cli import constants -from agent_cli.audio import read_from_queue -from agent_cli.wyoming_utils import manage_send_receive_tasks, wyoming_client_context - -if TYPE_CHECKING: - import logging - from collections.abc import Callable - - from rich.live import Live - from wyoming.client import AsyncClient - - -async def _send_audio_from_queue_for_wake_detection( - client: AsyncClient, - queue: asyncio.Queue, - logger: logging.Logger, - live: Live | None, - quiet: bool, - progress_message: str, -) -> None: - """Read from a queue and send to Wyoming wake word server.""" - await client.write_event(AudioStart(**constants.WYOMING_AUDIO_CONFIG).event()) - seconds_streamed = 0.0 - - async def send_chunk(chunk: bytes) -> None: - nonlocal seconds_streamed - """Send audio chunk to wake word server.""" - await client.write_event( - AudioChunk(audio=chunk, **constants.WYOMING_AUDIO_CONFIG).event(), - ) - seconds_streamed += len(chunk) / (constants.PYAUDIO_RATE * constants.PYAUDIO_CHANNELS * 2) - if live and not quiet: - live.update(f"{progress_message}... ({seconds_streamed:.1f}s)") - - try: - await read_from_queue( - queue=queue, - chunk_handler=send_chunk, - logger=logger, - ) - finally: - if client._writer is not None: - await client.write_event(AudioStop().event()) - logger.debug("Sent AudioStop for wake detection") - - -async def _receive_wake_detection( - client: AsyncClient, - logger: logging.Logger, - *, - detection_callback: Callable[[str], None] | None = None, -) -> str | None: - """Receive wake word detection events. - - Args: - client: Wyoming client connection - logger: Logger instance - detection_callback: Optional callback for when wake word is detected - - Returns: - Name of detected wake word or None if no detection - - """ - while True: - event = await client.read_event() - if event is None: - logger.warning("Connection to wake word server lost.") - break - - if Detection.is_type(event.type): - detection = Detection.from_event(event) - wake_word_name = detection.name or "unknown" - logger.info("Wake word detected: %s", wake_word_name) - if detection_callback: - detection_callback(wake_word_name) - return wake_word_name - if NotDetected.is_type(event.type): - logger.debug("No wake word detected") - break - logger.debug("Ignoring event type: %s", event.type) - - return None - - -async def detect_wake_word_from_queue( - wake_server_ip: str, - wake_server_port: int, - wake_word_name: str, - logger: logging.Logger, - queue: asyncio.Queue, - *, - live: Live | None = None, - detection_callback: Callable[[str], None] | None = None, - quiet: bool = False, - progress_message: str = "Listening for wake word", -) -> str | None: - """Detect wake word from an audio queue.""" - try: - async with wyoming_client_context( - wake_server_ip, - wake_server_port, - "wake word", - logger, - quiet=quiet, - ) as client: - await client.write_event(Detect(names=[wake_word_name]).event()) - - _send_task, recv_task = await manage_send_receive_tasks( - _send_audio_from_queue_for_wake_detection( - client, - queue, - logger, - live, - quiet, - progress_message, - ), - _receive_wake_detection(client, logger, detection_callback=detection_callback), - return_when=asyncio.FIRST_COMPLETED, - ) - - if recv_task.done() and not recv_task.cancelled(): - return recv_task.result() - - return None - except (ConnectionRefusedError, asyncio.CancelledError, Exception): - return None diff --git a/agent_cli/wyoming_utils.py b/agent_cli/wyoming_utils.py deleted file mode 100644 index edbe31a78..000000000 --- a/agent_cli/wyoming_utils.py +++ /dev/null @@ -1,95 +0,0 @@ -"""Utility functions for Wyoming protocol interactions to eliminate code duplication.""" - -from __future__ import annotations - -import asyncio -from contextlib import asynccontextmanager -from typing import TYPE_CHECKING - -from wyoming.client import AsyncClient - -from agent_cli.utils import print_error_message - -if TYPE_CHECKING: - import logging - from collections.abc import AsyncGenerator, Coroutine - - -@asynccontextmanager -async def wyoming_client_context( - server_ip: str, - server_port: int, - server_type: str, - logger: logging.Logger, - *, - quiet: bool = False, -) -> AsyncGenerator[AsyncClient, None]: - """Context manager for Wyoming client connections with unified error handling. - - Args: - server_ip: Wyoming server IP - server_port: Wyoming server port - server_type: Type of server (e.g., "ASR", "TTS", "wake word") - logger: Logger instance - quiet: If True, suppress console error messages - - Yields: - Connected Wyoming client - - Raises: - ConnectionRefusedError: If connection fails - Exception: For other connection errors - - """ - uri = f"tcp://{server_ip}:{server_port}" - logger.info("Connecting to Wyoming %s server at %s", server_type, uri) - - try: - async with AsyncClient.from_uri(uri) as client: - logger.info("%s connection established", server_type) - yield client - except ConnectionRefusedError: - logger.exception("%s connection refused.", server_type) - if not quiet: - print_error_message( - f"{server_type} connection refused.", - f"Is the Wyoming {server_type.lower()} server running at {uri}?", - ) - raise - except Exception as e: - logger.exception("An error occurred during %s connection", server_type.lower()) - if not quiet: - print_error_message(f"{server_type} error: {e}") - raise - - -async def manage_send_receive_tasks( - send_task_coro: Coroutine, - receive_task_coro: Coroutine, - *, - return_when: str = asyncio.ALL_COMPLETED, -) -> tuple[asyncio.Task, asyncio.Task]: - """Manage send and receive tasks with proper cancellation. - - Args: - send_task_coro: Send task coroutine - receive_task_coro: Receive task coroutine - return_when: When to return (e.g., asyncio.ALL_COMPLETED) - - Returns: - Tuple of (send_task, receive_task) - both completed or cancelled - - """ - send_task = asyncio.create_task(send_task_coro) - recv_task = asyncio.create_task(receive_task_coro) - - done, pending = await asyncio.wait( - [send_task, recv_task], - return_when=return_when, - ) - - # Cancel any pending tasks - for task in pending: - task.cancel() - - return send_task, recv_task From 636ee7a3c02be0f37882638ab4bf4585fa96cbcc Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Tue, 8 Jul 2025 20:44:01 -0700 Subject: [PATCH 7/8] Last Cline changes --- agent_cli/agents/_tts_common.py | 2 +- agent_cli/agents/_voice_agent_common.py | 23 +- agent_cli/agents/assistant.py | 39 ++- agent_cli/agents/autocorrect.py | 4 +- agent_cli/agents/chat.py | 18 +- agent_cli/agents/speak.py | 6 +- agent_cli/agents/transcribe.py | 18 +- agent_cli/agents/voice_edit.py | 9 +- agent_cli/asr.py | 290 ++++++++++++++++++++++ agent_cli/cli.py | 2 +- agent_cli/core/audio.py | 4 +- agent_cli/llm.py | 10 +- agent_cli/services/base.py | 12 - agent_cli/services/factory.py | 10 +- agent_cli/services/local.py | 53 ++-- agent_cli/services/openai.py | 12 +- agent_cli/tts.py | 313 ++++++++++++++++++++++++ tests/agents/test_interactive.py | 4 +- tests/agents/test_interactive_extra.py | 4 +- tests/agents/test_speak.py | 2 +- tests/agents/test_speak_e2e.py | 2 +- tests/agents/test_transcribe.py | 2 +- tests/agents/test_transcribe_e2e.py | 2 +- tests/agents/test_tts_common.py | 2 +- tests/agents/test_tts_common_extra.py | 2 +- tests/agents/test_voice_agent_common.py | 2 +- tests/agents/test_voice_edit_e2e.py | 2 +- tests/test_llm.py | 2 +- tests/test_services.py | 3 +- tests/test_tts.py | 2 +- tests/test_wake_word.py | 2 +- 31 files changed, 729 insertions(+), 129 deletions(-) create mode 100644 agent_cli/asr.py create mode 100644 agent_cli/tts.py diff --git a/agent_cli/agents/_tts_common.py b/agent_cli/agents/_tts_common.py index d8d01b937..110fd2d68 100644 --- a/agent_cli/agents/_tts_common.py +++ b/agent_cli/agents/_tts_common.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING from agent_cli import tts -from agent_cli.utils import InteractiveStopEvent, print_with_style +from agent_cli.core.utils import InteractiveStopEvent, print_with_style if TYPE_CHECKING: import logging diff --git a/agent_cli/agents/_voice_agent_common.py b/agent_cli/agents/_voice_agent_common.py index 58e2a6053..751ff6ba7 100644 --- a/agent_cli/agents/_voice_agent_common.py +++ b/agent_cli/agents/_voice_agent_common.py @@ -8,10 +8,10 @@ import pyperclip -from agent_cli import asr from agent_cli.agents._tts_common import handle_tts_playback +from agent_cli.core.utils import print_input_panel, print_with_style from agent_cli.llm import process_and_update_clipboard -from agent_cli.utils import print_input_panel, print_with_style +from agent_cli.services.factory import get_asr_service if TYPE_CHECKING: from rich.live import Live @@ -25,10 +25,8 @@ async def get_instruction_from_audio( *, audio_data: bytes, provider_config: config.ProviderSelection, - audio_input_config: config.AudioInput, wyoming_asr_config: config.WyomingASR, openai_asr_config: config.OpenAIASR, - ollama_config: config.Ollama, openai_llm_config: config.OpenAILLM, logger: logging.Logger, quiet: bool, @@ -36,18 +34,15 @@ async def get_instruction_from_audio( """Transcribe audio data and return the instruction.""" try: start_time = time.monotonic() - transcriber = asr.get_recorded_audio_transcriber(provider_config) - instruction = await transcriber( - audio_data=audio_data, - provider_config=provider_config, - audio_input_config=audio_input_config, - wyoming_asr_config=wyoming_asr_config, - openai_asr_config=openai_asr_config, - ollama_config=ollama_config, - openai_llm_config=openai_llm_config, - logger=logger, + transcriber = get_asr_service( + provider_config, + wyoming_asr_config, + openai_asr_config, + openai_llm_config, + logger, quiet=quiet, ) + instruction = await transcriber.transcribe(audio_data) elapsed = time.monotonic() - start_time if not instruction or not instruction.strip(): diff --git a/agent_cli/agents/assistant.py b/agent_cli/agents/assistant.py index 2eba108bf..393a44fc2 100644 --- a/agent_cli/agents/assistant.py +++ b/agent_cli/agents/assistant.py @@ -34,23 +34,23 @@ from typing import TYPE_CHECKING import agent_cli.agents._cli_options as opts -from agent_cli import asr, wake_word -from agent_cli.core import audio -from agent_cli.core import process as process_manager from agent_cli import config from agent_cli.agents._voice_agent_common import ( get_instruction_from_audio, process_instruction_and_respond, ) -from agent_cli.audio import pyaudio_context, setup_devices from agent_cli.cli import app, setup_logging -from agent_cli.utils import ( +from agent_cli.core import audio +from agent_cli.core import process as process_manager +from agent_cli.core.audio import pyaudio_context, setup_devices +from agent_cli.core.utils import ( InteractiveStopEvent, maybe_live, print_with_style, signal_handling_context, stop_or_status_or_toggle, ) +from agent_cli.services.local import WyomingWakeWordService if TYPE_CHECKING: import pyaudio @@ -112,15 +112,14 @@ async def _record_audio_with_wake_word( # Create a queue for wake word detection wake_queue = await tee.add_queue() - detected_word = await wake_word.detect_wake_word_from_queue( - wake_server_ip=wake_word_config.server_ip, - wake_server_port=wake_word_config.server_port, - wake_word_name=wake_word_config.wake_word_name, - logger=logger, - queue=wake_queue, - quiet=quiet, + wake_word_service = WyomingWakeWordService( + wake_word_config, + logger, + wake_queue, live=live, + quiet=quiet, ) + detected_word = await wake_word_service.detect() if not detected_word or stop_event.is_set(): # Clean up the queue if we exit early @@ -135,19 +134,17 @@ async def _record_audio_with_wake_word( # Add a new queue for recording record_queue = await tee.add_queue() - record_task = asyncio.create_task(asr.record_audio_to_buffer(record_queue, logger)) + record_task = asyncio.create_task(audio.record_audio_to_buffer(record_queue, logger)) # Use the same wake_queue for stop-word detection - stop_detected_word = await wake_word.detect_wake_word_from_queue( - wake_server_ip=wake_word_config.server_ip, - wake_server_port=wake_word_config.server_port, - wake_word_name=wake_word_config.wake_word_name, - logger=logger, - queue=wake_queue, - quiet=quiet, + wake_word_service = WyomingWakeWordService( + wake_word_config, + logger, + wake_queue, live=live, - progress_message="Recording... (say wake word to stop)", + quiet=quiet, ) + stop_detected_word = await wake_word_service.detect() # Stop the recording task by removing its queue await tee.remove_queue(record_queue) diff --git a/agent_cli/agents/autocorrect.py b/agent_cli/agents/autocorrect.py index 8d51b36e1..f99c2d944 100644 --- a/agent_cli/agents/autocorrect.py +++ b/agent_cli/agents/autocorrect.py @@ -14,8 +14,7 @@ import agent_cli.agents._cli_options as opts from agent_cli import config from agent_cli.cli import app, setup_logging -from agent_cli.llm import build_agent -from agent_cli.utils import ( +from agent_cli.core.utils import ( create_status, get_clipboard_text, print_error_message, @@ -23,6 +22,7 @@ print_output_panel, print_with_style, ) +from agent_cli.llm import build_agent if TYPE_CHECKING: from rich.status import Status diff --git a/agent_cli/agents/chat.py b/agent_cli/agents/chat.py index 7cd8940fc..6a782d33c 100644 --- a/agent_cli/agents/chat.py +++ b/agent_cli/agents/chat.py @@ -25,14 +25,12 @@ import typer import agent_cli.agents._cli_options as opts -from agent_cli import asr -from agent_cli.core import process as process_manager from agent_cli import config from agent_cli.agents._tts_common import handle_tts_playback -from agent_cli.audio import pyaudio_context, setup_devices from agent_cli.cli import app, setup_logging -from agent_cli.llm import get_llm_response -from agent_cli.utils import ( +from agent_cli.core import process as process_manager +from agent_cli.core.audio import pyaudio_context, setup_devices +from agent_cli.core.utils import ( InteractiveStopEvent, console, format_timedelta_to_ago, @@ -44,6 +42,8 @@ signal_handling_context, stop_or_status_or_toggle, ) +from agent_cli.llm import get_llm_response +from agent_cli.services.factory import get_asr_service if TYPE_CHECKING: import pyaudio @@ -151,7 +151,6 @@ async def _handle_conversation_turn( provider_cfg: config.ProviderSelection, general_cfg: config.General, history_cfg: config.History, - audio_in_cfg: config.AudioInput, wyoming_asr_cfg: config.WyomingASR, openai_asr_cfg: config.OpenAIASR, ollama_cfg: config.Ollama, @@ -177,14 +176,15 @@ async def _handle_conversation_turn( # 1. Transcribe user's command start_time = time.monotonic() - transcriber = asr.get_transcriber( + transcriber = get_asr_service( provider_cfg, - audio_in_cfg, wyoming_asr_cfg, openai_asr_cfg, openai_llm_cfg, + LOGGER, + quiet=general_cfg.quiet, ) - instruction = await transcriber( + instruction = await transcriber.transcribe( p=p, stop_event=stop_event, quiet=general_cfg.quiet, diff --git a/agent_cli/agents/speak.py b/agent_cli/agents/speak.py index 3e601c209..b1617340f 100644 --- a/agent_cli/agents/speak.py +++ b/agent_cli/agents/speak.py @@ -10,12 +10,12 @@ import typer import agent_cli.agents._cli_options as opts -from agent_cli.core import process as process_manager from agent_cli import config from agent_cli.agents._tts_common import handle_tts_playback -from agent_cli.audio import pyaudio_context, setup_devices from agent_cli.cli import app, setup_logging -from agent_cli.utils import ( +from agent_cli.core import process as process_manager +from agent_cli.core.audio import pyaudio_context, setup_devices +from agent_cli.core.utils import ( get_clipboard_text, maybe_live, print_input_panel, diff --git a/agent_cli/agents/transcribe.py b/agent_cli/agents/transcribe.py index a94e20509..0d3cee5e7 100644 --- a/agent_cli/agents/transcribe.py +++ b/agent_cli/agents/transcribe.py @@ -11,13 +11,11 @@ import pyperclip import agent_cli.agents._cli_options as opts -from agent_cli import asr -from agent_cli.core import process as process_manager from agent_cli import config -from agent_cli.audio import pyaudio_context, setup_devices from agent_cli.cli import app, setup_logging -from agent_cli.llm import process_and_update_clipboard -from agent_cli.utils import ( +from agent_cli.core import process as process_manager +from agent_cli.core.audio import pyaudio_context, setup_devices +from agent_cli.core.utils import ( maybe_live, print_input_panel, print_output_panel, @@ -25,6 +23,8 @@ signal_handling_context, stop_or_status_or_toggle, ) +from agent_cli.llm import process_and_update_clipboard +from agent_cli.services.factory import get_asr_service if TYPE_CHECKING: import pyaudio @@ -71,7 +71,6 @@ async def _async_main( *, provider_cfg: config.ProviderSelection, general_cfg: config.General, - audio_in_cfg: config.AudioInput, wyoming_asr_cfg: config.WyomingASR, openai_asr_cfg: config.OpenAIASR, ollama_cfg: config.Ollama, @@ -83,14 +82,15 @@ async def _async_main( start_time = time.monotonic() with maybe_live(not general_cfg.quiet) as live: with signal_handling_context(LOGGER, general_cfg.quiet) as stop_event: - transcriber = asr.get_transcriber( + transcriber = get_asr_service( provider_cfg, - audio_in_cfg, wyoming_asr_cfg, openai_asr_cfg, openai_llm_cfg, + LOGGER, + quiet=general_cfg.quiet, ) - transcript = await transcriber( + transcript = await transcriber.transcribe( logger=LOGGER, p=p, stop_event=stop_event, diff --git a/agent_cli/agents/voice_edit.py b/agent_cli/agents/voice_edit.py index eff459b8e..0b25d074f 100644 --- a/agent_cli/agents/voice_edit.py +++ b/agent_cli/agents/voice_edit.py @@ -39,16 +39,15 @@ from pathlib import Path # noqa: TC003 import agent_cli.agents._cli_options as opts -from agent_cli import asr -from agent_cli.core import process as process_manager from agent_cli import config from agent_cli.agents._voice_agent_common import ( get_instruction_from_audio, process_instruction_and_respond, ) -from agent_cli.audio import pyaudio_context, setup_devices from agent_cli.cli import app, setup_logging -from agent_cli.utils import ( +from agent_cli.core import process as process_manager +from agent_cli.core.audio import pyaudio_context, record_audio_with_manual_stop, setup_devices +from agent_cli.core.utils import ( get_clipboard_text, maybe_live, print_input_panel, @@ -119,7 +118,7 @@ async def _async_main( signal_handling_context(LOGGER, general_cfg.quiet) as stop_event, maybe_live(not general_cfg.quiet) as live, ): - audio_data = await asr.record_audio_with_manual_stop( + audio_data = await record_audio_with_manual_stop( p, input_device_index, stop_event, diff --git a/agent_cli/asr.py b/agent_cli/asr.py new file mode 100644 index 000000000..2b1165d10 --- /dev/null +++ b/agent_cli/asr.py @@ -0,0 +1,290 @@ +"""Module for Automatic Speech Recognition using Wyoming or OpenAI.""" + +from __future__ import annotations + +import asyncio +import io +from functools import partial +from typing import TYPE_CHECKING + +from wyoming.asr import Transcribe, Transcript, TranscriptChunk, TranscriptStart, TranscriptStop +from wyoming.audio import AudioChunk, AudioStart, AudioStop + +from agent_cli import constants +from agent_cli.audio import ( + open_pyaudio_stream, + read_audio_stream, + read_from_queue, + setup_input_stream, +) + +from agent_cli.core.audio import ( + open_pyaudio_stream, + read_audio_stream, + read_from_queue, + setup_input_stream, +) +from agent_cli.services.factory import get_asr_service +from agent_cli.services.local import ( + WyomingTranscriptionService, + manage_send_receive_tasks, + wyoming_client_context, +) +from agent_cli.services.openai import OpenAITranscriptionService + +if TYPE_CHECKING: + import logging + from collections.abc import Awaitable, Callable + + import pyaudio + from rich.live import Live + from wyoming.client import AsyncClient + + from agent_cli import config + from agent_cli.core.utils import InteractiveStopEvent + + +def get_transcriber( + provider_config: config.ProviderSelection, + audio_input_config: config.AudioInput, + wyoming_asr_config: config.WyomingASR, + openai_asr_config: config.OpenAIASR, + openai_llm_config: config.OpenAILLM, + logger: logging.Logger, + *, + quiet: bool = False, +) -> Callable[..., Awaitable[str | None]]: + """Return the appropriate transcriber for live audio based on the provider.""" + asr_service = get_asr_service( + provider_config, + wyoming_asr_config, + openai_asr_config, + openai_llm_config, + logger, + quiet=quiet, + ) + + async def transcribe_live_audio(p: pyaudio.PyAudio, stop_event: InteractiveStopEvent, live: Live) -> str | None: + """Record and transcribe live audio.""" + audio_data = await record_audio_with_manual_stop( + p, + audio_input_config.input_device_index, + stop_event, + logger, + quiet=quiet, + live=live, + ) + if not audio_data: + return None + return await asr_service.transcribe(audio_data) + + return transcribe_live_audio + + +def get_recorded_audio_transcriber( + provider_config: config.ProviderSelection, + wyoming_asr_config: config.WyomingASR, + openai_asr_config: config.OpenAIASR, + openai_llm_config: config.OpenAILLM, + logger: logging.Logger, + *, + quiet: bool = False, +) -> Callable[[bytes], Awaitable[str]]: + """Return the appropriate transcriber for recorded audio based on the provider.""" + asr_service = get_asr_service( + provider_config, + wyoming_asr_config, + openai_asr_config, + openai_llm_config, + logger, + quiet=quiet, + ) + return asr_service.transcribe + + +async def _send_audio( + client: AsyncClient, + stream: pyaudio.Stream, + stop_event: InteractiveStopEvent, + logger: logging.Logger, + *, + live: Live, + quiet: bool = False, +) -> None: + """Read from mic and send to Wyoming server.""" + await client.write_event(Transcribe().event()) + await client.write_event(AudioStart(**constants.WYOMING_AUDIO_CONFIG).event()) + + async def send_chunk(chunk: bytes) -> None: + """Send audio chunk to ASR server.""" + await client.write_event(AudioChunk(audio=chunk, **constants.WYOMING_AUDIO_CONFIG).event()) + + try: + await read_audio_stream( + stream=stream, + stop_event=stop_event, + chunk_handler=send_chunk, + logger=logger, + live=live, + quiet=quiet, + progress_message="Listening", + progress_style="blue", + ) + finally: + await client.write_event(AudioStop().event()) + logger.debug("Sent AudioStop") + + +async def record_audio_to_buffer( + queue: asyncio.Queue, + logger: logging.Logger, +) -> bytes: + """Record audio from a queue to a buffer.""" + audio_buffer = io.BytesIO() + + def buffer_chunk(chunk: bytes) -> None: + """Buffer audio chunk.""" + audio_buffer.write(chunk) + + await read_from_queue(queue=queue, chunk_handler=buffer_chunk, logger=logger) + + return audio_buffer.getvalue() + + +async def _receive_transcript( + client: AsyncClient, + logger: logging.Logger, + *, + chunk_callback: Callable[[str], None] | None = None, + final_callback: Callable[[str], None] | None = None, +) -> str: + """Receive transcription events and return the final transcript.""" + transcript_text = "" + while True: + event = await client.read_event() + if event is None: + logger.warning("Connection to ASR server lost.") + break + + if Transcript.is_type(event.type): + transcript = Transcript.from_event(event) + transcript_text = transcript.text + logger.info("Final transcript: %s", transcript_text) + if final_callback: + final_callback(transcript_text) + break + if TranscriptChunk.is_type(event.type): + chunk = TranscriptChunk.from_event(event) + logger.debug("Transcript chunk: %s", chunk.text) + if chunk_callback: + chunk_callback(chunk.text) + elif TranscriptStart.is_type(event.type) or TranscriptStop.is_type(event.type): + logger.debug("Received %s", event.type) + else: + logger.debug("Ignoring event type: %s", event.type) + + return transcript_text + + +async def record_audio_with_manual_stop( + p: pyaudio.PyAudio, + input_device_index: int | None, + stop_event: InteractiveStopEvent, + logger: logging.Logger, + *, + quiet: bool = False, + live: Live | None = None, +) -> bytes: + """Record audio to a buffer using a manual stop signal.""" + audio_buffer = io.BytesIO() + + def buffer_chunk(chunk: bytes) -> None: + """Buffer audio chunk.""" + audio_buffer.write(chunk) + + stream_config = setup_input_stream(input_device_index) + with open_pyaudio_stream(p, **stream_config) as stream: + await read_audio_stream( + stream=stream, + stop_event=stop_event, + chunk_handler=buffer_chunk, + logger=logger, + live=live, + quiet=quiet, + progress_message="Recording", + progress_style="green", + ) + return audio_buffer.getvalue() + + +async def transcribe_recorded_audio_wyoming( + *, + audio_data: bytes, + wyoming_asr_config: config.WyomingASR, + logger: logging.Logger, + quiet: bool = False, + **_kwargs: object, +) -> str: + """Process pre-recorded audio data with Wyoming ASR server.""" + service = WyomingTranscriptionService(wyoming_asr_config, logger, quiet=quiet) + return await service.transcribe(audio_data) + + +async def transcribe_live_audio_wyoming( + *, + audio_input_config: config.AudioInput, + wyoming_asr_config: config.WyomingASR, + logger: logging.Logger, + p: pyaudio.PyAudio, + stop_event: InteractiveStopEvent, + live: Live, + quiet: bool = False, + chunk_callback: Callable[[str], None] | None = None, + final_callback: Callable[[str], None] | None = None, + **_kwargs: object, +) -> str | None: + """Unified ASR transcription function.""" + service = WyomingTranscriptionService(wyoming_asr_config, logger, quiet=quiet) + audio_data = await record_audio_with_manual_stop( + p, + audio_input_config.input_device_index, + stop_event, + logger, + quiet=quiet, + live=live, + ) + if not audio_data: + return None + transcript = await service.transcribe(audio_data) + if chunk_callback: + chunk_callback(transcript) + if final_callback: + final_callback(transcript) + return transcript + + +async def transcribe_live_audio_openai( + *, + audio_input_config: config.AudioInput, + openai_asr_config: config.OpenAIASR, + openai_llm_config: config.OpenAILLM, + logger: logging.Logger, + p: pyaudio.PyAudio, + stop_event: InteractiveStopEvent, + live: Live, + quiet: bool = False, + **_kwargs: object, +) -> str | None: + """Record and transcribe live audio using OpenAI Whisper.""" + service = OpenAITranscriptionService(openai_asr_config, openai_llm_config, logger) + audio_data = await record_audio_with_manual_stop( + p, + audio_input_config.input_device_index, + stop_event, + logger, + quiet=quiet, + live=live, + ) + if not audio_data: + return None + return await service.transcribe(audio_data) diff --git a/agent_cli/cli.py b/agent_cli/cli.py index 7b8803ce0..1b861ee80 100644 --- a/agent_cli/cli.py +++ b/agent_cli/cli.py @@ -7,7 +7,7 @@ import typer -from .config_loader import load_config +from .config import load_config from .utils import console if TYPE_CHECKING: diff --git a/agent_cli/core/audio.py b/agent_cli/core/audio.py index 186cc8785..99eee2e67 100644 --- a/agent_cli/core/audio.py +++ b/agent_cli/core/audio.py @@ -12,7 +12,7 @@ from rich.text import Text from agent_cli import constants -from agent_cli.utils import InteractiveStopEvent, console, print_device_index, print_with_style +from agent_cli.core.utils import InteractiveStopEvent, console, print_device_index, print_with_style if TYPE_CHECKING: import logging @@ -20,7 +20,7 @@ from rich.live import Live - from agent_cli.agents import config + from agent_cli import config class _AudioTee: diff --git a/agent_cli/llm.py b/agent_cli/llm.py index 34121ae99..ea083104b 100644 --- a/agent_cli/llm.py +++ b/agent_cli/llm.py @@ -15,12 +15,9 @@ if TYPE_CHECKING: import logging - from pydantic_ai import Agent from pydantic_ai.tools import Tool - from agent_cli.agents import config - - + from agent_cli import config # --- LLM (Editing) Logic --- @@ -54,7 +51,10 @@ async def get_llm_response( ) -> str | None: """Get a response from the LLM with optional clipboard and output handling.""" llm_service = get_llm_service( - provider_config, ollama_config, openai_config, logger + provider_config, + ollama_config, + openai_config, + logger, ) start_time = time.monotonic() diff --git a/agent_cli/services/base.py b/agent_cli/services/base.py index d9890a842..49a4483fc 100644 --- a/agent_cli/services/base.py +++ b/agent_cli/services/base.py @@ -3,14 +3,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - import logging - - from rich.live import Live - - from agent_cli import config class ASRService(ABC): @@ -19,7 +11,6 @@ class ASRService(ABC): @abstractmethod async def transcribe(self, audio_data: bytes) -> str: """Transcribe audio data to text.""" - pass class LLMService(ABC): @@ -35,7 +26,6 @@ async def get_response( tools: list | None = None, ) -> str | None: """Get a response from the language model.""" - pass class TTSService(ABC): @@ -44,7 +34,6 @@ class TTSService(ABC): @abstractmethod async def synthesize(self, text: str) -> bytes | None: """Synthesize text to speech audio data.""" - pass class WakeWordService(ABC): @@ -53,4 +42,3 @@ class WakeWordService(ABC): @abstractmethod async def detect(self) -> str | None: """Detect the wake word.""" - pass diff --git a/agent_cli/services/factory.py b/agent_cli/services/factory.py index 264097f24..be5370bfc 100644 --- a/agent_cli/services/factory.py +++ b/agent_cli/services/factory.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING -from agent_cli.services.base import ASRService, LLMService, TTSService from agent_cli.services.local import ( OllamaLLMService, WyomingTranscriptionService, @@ -20,6 +19,7 @@ import logging from agent_cli import config + from agent_cli.services.base import ASRService, LLMService, TTSService def get_asr_service( @@ -39,7 +39,9 @@ def get_asr_service( logger=logger, ) return WyomingTranscriptionService( - wyoming_asr_config=wyoming_asr_config, logger=logger, quiet=quiet + wyoming_asr_config=wyoming_asr_config, + logger=logger, + quiet=quiet, ) @@ -72,5 +74,7 @@ def get_tts_service( logger=logger, ) return WyomingTTSService( - wyoming_tts_config=wyoming_tts_config, logger=logger, quiet=quiet + wyoming_tts_config=wyoming_tts_config, + logger=logger, + quiet=quiet, ) diff --git a/agent_cli/services/local.py b/agent_cli/services/local.py index dfcbd963a..4731e59a1 100644 --- a/agent_cli/services/local.py +++ b/agent_cli/services/local.py @@ -4,21 +4,26 @@ import asyncio import io +import wave from contextlib import asynccontextmanager from typing import TYPE_CHECKING -from wyoming.asr import Transcribe, Transcript, TranscriptChunk, TranscriptStart, TranscriptStop +from pydantic_ai import Agent +from pydantic_ai.models.openai import OpenAIModel +from pydantic_ai.providers.openai import OpenAIProvider +from wyoming.asr import Transcribe, Transcript from wyoming.audio import AudioChunk, AudioStart, AudioStop from wyoming.client import AsyncClient from wyoming.tts import Synthesize, SynthesizeVoice from wyoming.wake import Detect, Detection, NotDetected from agent_cli import constants +from agent_cli.core.audio import read_from_queue from agent_cli.services.base import ASRService, LLMService, TTSService, WakeWordService if TYPE_CHECKING: import logging - from collections.abc import AsyncGenerator, Awaitable, Callable + from collections.abc import AsyncGenerator, Awaitable from rich.live import Live @@ -41,7 +46,12 @@ async def wyoming_client_context( async with AsyncClient(host, port) as client: yield client except ConnectionRefusedError: - logger.error("Connection refused to %s server at %s:%s", service_name, host, port) + logger.exception( + "Connection refused to %s server at %s:%s", + service_name, + host, + port, + ) raise @@ -68,7 +78,8 @@ def __init__( logger: logging.Logger, *, quiet: bool = False, - ): + ) -> None: + """Initialize the WyomingTranscriptionService.""" self.wyoming_asr_config = wyoming_asr_config self.logger = logger self.quiet = quiet @@ -89,7 +100,7 @@ async def transcribe(self, audio_data: bytes) -> str: for i in range(0, len(audio_data), chunk_size): chunk = audio_data[i : i + chunk_size] await client.write_event( - AudioChunk(audio=chunk, **constants.WYOMING_AUDIO_CONFIG).event() + AudioChunk(audio=chunk, **constants.WYOMING_AUDIO_CONFIG).event(), ) await client.write_event(AudioStop().event()) return await self._receive_transcript(client) @@ -116,15 +127,14 @@ def __init__( self, ollama_config: config.Ollama, logger: logging.Logger, - ): - from pydantic_ai.models.openai import OpenAIModel # noqa: PLC0415 - from pydantic_ai.providers.openai import OpenAIProvider # noqa: PLC0415 - + ) -> None: + """Initialize the OllamaLLMService.""" self.ollama_config = ollama_config self.logger = logger provider = OpenAIProvider(base_url=f"{self.ollama_config.ollama_host}/v1") self.model = OpenAIModel( - model_name=self.ollama_config.ollama_model, provider=provider + model_name=self.ollama_config.ollama_model, + provider=provider, ) async def get_response( @@ -136,8 +146,6 @@ async def get_response( tools: list | None = None, ) -> str | None: """Get a response from the language model.""" - from pydantic_ai import Agent # noqa: PLC0415 - agent = Agent( model=self.model, system_prompt=system_prompt or (), @@ -157,7 +165,8 @@ def __init__( logger: logging.Logger, *, quiet: bool = False, - ): + ) -> None: + """Initialize the WyomingTTSService.""" self.wyoming_tts_config = wyoming_tts_config self.logger = logger self.quiet = quiet @@ -197,7 +206,8 @@ def _create_synthesis_request(self, text: str) -> Synthesize: return synthesize_event async def _process_audio_events( - self, client: AsyncClient + self, + client: AsyncClient, ) -> tuple[bytes, int | None, int | None, int | None]: """Process audio events from TTS server and return audio data with metadata.""" audio_data = io.BytesIO() @@ -220,11 +230,13 @@ async def _process_audio_events( return audio_data.getvalue(), sample_rate, sample_width, channels def _create_wav_data( - self, audio_data: bytes, sample_rate: int, sample_width: int, channels: int + self, + audio_data: bytes, + sample_rate: int, + sample_width: int, + channels: int, ) -> bytes: """Convert raw audio data to WAV format.""" - import wave - wav_data = io.BytesIO() with wave.open(wav_data, "wb") as wav_file: wav_file.setnchannels(channels) @@ -245,7 +257,8 @@ def __init__( *, live: Live | None = None, quiet: bool = False, - ): + ) -> None: + """Initialize the WyomingWakeWordService.""" self.wake_word_config = wake_word_config self.logger = logger self.queue = queue @@ -273,14 +286,12 @@ async def detect(self) -> str | None: async def _send_audio_from_queue_for_wake_detection(self, client: AsyncClient) -> None: """Read from a queue and send to Wyoming wake word server.""" - from agent_cli.core.audio import read_from_queue - await client.write_event(AudioStart(**constants.WYOMING_AUDIO_CONFIG).event()) try: await read_from_queue( queue=self.queue, chunk_handler=lambda chunk: client.write_event( - AudioChunk(audio=chunk, **constants.WYOMING_AUDIO_CONFIG).event() + AudioChunk(audio=chunk, **constants.WYOMING_AUDIO_CONFIG).event(), ), logger=self.logger, ) diff --git a/agent_cli/services/openai.py b/agent_cli/services/openai.py index 4d1c911eb..c372a8fa4 100644 --- a/agent_cli/services/openai.py +++ b/agent_cli/services/openai.py @@ -33,7 +33,8 @@ def __init__( openai_asr_config: config.OpenAIASR, openai_llm_config: config.OpenAILLM, logger: logging.Logger, - ): + ) -> None: + """Initialize the OpenAITranscriptionService.""" self.openai_asr_config = openai_asr_config self.openai_llm_config = openai_llm_config self.logger = logger @@ -61,7 +62,8 @@ def __init__( self, openai_llm_config: config.OpenAILLM, logger: logging.Logger, - ): + ) -> None: + """Initialize the OpenAILLMService.""" from pydantic_ai.models.openai import OpenAIModel # noqa: PLC0415 from pydantic_ai.providers.openai import OpenAIProvider # noqa: PLC0415 @@ -72,7 +74,8 @@ def __init__( raise ValueError(msg) provider = OpenAIProvider(api_key=self.openai_llm_config.openai_api_key) self.model = OpenAIModel( - model_name=self.openai_llm_config.openai_llm_model, provider=provider + model_name=self.openai_llm_config.openai_llm_model, + provider=provider, ) async def get_response( @@ -104,7 +107,8 @@ def __init__( openai_tts_config: config.OpenAITTS, openai_llm_config: config.OpenAILLM, logger: logging.Logger, - ): + ) -> None: + """Initialize the OpenAITTSService.""" self.openai_tts_config = openai_tts_config self.openai_llm_config = openai_llm_config self.logger = logger diff --git a/agent_cli/tts.py b/agent_cli/tts.py new file mode 100644 index 000000000..e556c87ba --- /dev/null +++ b/agent_cli/tts.py @@ -0,0 +1,313 @@ +"""Module for Text-to-Speech using Wyoming or OpenAI.""" + +from __future__ import annotations + +import asyncio +import importlib.util +import io +import wave +from functools import partial +from typing import TYPE_CHECKING + +from wyoming.audio import AudioChunk, AudioStart, AudioStop +from wyoming.tts import Synthesize, SynthesizeVoice + +from agent_cli import constants +from agent_cli.audio import open_pyaudio_stream, pyaudio_context, setup_output_stream + +from agent_cli.core.audio import open_pyaudio_stream, pyaudio_context, setup_output_stream +from agent_cli.core.utils import ( + InteractiveStopEvent, + live_timer, + print_error_message, + print_with_style, +) +from agent_cli.services.factory import get_tts_service +from agent_cli.services.local import manage_send_receive_tasks, wyoming_client_context + +if TYPE_CHECKING: + import logging + from collections.abc import Awaitable, Callable + + from rich.live import Live + from wyoming.client import AsyncClient + + from agent_cli import config + +has_audiostretchy = importlib.util.find_spec("audiostretchy") is not None + + +def get_synthesizer( + provider_config: config.ProviderSelection, + audio_output_config: config.AudioOutput, + wyoming_tts_config: config.WyomingTTS, + openai_tts_config: config.OpenAITTS, + openai_llm_config: config.OpenAILLM, + logger: logging.Logger, + *, + quiet: bool = False, +) -> Callable[..., Awaitable[bytes | None]]: + """Return the appropriate synthesizer based on the config.""" + if not audio_output_config.enable_tts: + return _dummy_synthesizer + + tts_service = get_tts_service( + provider_config, + wyoming_tts_config, + openai_tts_config, + openai_llm_config, + logger, + quiet=quiet, + ) + + async def synthesize(text: str) -> bytes | None: + """Synthesize speech from text.""" + return await tts_service.synthesize(text) + + return synthesize + + +def _create_synthesis_request( + text: str, + *, + voice_name: str | None = None, + language: str | None = None, + speaker: str | None = None, +) -> Synthesize: + """Create a synthesis request with optional voice parameters.""" + synthesize_event = Synthesize(text=text) + + # Add voice parameters if specified + if voice_name or language or speaker: + synthesize_event.voice = SynthesizeVoice( + name=voice_name, + language=language, + speaker=speaker, + ) + + return synthesize_event + + +async def _process_audio_events( + client: AsyncClient, + logger: logging.Logger, +) -> tuple[bytes, int | None, int | None, int | None]: + """Process audio events from TTS server and return audio data with metadata.""" + audio_data = io.BytesIO() + sample_rate = None + sample_width = None + channels = None + + while True: + event = await client.read_event() + if event is None: + logger.warning("Connection to TTS server lost.") + break + + if AudioStart.is_type(event.type): + audio_start = AudioStart.from_event(event) + sample_rate = audio_start.rate + sample_width = audio_start.width + channels = audio_start.channels + logger.debug( + "Audio stream started: %dHz, %d channels, %d bytes/sample", + sample_rate, + channels, + sample_width, + ) + + elif AudioChunk.is_type(event.type): + chunk = AudioChunk.from_event(event) + audio_data.write(chunk.audio) + logger.debug("Received %d bytes of audio", len(chunk.audio)) + + elif AudioStop.is_type(event.type): + logger.debug("Audio stream completed") + break + else: + logger.debug("Ignoring event type: %s", event.type) + + return audio_data.getvalue(), sample_rate, sample_width, channels + + +def _create_wav_data( + audio_data: bytes, + sample_rate: int, + sample_width: int, + channels: int, +) -> bytes: + """Convert raw audio data to WAV format.""" + wav_data = io.BytesIO() + with wave.open(wav_data, "wb") as wav_file: + wav_file.setnchannels(channels) + wav_file.setsampwidth(sample_width) + wav_file.setframerate(sample_rate) + wav_file.writeframes(audio_data) + return wav_data.getvalue() + + +async def _dummy_synthesizer(**_kwargs: object) -> bytes | None: + """A dummy synthesizer that does nothing.""" + return None + + +async def _synthesize_speech_wyoming( + *, + text: str, + wyoming_tts_config: config.WyomingTTS, + logger: logging.Logger, + quiet: bool = False, + live: Live, + **_kwargs: object, +) -> bytes | None: + """Synthesize speech from text using Wyoming TTS server.""" + try: + async with wyoming_client_context( + wyoming_tts_config.wyoming_tts_ip, + wyoming_tts_config.wyoming_tts_port, + "TTS", + logger, + quiet=quiet, + ) as client: + async with live_timer(live, "🔊 Synthesizing text", style="blue", quiet=quiet): + synthesize_event = _create_synthesis_request( + text, + voice_name=wyoming_tts_config.wyoming_voice, + language=wyoming_tts_config.wyoming_tts_language, + speaker=wyoming_tts_config.wyoming_speaker, + ) + _send_task, recv_task = await manage_send_receive_tasks( + client.write_event(synthesize_event.event()), + _process_audio_events(client, logger), + ) + audio_data, sample_rate, sample_width, channels = recv_task.result() + if sample_rate and sample_width and channels and audio_data: + wav_data = _create_wav_data(audio_data, sample_rate, sample_width, channels) + logger.info("Speech synthesis completed: %d bytes", len(wav_data)) + return wav_data + logger.warning("No audio data received from TTS server") + return None + except (ConnectionRefusedError, Exception): + return None + + +def _apply_speed_adjustment( + audio_data: io.BytesIO, + speed: float, +) -> tuple[io.BytesIO, bool]: + """Apply speed adjustment to audio data.""" + if speed == 1.0 or not has_audiostretchy: + return audio_data, False + from audiostretchy.stretch import AudioStretch # noqa: PLC0415 + + audio_data.seek(0) + input_copy = io.BytesIO(audio_data.read()) + audio_stretch = AudioStretch() + audio_stretch.open(file=input_copy, format="wav") + audio_stretch.stretch(ratio=1 / speed) + out = io.BytesIO() + audio_stretch.save_wav(out, close=False) + out.seek(0) + return out, True + + +async def play_audio( + audio_data: bytes, + logger: logging.Logger, + *, + audio_output_config: config.AudioOutput, + quiet: bool = False, + stop_event: InteractiveStopEvent | None = None, + live: Live, +) -> None: + """Play WAV audio data using PyAudio.""" + try: + wav_io = io.BytesIO(audio_data) + speed = audio_output_config.tts_speed + wav_io, speed_changed = _apply_speed_adjustment(wav_io, speed) + with wave.open(wav_io, "rb") as wav_file: + sample_rate = wav_file.getframerate() + channels = wav_file.getnchannels() + sample_width = wav_file.getsampwidth() + frames = wav_file.readframes(wav_file.getnframes()) + if not speed_changed: + sample_rate = int(sample_rate * speed) + base_msg = f"🔊 Playing audio at {speed}x speed" if speed != 1.0 else "🔊 Playing audio" + async with live_timer(live, base_msg, style="blue", quiet=quiet): + with pyaudio_context() as p: + stream_config = setup_output_stream( + audio_output_config.output_device_index, + sample_rate=sample_rate, + sample_width=sample_width, + channels=channels, + ) + with open_pyaudio_stream(p, **stream_config) as stream: + chunk_size = constants.PYAUDIO_CHUNK_SIZE + for i in range(0, len(frames), chunk_size): + if stop_event and stop_event.is_set(): + logger.info("Audio playback interrupted") + if not quiet: + print_with_style("⏹️ Audio playback interrupted", style="yellow") + break + chunk = frames[i : i + chunk_size] + stream.write(chunk) + await asyncio.sleep(0) + if not (stop_event and stop_event.is_set()): + logger.info("Audio playback completed (speed: %.1fx)", speed) + if not quiet: + print_with_style("✅ Audio playback finished") + except Exception as e: + logger.exception("Error during audio playback") + if not quiet: + print_error_message(f"Playback error: {e}") + + +async def speak_text( + *, + text: str, + provider_config: config.ProviderSelection, + audio_output_config: config.AudioOutput, + wyoming_tts_config: config.WyomingTTS, + openai_tts_config: config.OpenAITTS, + openai_llm_config: config.OpenAILLM, + logger: logging.Logger, + quiet: bool = False, + play_audio_flag: bool = True, + stop_event: InteractiveStopEvent | None = None, + live: Live, +) -> bytes | None: + """Synthesize and optionally play speech from text.""" + synthesizer = get_synthesizer( + provider_config, + audio_output_config, + wyoming_tts_config, + openai_tts_config, + openai_llm_config, + ) + audio_data = None + try: + async with live_timer(live, "🔊 Synthesizing text", style="blue", quiet=quiet): + audio_data = await synthesizer( + text=text, + wyoming_tts_config=wyoming_tts_config, + openai_tts_config=openai_tts_config, + openai_llm_config=openai_llm_config, + logger=logger, + quiet=quiet, + live=live, + ) + except Exception: + logger.exception("Error during speech synthesis") + return None + + if audio_data and play_audio_flag: + await play_audio( + audio_data, + logger, + audio_output_config=audio_output_config, + quiet=quiet, + stop_event=stop_event, + live=live, + ) + + return audio_data diff --git a/tests/agents/test_interactive.py b/tests/agents/test_interactive.py index c75dfe6bc..716387ac0 100644 --- a/tests/agents/test_interactive.py +++ b/tests/agents/test_interactive.py @@ -9,7 +9,7 @@ import pytest -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents.chat import ( ConversationEntry, _async_main, @@ -17,7 +17,7 @@ _load_conversation_history, _save_conversation_history, ) -from agent_cli.utils import InteractiveStopEvent +from agent_cli.core.utils import InteractiveStopEvent if TYPE_CHECKING: from pathlib import Path diff --git a/tests/agents/test_interactive_extra.py b/tests/agents/test_interactive_extra.py index bc2a188f5..de2082a87 100644 --- a/tests/agents/test_interactive_extra.py +++ b/tests/agents/test_interactive_extra.py @@ -5,13 +5,13 @@ import pytest from typer.testing import CliRunner -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents.chat import ( _async_main, _handle_conversation_turn, ) from agent_cli.cli import app -from agent_cli.utils import InteractiveStopEvent +from agent_cli.core.utils import InteractiveStopEvent @pytest.mark.asyncio diff --git a/tests/agents/test_speak.py b/tests/agents/test_speak.py index 57c4aebfd..9acf51ca2 100644 --- a/tests/agents/test_speak.py +++ b/tests/agents/test_speak.py @@ -7,7 +7,7 @@ import pytest from typer.testing import CliRunner -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents.speak import _async_main from agent_cli.cli import app diff --git a/tests/agents/test_speak_e2e.py b/tests/agents/test_speak_e2e.py index 0a63eac00..80d2d33b1 100644 --- a/tests/agents/test_speak_e2e.py +++ b/tests/agents/test_speak_e2e.py @@ -6,7 +6,7 @@ import pytest -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents.speak import _async_main from tests.mocks.audio import MockPyAudio from tests.mocks.wyoming import MockTTSClient diff --git a/tests/agents/test_transcribe.py b/tests/agents/test_transcribe.py index d49361e87..a8f11c24c 100644 --- a/tests/agents/test_transcribe.py +++ b/tests/agents/test_transcribe.py @@ -8,7 +8,7 @@ import pytest -from agent_cli.agents import config, transcribe +from agent_cli import config, transcribe from tests.mocks.wyoming import MockASRClient diff --git a/tests/agents/test_transcribe_e2e.py b/tests/agents/test_transcribe_e2e.py index 8df4e2930..c33ac751b 100644 --- a/tests/agents/test_transcribe_e2e.py +++ b/tests/agents/test_transcribe_e2e.py @@ -8,7 +8,7 @@ import pytest -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents.transcribe import _async_main from tests.mocks.audio import MockPyAudio from tests.mocks.wyoming import MockASRClient diff --git a/tests/agents/test_tts_common.py b/tests/agents/test_tts_common.py index 8a0686c4e..18f69a0a6 100644 --- a/tests/agents/test_tts_common.py +++ b/tests/agents/test_tts_common.py @@ -7,7 +7,7 @@ import pytest -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents._tts_common import handle_tts_playback if TYPE_CHECKING: diff --git a/tests/agents/test_tts_common_extra.py b/tests/agents/test_tts_common_extra.py index 4749d20b5..f481c2eaf 100644 --- a/tests/agents/test_tts_common_extra.py +++ b/tests/agents/test_tts_common_extra.py @@ -7,7 +7,7 @@ import pytest -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents._tts_common import _save_audio_file, handle_tts_playback diff --git a/tests/agents/test_voice_agent_common.py b/tests/agents/test_voice_agent_common.py index a3168610c..5eab0ba48 100644 --- a/tests/agents/test_voice_agent_common.py +++ b/tests/agents/test_voice_agent_common.py @@ -6,7 +6,7 @@ import pytest -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents._voice_agent_common import ( get_instruction_from_audio, process_instruction_and_respond, diff --git a/tests/agents/test_voice_edit_e2e.py b/tests/agents/test_voice_edit_e2e.py index e7d345828..d0187015c 100644 --- a/tests/agents/test_voice_edit_e2e.py +++ b/tests/agents/test_voice_edit_e2e.py @@ -6,7 +6,7 @@ import pytest -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents.voice_edit import ( AGENT_INSTRUCTIONS, SYSTEM_PROMPT, diff --git a/tests/test_llm.py b/tests/test_llm.py index 579994259..7a768bd0c 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -7,7 +7,7 @@ import pytest -from agent_cli.agents import config +from agent_cli import config from agent_cli.llm import build_agent, get_llm_response, process_and_update_clipboard diff --git a/tests/test_services.py b/tests/test_services.py index 11d868baf..998fc2a82 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -6,8 +6,7 @@ import pytest -from agent_cli import asr, tts -from agent_cli.agents import config +from agent_cli import asr, config, tts from agent_cli.services import synthesize_speech_openai, transcribe_audio_openai diff --git a/tests/test_tts.py b/tests/test_tts.py index c7919cae7..900b63236 100644 --- a/tests/test_tts.py +++ b/tests/test_tts.py @@ -8,7 +8,7 @@ import pytest -from agent_cli.agents import config +from agent_cli import config from agent_cli.tts import _apply_speed_adjustment, get_synthesizer, speak_text diff --git a/tests/test_wake_word.py b/tests/test_wake_word.py index 38bc49fa1..df65b83d2 100644 --- a/tests/test_wake_word.py +++ b/tests/test_wake_word.py @@ -7,7 +7,7 @@ from rich.live import Live from agent_cli import wake_word -from agent_cli.utils import InteractiveStopEvent +from agent_cli.core.utils import InteractiveStopEvent @pytest.fixture From 3d0c3b972f22dc705e9e6f83fa457faf39064025 Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Tue, 8 Jul 2025 21:01:32 -0700 Subject: [PATCH 8/8] wip --- agent_cli/agents/autocorrect.py | 25 ++++-- agent_cli/agents/voice_edit.py | 3 +- agent_cli/asr.py | 18 ++-- agent_cli/cli.py | 2 +- agent_cli/core/utils.py | 12 +-- agent_cli/services/local.py | 12 +-- agent_cli/tts.py | 4 +- tests/agents/test_transcribe.py | 3 +- tests/test_config.py | 6 +- tests/test_wake_word.py | 146 -------------------------------- 10 files changed, 43 insertions(+), 188 deletions(-) delete mode 100644 tests/test_wake_word.py diff --git a/agent_cli/agents/autocorrect.py b/agent_cli/agents/autocorrect.py index f99c2d944..1a0925192 100644 --- a/agent_cli/agents/autocorrect.py +++ b/agent_cli/agents/autocorrect.py @@ -4,6 +4,7 @@ import asyncio import contextlib +import logging import sys import time from typing import TYPE_CHECKING @@ -22,7 +23,7 @@ print_output_panel, print_with_style, ) -from agent_cli.llm import build_agent +from agent_cli.services.factory import get_llm_service if TYPE_CHECKING: from rich.status import Status @@ -76,23 +77,27 @@ async def _process_text( provider_cfg: config.ProviderSelection, ollama_cfg: config.Ollama, openai_llm_cfg: config.OpenAILLM, + logger, ) -> tuple[str, float]: """Process text with the LLM and return the corrected text and elapsed time.""" - agent = build_agent( - provider_config=provider_cfg, - ollama_config=ollama_cfg, - openai_config=openai_llm_cfg, - system_prompt=SYSTEM_PROMPT, - instructions=AGENT_INSTRUCTIONS, + llm_service = get_llm_service( + provider_cfg, + ollama_cfg, + openai_llm_cfg, + logger, ) # Format the input using the template to clearly separate text from instructions formatted_input = INPUT_TEMPLATE.format(text=text) start_time = time.monotonic() - result = await agent.run(formatted_input) + result = await llm_service.get_response( + system_prompt=SYSTEM_PROMPT, + agent_instructions=AGENT_INSTRUCTIONS, + user_input=formatted_input, + ) elapsed = time.monotonic() - start_time - return result.output, elapsed + return result or "", elapsed def _display_original_text(original_text: str, quiet: bool) -> None: @@ -160,11 +165,13 @@ async def _async_autocorrect( try: with _maybe_status(provider_cfg, ollama_cfg, openai_llm_cfg, general_cfg.quiet): + logger = logging.getLogger(__name__) corrected_text, elapsed = await _process_text( original_text, provider_cfg, ollama_cfg, openai_llm_cfg, + logger, ) _display_result(corrected_text, original_text, elapsed, simple_output=general_cfg.quiet) diff --git a/agent_cli/agents/voice_edit.py b/agent_cli/agents/voice_edit.py index 0b25d074f..956728688 100644 --- a/agent_cli/agents/voice_edit.py +++ b/agent_cli/agents/voice_edit.py @@ -44,9 +44,10 @@ get_instruction_from_audio, process_instruction_and_respond, ) +from agent_cli.asr import record_audio_with_manual_stop from agent_cli.cli import app, setup_logging from agent_cli.core import process as process_manager -from agent_cli.core.audio import pyaudio_context, record_audio_with_manual_stop, setup_devices +from agent_cli.core.audio import pyaudio_context, setup_devices from agent_cli.core.utils import ( get_clipboard_text, maybe_live, diff --git a/agent_cli/asr.py b/agent_cli/asr.py index 2b1165d10..2148c9854 100644 --- a/agent_cli/asr.py +++ b/agent_cli/asr.py @@ -2,22 +2,13 @@ from __future__ import annotations -import asyncio import io -from functools import partial from typing import TYPE_CHECKING from wyoming.asr import Transcribe, Transcript, TranscriptChunk, TranscriptStart, TranscriptStop from wyoming.audio import AudioChunk, AudioStart, AudioStop from agent_cli import constants -from agent_cli.audio import ( - open_pyaudio_stream, - read_audio_stream, - read_from_queue, - setup_input_stream, -) - from agent_cli.core.audio import ( open_pyaudio_stream, read_audio_stream, @@ -27,12 +18,11 @@ from agent_cli.services.factory import get_asr_service from agent_cli.services.local import ( WyomingTranscriptionService, - manage_send_receive_tasks, - wyoming_client_context, ) from agent_cli.services.openai import OpenAITranscriptionService if TYPE_CHECKING: + import asyncio import logging from collections.abc import Awaitable, Callable @@ -64,7 +54,11 @@ def get_transcriber( quiet=quiet, ) - async def transcribe_live_audio(p: pyaudio.PyAudio, stop_event: InteractiveStopEvent, live: Live) -> str | None: + async def transcribe_live_audio( + p: pyaudio.PyAudio, + stop_event: InteractiveStopEvent, + live: Live, + ) -> str | None: """Record and transcribe live audio.""" audio_data = await record_audio_with_manual_stop( p, diff --git a/agent_cli/cli.py b/agent_cli/cli.py index 1b861ee80..131814afc 100644 --- a/agent_cli/cli.py +++ b/agent_cli/cli.py @@ -8,7 +8,7 @@ import typer from .config import load_config -from .utils import console +from .core.utils import console if TYPE_CHECKING: from logging import Handler diff --git a/agent_cli/core/utils.py b/agent_cli/core/utils.py index 7367c343f..1e1abbfcc 100644 --- a/agent_cli/core/utils.py +++ b/agent_cli/core/utils.py @@ -23,7 +23,7 @@ from rich.status import Status from rich.text import Text -from agent_cli import process_manager +from agent_cli.core.process import is_process_running, kill_process, read_pid_file if TYPE_CHECKING: import logging @@ -207,7 +207,7 @@ def stop_or_status_or_toggle( ) -> bool: """Handle process control for a given process name.""" if stop: - if process_manager.kill_process(process_name): + if kill_process(process_name): if not quiet: print_with_style(f"✅ {which.capitalize()} stopped.") elif not quiet: @@ -215,8 +215,8 @@ def stop_or_status_or_toggle( return True if status: - if process_manager.is_process_running(process_name): - pid = process_manager.read_pid_file(process_name) + if is_process_running(process_name): + pid = read_pid_file(process_name) if not quiet: print_with_style(f"✅ {which.capitalize()} is running (PID: {pid}).") elif not quiet: @@ -224,8 +224,8 @@ def stop_or_status_or_toggle( return True if toggle: - if process_manager.is_process_running(process_name): - if process_manager.kill_process(process_name) and not quiet: + if is_process_running(process_name): + if kill_process(process_name) and not quiet: print_with_style(f"✅ {which.capitalize()} stopped.") return True if not quiet: diff --git a/agent_cli/services/local.py b/agent_cli/services/local.py index 4731e59a1..0703c50e2 100644 --- a/agent_cli/services/local.py +++ b/agent_cli/services/local.py @@ -6,7 +6,7 @@ import io import wave from contextlib import asynccontextmanager -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel @@ -23,7 +23,7 @@ if TYPE_CHECKING: import logging - from collections.abc import AsyncGenerator, Awaitable + from collections.abc import AsyncGenerator, Coroutine from rich.live import Live @@ -56,13 +56,13 @@ async def wyoming_client_context( async def manage_send_receive_tasks( - send_task: Awaitable, - recv_task: Awaitable, + send_task: Coroutine[Any, Any, Any], + recv_task: Coroutine[Any, Any, Any], return_when: str = asyncio.ALL_COMPLETED, ) -> tuple[asyncio.Task, asyncio.Task]: """Manage send and receive tasks for a Wyoming client.""" - send = asyncio.create_task(send_task) - recv = asyncio.create_task(recv_task) + send: asyncio.Task = asyncio.create_task(send_task) + recv: asyncio.Task = asyncio.create_task(recv_task) done, pending = await asyncio.wait({send, recv}, return_when=return_when) for task in pending: task.cancel() diff --git a/agent_cli/tts.py b/agent_cli/tts.py index e556c87ba..cf7443187 100644 --- a/agent_cli/tts.py +++ b/agent_cli/tts.py @@ -6,15 +6,12 @@ import importlib.util import io import wave -from functools import partial from typing import TYPE_CHECKING from wyoming.audio import AudioChunk, AudioStart, AudioStop from wyoming.tts import Synthesize, SynthesizeVoice from agent_cli import constants -from agent_cli.audio import open_pyaudio_stream, pyaudio_context, setup_output_stream - from agent_cli.core.audio import open_pyaudio_stream, pyaudio_context, setup_output_stream from agent_cli.core.utils import ( InteractiveStopEvent, @@ -283,6 +280,7 @@ async def speak_text( wyoming_tts_config, openai_tts_config, openai_llm_config, + logger, ) audio_data = None try: diff --git a/tests/agents/test_transcribe.py b/tests/agents/test_transcribe.py index a8f11c24c..e069e65ac 100644 --- a/tests/agents/test_transcribe.py +++ b/tests/agents/test_transcribe.py @@ -8,7 +8,8 @@ import pytest -from agent_cli import config, transcribe +from agent_cli import config +from agent_cli.agents import transcribe from tests.mocks.wyoming import MockASRClient diff --git a/tests/test_config.py b/tests/test_config.py index 67f70b418..04bf3328d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -11,7 +11,7 @@ from typer.testing import CliRunner from agent_cli.cli import set_config_defaults -from agent_cli.config_loader import load_config +from agent_cli.config import load_config if TYPE_CHECKING: from pathlib import Path @@ -89,8 +89,8 @@ def test_set_config_defaults(config_file: Path) -> None: assert ctx.default_map == expected_merged_defaults -@patch("agent_cli.config_loader.CONFIG_PATH") -@patch("agent_cli.config_loader.CONFIG_PATH_2") +@patch("agent_cli.config.CONFIG_PATH") +@patch("agent_cli.config.CONFIG_PATH_2") def test_default_config_paths( mock_path2: MagicMock, mock_path1: MagicMock, diff --git a/tests/test_wake_word.py b/tests/test_wake_word.py deleted file mode 100644 index df65b83d2..000000000 --- a/tests/test_wake_word.py +++ /dev/null @@ -1,146 +0,0 @@ -"""Tests for the wake word detection module.""" - -import asyncio -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from rich.live import Live - -from agent_cli import wake_word -from agent_cli.core.utils import InteractiveStopEvent - - -@pytest.fixture -def mock_pyaudio(): - """Mock PyAudio instance.""" - return MagicMock() - - -@pytest.fixture -def mock_logger(): - """Mock logger instance.""" - return MagicMock() - - -@pytest.fixture -def mock_stop_event(): - """Mock stop event.""" - stop_event = MagicMock(spec=InteractiveStopEvent) - stop_event.is_set.return_value = False - stop_event.ctrl_c_pressed = False - return stop_event - - -@pytest.fixture -def mock_live(): - """Mock Rich Live instance.""" - return MagicMock(spec=Live) - - -class TestReceiveWakeDetection: - """Tests for _receive_wake_detection function.""" - - @pytest.mark.asyncio - async def test_returns_detected_wake_word(self, mock_logger: MagicMock) -> None: - """Test detection of wake word.""" - mock_client = AsyncMock() - - # Mock detection event - mock_event = MagicMock() - mock_event.type = "detection" - - # Mock Detection.is_type and Detection.from_event - with ( - patch("agent_cli.wake_word.Detection.is_type", return_value=True), - patch("agent_cli.wake_word.Detection.from_event") as mock_from_event, - ): - mock_detection = MagicMock() - mock_detection.name = "test_wake_word" - mock_from_event.return_value = mock_detection - - mock_client.read_event.return_value = mock_event - - result = await wake_word._receive_wake_detection(mock_client, mock_logger) - - assert result == "test_wake_word" - mock_logger.info.assert_called_with("Wake word detected: %s", "test_wake_word") - - @pytest.mark.asyncio - async def test_calls_detection_callback(self, mock_logger: MagicMock) -> None: - """Test that detection callback is called.""" - mock_client = AsyncMock() - mock_callback = MagicMock() - - # Mock detection event - mock_event = MagicMock() - mock_event.type = "detection" - - with ( - patch("agent_cli.wake_word.Detection.is_type", return_value=True), - patch("agent_cli.wake_word.Detection.from_event") as mock_from_event, - ): - mock_detection = MagicMock() - mock_detection.name = "test_wake_word" - mock_from_event.return_value = mock_detection - - mock_client.read_event.return_value = mock_event - - result = await wake_word._receive_wake_detection( - mock_client, - mock_logger, - detection_callback=mock_callback, - ) - - assert result == "test_wake_word" - mock_callback.assert_called_once_with("test_wake_word") - - @pytest.mark.asyncio - async def test_handles_not_detected_event(self, mock_logger: MagicMock) -> None: - """Test handling of not-detected event.""" - mock_client = AsyncMock() - - # Mock not-detected event - mock_event = MagicMock() - mock_event.type = "not-detected" - - with ( - patch("agent_cli.wake_word.Detection.is_type", return_value=False), - patch("agent_cli.wake_word.NotDetected.is_type", return_value=True), - ): - mock_client.read_event.return_value = mock_event - - result = await wake_word._receive_wake_detection(mock_client, mock_logger) - - assert result is None - mock_logger.debug.assert_called_with("No wake word detected") - - @pytest.mark.asyncio - async def test_handles_connection_loss(self, mock_logger: MagicMock) -> None: - """Test handling of lost connection.""" - mock_client = AsyncMock() - mock_client.read_event.return_value = None - - result = await wake_word._receive_wake_detection(mock_client, mock_logger) - - assert result is None - mock_logger.warning.assert_called_with("Connection to wake word server lost.") - - -@pytest.mark.asyncio -@patch("agent_cli.wake_word.wyoming_client_context", side_effect=ConnectionRefusedError) -async def test_detect_wake_word_from_queue_connection_error( - mock_wyoming_client_context: MagicMock, - mock_logger: MagicMock, - mock_live: MagicMock, -): - """Test that detect_wake_word_from_queue handles ConnectionRefusedError.""" - result = await wake_word.detect_wake_word_from_queue( - "localhost", - 1234, - "test_word", - mock_logger, - asyncio.Queue(), - live=mock_live, - ) - assert result is None - mock_wyoming_client_context.assert_called_once()