diff --git a/REFACTORING_PLAN.md b/REFACTORING_PLAN.md new file mode 100644 index 000000000..32867abbb --- /dev/null +++ b/REFACTORING_PLAN.md @@ -0,0 +1,90 @@ +# Agent-CLI Refactoring Plan + +## 1. Goal + +The primary goal of this refactoring is to improve the overall organization of the `agent-cli` package. This involves restructuring the project to better separate concerns, reduce cross-dependencies between modules, and make the codebase more intuitive, maintainable, and extensible. + +## 2. Proposed File Structure + +The new architecture will introduce `core` and `services` packages to logically group related functionality. + +``` +agent_cli/ +├── __init__.py +├── cli.py +├── constants.py +├── py.typed +├── agents/ +│ ├── __init__.py +│ ├── _cli_options.py +│ ├── _tts_common.py +│ ├── _voice_agent_common.py +│ ├── assistant.py +│ ├── autocorrect.py +│ ├── chat.py +│ ├── speak.py +│ ├── transcribe.py +│ └── voice_edit.py +├── config.py # New unified config module +├── core/ # New package for core logic +│ ├── __init__.py +│ ├── audio.py # For audio device I/O +│ ├── process.py # For process management +│ └── utils.py # For generic utilities +└── services/ # New package for external service integrations + ├── __init__.py + ├── base.py # Abstract base classes for services + ├── factory.py # Factory to get the correct service + ├── local.py # Implementations for local services (Wyoming/Ollama) + └── openai.py # Implementations for OpenAI services +``` + +## 3. Detailed Migration Plan + +### Step 1: Consolidate Configuration + +- **Action:** Create a new `agent_cli/config.py` file. +- **Source Logic:** Merge the contents of `agent_cli/config_loader.py` and `agent_cli/agents/config.py`. +- **Content:** + - **Loading Logic:** `load_config()`, `_replace_dashed_keys()` from `config_loader.py`. 
+ - **Pydantic Models:** All configuration models (`ProviderSelection`, `Ollama`, `OpenAILLM`, `AudioInput`, `WyomingASR`, `OpenAIASR`, `AudioOutput`, `WyomingTTS`, `OpenAITTS`, `WakeWord`, `General`, `History`). +- **Cleanup:** Delete `agent_cli/config_loader.py` and `agent_cli/agents/config.py`. + +### Step 2: Create `core` Package + +- **Action:** Create a new directory `agent_cli/core/`. +- **`agent_cli/core/audio.py`**: + - **Action:** Move `agent_cli/audio.py` to `agent_cli/core/audio.py`. + - **Content:** All PyAudio device management and streaming logic. +- **`agent_cli/core/process.py`**: + - **Action:** Move `agent_cli/process_manager.py` to `agent_cli/core/process.py`. + - **Content:** All PID file and process management functions. +- **`agent_cli/core/utils.py`**: + - **Action:** Create `agent_cli/core/utils.py` and move generic helpers from `agent_cli/utils.py`. + - **Content:** `console`, `InteractiveStopEvent`, `signal_handling_context`, `live_timer`, `print_*_panel`, `get_clipboard_text`. + +### Step 3: Create `services` Package + +- **Action:** Create a new directory `agent_cli/services/`. +- **`agent_cli/services/base.py`** (New File): + - **Content:** Define Abstract Base Classes (ABCs) for `ASRService`, `LLMService`, `TTSService`, and `WakeWordService`. +- **`agent_cli/services/local.py`** (New File): + - **Content:** Implementations for all local services. + - **Wyoming ASR:** Logic from `asr.py`. + - **Wyoming TTS:** Logic from `tts.py`. + - **Wyoming Wake Word:** Logic from `wake_word.py`. + - **Ollama LLM:** Logic from `llm.py`. + - **Wyoming Utils:** `wyoming_client_context` from `wyoming_utils.py`. +- **`agent_cli/services/openai.py`** (New File): + - **Content:** Implementations for all OpenAI services. + - **OpenAI ASR:** Logic from `services.py` and `asr.py`. + - **OpenAI TTS:** Logic from `services.py` and `tts.py`. + - **OpenAI LLM:** Logic from `llm.py`. 
+- **`agent_cli/services/factory.py`** (New File): + - **Content:** Factory functions (`get_asr_service`, `get_llm_service`, `get_tts_service`) that return the correct service implementation based on the user's configuration. + +### Step 4: Refactor and Cleanup + +- **Action:** Update all imports across the project to reflect the new structure. +- **Action:** Delete the old, now-empty files: `asr.py`, `llm.py`, `tts.py`, `wake_word.py`, `services.py`, `process_manager.py`, `config_loader.py`, `wyoming_utils.py`, and `agents/config.py`. +- **Action:** Refactor `agent_cli/utils.py` to remove the functions that were moved to `core/utils.py`. diff --git a/agent_cli/agents/_tts_common.py b/agent_cli/agents/_tts_common.py index 0f8cea732..110fd2d68 100644 --- a/agent_cli/agents/_tts_common.py +++ b/agent_cli/agents/_tts_common.py @@ -7,14 +7,14 @@ from typing import TYPE_CHECKING from agent_cli import tts -from agent_cli.utils import InteractiveStopEvent, print_with_style +from agent_cli.core.utils import InteractiveStopEvent, print_with_style if TYPE_CHECKING: import logging from rich.live import Live - from agent_cli.agents import config + from agent_cli import config async def _save_audio_file( diff --git a/agent_cli/agents/_voice_agent_common.py b/agent_cli/agents/_voice_agent_common.py index 01d293a3a..751ff6ba7 100644 --- a/agent_cli/agents/_voice_agent_common.py +++ b/agent_cli/agents/_voice_agent_common.py @@ -8,15 +8,15 @@ import pyperclip -from agent_cli import asr from agent_cli.agents._tts_common import handle_tts_playback +from agent_cli.core.utils import print_input_panel, print_with_style from agent_cli.llm import process_and_update_clipboard -from agent_cli.utils import print_input_panel, print_with_style +from agent_cli.services.factory import get_asr_service if TYPE_CHECKING: from rich.live import Live - from agent_cli.agents import config + from agent_cli import config LOGGER = logging.getLogger() @@ -25,10 +25,8 @@ async def 
get_instruction_from_audio( *, audio_data: bytes, provider_config: config.ProviderSelection, - audio_input_config: config.AudioInput, wyoming_asr_config: config.WyomingASR, openai_asr_config: config.OpenAIASR, - ollama_config: config.Ollama, openai_llm_config: config.OpenAILLM, logger: logging.Logger, quiet: bool, @@ -36,18 +34,15 @@ async def get_instruction_from_audio( """Transcribe audio data and return the instruction.""" try: start_time = time.monotonic() - transcriber = asr.get_recorded_audio_transcriber(provider_config) - instruction = await transcriber( - audio_data=audio_data, - provider_config=provider_config, - audio_input_config=audio_input_config, - wyoming_asr_config=wyoming_asr_config, - openai_asr_config=openai_asr_config, - ollama_config=ollama_config, - openai_llm_config=openai_llm_config, - logger=logger, + transcriber = get_asr_service( + provider_config, + wyoming_asr_config, + openai_asr_config, + openai_llm_config, + logger, quiet=quiet, ) + instruction = await transcriber.transcribe(audio_data) elapsed = time.monotonic() - start_time if not instruction or not instruction.strip(): diff --git a/agent_cli/agents/assistant.py b/agent_cli/agents/assistant.py index b0d4be36d..393a44fc2 100644 --- a/agent_cli/agents/assistant.py +++ b/agent_cli/agents/assistant.py @@ -34,21 +34,23 @@ from typing import TYPE_CHECKING import agent_cli.agents._cli_options as opts -from agent_cli import asr, audio, process_manager, wake_word -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents._voice_agent_common import ( get_instruction_from_audio, process_instruction_and_respond, ) -from agent_cli.audio import pyaudio_context, setup_devices from agent_cli.cli import app, setup_logging -from agent_cli.utils import ( +from agent_cli.core import audio +from agent_cli.core import process as process_manager +from agent_cli.core.audio import pyaudio_context, setup_devices +from agent_cli.core.utils import ( InteractiveStopEvent, 
maybe_live, print_with_style, signal_handling_context, stop_or_status_or_toggle, ) +from agent_cli.services.local import WyomingWakeWordService if TYPE_CHECKING: import pyaudio @@ -110,15 +112,14 @@ async def _record_audio_with_wake_word( # Create a queue for wake word detection wake_queue = await tee.add_queue() - detected_word = await wake_word.detect_wake_word_from_queue( - wake_server_ip=wake_word_config.server_ip, - wake_server_port=wake_word_config.server_port, - wake_word_name=wake_word_config.wake_word_name, - logger=logger, - queue=wake_queue, - quiet=quiet, + wake_word_service = WyomingWakeWordService( + wake_word_config, + logger, + wake_queue, live=live, + quiet=quiet, ) + detected_word = await wake_word_service.detect() if not detected_word or stop_event.is_set(): # Clean up the queue if we exit early @@ -133,19 +134,17 @@ async def _record_audio_with_wake_word( # Add a new queue for recording record_queue = await tee.add_queue() - record_task = asyncio.create_task(asr.record_audio_to_buffer(record_queue, logger)) + record_task = asyncio.create_task(audio.record_audio_to_buffer(record_queue, logger)) # Use the same wake_queue for stop-word detection - stop_detected_word = await wake_word.detect_wake_word_from_queue( - wake_server_ip=wake_word_config.server_ip, - wake_server_port=wake_word_config.server_port, - wake_word_name=wake_word_config.wake_word_name, - logger=logger, - queue=wake_queue, - quiet=quiet, + wake_word_service = WyomingWakeWordService( + wake_word_config, + logger, + wake_queue, live=live, - progress_message="Recording... 
(say wake word to stop)", + quiet=quiet, ) + stop_detected_word = await wake_word_service.detect() # Stop the recording task by removing its queue await tee.remove_queue(record_queue) diff --git a/agent_cli/agents/autocorrect.py b/agent_cli/agents/autocorrect.py index a740001a6..1a0925192 100644 --- a/agent_cli/agents/autocorrect.py +++ b/agent_cli/agents/autocorrect.py @@ -4,6 +4,7 @@ import asyncio import contextlib +import logging import sys import time from typing import TYPE_CHECKING @@ -12,10 +13,9 @@ import typer import agent_cli.agents._cli_options as opts -from agent_cli.agents import config +from agent_cli import config from agent_cli.cli import app, setup_logging -from agent_cli.llm import build_agent -from agent_cli.utils import ( +from agent_cli.core.utils import ( create_status, get_clipboard_text, print_error_message, @@ -23,6 +23,7 @@ print_output_panel, print_with_style, ) +from agent_cli.services.factory import get_llm_service if TYPE_CHECKING: from rich.status import Status @@ -76,23 +77,27 @@ async def _process_text( provider_cfg: config.ProviderSelection, ollama_cfg: config.Ollama, openai_llm_cfg: config.OpenAILLM, + logger, ) -> tuple[str, float]: """Process text with the LLM and return the corrected text and elapsed time.""" - agent = build_agent( - provider_config=provider_cfg, - ollama_config=ollama_cfg, - openai_config=openai_llm_cfg, - system_prompt=SYSTEM_PROMPT, - instructions=AGENT_INSTRUCTIONS, + llm_service = get_llm_service( + provider_cfg, + ollama_cfg, + openai_llm_cfg, + logger, ) # Format the input using the template to clearly separate text from instructions formatted_input = INPUT_TEMPLATE.format(text=text) start_time = time.monotonic() - result = await agent.run(formatted_input) + result = await llm_service.get_response( + system_prompt=SYSTEM_PROMPT, + agent_instructions=AGENT_INSTRUCTIONS, + user_input=formatted_input, + ) elapsed = time.monotonic() - start_time - return result.output, elapsed + return result or "", elapsed 
def _display_original_text(original_text: str, quiet: bool) -> None: @@ -160,11 +165,13 @@ async def _async_autocorrect( try: with _maybe_status(provider_cfg, ollama_cfg, openai_llm_cfg, general_cfg.quiet): + logger = logging.getLogger(__name__) corrected_text, elapsed = await _process_text( original_text, provider_cfg, ollama_cfg, openai_llm_cfg, + logger, ) _display_result(corrected_text, original_text, elapsed, simple_output=general_cfg.quiet) diff --git a/agent_cli/agents/chat.py b/agent_cli/agents/chat.py index 741355cd1..6a782d33c 100644 --- a/agent_cli/agents/chat.py +++ b/agent_cli/agents/chat.py @@ -25,13 +25,12 @@ import typer import agent_cli.agents._cli_options as opts -from agent_cli import asr, process_manager -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents._tts_common import handle_tts_playback -from agent_cli.audio import pyaudio_context, setup_devices from agent_cli.cli import app, setup_logging -from agent_cli.llm import get_llm_response -from agent_cli.utils import ( +from agent_cli.core import process as process_manager +from agent_cli.core.audio import pyaudio_context, setup_devices +from agent_cli.core.utils import ( InteractiveStopEvent, console, format_timedelta_to_ago, @@ -43,6 +42,8 @@ signal_handling_context, stop_or_status_or_toggle, ) +from agent_cli.llm import get_llm_response +from agent_cli.services.factory import get_asr_service if TYPE_CHECKING: import pyaudio @@ -150,7 +151,6 @@ async def _handle_conversation_turn( provider_cfg: config.ProviderSelection, general_cfg: config.General, history_cfg: config.History, - audio_in_cfg: config.AudioInput, wyoming_asr_cfg: config.WyomingASR, openai_asr_cfg: config.OpenAIASR, ollama_cfg: config.Ollama, @@ -176,14 +176,15 @@ async def _handle_conversation_turn( # 1. 
Transcribe user's command start_time = time.monotonic() - transcriber = asr.get_transcriber( + transcriber = get_asr_service( provider_cfg, - audio_in_cfg, wyoming_asr_cfg, openai_asr_cfg, openai_llm_cfg, + LOGGER, + quiet=general_cfg.quiet, ) - instruction = await transcriber( + instruction = await transcriber.transcribe( p=p, stop_event=stop_event, quiet=general_cfg.quiet, diff --git a/agent_cli/agents/speak.py b/agent_cli/agents/speak.py index 6dfc7ed15..b1617340f 100644 --- a/agent_cli/agents/speak.py +++ b/agent_cli/agents/speak.py @@ -10,12 +10,12 @@ import typer import agent_cli.agents._cli_options as opts -from agent_cli import process_manager -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents._tts_common import handle_tts_playback -from agent_cli.audio import pyaudio_context, setup_devices from agent_cli.cli import app, setup_logging -from agent_cli.utils import ( +from agent_cli.core import process as process_manager +from agent_cli.core.audio import pyaudio_context, setup_devices +from agent_cli.core.utils import ( get_clipboard_text, maybe_live, print_input_panel, diff --git a/agent_cli/agents/transcribe.py b/agent_cli/agents/transcribe.py index bca448c7e..0d3cee5e7 100644 --- a/agent_cli/agents/transcribe.py +++ b/agent_cli/agents/transcribe.py @@ -11,12 +11,11 @@ import pyperclip import agent_cli.agents._cli_options as opts -from agent_cli import asr, process_manager -from agent_cli.agents import config -from agent_cli.audio import pyaudio_context, setup_devices +from agent_cli import config from agent_cli.cli import app, setup_logging -from agent_cli.llm import process_and_update_clipboard -from agent_cli.utils import ( +from agent_cli.core import process as process_manager +from agent_cli.core.audio import pyaudio_context, setup_devices +from agent_cli.core.utils import ( maybe_live, print_input_panel, print_output_panel, @@ -24,6 +23,8 @@ signal_handling_context, stop_or_status_or_toggle, ) +from agent_cli.llm 
import process_and_update_clipboard +from agent_cli.services.factory import get_asr_service if TYPE_CHECKING: import pyaudio @@ -70,7 +71,6 @@ async def _async_main( *, provider_cfg: config.ProviderSelection, general_cfg: config.General, - audio_in_cfg: config.AudioInput, wyoming_asr_cfg: config.WyomingASR, openai_asr_cfg: config.OpenAIASR, ollama_cfg: config.Ollama, @@ -82,14 +82,15 @@ async def _async_main( start_time = time.monotonic() with maybe_live(not general_cfg.quiet) as live: with signal_handling_context(LOGGER, general_cfg.quiet) as stop_event: - transcriber = asr.get_transcriber( + transcriber = get_asr_service( provider_cfg, - audio_in_cfg, wyoming_asr_cfg, openai_asr_cfg, openai_llm_cfg, + LOGGER, + quiet=general_cfg.quiet, ) - transcript = await transcriber( + transcript = await transcriber.transcribe( logger=LOGGER, p=p, stop_event=stop_event, diff --git a/agent_cli/agents/voice_edit.py b/agent_cli/agents/voice_edit.py index 283dade5a..956728688 100644 --- a/agent_cli/agents/voice_edit.py +++ b/agent_cli/agents/voice_edit.py @@ -39,15 +39,16 @@ from pathlib import Path # noqa: TC003 import agent_cli.agents._cli_options as opts -from agent_cli import asr, process_manager -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents._voice_agent_common import ( get_instruction_from_audio, process_instruction_and_respond, ) -from agent_cli.audio import pyaudio_context, setup_devices +from agent_cli.asr import record_audio_with_manual_stop from agent_cli.cli import app, setup_logging -from agent_cli.utils import ( +from agent_cli.core import process as process_manager +from agent_cli.core.audio import pyaudio_context, setup_devices +from agent_cli.core.utils import ( get_clipboard_text, maybe_live, print_input_panel, @@ -118,7 +119,7 @@ async def _async_main( signal_handling_context(LOGGER, general_cfg.quiet) as stop_event, maybe_live(not general_cfg.quiet) as live, ): - audio_data = await asr.record_audio_with_manual_stop( + 
audio_data = await record_audio_with_manual_stop( p, input_device_index, stop_event, diff --git a/agent_cli/asr.py b/agent_cli/asr.py index 78974b6e1..2148c9854 100644 --- a/agent_cli/asr.py +++ b/agent_cli/asr.py @@ -2,25 +2,27 @@ from __future__ import annotations -import asyncio import io -from functools import partial from typing import TYPE_CHECKING from wyoming.asr import Transcribe, Transcript, TranscriptChunk, TranscriptStart, TranscriptStop from wyoming.audio import AudioChunk, AudioStart, AudioStop from agent_cli import constants -from agent_cli.audio import ( +from agent_cli.core.audio import ( open_pyaudio_stream, read_audio_stream, read_from_queue, setup_input_stream, ) -from agent_cli.services import transcribe_audio_openai -from agent_cli.wyoming_utils import manage_send_receive_tasks, wyoming_client_context +from agent_cli.services.factory import get_asr_service +from agent_cli.services.local import ( + WyomingTranscriptionService, +) +from agent_cli.services.openai import OpenAITranscriptionService if TYPE_CHECKING: + import asyncio import logging from collections.abc import Awaitable, Callable @@ -28,8 +30,8 @@ from rich.live import Live from wyoming.client import AsyncClient - from agent_cli.agents import config - from agent_cli.utils import InteractiveStopEvent + from agent_cli import config + from agent_cli.core.utils import InteractiveStopEvent def get_transcriber( @@ -38,29 +40,60 @@ def get_transcriber( wyoming_asr_config: config.WyomingASR, openai_asr_config: config.OpenAIASR, openai_llm_config: config.OpenAILLM, + logger: logging.Logger, + *, + quiet: bool = False, ) -> Callable[..., Awaitable[str | None]]: """Return the appropriate transcriber for live audio based on the provider.""" - if provider_config.asr_provider == "openai": - return partial( - transcribe_live_audio_openai, - audio_input_config=audio_input_config, - openai_asr_config=openai_asr_config, - openai_llm_config=openai_llm_config, - ) - return partial( - 
transcribe_live_audio_wyoming, - audio_input_config=audio_input_config, - wyoming_asr_config=wyoming_asr_config, + asr_service = get_asr_service( + provider_config, + wyoming_asr_config, + openai_asr_config, + openai_llm_config, + logger, + quiet=quiet, ) + async def transcribe_live_audio( + p: pyaudio.PyAudio, + stop_event: InteractiveStopEvent, + live: Live, + ) -> str | None: + """Record and transcribe live audio.""" + audio_data = await record_audio_with_manual_stop( + p, + audio_input_config.input_device_index, + stop_event, + logger, + quiet=quiet, + live=live, + ) + if not audio_data: + return None + return await asr_service.transcribe(audio_data) + + return transcribe_live_audio + def get_recorded_audio_transcriber( provider_config: config.ProviderSelection, -) -> Callable[..., Awaitable[str]]: + wyoming_asr_config: config.WyomingASR, + openai_asr_config: config.OpenAIASR, + openai_llm_config: config.OpenAILLM, + logger: logging.Logger, + *, + quiet: bool = False, +) -> Callable[[bytes], Awaitable[str]]: """Return the appropriate transcriber for recorded audio based on the provider.""" - if provider_config.asr_provider == "openai": - return transcribe_audio_openai - return transcribe_recorded_audio_wyoming + asr_service = get_asr_service( + provider_config, + wyoming_asr_config, + openai_asr_config, + openai_llm_config, + logger, + quiet=quiet, + ) + return asr_service.transcribe async def _send_audio( @@ -187,31 +220,8 @@ async def transcribe_recorded_audio_wyoming( **_kwargs: object, ) -> str: """Process pre-recorded audio data with Wyoming ASR server.""" - try: - async with wyoming_client_context( - wyoming_asr_config.wyoming_asr_ip, - wyoming_asr_config.wyoming_asr_port, - "ASR", - logger, - quiet=quiet, - ) as client: - await client.write_event(Transcribe().event()) - await client.write_event(AudioStart(**constants.WYOMING_AUDIO_CONFIG).event()) - - chunk_size = constants.PYAUDIO_CHUNK_SIZE * 2 - for i in range(0, len(audio_data), chunk_size): - chunk 
= audio_data[i : i + chunk_size] - await client.write_event( - AudioChunk(audio=chunk, **constants.WYOMING_AUDIO_CONFIG).event(), - ) - logger.debug("Sent %d byte(s) of audio", len(chunk)) - - await client.write_event(AudioStop().event()) - logger.debug("Sent AudioStop") - - return await _receive_transcript(client, logger) - except (ConnectionRefusedError, Exception): - return "" + service = WyomingTranscriptionService(wyoming_asr_config, logger, quiet=quiet) + return await service.transcribe(audio_data) async def transcribe_live_audio_wyoming( @@ -228,29 +238,23 @@ async def transcribe_live_audio_wyoming( **_kwargs: object, ) -> str | None: """Unified ASR transcription function.""" - try: - async with wyoming_client_context( - wyoming_asr_config.wyoming_asr_ip, - wyoming_asr_config.wyoming_asr_port, - "ASR", - logger, - quiet=quiet, - ) as client: - stream_config = setup_input_stream(audio_input_config.input_device_index) - with open_pyaudio_stream(p, **stream_config) as stream: - _, recv_task = await manage_send_receive_tasks( - _send_audio(client, stream, stop_event, logger, live=live, quiet=quiet), - _receive_transcript( - client, - logger, - chunk_callback=chunk_callback, - final_callback=final_callback, - ), - return_when=asyncio.ALL_COMPLETED, - ) - return recv_task.result() - except (ConnectionRefusedError, Exception): + service = WyomingTranscriptionService(wyoming_asr_config, logger, quiet=quiet) + audio_data = await record_audio_with_manual_stop( + p, + audio_input_config.input_device_index, + stop_event, + logger, + quiet=quiet, + live=live, + ) + if not audio_data: return None + transcript = await service.transcribe(audio_data) + if chunk_callback: + chunk_callback(transcript) + if final_callback: + final_callback(transcript) + return transcript async def transcribe_live_audio_openai( @@ -266,6 +270,7 @@ async def transcribe_live_audio_openai( **_kwargs: object, ) -> str | None: """Record and transcribe live audio using OpenAI Whisper.""" + service = 
OpenAITranscriptionService(openai_asr_config, openai_llm_config, logger) audio_data = await record_audio_with_manual_stop( p, audio_input_config.input_device_index, @@ -276,13 +281,4 @@ async def transcribe_live_audio_openai( ) if not audio_data: return None - try: - return await transcribe_audio_openai( - audio_data, - openai_asr_config, - openai_llm_config, - logger, - ) - except Exception: - logger.exception("Error during transcription") - return "" + return await service.transcribe(audio_data) diff --git a/agent_cli/cli.py b/agent_cli/cli.py index 7b8803ce0..131814afc 100644 --- a/agent_cli/cli.py +++ b/agent_cli/cli.py @@ -7,8 +7,8 @@ import typer -from .config_loader import load_config -from .utils import console +from .config import load_config +from .core.utils import console if TYPE_CHECKING: from logging import Handler diff --git a/agent_cli/agents/config.py b/agent_cli/config.py similarity index 67% rename from agent_cli/agents/config.py rename to agent_cli/config.py index 452727ad8..37ee761fd 100644 --- a/agent_cli/agents/config.py +++ b/agent_cli/config.py @@ -1,11 +1,54 @@ -"""Pydantic models for agent configurations, aligned with CLI option groups.""" +"""Pydantic models for agent configurations and config file loading.""" from __future__ import annotations +import tomllib from pathlib import Path -from typing import Literal +from typing import Any, Literal from pydantic import BaseModel, field_validator +from rich.console import Console + +console = Console() + +# --- Config File Loading --- + +CONFIG_PATH = Path.home() / ".config" / "agent-cli" / "config.toml" +CONFIG_PATH_2 = Path("agent-cli-config.toml") + + +def _replace_dashed_keys(cfg: dict[str, Any]) -> dict[str, Any]: + """Replace dashed keys with underscores in the config options.""" + return {k.replace("-", "_"): v for k, v in cfg.items()} + + +def load_config(config_path_str: str | None = None) -> dict[str, Any]: + """Load the TOML configuration file and process it for nested 
structures.""" + # Determine which config path to use + if config_path_str: + config_path = Path(config_path_str) + elif CONFIG_PATH.exists(): + config_path = CONFIG_PATH + elif CONFIG_PATH_2.exists(): + config_path = CONFIG_PATH_2 + else: + return {} + + # Try to load and process the config + if config_path.exists(): + with config_path.open("rb") as f: + cfg = tomllib.load(f) + return {k: _replace_dashed_keys(v) for k, v in cfg.items()} + + # Report error only if an explicit path was given + if config_path_str: + console.print( + f"[bold red]Config file not found at {config_path_str}[/bold red]", + ) + return {} + + +# --- Pydantic Models for Configuration --- # --- Panel: Provider Selection --- diff --git a/agent_cli/config_loader.py b/agent_cli/config_loader.py deleted file mode 100644 index 826271f50..000000000 --- a/agent_cli/config_loader.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Handles loading and parsing of the agent-cli configuration file.""" - -from __future__ import annotations - -import tomllib -from pathlib import Path -from typing import Any - -from .utils import console - -CONFIG_PATH = Path.home() / ".config" / "agent-cli" / "config.toml" -CONFIG_PATH_2 = Path("agent-cli-config.toml") - - -def _replace_dashed_keys(cfg: dict[str, Any]) -> dict[str, Any]: - """Replace dashed keys with underscores in the config options.""" - return {k.replace("-", "_"): v for k, v in cfg.items()} - - -def load_config(config_path_str: str | None = None) -> dict[str, Any]: - """Load the TOML configuration file and process it for nested structures.""" - # Determine which config path to use - if config_path_str: - config_path = Path(config_path_str) - elif CONFIG_PATH.exists(): - config_path = CONFIG_PATH - elif CONFIG_PATH_2.exists(): - config_path = CONFIG_PATH_2 - else: - return {} - - # Try to load and process the config - if config_path.exists(): - with config_path.open("rb") as f: - cfg = tomllib.load(f) - return {k: _replace_dashed_keys(v) for k, v in cfg.items()} - - 
# Report error only if an explicit path was given - if config_path_str: - console.print( - f"[bold red]Config file not found at {config_path_str}[/bold red]", - ) - return {} diff --git a/agent_cli/core/__init__.py b/agent_cli/core/__init__.py new file mode 100644 index 000000000..c5895c9f5 --- /dev/null +++ b/agent_cli/core/__init__.py @@ -0,0 +1 @@ +"""Core package for agent-cli.""" diff --git a/agent_cli/audio.py b/agent_cli/core/audio.py similarity index 99% rename from agent_cli/audio.py rename to agent_cli/core/audio.py index 186cc8785..99eee2e67 100644 --- a/agent_cli/audio.py +++ b/agent_cli/core/audio.py @@ -12,7 +12,7 @@ from rich.text import Text from agent_cli import constants -from agent_cli.utils import InteractiveStopEvent, console, print_device_index, print_with_style +from agent_cli.core.utils import InteractiveStopEvent, console, print_device_index, print_with_style if TYPE_CHECKING: import logging @@ -20,7 +20,7 @@ from rich.live import Live - from agent_cli.agents import config + from agent_cli import config class _AudioTee: diff --git a/agent_cli/process_manager.py b/agent_cli/core/process.py similarity index 100% rename from agent_cli/process_manager.py rename to agent_cli/core/process.py diff --git a/agent_cli/utils.py b/agent_cli/core/utils.py similarity index 96% rename from agent_cli/utils.py rename to agent_cli/core/utils.py index 7367c343f..1e1abbfcc 100644 --- a/agent_cli/utils.py +++ b/agent_cli/core/utils.py @@ -23,7 +23,7 @@ from rich.status import Status from rich.text import Text -from agent_cli import process_manager +from agent_cli.core.process import is_process_running, kill_process, read_pid_file if TYPE_CHECKING: import logging @@ -207,7 +207,7 @@ def stop_or_status_or_toggle( ) -> bool: """Handle process control for a given process name.""" if stop: - if process_manager.kill_process(process_name): + if kill_process(process_name): if not quiet: print_with_style(f"✅ {which.capitalize()} stopped.") elif not quiet: @@ -215,8 
+215,8 @@ def stop_or_status_or_toggle( return True if status: - if process_manager.is_process_running(process_name): - pid = process_manager.read_pid_file(process_name) + if is_process_running(process_name): + pid = read_pid_file(process_name) if not quiet: print_with_style(f"✅ {which.capitalize()} is running (PID: {pid}).") elif not quiet: @@ -224,8 +224,8 @@ def stop_or_status_or_toggle( return True if toggle: - if process_manager.is_process_running(process_name): - if process_manager.kill_process(process_name) and not quiet: + if is_process_running(process_name): + if kill_process(process_name) and not quiet: print_with_style(f"✅ {which.capitalize()} stopped.") return True if not quiet: diff --git a/agent_cli/llm.py b/agent_cli/llm.py index 7b5bbe76f..ea083104b 100644 --- a/agent_cli/llm.py +++ b/agent_cli/llm.py @@ -9,48 +9,15 @@ import pyperclip from rich.live import Live -from agent_cli.utils import console, live_timer, print_error_message, print_output_panel +from agent_cli.core.utils import console, live_timer, print_error_message, print_output_panel +from agent_cli.services.factory import get_llm_service if TYPE_CHECKING: import logging - from pydantic_ai import Agent from pydantic_ai.tools import Tool - from agent_cli.agents import config - - -def build_agent( - provider_config: config.ProviderSelection, - ollama_config: config.Ollama, - openai_config: config.OpenAILLM, - *, - system_prompt: str | None = None, - instructions: str | None = None, - tools: list[Tool] | None = None, -) -> Agent: - """Construct and return a PydanticAI agent.""" - from pydantic_ai import Agent # noqa: PLC0415 - from pydantic_ai.models.openai import OpenAIModel # noqa: PLC0415 - from pydantic_ai.providers.openai import OpenAIProvider # noqa: PLC0415 - - if provider_config.llm_provider == "openai": - if not openai_config.openai_api_key: - msg = "OpenAI API key is not set." 
- raise ValueError(msg) - provider = OpenAIProvider(api_key=openai_config.openai_api_key) - model_name = openai_config.openai_llm_model - else: - provider = OpenAIProvider(base_url=f"{ollama_config.ollama_host}/v1") - model_name = ollama_config.ollama_model - - llm_model = OpenAIModel(model_name=model_name, provider=provider) - return Agent( - model=llm_model, - system_prompt=system_prompt or (), - instructions=instructions, - tools=tools or [], - ) + from agent_cli import config # --- LLM (Editing) Logic --- @@ -83,13 +50,11 @@ async def get_llm_response( exit_on_error: bool = False, ) -> str | None: """Get a response from the LLM with optional clipboard and output handling.""" - agent = build_agent( - provider_config=provider_config, - ollama_config=ollama_config, - openai_config=openai_config, - system_prompt=system_prompt, - instructions=agent_instructions, - tools=tools, + llm_service = get_llm_service( + provider_config, + ollama_config, + openai_config, + logger, ) start_time = time.monotonic() @@ -107,10 +72,17 @@ async def get_llm_response( style="bold yellow", quiet=quiet, ): - result = await agent.run(user_input) + result_text = await llm_service.get_response( + system_prompt=system_prompt, + agent_instructions=agent_instructions, + user_input=user_input, + tools=tools, + ) elapsed = time.monotonic() - start_time - result_text = result.output + + if not result_text: + return None if clipboard: pyperclip.copy(result_text) diff --git a/agent_cli/services.py b/agent_cli/services.py deleted file mode 100644 index d33cfaa65..000000000 --- a/agent_cli/services.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Module for interacting with online services like OpenAI.""" - -from __future__ import annotations - -import io -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - import logging - - from openai import AsyncOpenAI - - from agent_cli.agents import config - - -def _get_openai_client(api_key: str) -> AsyncOpenAI: - """Get an OpenAI client instance.""" - from openai 
import AsyncOpenAI # noqa: PLC0415 - - if not api_key: - msg = "OpenAI API key is not set." - raise ValueError(msg) - return AsyncOpenAI(api_key=api_key) - - -async def transcribe_audio_openai( - audio_data: bytes, - openai_asr_config: config.OpenAIASR, - openai_llm_config: config.OpenAILLM, - logger: logging.Logger, -) -> str: - """Transcribe audio using OpenAI's Whisper API.""" - logger.info("Transcribing audio with OpenAI Whisper...") - if not openai_llm_config.openai_api_key: - msg = "OpenAI API key is not set." - raise ValueError(msg) - client = _get_openai_client(api_key=openai_llm_config.openai_api_key) - audio_file = io.BytesIO(audio_data) - audio_file.name = "audio.wav" - response = await client.audio.transcriptions.create( - model=openai_asr_config.openai_asr_model, - file=audio_file, - ) - return response.text - - -async def synthesize_speech_openai( - text: str, - openai_tts_config: config.OpenAITTS, - openai_llm_config: config.OpenAILLM, - logger: logging.Logger, -) -> bytes: - """Synthesize speech using OpenAI's TTS API.""" - logger.info("Synthesizing speech with OpenAI TTS...") - if not openai_llm_config.openai_api_key: - msg = "OpenAI API key is not set." 
- raise ValueError(msg) - client = _get_openai_client(api_key=openai_llm_config.openai_api_key) - response = await client.audio.speech.create( - model=openai_tts_config.openai_tts_model, - voice=openai_tts_config.openai_tts_voice, - input=text, - response_format="wav", - ) - return response.content diff --git a/agent_cli/services/__init__.py b/agent_cli/services/__init__.py new file mode 100644 index 000000000..67ed73c1c --- /dev/null +++ b/agent_cli/services/__init__.py @@ -0,0 +1 @@ +"""Services package for agent-cli.""" diff --git a/agent_cli/services/base.py b/agent_cli/services/base.py new file mode 100644 index 000000000..49a4483fc --- /dev/null +++ b/agent_cli/services/base.py @@ -0,0 +1,44 @@ +"""Base classes for external services.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod + + +class ASRService(ABC): + """Abstract base class for Automatic Speech Recognition services.""" + + @abstractmethod + async def transcribe(self, audio_data: bytes) -> str: + """Transcribe audio data to text.""" + + +class LLMService(ABC): + """Abstract base class for Language Model services.""" + + @abstractmethod + async def get_response( + self, + *, + system_prompt: str, + agent_instructions: str, + user_input: str, + tools: list | None = None, + ) -> str | None: + """Get a response from the language model.""" + + +class TTSService(ABC): + """Abstract base class for Text-to-Speech services.""" + + @abstractmethod + async def synthesize(self, text: str) -> bytes | None: + """Synthesize text to speech audio data.""" + + +class WakeWordService(ABC): + """Abstract base class for Wake Word detection services.""" + + @abstractmethod + async def detect(self) -> str | None: + """Detect the wake word.""" diff --git a/agent_cli/services/factory.py b/agent_cli/services/factory.py new file mode 100644 index 000000000..be5370bfc --- /dev/null +++ b/agent_cli/services/factory.py @@ -0,0 +1,80 @@ +"""Factory functions for creating service instances.""" + +from 
__future__ import annotations + +from typing import TYPE_CHECKING + +from agent_cli.services.local import ( + OllamaLLMService, + WyomingTranscriptionService, + WyomingTTSService, +) +from agent_cli.services.openai import ( + OpenAILLMService, + OpenAITranscriptionService, + OpenAITTSService, +) + +if TYPE_CHECKING: + import logging + + from agent_cli import config + from agent_cli.services.base import ASRService, LLMService, TTSService + + +def get_asr_service( + provider_config: config.ProviderSelection, + wyoming_asr_config: config.WyomingASR, + openai_asr_config: config.OpenAIASR, + openai_llm_config: config.OpenAILLM, + logger: logging.Logger, + *, + quiet: bool = False, +) -> ASRService: + """Get the appropriate ASR service based on the provider.""" + if provider_config.asr_provider == "openai": + return OpenAITranscriptionService( + openai_asr_config=openai_asr_config, + openai_llm_config=openai_llm_config, + logger=logger, + ) + return WyomingTranscriptionService( + wyoming_asr_config=wyoming_asr_config, + logger=logger, + quiet=quiet, + ) + + +def get_llm_service( + provider_config: config.ProviderSelection, + ollama_config: config.Ollama, + openai_llm_config: config.OpenAILLM, + logger: logging.Logger, +) -> LLMService: + """Get the appropriate LLM service based on the provider.""" + if provider_config.llm_provider == "openai": + return OpenAILLMService(openai_llm_config=openai_llm_config, logger=logger) + return OllamaLLMService(ollama_config=ollama_config, logger=logger) + + +def get_tts_service( + provider_config: config.ProviderSelection, + wyoming_tts_config: config.WyomingTTS, + openai_tts_config: config.OpenAITTS, + openai_llm_config: config.OpenAILLM, + logger: logging.Logger, + *, + quiet: bool = False, +) -> TTSService: + """Get the appropriate TTS service based on the provider.""" + if provider_config.tts_provider == "openai": + return OpenAITTSService( + openai_tts_config=openai_tts_config, + openai_llm_config=openai_llm_config, + 
logger=logger, + ) + return WyomingTTSService( + wyoming_tts_config=wyoming_tts_config, + logger=logger, + quiet=quiet, + ) diff --git a/agent_cli/services/local.py b/agent_cli/services/local.py new file mode 100644 index 000000000..0703c50e2 --- /dev/null +++ b/agent_cli/services/local.py @@ -0,0 +1,312 @@ +"""Module for interacting with local services like Wyoming and Ollama.""" + +from __future__ import annotations + +import asyncio +import io +import wave +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING, Any + +from pydantic_ai import Agent +from pydantic_ai.models.openai import OpenAIModel +from pydantic_ai.providers.openai import OpenAIProvider +from wyoming.asr import Transcribe, Transcript +from wyoming.audio import AudioChunk, AudioStart, AudioStop +from wyoming.client import AsyncClient +from wyoming.tts import Synthesize, SynthesizeVoice +from wyoming.wake import Detect, Detection, NotDetected + +from agent_cli import constants +from agent_cli.core.audio import read_from_queue +from agent_cli.services.base import ASRService, LLMService, TTSService, WakeWordService + +if TYPE_CHECKING: + import logging + from collections.abc import AsyncGenerator, Coroutine + + from rich.live import Live + + from agent_cli import config + + +@asynccontextmanager +async def wyoming_client_context( + host: str, + port: int, + service_name: str, + logger: logging.Logger, + *, + quiet: bool = False, +) -> AsyncGenerator[AsyncClient, None]: + """Connect to a Wyoming server.""" + if not quiet: + logger.info("Connecting to %s server at %s:%s", service_name, host, port) + try: + async with AsyncClient(host, port) as client: + yield client + except ConnectionRefusedError: + logger.exception( + "Connection refused to %s server at %s:%s", + service_name, + host, + port, + ) + raise + + +async def manage_send_receive_tasks( + send_task: Coroutine[Any, Any, Any], + recv_task: Coroutine[Any, Any, Any], + return_when: str = asyncio.ALL_COMPLETED, +) -> 
tuple[asyncio.Task, asyncio.Task]: + """Manage send and receive tasks for a Wyoming client.""" + send: asyncio.Task = asyncio.create_task(send_task) + recv: asyncio.Task = asyncio.create_task(recv_task) + done, pending = await asyncio.wait({send, recv}, return_when=return_when) + for task in pending: + task.cancel() + return send, recv + + +class WyomingTranscriptionService(ASRService): + """Transcription service using a Wyoming ASR server.""" + + def __init__( + self, + wyoming_asr_config: config.WyomingASR, + logger: logging.Logger, + *, + quiet: bool = False, + ) -> None: + """Initialize the WyomingTranscriptionService.""" + self.wyoming_asr_config = wyoming_asr_config + self.logger = logger + self.quiet = quiet + + async def transcribe(self, audio_data: bytes) -> str: + """Transcribe audio using a Wyoming ASR server.""" + async with wyoming_client_context( + self.wyoming_asr_config.wyoming_asr_ip, + self.wyoming_asr_config.wyoming_asr_port, + "ASR", + self.logger, + quiet=self.quiet, + ) as client: + await client.write_event(Transcribe().event()) + await client.write_event(AudioStart(**constants.WYOMING_AUDIO_CONFIG).event()) + + chunk_size = constants.PYAUDIO_CHUNK_SIZE * 2 + for i in range(0, len(audio_data), chunk_size): + chunk = audio_data[i : i + chunk_size] + await client.write_event( + AudioChunk(audio=chunk, **constants.WYOMING_AUDIO_CONFIG).event(), + ) + await client.write_event(AudioStop().event()) + return await self._receive_transcript(client) + + async def _receive_transcript(self, client: AsyncClient) -> str: + """Receive transcription events and return the final transcript.""" + transcript_text = "" + while True: + event = await client.read_event() + if event is None: + self.logger.warning("Connection to ASR server lost.") + break + if Transcript.is_type(event.type): + transcript = Transcript.from_event(event) + transcript_text = transcript.text + break + return transcript_text + + +class OllamaLLMService(LLMService): + """LLM service using an 
Ollama server.""" + + def __init__( + self, + ollama_config: config.Ollama, + logger: logging.Logger, + ) -> None: + """Initialize the OllamaLLMService.""" + self.ollama_config = ollama_config + self.logger = logger + provider = OpenAIProvider(base_url=f"{self.ollama_config.ollama_host}/v1") + self.model = OpenAIModel( + model_name=self.ollama_config.ollama_model, + provider=provider, + ) + + async def get_response( + self, + *, + system_prompt: str, + agent_instructions: str, + user_input: str, + tools: list | None = None, + ) -> str | None: + """Get a response from the language model.""" + agent = Agent( + model=self.model, + system_prompt=system_prompt or (), + instructions=agent_instructions, + tools=tools or [], + ) + result = await agent.run(user_input) + return result.output + + +class WyomingTTSService(TTSService): + """TTS service using a Wyoming TTS server.""" + + def __init__( + self, + wyoming_tts_config: config.WyomingTTS, + logger: logging.Logger, + *, + quiet: bool = False, + ) -> None: + """Initialize the WyomingTTSService.""" + self.wyoming_tts_config = wyoming_tts_config + self.logger = logger + self.quiet = quiet + + async def synthesize(self, text: str) -> bytes | None: + """Synthesize speech using a Wyoming TTS server.""" + async with wyoming_client_context( + self.wyoming_tts_config.wyoming_tts_ip, + self.wyoming_tts_config.wyoming_tts_port, + "TTS", + self.logger, + quiet=self.quiet, + ) as client: + synthesize_event = self._create_synthesis_request(text) + _send_task, recv_task = await manage_send_receive_tasks( + client.write_event(synthesize_event.event()), + self._process_audio_events(client), + ) + audio_data, sample_rate, sample_width, channels = recv_task.result() + if sample_rate and sample_width and channels and audio_data: + return self._create_wav_data(audio_data, sample_rate, sample_width, channels) + return None + + def _create_synthesis_request(self, text: str) -> Synthesize: + """Create a synthesis request with optional voice 
parameters.""" + synthesize_event = Synthesize(text=text) + if ( + self.wyoming_tts_config.wyoming_voice + or self.wyoming_tts_config.wyoming_tts_language + or self.wyoming_tts_config.wyoming_speaker + ): + synthesize_event.voice = SynthesizeVoice( + name=self.wyoming_tts_config.wyoming_voice, + language=self.wyoming_tts_config.wyoming_tts_language, + speaker=self.wyoming_tts_config.wyoming_speaker, + ) + return synthesize_event + + async def _process_audio_events( + self, + client: AsyncClient, + ) -> tuple[bytes, int | None, int | None, int | None]: + """Process audio events from TTS server and return audio data with metadata.""" + audio_data = io.BytesIO() + sample_rate, sample_width, channels = None, None, None + while True: + event = await client.read_event() + if event is None: + break + if AudioStart.is_type(event.type): + audio_start = AudioStart.from_event(event) + sample_rate, sample_width, channels = ( + audio_start.rate, + audio_start.width, + audio_start.channels, + ) + elif AudioChunk.is_type(event.type): + audio_data.write(AudioChunk.from_event(event).audio) + elif AudioStop.is_type(event.type): + break + return audio_data.getvalue(), sample_rate, sample_width, channels + + def _create_wav_data( + self, + audio_data: bytes, + sample_rate: int, + sample_width: int, + channels: int, + ) -> bytes: + """Convert raw audio data to WAV format.""" + wav_data = io.BytesIO() + with wave.open(wav_data, "wb") as wav_file: + wav_file.setnchannels(channels) + wav_file.setsampwidth(sample_width) + wav_file.setframerate(sample_rate) + wav_file.writeframes(audio_data) + return wav_data.getvalue() + + +class WyomingWakeWordService(WakeWordService): + """Wake word detection service using a Wyoming wake word server.""" + + def __init__( + self, + wake_word_config: config.WakeWord, + logger: logging.Logger, + queue: asyncio.Queue, + *, + live: Live | None = None, + quiet: bool = False, + ) -> None: + """Initialize the WyomingWakeWordService.""" + self.wake_word_config = 
wake_word_config + self.logger = logger + self.queue = queue + self.live = live + self.quiet = quiet + + async def detect(self) -> str | None: + """Detect the wake word.""" + async with wyoming_client_context( + self.wake_word_config.wake_server_ip, + self.wake_word_config.wake_server_port, + "wake word", + self.logger, + quiet=self.quiet, + ) as client: + await client.write_event(Detect(names=[self.wake_word_config.wake_word_name]).event()) + _send_task, recv_task = await manage_send_receive_tasks( + self._send_audio_from_queue_for_wake_detection(client), + self._receive_wake_detection(client), + return_when=asyncio.FIRST_COMPLETED, + ) + if recv_task.done() and not recv_task.cancelled(): + return recv_task.result() + return None + + async def _send_audio_from_queue_for_wake_detection(self, client: AsyncClient) -> None: + """Read from a queue and send to Wyoming wake word server.""" + await client.write_event(AudioStart(**constants.WYOMING_AUDIO_CONFIG).event()) + try: + await read_from_queue( + queue=self.queue, + chunk_handler=lambda chunk: client.write_event( + AudioChunk(audio=chunk, **constants.WYOMING_AUDIO_CONFIG).event(), + ), + logger=self.logger, + ) + finally: + if client._writer is not None: + await client.write_event(AudioStop().event()) + + async def _receive_wake_detection(self, client: AsyncClient) -> str | None: + """Receive wake word detection events.""" + while True: + event = await client.read_event() + if event is None: + break + if Detection.is_type(event.type): + return Detection.from_event(event).name + if NotDetected.is_type(event.type): + break + return None diff --git a/agent_cli/services/openai.py b/agent_cli/services/openai.py new file mode 100644 index 000000000..c372a8fa4 --- /dev/null +++ b/agent_cli/services/openai.py @@ -0,0 +1,129 @@ +"""Module for interacting with OpenAI services.""" + +from __future__ import annotations + +import io +from typing import TYPE_CHECKING + +from agent_cli.services.base import ASRService, LLMService, 
TTSService + +if TYPE_CHECKING: + import logging + + from openai import AsyncOpenAI + + from agent_cli import config + + +def _get_openai_client(api_key: str) -> AsyncOpenAI: + """Get an OpenAI client instance.""" + from openai import AsyncOpenAI # noqa: PLC0415 + + if not api_key: + msg = "OpenAI API key is not set." + raise ValueError(msg) + return AsyncOpenAI(api_key=api_key) + + +class OpenAITranscriptionService(ASRService): + """Transcription service using OpenAI's Whisper API.""" + + def __init__( + self, + openai_asr_config: config.OpenAIASR, + openai_llm_config: config.OpenAILLM, + logger: logging.Logger, + ) -> None: + """Initialize the OpenAITranscriptionService.""" + self.openai_asr_config = openai_asr_config + self.openai_llm_config = openai_llm_config + self.logger = logger + if not self.openai_llm_config.openai_api_key: + msg = "OpenAI API key is not set." + raise ValueError(msg) + self.client = _get_openai_client(api_key=self.openai_llm_config.openai_api_key) + + async def transcribe(self, audio_data: bytes) -> str: + """Transcribe audio using OpenAI's Whisper API.""" + self.logger.info("Transcribing audio with OpenAI Whisper...") + audio_file = io.BytesIO(audio_data) + audio_file.name = "audio.wav" + response = await self.client.audio.transcriptions.create( + model=self.openai_asr_config.openai_asr_model, + file=audio_file, + ) + return response.text + + +class OpenAILLMService(LLMService): + """LLM service using OpenAI's API.""" + + def __init__( + self, + openai_llm_config: config.OpenAILLM, + logger: logging.Logger, + ) -> None: + """Initialize the OpenAILLMService.""" + from pydantic_ai.models.openai import OpenAIModel # noqa: PLC0415 + from pydantic_ai.providers.openai import OpenAIProvider # noqa: PLC0415 + + self.openai_llm_config = openai_llm_config + self.logger = logger + if not self.openai_llm_config.openai_api_key: + msg = "OpenAI API key is not set." 
+ raise ValueError(msg) + provider = OpenAIProvider(api_key=self.openai_llm_config.openai_api_key) + self.model = OpenAIModel( + model_name=self.openai_llm_config.openai_llm_model, + provider=provider, + ) + + async def get_response( + self, + *, + system_prompt: str, + agent_instructions: str, + user_input: str, + tools: list | None = None, + ) -> str | None: + """Get a response from the language model.""" + from pydantic_ai import Agent # noqa: PLC0415 + + agent = Agent( + model=self.model, + system_prompt=system_prompt or (), + instructions=agent_instructions, + tools=tools or [], + ) + result = await agent.run(user_input) + return result.output + + +class OpenAITTSService(TTSService): + """TTS service using OpenAI's API.""" + + def __init__( + self, + openai_tts_config: config.OpenAITTS, + openai_llm_config: config.OpenAILLM, + logger: logging.Logger, + ) -> None: + """Initialize the OpenAITTSService.""" + self.openai_tts_config = openai_tts_config + self.openai_llm_config = openai_llm_config + self.logger = logger + if not self.openai_llm_config.openai_api_key: + msg = "OpenAI API key is not set." 
+ raise ValueError(msg) + self.client = _get_openai_client(api_key=self.openai_llm_config.openai_api_key) + + async def synthesize(self, text: str) -> bytes: + """Synthesize speech using OpenAI's TTS API.""" + self.logger.info("Synthesizing speech with OpenAI TTS...") + response = await self.client.audio.speech.create( + model=self.openai_tts_config.openai_tts_model, + voice=self.openai_tts_config.openai_tts_voice, + input=text, + response_format="wav", + ) + return response.content diff --git a/agent_cli/tts.py b/agent_cli/tts.py index cded7a074..cf7443187 100644 --- a/agent_cli/tts.py +++ b/agent_cli/tts.py @@ -6,17 +6,21 @@ import importlib.util import io import wave -from functools import partial from typing import TYPE_CHECKING from wyoming.audio import AudioChunk, AudioStart, AudioStop from wyoming.tts import Synthesize, SynthesizeVoice from agent_cli import constants -from agent_cli.audio import open_pyaudio_stream, pyaudio_context, setup_output_stream -from agent_cli.services import synthesize_speech_openai -from agent_cli.utils import InteractiveStopEvent, live_timer, print_error_message, print_with_style -from agent_cli.wyoming_utils import manage_send_receive_tasks, wyoming_client_context +from agent_cli.core.audio import open_pyaudio_stream, pyaudio_context, setup_output_stream +from agent_cli.core.utils import ( + InteractiveStopEvent, + live_timer, + print_error_message, + print_with_style, +) +from agent_cli.services.factory import get_tts_service +from agent_cli.services.local import manage_send_receive_tasks, wyoming_client_context if TYPE_CHECKING: import logging @@ -25,7 +29,7 @@ from rich.live import Live from wyoming.client import AsyncClient - from agent_cli.agents import config + from agent_cli import config has_audiostretchy = importlib.util.find_spec("audiostretchy") is not None @@ -36,17 +40,28 @@ def get_synthesizer( wyoming_tts_config: config.WyomingTTS, openai_tts_config: config.OpenAITTS, openai_llm_config: config.OpenAILLM, + logger: 
logging.Logger, + *, + quiet: bool = False, ) -> Callable[..., Awaitable[bytes | None]]: """Return the appropriate synthesizer based on the config.""" if not audio_output_config.enable_tts: return _dummy_synthesizer - if provider_config.tts_provider == "openai": - return partial( - _synthesize_speech_openai, - openai_tts_config=openai_tts_config, - openai_llm_config=openai_llm_config, - ) - return partial(_synthesize_speech_wyoming, wyoming_tts_config=wyoming_tts_config) + + tts_service = get_tts_service( + provider_config, + wyoming_tts_config, + openai_tts_config, + openai_llm_config, + logger, + quiet=quiet, + ) + + async def synthesize(text: str) -> bytes | None: + """Synthesize speech from text.""" + return await tts_service.synthesize(text) + + return synthesize def _create_synthesis_request( @@ -133,23 +148,6 @@ async def _dummy_synthesizer(**_kwargs: object) -> bytes | None: return None -async def _synthesize_speech_openai( - *, - text: str, - openai_tts_config: config.OpenAITTS, - openai_llm_config: config.OpenAILLM, - logger: logging.Logger, - **_kwargs: object, -) -> bytes | None: - """Synthesize speech from text using OpenAI TTS server.""" - return await synthesize_speech_openai( - text=text, - openai_tts_config=openai_tts_config, - openai_llm_config=openai_llm_config, - logger=logger, - ) - - async def _synthesize_speech_wyoming( *, text: str, @@ -282,6 +280,7 @@ async def speak_text( wyoming_tts_config, openai_tts_config, openai_llm_config, + logger, ) audio_data = None try: diff --git a/agent_cli/wake_word.py b/agent_cli/wake_word.py deleted file mode 100644 index 8015ea69e..000000000 --- a/agent_cli/wake_word.py +++ /dev/null @@ -1,136 +0,0 @@ -"""Module for Wake Word Detection using Wyoming.""" - -from __future__ import annotations - -import asyncio -from typing import TYPE_CHECKING - -from wyoming.audio import AudioChunk, AudioStart, AudioStop -from wyoming.wake import Detect, Detection, NotDetected - -from agent_cli import constants -from 
agent_cli.audio import read_from_queue -from agent_cli.wyoming_utils import manage_send_receive_tasks, wyoming_client_context - -if TYPE_CHECKING: - import logging - from collections.abc import Callable - - from rich.live import Live - from wyoming.client import AsyncClient - - -async def _send_audio_from_queue_for_wake_detection( - client: AsyncClient, - queue: asyncio.Queue, - logger: logging.Logger, - live: Live | None, - quiet: bool, - progress_message: str, -) -> None: - """Read from a queue and send to Wyoming wake word server.""" - await client.write_event(AudioStart(**constants.WYOMING_AUDIO_CONFIG).event()) - seconds_streamed = 0.0 - - async def send_chunk(chunk: bytes) -> None: - nonlocal seconds_streamed - """Send audio chunk to wake word server.""" - await client.write_event( - AudioChunk(audio=chunk, **constants.WYOMING_AUDIO_CONFIG).event(), - ) - seconds_streamed += len(chunk) / (constants.PYAUDIO_RATE * constants.PYAUDIO_CHANNELS * 2) - if live and not quiet: - live.update(f"{progress_message}... ({seconds_streamed:.1f}s)") - - try: - await read_from_queue( - queue=queue, - chunk_handler=send_chunk, - logger=logger, - ) - finally: - if client._writer is not None: - await client.write_event(AudioStop().event()) - logger.debug("Sent AudioStop for wake detection") - - -async def _receive_wake_detection( - client: AsyncClient, - logger: logging.Logger, - *, - detection_callback: Callable[[str], None] | None = None, -) -> str | None: - """Receive wake word detection events. 
- - Args: - client: Wyoming client connection - logger: Logger instance - detection_callback: Optional callback for when wake word is detected - - Returns: - Name of detected wake word or None if no detection - - """ - while True: - event = await client.read_event() - if event is None: - logger.warning("Connection to wake word server lost.") - break - - if Detection.is_type(event.type): - detection = Detection.from_event(event) - wake_word_name = detection.name or "unknown" - logger.info("Wake word detected: %s", wake_word_name) - if detection_callback: - detection_callback(wake_word_name) - return wake_word_name - if NotDetected.is_type(event.type): - logger.debug("No wake word detected") - break - logger.debug("Ignoring event type: %s", event.type) - - return None - - -async def detect_wake_word_from_queue( - wake_server_ip: str, - wake_server_port: int, - wake_word_name: str, - logger: logging.Logger, - queue: asyncio.Queue, - *, - live: Live | None = None, - detection_callback: Callable[[str], None] | None = None, - quiet: bool = False, - progress_message: str = "Listening for wake word", -) -> str | None: - """Detect wake word from an audio queue.""" - try: - async with wyoming_client_context( - wake_server_ip, - wake_server_port, - "wake word", - logger, - quiet=quiet, - ) as client: - await client.write_event(Detect(names=[wake_word_name]).event()) - - _send_task, recv_task = await manage_send_receive_tasks( - _send_audio_from_queue_for_wake_detection( - client, - queue, - logger, - live, - quiet, - progress_message, - ), - _receive_wake_detection(client, logger, detection_callback=detection_callback), - return_when=asyncio.FIRST_COMPLETED, - ) - - if recv_task.done() and not recv_task.cancelled(): - return recv_task.result() - - return None - except (ConnectionRefusedError, asyncio.CancelledError, Exception): - return None diff --git a/agent_cli/wyoming_utils.py b/agent_cli/wyoming_utils.py deleted file mode 100644 index edbe31a78..000000000 --- 
a/agent_cli/wyoming_utils.py +++ /dev/null @@ -1,95 +0,0 @@ -"""Utility functions for Wyoming protocol interactions to eliminate code duplication.""" - -from __future__ import annotations - -import asyncio -from contextlib import asynccontextmanager -from typing import TYPE_CHECKING - -from wyoming.client import AsyncClient - -from agent_cli.utils import print_error_message - -if TYPE_CHECKING: - import logging - from collections.abc import AsyncGenerator, Coroutine - - -@asynccontextmanager -async def wyoming_client_context( - server_ip: str, - server_port: int, - server_type: str, - logger: logging.Logger, - *, - quiet: bool = False, -) -> AsyncGenerator[AsyncClient, None]: - """Context manager for Wyoming client connections with unified error handling. - - Args: - server_ip: Wyoming server IP - server_port: Wyoming server port - server_type: Type of server (e.g., "ASR", "TTS", "wake word") - logger: Logger instance - quiet: If True, suppress console error messages - - Yields: - Connected Wyoming client - - Raises: - ConnectionRefusedError: If connection fails - Exception: For other connection errors - - """ - uri = f"tcp://{server_ip}:{server_port}" - logger.info("Connecting to Wyoming %s server at %s", server_type, uri) - - try: - async with AsyncClient.from_uri(uri) as client: - logger.info("%s connection established", server_type) - yield client - except ConnectionRefusedError: - logger.exception("%s connection refused.", server_type) - if not quiet: - print_error_message( - f"{server_type} connection refused.", - f"Is the Wyoming {server_type.lower()} server running at {uri}?", - ) - raise - except Exception as e: - logger.exception("An error occurred during %s connection", server_type.lower()) - if not quiet: - print_error_message(f"{server_type} error: {e}") - raise - - -async def manage_send_receive_tasks( - send_task_coro: Coroutine, - receive_task_coro: Coroutine, - *, - return_when: str = asyncio.ALL_COMPLETED, -) -> tuple[asyncio.Task, asyncio.Task]: 
- """Manage send and receive tasks with proper cancellation. - - Args: - send_task_coro: Send task coroutine - receive_task_coro: Receive task coroutine - return_when: When to return (e.g., asyncio.ALL_COMPLETED) - - Returns: - Tuple of (send_task, receive_task) - both completed or cancelled - - """ - send_task = asyncio.create_task(send_task_coro) - recv_task = asyncio.create_task(receive_task_coro) - - done, pending = await asyncio.wait( - [send_task, recv_task], - return_when=return_when, - ) - - # Cancel any pending tasks - for task in pending: - task.cancel() - - return send_task, recv_task diff --git a/tests/agents/test_interactive.py b/tests/agents/test_interactive.py index c75dfe6bc..716387ac0 100644 --- a/tests/agents/test_interactive.py +++ b/tests/agents/test_interactive.py @@ -9,7 +9,7 @@ import pytest -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents.chat import ( ConversationEntry, _async_main, @@ -17,7 +17,7 @@ _load_conversation_history, _save_conversation_history, ) -from agent_cli.utils import InteractiveStopEvent +from agent_cli.core.utils import InteractiveStopEvent if TYPE_CHECKING: from pathlib import Path diff --git a/tests/agents/test_interactive_extra.py b/tests/agents/test_interactive_extra.py index bc2a188f5..de2082a87 100644 --- a/tests/agents/test_interactive_extra.py +++ b/tests/agents/test_interactive_extra.py @@ -5,13 +5,13 @@ import pytest from typer.testing import CliRunner -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents.chat import ( _async_main, _handle_conversation_turn, ) from agent_cli.cli import app -from agent_cli.utils import InteractiveStopEvent +from agent_cli.core.utils import InteractiveStopEvent @pytest.mark.asyncio diff --git a/tests/agents/test_speak.py b/tests/agents/test_speak.py index 57c4aebfd..9acf51ca2 100644 --- a/tests/agents/test_speak.py +++ b/tests/agents/test_speak.py @@ -7,7 +7,7 @@ import pytest from typer.testing import 
CliRunner -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents.speak import _async_main from agent_cli.cli import app diff --git a/tests/agents/test_speak_e2e.py b/tests/agents/test_speak_e2e.py index 0a63eac00..80d2d33b1 100644 --- a/tests/agents/test_speak_e2e.py +++ b/tests/agents/test_speak_e2e.py @@ -6,7 +6,7 @@ import pytest -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents.speak import _async_main from tests.mocks.audio import MockPyAudio from tests.mocks.wyoming import MockTTSClient diff --git a/tests/agents/test_transcribe.py b/tests/agents/test_transcribe.py index d49361e87..e069e65ac 100644 --- a/tests/agents/test_transcribe.py +++ b/tests/agents/test_transcribe.py @@ -8,7 +8,8 @@ import pytest -from agent_cli.agents import config, transcribe +from agent_cli import config +from agent_cli.agents import transcribe from tests.mocks.wyoming import MockASRClient diff --git a/tests/agents/test_transcribe_e2e.py b/tests/agents/test_transcribe_e2e.py index 8df4e2930..c33ac751b 100644 --- a/tests/agents/test_transcribe_e2e.py +++ b/tests/agents/test_transcribe_e2e.py @@ -8,7 +8,7 @@ import pytest -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents.transcribe import _async_main from tests.mocks.audio import MockPyAudio from tests.mocks.wyoming import MockASRClient diff --git a/tests/agents/test_tts_common.py b/tests/agents/test_tts_common.py index 8a0686c4e..18f69a0a6 100644 --- a/tests/agents/test_tts_common.py +++ b/tests/agents/test_tts_common.py @@ -7,7 +7,7 @@ import pytest -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents._tts_common import handle_tts_playback if TYPE_CHECKING: diff --git a/tests/agents/test_tts_common_extra.py b/tests/agents/test_tts_common_extra.py index 4749d20b5..f481c2eaf 100644 --- a/tests/agents/test_tts_common_extra.py +++ b/tests/agents/test_tts_common_extra.py @@ -7,7 +7,7 @@ import 
pytest -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents._tts_common import _save_audio_file, handle_tts_playback diff --git a/tests/agents/test_voice_agent_common.py b/tests/agents/test_voice_agent_common.py index a3168610c..5eab0ba48 100644 --- a/tests/agents/test_voice_agent_common.py +++ b/tests/agents/test_voice_agent_common.py @@ -6,7 +6,7 @@ import pytest -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents._voice_agent_common import ( get_instruction_from_audio, process_instruction_and_respond, diff --git a/tests/agents/test_voice_edit_e2e.py b/tests/agents/test_voice_edit_e2e.py index e7d345828..d0187015c 100644 --- a/tests/agents/test_voice_edit_e2e.py +++ b/tests/agents/test_voice_edit_e2e.py @@ -6,7 +6,7 @@ import pytest -from agent_cli.agents import config +from agent_cli import config from agent_cli.agents.voice_edit import ( AGENT_INSTRUCTIONS, SYSTEM_PROMPT, diff --git a/tests/test_config.py b/tests/test_config.py index 67f70b418..04bf3328d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -11,7 +11,7 @@ from typer.testing import CliRunner from agent_cli.cli import set_config_defaults -from agent_cli.config_loader import load_config +from agent_cli.config import load_config if TYPE_CHECKING: from pathlib import Path @@ -89,8 +89,8 @@ def test_set_config_defaults(config_file: Path) -> None: assert ctx.default_map == expected_merged_defaults -@patch("agent_cli.config_loader.CONFIG_PATH") -@patch("agent_cli.config_loader.CONFIG_PATH_2") +@patch("agent_cli.config.CONFIG_PATH") +@patch("agent_cli.config.CONFIG_PATH_2") def test_default_config_paths( mock_path2: MagicMock, mock_path1: MagicMock, diff --git a/tests/test_llm.py b/tests/test_llm.py index 579994259..7a768bd0c 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -7,7 +7,7 @@ import pytest -from agent_cli.agents import config +from agent_cli import config from agent_cli.llm import build_agent, 
get_llm_response, process_and_update_clipboard diff --git a/tests/test_services.py b/tests/test_services.py index 11d868baf..998fc2a82 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -6,8 +6,7 @@ import pytest -from agent_cli import asr, tts -from agent_cli.agents import config +from agent_cli import asr, config, tts from agent_cli.services import synthesize_speech_openai, transcribe_audio_openai diff --git a/tests/test_tts.py b/tests/test_tts.py index c7919cae7..900b63236 100644 --- a/tests/test_tts.py +++ b/tests/test_tts.py @@ -8,7 +8,7 @@ import pytest -from agent_cli.agents import config +from agent_cli import config from agent_cli.tts import _apply_speed_adjustment, get_synthesizer, speak_text diff --git a/tests/test_wake_word.py b/tests/test_wake_word.py deleted file mode 100644 index 38bc49fa1..000000000 --- a/tests/test_wake_word.py +++ /dev/null @@ -1,146 +0,0 @@ -"""Tests for the wake word detection module.""" - -import asyncio -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from rich.live import Live - -from agent_cli import wake_word -from agent_cli.utils import InteractiveStopEvent - - -@pytest.fixture -def mock_pyaudio(): - """Mock PyAudio instance.""" - return MagicMock() - - -@pytest.fixture -def mock_logger(): - """Mock logger instance.""" - return MagicMock() - - -@pytest.fixture -def mock_stop_event(): - """Mock stop event.""" - stop_event = MagicMock(spec=InteractiveStopEvent) - stop_event.is_set.return_value = False - stop_event.ctrl_c_pressed = False - return stop_event - - -@pytest.fixture -def mock_live(): - """Mock Rich Live instance.""" - return MagicMock(spec=Live) - - -class TestReceiveWakeDetection: - """Tests for _receive_wake_detection function.""" - - @pytest.mark.asyncio - async def test_returns_detected_wake_word(self, mock_logger: MagicMock) -> None: - """Test detection of wake word.""" - mock_client = AsyncMock() - - # Mock detection event - mock_event = MagicMock() - mock_event.type 
= "detection" - - # Mock Detection.is_type and Detection.from_event - with ( - patch("agent_cli.wake_word.Detection.is_type", return_value=True), - patch("agent_cli.wake_word.Detection.from_event") as mock_from_event, - ): - mock_detection = MagicMock() - mock_detection.name = "test_wake_word" - mock_from_event.return_value = mock_detection - - mock_client.read_event.return_value = mock_event - - result = await wake_word._receive_wake_detection(mock_client, mock_logger) - - assert result == "test_wake_word" - mock_logger.info.assert_called_with("Wake word detected: %s", "test_wake_word") - - @pytest.mark.asyncio - async def test_calls_detection_callback(self, mock_logger: MagicMock) -> None: - """Test that detection callback is called.""" - mock_client = AsyncMock() - mock_callback = MagicMock() - - # Mock detection event - mock_event = MagicMock() - mock_event.type = "detection" - - with ( - patch("agent_cli.wake_word.Detection.is_type", return_value=True), - patch("agent_cli.wake_word.Detection.from_event") as mock_from_event, - ): - mock_detection = MagicMock() - mock_detection.name = "test_wake_word" - mock_from_event.return_value = mock_detection - - mock_client.read_event.return_value = mock_event - - result = await wake_word._receive_wake_detection( - mock_client, - mock_logger, - detection_callback=mock_callback, - ) - - assert result == "test_wake_word" - mock_callback.assert_called_once_with("test_wake_word") - - @pytest.mark.asyncio - async def test_handles_not_detected_event(self, mock_logger: MagicMock) -> None: - """Test handling of not-detected event.""" - mock_client = AsyncMock() - - # Mock not-detected event - mock_event = MagicMock() - mock_event.type = "not-detected" - - with ( - patch("agent_cli.wake_word.Detection.is_type", return_value=False), - patch("agent_cli.wake_word.NotDetected.is_type", return_value=True), - ): - mock_client.read_event.return_value = mock_event - - result = await wake_word._receive_wake_detection(mock_client, 
mock_logger) - - assert result is None - mock_logger.debug.assert_called_with("No wake word detected") - - @pytest.mark.asyncio - async def test_handles_connection_loss(self, mock_logger: MagicMock) -> None: - """Test handling of lost connection.""" - mock_client = AsyncMock() - mock_client.read_event.return_value = None - - result = await wake_word._receive_wake_detection(mock_client, mock_logger) - - assert result is None - mock_logger.warning.assert_called_with("Connection to wake word server lost.") - - -@pytest.mark.asyncio -@patch("agent_cli.wake_word.wyoming_client_context", side_effect=ConnectionRefusedError) -async def test_detect_wake_word_from_queue_connection_error( - mock_wyoming_client_context: MagicMock, - mock_logger: MagicMock, - mock_live: MagicMock, -): - """Test that detect_wake_word_from_queue handles ConnectionRefusedError.""" - result = await wake_word.detect_wake_word_from_queue( - "localhost", - 1234, - "test_word", - mock_logger, - asyncio.Queue(), - live=mock_live, - ) - assert result is None - mock_wyoming_client_context.assert_called_once()