diff --git a/agent_cli/agents/claude_serve.py b/agent_cli/agents/claude_serve.py new file mode 100644 index 000000000..179b7b91f --- /dev/null +++ b/agent_cli/agents/claude_serve.py @@ -0,0 +1,245 @@ +"""Claude Code remote server command for Agent CLI.""" + +from __future__ import annotations + +import ipaddress +import json +from datetime import UTC +from importlib.util import find_spec +from pathlib import Path + +import typer + +from agent_cli import opts +from agent_cli.cli import app +from agent_cli.core.utils import ( + console, + print_command_line_args, + print_error_message, +) + +has_uvicorn = find_spec("uvicorn") is not None +has_fastapi = find_spec("fastapi") is not None +has_claude_sdk = find_spec("claude_agent_sdk") is not None + +# Default paths for SSL certificates +SSL_CERT_DIR = Path.home() / ".config" / "agent-cli" / "ssl" +SSL_CERT_FILE = SSL_CERT_DIR / "cert.pem" +SSL_KEY_FILE = SSL_CERT_DIR / "key.pem" + + +def _generate_self_signed_cert() -> tuple[Path, Path]: + """Generate a self-signed SSL certificate for HTTPS.""" + from datetime import datetime, timedelta # noqa: PLC0415 + + from cryptography import x509 # noqa: PLC0415 + from cryptography.hazmat.primitives import hashes, serialization # noqa: PLC0415 + from cryptography.hazmat.primitives.asymmetric import rsa # noqa: PLC0415 + from cryptography.x509.oid import NameOID # noqa: PLC0415 + + SSL_CERT_DIR.mkdir(parents=True, exist_ok=True) + + # Generate private key + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + + # Generate certificate + subject = issuer = x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, "Claude Code Server"), + ], + ) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now(UTC)) + .not_valid_after(datetime.now(UTC) + timedelta(days=365)) + .add_extension( + x509.SubjectAlternativeName( + [ + x509.DNSName("localhost"), + x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")), + ], + ), + critical=False, + ) + .sign(key, hashes.SHA256()) + ) + + # Write certificate and key + SSL_KEY_FILE.write_bytes( + key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ), + ) + SSL_CERT_FILE.write_bytes(cert.public_bytes(serialization.Encoding.PEM)) + + return SSL_CERT_FILE, SSL_KEY_FILE + + +def run_claude_server( + host: str = "0.0.0.0", # noqa: S104 + port: int = 8765, + reload: bool = False, + cwd: Path | None = None, + projects: dict[str, str] | None = None, + default_project: str | None = None, + ssl: bool = False, +) -> None: + """Run the Claude Code FastAPI server.""" + import os # noqa: PLC0415 + + import uvicorn # noqa: PLC0415 + + # Set working directory for the API to use + if cwd: + os.environ["CLAUDE_API_CWD"] = str(cwd.resolve()) + + # Pass projects config via environment variable + if projects: + os.environ["CLAUDE_API_PROJECTS"] = json.dumps(projects) + if default_project: + os.environ["CLAUDE_API_DEFAULT_PROJECT"] = default_project + + ssl_keyfile = None + ssl_certfile = None + if ssl: + if not SSL_CERT_FILE.exists() or not SSL_KEY_FILE.exists(): + console.print("[yellow]Generating self-signed SSL certificate...[/yellow]") + _generate_self_signed_cert() + ssl_certfile = str(SSL_CERT_FILE) + ssl_keyfile = str(SSL_KEY_FILE) + + uvicorn.run( + "agent_cli.claude_api:app", + host=host, + port=port, + reload=reload, + log_level="info", + ssl_keyfile=ssl_keyfile, + ssl_certfile=ssl_certfile, + ) + + +@app.command("claude-serve") +def claude_serve( + host: str = typer.Option( + "0.0.0.0", # noqa: S104 + help="Host to bind the server to", + ), + port: int = typer.Option(8880, help="Port to bind the server to"), + cwd: Path = typer.Option( # noqa: B008 + None, + help="Working directory for Claude Code (defaults to current directory)", + ), + reload: bool = typer.Option( + False, # noqa: FBT003 + "--reload", + help="Enable auto-reload for development", + ), + ssl: bool = typer.Option( + False, # noqa: FBT003 + "--ssl", + help="Enable HTTPS with self-signed certificate (required for voice on Safari/iOS)", + ), + config_file: str | None = opts.CONFIG_FILE, + print_args: bool = opts.PRINT_ARGS, +) -> None: + """Start Claude Code remote server for iOS/web access. + + This starts a FastAPI server that exposes Claude Code capabilities via REST and WebSocket + endpoints, allowing remote access from iOS Shortcuts, web interfaces, or any HTTP client. + + Prerequisites: + - Run `claude /login` once to authenticate with your Claude.ai account + - Install dependencies: pip install agent-cli[claude] + + Example usage: + agent-cli claude-serve --port 8765 + + Configure projects in config.toml: + [claude_server] + default_project = "my-project" + + [claude_server.projects] + my-project = "/path/to/project" + dotfiles = "~/.dotfiles" + + Endpoints: + - POST /prompt - Simple prompt with auto project management + - GET /logs - View recent logs + - GET /log/{id} - View log details + - GET /projects - List configured projects + - POST /switch-project - Switch current project + """ + if print_args: + print_command_line_args(locals()) + + if not has_uvicorn or not has_fastapi: + msg = ( + "uvicorn or fastapi is not installed. " + "Please install with: pip install agent-cli[claude]" + ) + print_error_message(msg) + raise typer.Exit(1) + + if not has_claude_sdk: + msg = ( + "claude-agent-sdk is not installed. Please install with: pip install agent-cli[claude]" + ) + print_error_message(msg) + raise typer.Exit(1) + + # Load config for projects + from agent_cli.config import load_config # noqa: PLC0415 + + config = load_config(config_file) + claude_server_config = config.get("claude_server", {}) + projects = claude_server_config.get("projects", {}) + default_project = claude_server_config.get("default_project") + + # Default to current directory if not specified + if cwd is None: + cwd = Path.cwd() + + # If no projects configured, add cwd as default project + if not projects: + projects = {"default": str(cwd.resolve())} + default_project = "default" + + protocol = "https" if ssl else "http" + console.print( + f"[bold green]Starting Claude Code remote server on {protocol}://{host}:{port}[/bold green]", + ) + console.print(f"[dim]Working directory: {cwd.resolve()}[/dim]") + if projects: + console.print(f"[dim]Projects: {', '.join(projects.keys())}[/dim]") + if default_project: + console.print(f"[dim]Default project: {default_project}[/dim]") + console.print() + console.print("[bold]Endpoints:[/bold]") + console.print(f" Chat {protocol}://{host}:{port}/chat") + console.print(f" POST {protocol}://{host}:{port}/prompt") + console.print(f" GET {protocol}://{host}:{port}/logs") + console.print() + + if ssl: + console.print( + "[yellow]HTTPS enabled (self-signed cert) - accept certificate warning in browser[/yellow]", + ) + if reload: + console.print("[yellow]Auto-reload enabled for development[/yellow]") + + run_claude_server( + host=host, + port=port, + reload=reload, + cwd=cwd, + projects=projects, + default_project=default_project, + ssl=ssl, + ) diff --git a/agent_cli/claude_api.py b/agent_cli/claude_api.py new file mode 100644 index 000000000..d5923c35a --- /dev/null +++ b/agent_cli/claude_api.py @@ -0,0 +1,1279 @@ +"""FastAPI web service for remote Claude Code access via Agent SDK.""" + +from __future__ import annotations + +import asyncio +import contextlib +import json +import logging +import os +import re +import uuid +from dataclasses import dataclass +from datetime import UTC, datetime +from pathlib import Path +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + +import httpx +from claude_agent_sdk import ClaudeAgentOptions, query +from claude_agent_sdk.types import ( + AssistantMessage, + ResultMessage, + SystemMessage, + TextBlock, + ThinkingBlock, + ToolResultBlock, + ToolUseBlock, +) +from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect +from fastapi.responses import HTMLResponse, StreamingResponse +from pydantic import BaseModel + +# Configure logging +logging.basicConfig(level=logging.INFO) +LOGGER = logging.getLogger(__name__) + +# Constants +PROMPT_TRUNCATE_LENGTH = 50 + +# Default tools allowed for Claude Code operations +DEFAULT_ALLOWED_TOOLS = [ + "Read", + "Write", + "Edit", + "Bash", + "Glob", + "Grep", + "WebSearch", + "WebFetch", +] + + +# Pydantic models for request/response +class NewSessionRequest(BaseModel): + """Request model for creating a new session.""" + + cwd: str = "." + + +class NewSessionResponse(BaseModel): + """Response model for session creation.""" + + session_id: str + status: str = "created" + + +class PromptRequest(BaseModel): + """Request model for sending a prompt.""" + + prompt: str + + +class PromptResponse(BaseModel): + """Response model for prompt results.""" + + result: str + success: bool + error: str | None = None + + +class HealthResponse(BaseModel): + """Response model for health check.""" + + status: str + version: str + + +class SimplePromptRequest(BaseModel): + """Request model for simplified prompt endpoint.""" + + prompt: str + project: str | None = None # Optional project name, uses default if not specified + + +class SimplePromptResponse(BaseModel): + """Response model for simplified prompt endpoint.""" + + summary: str + files_changed: list[str] + log_id: str + log_url: str + success: bool + error: str | None = None + + +class ToolCall(BaseModel): + """Represents a tool call made during execution.""" + + name: str + input: dict[str, Any] + file_path: str | None = None + + +@dataclass +class LogEntry: + """A log entry for a Claude Code interaction.""" + + log_id: str + project: str + prompt: str + summary: str + files_changed: list[str] + tool_calls: list[dict[str, Any]] + full_response: str + timestamp: datetime + success: bool + error: str | None = None + + +class LogStore: + """In-memory storage for conversation logs.""" + + def __init__(self, max_entries: int = 100) -> None: + """Initialize the log store.""" + self.entries: dict[str, LogEntry] = {} + self.max_entries = max_entries + + def add(self, entry: LogEntry) -> None: + """Add a log entry.""" + # Remove oldest entries if at capacity + if len(self.entries) >= self.max_entries: + oldest_id = min(self.entries, key=lambda k: self.entries[k].timestamp) + del self.entries[oldest_id] + self.entries[entry.log_id] = entry + + def get(self, log_id: str) -> LogEntry | None: + """Get a log entry by ID.""" + return self.entries.get(log_id) + + def list_recent(self, limit: int = 20) -> list[LogEntry]: + """List recent log entries.""" + sorted_entries = sorted( + self.entries.values(), + key=lambda e: e.timestamp, + reverse=True, + ) + return sorted_entries[:limit] + + +# Global log store +log_store = LogStore() + + +@dataclass +class Session: + """Represents an active Claude Code session.""" + + session_id: str + cwd: Path + project_name: str | None = None + cancelled: bool = False + claude_session_id: str | None = None # The actual Claude SDK session ID + + +class SessionManager: + """Manages active Claude Code sessions.""" + + def __init__(self) -> None: + """Initialize the session manager.""" + self.sessions: dict[str, Session] = {} + self._project_sessions: dict[str, str] = {} # project_name -> session_id + self._cancel_events: dict[str, asyncio.Event] = {} + + def create_session(self, cwd: str = ".", project_name: str | None = None) -> Session: + """Create a new session.""" + session_id = str(uuid.uuid4()) + session = Session( + session_id=session_id, + cwd=Path(cwd).resolve(), + project_name=project_name, + ) + self.sessions[session_id] = session + self._cancel_events[session_id] = asyncio.Event() + if project_name: + self._project_sessions[project_name] = session_id + LOGGER.info( + "Created session %s with cwd=%s project=%s", + session_id, + session.cwd, + project_name, + ) + return session + + def get_session(self, session_id: str) -> Session | None: + """Get a session by ID.""" + return self.sessions.get(session_id) + + def get_or_create_project_session(self, project_name: str, cwd: str) -> Session: + """Get existing session for project or create a new one.""" + if project_name in self._project_sessions: + session_id = self._project_sessions[project_name] + session = self.sessions.get(session_id) + if session: + return session + return self.create_session(cwd=cwd, project_name=project_name) + + def cancel_session(self, session_id: str) -> bool: + """Mark a session as cancelled.""" + if session_id in self.sessions: + self.sessions[session_id].cancelled = True + if session_id in self._cancel_events: + self._cancel_events[session_id].set() + return True + return False + + def remove_session(self, session_id: str) -> None: + """Remove a session.""" + session = self.sessions.pop(session_id, None) + if session and session.project_name: + self._project_sessions.pop(session.project_name, None) + self._cancel_events.pop(session_id, None) + + +class ProjectManager: + """Manages named projects and the current/default project.""" + + def __init__(self) -> None: + """Initialize the project manager.""" + self.projects: dict[str, str] = {} # name -> path + self.default_project: str | None = None + self.current_project: str | None = None # Sticky session support + + def configure(self, projects: dict[str, str], default: str | None = None) -> None: + """Configure projects from settings.""" + self.projects = projects + self.default_project = default + if default and not self.current_project: + self.current_project = default + + def get_project_path(self, project_name: str) -> str | None: + """Get the path for a project.""" + return self.projects.get(project_name) + + def resolve_project( + self, + prompt: str, + explicit_project: str | None = None, + ) -> tuple[str, str, str]: + """Resolve which project to use and clean the prompt. + + Returns: (project_name, project_path, cleaned_prompt) + """ + # 1. Check for explicit project parameter + if explicit_project: + path = self.get_project_path(explicit_project) + if path: + self.current_project = explicit_project + return explicit_project, path, prompt + msg = f"Unknown project: {explicit_project}" + raise ValueError(msg) + + # 2. Check for "in {project}," prefix in prompt + match = re.match(r"^[Ii]n\s+([\w-]+)[,:]?\s*(.*)$", prompt) + if match: + project_name = match.group(1).lower() + cleaned_prompt = match.group(2) + path = self.get_project_path(project_name) + if path: + self.current_project = project_name + return project_name, path, cleaned_prompt + + # 3. Use current/sticky project + if self.current_project: + path = self.get_project_path(self.current_project) + if path: + return self.current_project, path, prompt + + # 4. Use default project + if self.default_project: + path = self.get_project_path(self.default_project) + if path: + self.current_project = self.default_project + return self.default_project, path, prompt + + msg = "No project specified and no default project configured" + raise ValueError(msg) + + def switch_project(self, project_name: str) -> str: + """Switch the current project.""" + path = self.get_project_path(project_name) + if not path: + msg = f"Unknown project: {project_name}" + raise ValueError(msg) + self.current_project = project_name + return path + + +# Global managers +session_manager = SessionManager() +project_manager = ProjectManager() + +# FastAPI app +app = FastAPI( + title="Claude Code Remote API", + description="Remote access to Claude Code via Agent SDK", + version="1.0.0", +) + + +@app.on_event("startup") +async def startup_event() -> None: + """Configure project manager from environment variables on startup.""" + # Read projects from environment (set by CLI) + projects_json = os.environ.get("CLAUDE_API_PROJECTS") + default_project = os.environ.get("CLAUDE_API_DEFAULT_PROJECT") + + if projects_json: + try: + projects = json.loads(projects_json) + project_manager.configure(projects, default_project) + LOGGER.info( + "Configured projects: %s (default: %s)", + list(projects.keys()), + default_project, + ) + except json.JSONDecodeError: + LOGGER.warning("Failed to parse CLAUDE_API_PROJECTS environment variable") + + +def _build_options(session: Session) -> ClaudeAgentOptions: + """Build ClaudeAgentOptions for a session.""" + options = ClaudeAgentOptions( + cwd=str(session.cwd), + permission_mode="bypassPermissions", + allowed_tools=DEFAULT_ALLOWED_TOOLS, + ) + if session.claude_session_id: + options.resume = session.claude_session_id + return options + + +def _extract_file_changes(tool_calls: list[dict[str, Any]]) -> list[str]: + """Extract list of changed files from tool calls.""" + files = set() + for call in tool_calls: + name = call.get("name", "") + input_data = call.get("input", {}) + if name in ("Edit", "Write", "MultiEdit"): + file_path = input_data.get("file_path") + if file_path: + files.add(file_path) + elif name == "Bash": + # Check for common file-modifying commands + cmd = input_data.get("command", "") + if any(op in cmd for op in ["mv ", "cp ", "rm ", "touch ", "mkdir "]): + # Can't reliably extract file paths, but note the command type + pass + return sorted(files) + + +@app.get("/health", response_model=HealthResponse) +async def health_check() -> HealthResponse: + """Health check endpoint.""" + return HealthResponse(status="healthy", version="1.0.0") + + +@app.post("/prompt", response_model=SimplePromptResponse) +async def simple_prompt( # noqa: PLR0912 + request: SimplePromptRequest, + req: Request, +) -> SimplePromptResponse: + """Simplified prompt endpoint with automatic project/session management. + + Supports: + - Explicit project parameter: {"prompt": "...", "project": "my-project"} + - Project prefix in prompt: "In my-project, fix the bug" + - Sticky sessions: remembers last used project + - Default project from config + """ + try: + # Resolve project + project_name, project_path, cleaned_prompt = project_manager.resolve_project( + request.prompt, + request.project, + ) + except ValueError as e: + return SimplePromptResponse( + summary="", + files_changed=[], + log_id="", + log_url="", + success=False, + error=str(e), + ) + + # Get or create session for this project + session = session_manager.get_or_create_project_session(project_name, project_path) + + try: + summary = "" + full_response = "" + tool_calls: list[dict[str, Any]] = [] + session.cancelled = False + options = _build_options(session) + + async for message in query(prompt=cleaned_prompt, options=options): + if session.cancelled: + break + + if isinstance(message, SystemMessage): + if message.subtype == "init": + init_session_id = message.data.get("session_id") + if init_session_id: + session.claude_session_id = init_session_id + + elif isinstance(message, ResultMessage): + if message.result: + summary = message.result + + elif isinstance(message, AssistantMessage): + for block in message.content: + if isinstance(block, TextBlock): + full_response += block.text + elif isinstance(block, ToolUseBlock): + tool_calls.append( + { + "name": block.name, + "input": block.input, + }, + ) + + # Extract file changes + files_changed = _extract_file_changes(tool_calls) + + # Generate log entry + log_id = str(uuid.uuid4())[:8] + base_url = str(req.base_url).rstrip("/") + log_url = f"{base_url}/log/{log_id}" + + # Store log entry + log_entry = LogEntry( + log_id=log_id, + project=project_name, + prompt=request.prompt, + summary=summary or full_response[:200], + files_changed=files_changed, + tool_calls=tool_calls, + full_response=full_response, + timestamp=datetime.now(UTC), + success=True, + ) + log_store.add(log_entry) + + return SimplePromptResponse( + summary=summary or full_response[:200], + files_changed=files_changed, + log_id=log_id, + log_url=log_url, + success=True, + ) + + except Exception as e: + LOGGER.exception("Error during Claude Code query") + return SimplePromptResponse( + summary="", + files_changed=[], + log_id="", + log_url="", + success=False, + error=str(e), + ) + + +@app.get("/log/{log_id}", response_class=HTMLResponse) +async def view_log(log_id: str) -> HTMLResponse: + """View log entry details in a web UI.""" + entry = log_store.get(log_id) + if not entry: + raise HTTPException(status_code=404, detail="Log entry not found") + + # Generate HTML page + files_html = ( + "".join(f"
  • {f}
  • " for f in entry.files_changed) or "
  • No files changed
  • " + ) + tools_html = ( + "".join( + f"
  • {t['name']}: {t.get('input', {}).get('file_path', 'N/A')}
  • " + for t in entry.tool_calls + ) + or "
  • No tool calls
  • " + ) + + html = f""" + + + + Claude Code Log - {log_id} + + + + + +
    +
    +

    🤖 Claude Code Log

    + ← All Logs +
    + +
    + + {"❌ Error" if not entry.success else "✅ Success"} + + {entry.project} + {entry.timestamp.strftime("%Y-%m-%d %H:%M:%S UTC")} +
    + +
    +
    +

    📝 Prompt

    +

    {entry.prompt}

    +
    +
    + +
    +
    +

    💬 Summary

    +

    {entry.summary}

    +
    +
    + +
    +
    +
    +

    📁 Files Changed ({len(entry.files_changed)})

    + +
    +
    +
    +
    +

    🔧 Tool Calls ({len(entry.tool_calls)})

    + +
    +
    +
    + +
    + +
    📄 Full Response
    +
    +
    {entry.full_response or "(No text response)"}
    +
    +
    + + {f'
    ❌ {entry.error}
    ' if entry.error else ""} +
    + + + """ + return HTMLResponse(content=html) + + +def _truncate_prompt(prompt: str) -> str: + """Truncate prompt for display.""" + if len(prompt) > PROMPT_TRUNCATE_LENGTH: + return prompt[:PROMPT_TRUNCATE_LENGTH] + "..." + return prompt + + +@app.get("/logs", response_class=HTMLResponse) +async def list_logs() -> HTMLResponse: + """List recent log entries.""" + entries = log_store.list_recent(20) + + rows = ( + "".join( + f""" + {e.log_id} + {e.project} + {_truncate_prompt(e.prompt)} + {len(e.files_changed)} + {"" if e.success else ""} + {e.timestamp.strftime("%H:%M:%S")} + """ + for e in entries + ) + or "No logs yet" + ) + + html = f""" + + + + Claude Code Logs + + + + + +
    +

    🤖 Claude Code Logs

    + +
    + + + + + + + + + + + + {rows} +
    IDProjectPromptFilesStatusTime
    +
    +
    + + + """ + return HTMLResponse(content=html) + + +@app.get("/chat", response_class=HTMLResponse) +async def chat_page() -> HTMLResponse: + """Interactive chat page with streaming support.""" + current = project_manager.current_project or project_manager.default_project or "" + + html = f""" + + + + Claude Code Chat + + + + + + +
    + + + + +
    +
    Loading...
    +
    + + +
    + +
    + + +
    +
    +
    + + + + + """ + return HTMLResponse(content=html) + + +@app.get("/logs/json") +async def list_logs_json(limit: int = 20) -> list[dict[str, Any]]: + """List recent log entries as JSON for chat history.""" + entries = log_store.list_recent(limit) + return [ + { + "log_id": e.log_id, + "project": e.project, + "prompt": e.prompt, + "summary": e.summary, + "files_changed": e.files_changed, + "success": e.success, + "timestamp": e.timestamp.isoformat(), + } + for e in entries + ] + + +def _render_message( + role: str, + content: str, + files_changed: list[str] | None = None, + log_id: str | None = None, +) -> str: + """Render a single chat message as HTML.""" + chat_class = "chat chat-end" if role == "user" else "chat chat-start" + bubble_class = "chat-bubble-primary" if role == "user" else "" + + meta_html = "" + if files_changed: + meta_html += f'
    📁 {", ".join(files_changed)}
    ' + if log_id: + meta_html += f'View details →' + + return f"""
    +
    {content}{meta_html}
    +
    """ + + +@app.get("/chat/messages", response_class=HTMLResponse) +async def chat_messages() -> HTMLResponse: + """Return chat history as HTML fragments for HTMX.""" + entries = log_store.list_recent(20) + if not entries: + return HTMLResponse(content="") + + # Reverse to show oldest first + html_parts = [] + for entry in reversed(entries): + html_parts.append(_render_message("user", entry.prompt)) + html_parts.append( + _render_message( + "assistant", + entry.summary, + entry.files_changed, + entry.log_id, + ), + ) + return HTMLResponse(content="\n".join(html_parts)) + + +class ChatSendRequest(BaseModel): + """Form data for chat send endpoint.""" + + prompt: str + project: str | None = None + + +@app.post("/chat/send", response_class=HTMLResponse) +async def chat_send(request: Request) -> HTMLResponse: + """Send a message and return HTML fragments for HTMX.""" + # Parse form data + form = await request.form() + prompt = str(form.get("prompt", "")).strip() + project = str(form.get("project", "")) or None + + if not prompt: + return HTMLResponse(content="") + + # Render user message immediately + user_html = _render_message("user", prompt) + + try: + # Resolve project + project_name, project_path, cleaned_prompt = project_manager.resolve_project( + prompt, + project, + ) + except ValueError as e: + error_html = _render_message("assistant", f"❌ Error: {e}") + return HTMLResponse(content=user_html + error_html) + + # Get or create session + session = session_manager.get_or_create_project_session(project_name, project_path) + + try: + summary = "" + full_response = "" + tool_calls: list[dict[str, Any]] = [] + session.cancelled = False + options = _build_options(session) + + async for message in query(prompt=cleaned_prompt, options=options): + if session.cancelled: + break + + if isinstance(message, SystemMessage) and message.subtype == "init": + session.claude_session_id = message.data.get("session_id") + elif isinstance(message, ResultMessage) and message.result: + summary = message.result + elif isinstance(message, AssistantMessage): + for block in message.content: + if isinstance(block, TextBlock): + full_response += block.text + elif isinstance(block, ToolUseBlock): + tool_calls.append({"name": block.name, "input": block.input}) + + # Extract file changes and store log + files_changed = _extract_file_changes(tool_calls) + log_id = str(uuid.uuid4())[:8] + + log_entry = LogEntry( + log_id=log_id, + project=project_name, + prompt=prompt, + summary=summary or full_response[:200], + files_changed=files_changed, + tool_calls=tool_calls, + full_response=full_response, + timestamp=datetime.now(UTC), + success=True, + ) + log_store.add(log_entry) + + assistant_html = _render_message( + "assistant", + summary or full_response[:200], + files_changed, + log_id, + ) + return HTMLResponse(content=user_html + assistant_html) + + except Exception as e: + LOGGER.exception("Error during chat send") + error_html = _render_message("assistant", f"❌ Error: {e}") + return HTMLResponse(content=user_html + error_html) + + +def _sse_event(event: str, data: str) -> str: + """Format an SSE event.""" + return f"event: {event}\ndata: {data}\n\n" + + +@app.post("/chat/stream") +async def chat_stream(request: Request) -> StreamingResponse: + """Stream chat response via Server-Sent Events.""" + # Parse form data + form = await request.form() + prompt = str(form.get("prompt", "")).strip() + project = str(form.get("project", "")) or None + + async def generate() -> AsyncGenerator[str, None]: + if not prompt: + return + + # Send user message first + user_html = _render_message("user", prompt) + yield _sse_event("user", user_html) + + try: + project_name, project_path, cleaned_prompt = project_manager.resolve_project( + prompt, + project, + ) + except ValueError as e: + yield _sse_event("error", _render_message("assistant", f"❌ Error: {e}")) + yield _sse_event("done", "") + return + + session = session_manager.get_or_create_project_session( + project_name, + project_path, + ) + + try: + full_response = "" + summary = "" + tool_calls: list[dict[str, Any]] = [] + session.cancelled = False + options = _build_options(session) + streaming_text = "" + + # Send initial empty assistant bubble that we'll update + yield _sse_event( + "start", + '
    ', + ) + + async for message in query(prompt=cleaned_prompt, options=options): + if session.cancelled: + break + + if isinstance(message, SystemMessage) and message.subtype == "init": + session.claude_session_id = message.data.get("session_id") + elif isinstance(message, ResultMessage) and message.result: + summary = message.result + elif isinstance(message, AssistantMessage): + for block in message.content: + if isinstance(block, TextBlock): + streaming_text += block.text + full_response += block.text + # Stream the text chunk + yield _sse_event("chunk", block.text) + elif isinstance(block, ToolUseBlock): + tool_calls.append({"name": block.name, "input": block.input}) + # Show tool usage + yield _sse_event("tool", f"🔧 {block.name}") + + # Save log entry + files_changed = _extract_file_changes(tool_calls) + log_id = str(uuid.uuid4())[:8] + log_entry = LogEntry( + log_id=log_id, + project=project_name, + prompt=prompt, + summary=summary or full_response[:200], + files_changed=files_changed, + tool_calls=tool_calls, + full_response=full_response, + timestamp=datetime.now(UTC), + success=True, + ) + log_store.add(log_entry) + + # Send final complete message with metadata + final_html = _render_message( + "assistant", + summary or full_response[:500] or "(No response)", + files_changed, + log_id, + ) + yield _sse_event("complete", final_html) + yield _sse_event("done", "") + + except Exception as e: + LOGGER.exception("Error during streaming") + yield _sse_event("error", _render_message("assistant", f"❌ Error: {e}")) + yield _sse_event("done", "") + + return StreamingResponse( + generate(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, + ) + + +@app.post("/transcribe-proxy") +async def transcribe_proxy(request: Request) -> dict[str, Any]: + """Proxy transcription requests to avoid CORS issues.""" + # Get voice server URL from query param or use default + voice_server = request.query_params.get("voice_server", "http://localhost:61337") + + # Forward the multipart form data + body = await request.body() + content_type = request.headers.get("content-type", "") + + try: + async with httpx.AsyncClient(timeout=60.0) as client: + resp = await client.post( + f"{voice_server}/transcribe", + content=body, + headers={"content-type": content_type}, + ) + return resp.json() + except Exception as e: + LOGGER.exception("Transcription proxy error") + return {"error": str(e), "raw_transcript": "", "cleaned_transcript": ""} + + +@app.post("/switch-project") +async def switch_project(project: str) -> dict[str, str]: + """Switch the current/sticky project.""" + try: + path = project_manager.switch_project(project) + return {"status": "switched", "project": project, "path": path} + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + + +@app.get("/projects") +async def list_projects() -> dict[str, Any]: + """List configured projects.""" + return { + "projects": project_manager.projects, + "default": project_manager.default_project, + "current": project_manager.current_project, + } + + +@app.post("/session/new", response_model=NewSessionResponse) +async def create_session(request: NewSessionRequest) -> NewSessionResponse: + """Create a new Claude Code session.""" + session = session_manager.create_session(cwd=request.cwd) + return NewSessionResponse(session_id=session.session_id, status="created") + + +@app.post("/session/{session_id}/prompt", response_model=PromptResponse) +async def send_prompt(session_id: str, request: PromptRequest) -> PromptResponse: + """Send a prompt to Claude Code and get the result.""" + session = session_manager.get_session(session_id) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + + try: + result_text = "" + session.cancelled = False + options = _build_options(session) + + async for message in query(prompt=request.prompt, options=options): + if session.cancelled: + break + + # Handle different message types with proper isinstance checks + if isinstance(message, SystemMessage): + if message.subtype == "init": + # Extract session_id from data dict + init_session_id = message.data.get("session_id") + if init_session_id: + session.claude_session_id = init_session_id + LOGGER.info( + "Captured Claude session ID: %s", + session.claude_session_id, + ) + + elif isinstance(message, ResultMessage): + if message.result: + result_text = message.result + + elif isinstance(message, AssistantMessage): + for block in message.content: + if isinstance(block, TextBlock): + result_text += block.text + + return PromptResponse(result=result_text, success=True) + + except Exception as e: + LOGGER.exception("Error during Claude Code query") + return PromptResponse(result="", success=False, error=str(e)) + + +@app.post("/session/{session_id}/cancel") +async def cancel_session(session_id: str) -> dict[str, str]: + """Cancel the current operation in a session.""" + if session_manager.cancel_session(session_id): + return {"status": "cancelled"} + raise HTTPException(status_code=404, detail="Session not found") + + +@app.websocket("/session/{session_id}/stream") +async def stream_session(websocket: WebSocket, session_id: str) -> None: + """WebSocket endpoint for streaming Claude Code responses.""" + await websocket.accept() + + session = session_manager.get_session(session_id) + if not session: + await websocket.send_json({"type": "error", "error": "Session not found"}) + await websocket.close(code=4004) + return + + try: + while True: + # Wait for prompt from client + data = await websocket.receive_json() + prompt = data.get("prompt", "") + + if not prompt: + await websocket.send_json({"type": "error", "error": "No prompt provided"}) + continue + + session.cancelled = False + options = _build_options(session) + + try: + async for message in query(prompt=prompt, options=options): + if session.cancelled: + await websocket.send_json({"type": "cancelled"}) + break + + # Convert message to JSON with proper type checks + msg_dict = _message_to_dict( + message, + session, + SystemMessage, + ResultMessage, + AssistantMessage, + TextBlock, + ThinkingBlock, + ToolUseBlock, + ToolResultBlock, + ) + if msg_dict: + await websocket.send_json(msg_dict) + + await websocket.send_json({"type": "done"}) + + except Exception as e: + LOGGER.exception("Error during streaming") + await websocket.send_json({"type": "error", "error": str(e)}) + + except WebSocketDisconnect: + LOGGER.info("WebSocket disconnected for session %s", session_id) + except Exception as e: + LOGGER.exception("WebSocket error") + with contextlib.suppress(Exception): + await websocket.send_json({"type": "error", "error": str(e)}) + + +def _message_to_dict( + message: Any, + session: Session, + system_msg_type: type, + result_msg_type: type, + assistant_msg_type: type, + text_block_type: type, + thinking_block_type: type, + tool_use_block_type: type, + tool_result_block_type: type, +) -> dict[str, Any] | None: + """Convert a Claude SDK message to a JSON-serializable dict.""" + if isinstance(message, system_msg_type): + if message.subtype == "init": # type: ignore[attr-defined] + init_session_id = message.data.get("session_id") # type: ignore[attr-defined] + if init_session_id: + session.claude_session_id = init_session_id + return {"type": "init", "session_id": init_session_id} + return None + + if isinstance(message, result_msg_type): + return { + "type": "result", + "subtype": message.subtype, # type: ignore[attr-defined] + "result": message.result or "", # type: ignore[attr-defined] + } + + if isinstance(message, assistant_msg_type): + blocks = [] + for block in message.content: # type: ignore[attr-defined] + if isinstance(block, text_block_type): + blocks.append({"type": "text", "text": block.text}) # type: ignore[attr-defined] + elif isinstance(block, thinking_block_type): + blocks.append({"type": "thinking", "thinking": block.thinking}) # type: ignore[attr-defined] + elif isinstance(block, tool_use_block_type): + blocks.append( + { + "type": "tool_use", + "id": block.id, # type: ignore[attr-defined] + "name": block.name, # type: ignore[attr-defined] + "input": block.input, # type: ignore[attr-defined] + }, + ) + elif isinstance(block, tool_result_block_type): + blocks.append( + { + "type": "tool_result", + "tool_use_id": block.tool_use_id, # type: ignore[attr-defined] + "content": block.content, # type: ignore[attr-defined] + }, + ) + if blocks: + return {"type": "assistant", "content": blocks} + return None + + return None diff --git a/agent_cli/cli.py b/agent_cli/cli.py index 91ab4e8ea..a6db549cf 100644 --- a/agent_cli/cli.py +++ b/agent_cli/cli.py @@ -75,6 +75,7 @@ def set_config_defaults(ctx: typer.Context, config_file: str | None) -> None: assistant, autocorrect, chat, + claude_serve, memory, rag_proxy, server, diff --git a/agent_cli/config.py b/agent_cli/config.py index 981956143..15c53a333 100644 --- a/agent_cli/config.py +++ b/agent_cli/config.py @@ -171,6 +171,30 @@ class WakeWord(BaseModel): wake_word: str +# --- Panel: Claude Server Options --- + + +class ClaudeServer(BaseModel): + """Configuration for the Claude Code remote server.""" + + host: str = "0.0.0.0" # noqa: S104 + port: int = 8765 + permission_mode: Literal["default", "acceptEdits", "bypassPermissions"] = "bypassPermissions" + allowed_tools: list[str] = [ + "Read", + "Write", + "Edit", + "Bash", + "Glob", + "Grep", + "WebSearch", + "WebFetch", + ] + # Named projects: {"project-name": "/path/to/project"} + projects: dict[str, str] = {} + default_project: str | None = None + + # --- Panel: General Options --- diff --git a/iOS_Shortcut_Guide.md b/iOS_Shortcut_Guide.md index 36c796905..09787fe7c 100644 --- a/iOS_Shortcut_Guide.md +++ b/iOS_Shortcut_Guide.md @@ -241,6 +241,191 @@ extra_instructions: (optional) - **Access Control**: Consider adding authentication to your API - **Firewall**: Only expose necessary ports +--- + +# iOS Shortcut Setup for Claude Code Remote Access + +This section shows how to create an iOS Shortcut that sends prompts to Claude Code running on your server. + +## Prerequisites + +1. **Claude Code Installed and Authenticated**: Run `claude /login` on your server +2. **Agent CLI with Claude extras**: `pip install agent-cli[claude]` +3. **Network Access**: Your iPhone needs network access to reach the server + +## Setup Claude Code Server + +1. Start the server: + ```bash + agent-cli claude-serve --host 0.0.0.0 --port 8765 --cwd /path/to/your/project + ``` + +2. Test the server is working: + ```bash + curl http://your-server-ip:8765/health + ``` + +## Create iOS Shortcut for Claude Code + +### Step 1: Create New Shortcut +- Open the **Shortcuts** app on your iPhone +- Tap the **+** button to create a new shortcut + +### Step 2: Add Actions + +**Action 1: Ask for Input (Optional)** +1. Search for and add **"Ask for Input"** action +2. Configure: + - **Question**: "What would you like Claude to do?" + - **Input Type**: Text + +**Action 2: Create Session** +1. Search for and add **"Get Contents of URL"** action +2. Configure: + - **URL**: `http://YOUR_SERVER_IP:8765/session/new` + - **Method**: POST + - **Request Body**: JSON + - Add field: `cwd` with value `/path/to/project` + +**Action 3: Get Session ID** +1. Add **"Get Dictionary Value"** action +2. Configure: + - **Dictionary**: Output from previous step + - **Get Value for**: `session_id` +3. Add **"Set Variable"** action, name it `SessionID` + +**Action 4: Send Prompt** +1. Add another **"Get Contents of URL"** action +2. Configure: + - **URL**: `http://YOUR_SERVER_IP:8765/session/` then insert SessionID variable then `/prompt` + - **Method**: POST + - **Request Body**: JSON + - Add field: `prompt` with the input from Action 1 + +**Action 5: Get Result** +1. Add **"Get Dictionary Value"** action +2. Configure: + - **Dictionary**: Output from previous step + - **Get Value for**: `result` + +**Action 6: Copy and Notify** +1. Add **"Copy to Clipboard"** action +2. Add **"Show Result"** action to display Claude's response + +### Simple One-Shot Shortcut (Alternative) + +For a simpler shortcut that doesn't maintain sessions: + +1. **Ask for Input**: Get the prompt from user +2. **Get Contents of URL**: POST to `/session/new` with `{"cwd": "."}` +3. **Get Dictionary Value**: Extract `session_id` +4. **Set Variable**: Store as `SessionID` +5. **Get Contents of URL**: POST to `/session/{SessionID}/prompt` with `{"prompt": "..."}` +6. **Get Dictionary Value**: Extract `result` +7. **Show Result**: Display Claude's response + +## Claude Code API Reference + +### POST /session/new + +Create a new Claude Code session. + +**Request:** +```json +{ + "cwd": "/path/to/project" +} +``` + +**Response:** +```json +{ + "session_id": "uuid-here", + "status": "created" +} +``` + +### POST /session/{session_id}/prompt + +Send a prompt and get the result. + +**Request:** +```json +{ + "prompt": "Find and fix bugs in auth.py" +} +``` + +**Response:** +```json +{ + "result": "I analyzed auth.py and found 2 bugs...", + "success": true, + "error": null +} +``` + +### POST /session/{session_id}/cancel + +Cancel the current operation. + +**Response:** +```json +{ + "status": "cancelled" +} +``` + +### WebSocket /session/{session_id}/stream + +Real-time streaming endpoint for watching Claude work. + +**Client sends:** +```json +{"prompt": "Your prompt here"} +``` + +**Server streams:** +```json +{"type": "text", "content": "Looking at the code..."} +{"type": "tool_call", "name": "Read", "input": {...}} +{"type": "result", "result": "Done!"} +{"type": "done"} +``` + +### GET /health + +Health check endpoint. + +**Response:** +```json +{ + "status": "healthy", + "version": "1.0.0" +} +``` + +## Configuration + +Add to your `~/.config/agent-cli/config.toml`: + +```toml +[claude-serve] +host = "0.0.0.0" +port = 8765 +permission-mode = "bypassPermissions" +allowed-tools = ["Read", "Write", "Edit", "Bash", "Glob", "Grep", "WebSearch", "WebFetch"] +``` + +## Security Considerations + +- **Authentication**: The server has no authentication by default - only expose on trusted networks +- **Permission Mode**: `bypassPermissions` allows Claude to run any command - use with caution +- **Network**: Use VPN or SSH tunneling for secure remote access +- **HTTPS**: Set up a reverse proxy with SSL for production use + +--- + ## Next Steps - Set up HTTPS with SSL certificates for production use diff --git a/pyproject.toml b/pyproject.toml index 35a453ae8..6dc7cdd2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,10 @@ dev = [ "notebook", ] speed = ["audiostretchy>=1.3.0"] +claude = [ + "fastapi[standard]", + "claude-agent-sdk>=0.1.0", +] # Duplicate of test+dev optional-dependencies groups [dependency-groups] @@ -165,6 +169,7 @@ ignore = [ ".github/*" = ["INP001"] "example/*" = ["INP001", "D100"] "docs/*" = ["INP001", "E501"] +"agent_cli/claude_api.py" = ["S608"] # HTML templates, not SQL [tool.ruff.lint.mccabe] max-complexity = 18 diff --git a/tests/test_claude_api.py b/tests/test_claude_api.py new file mode 100644 index 000000000..184d5b28c --- /dev/null +++ b/tests/test_claude_api.py @@ -0,0 +1,409 @@ +"""Tests for the Claude Code remote API.""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Any +from unittest.mock import MagicMock, patch + +import pytest +from fastapi.testclient import TestClient + +if TYPE_CHECKING: + from collections.abc import Generator + + +@pytest.fixture +def mock_claude_sdk() -> Generator[tuple[MagicMock, MagicMock], None, None]: + """Mock the claude_agent_sdk module.""" + mock_sdk = MagicMock() + mock_sdk.ClaudeAgentOptions = MagicMock + + # Mock types + mock_types = MagicMock() + mock_types.SystemMessage = type("SystemMessage", (), {}) + mock_types.AssistantMessage = type("AssistantMessage", (), {}) + mock_types.ResultMessage = type("ResultMessage", (), {}) + mock_types.TextBlock = type("TextBlock", (), {}) + mock_types.ThinkingBlock = type("ThinkingBlock", (), {}) + mock_types.ToolUseBlock = type("ToolUseBlock", (), {}) + mock_types.ToolResultBlock = type("ToolResultBlock", (), {}) + + with ( + patch.dict("sys.modules", {"claude_agent_sdk": mock_sdk}), + patch.dict("sys.modules", {"claude_agent_sdk.types": mock_types}), + ): + yield mock_sdk, mock_types + + +@pytest.fixture +def client(mock_claude_sdk: Any) -> TestClient: # noqa: ARG001 + """Create a test client for the Claude API app.""" + from agent_cli.claude_api import app # noqa: PLC0415 + + return TestClient(app) + + +class TestHealthEndpoint: + """Tests for the health check endpoint.""" + + def test_health_check(self, client: TestClient) -> None: + """Test the health check endpoint returns healthy status.""" + response = client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert data["version"] == "1.0.0" + + +class TestSessionManagement: + """Tests for session creation and management.""" + + def test_create_session_default_cwd(self, client: TestClient) -> None: + """Test creating a session with default working directory.""" + response = client.post("/session/new", json={}) + assert response.status_code == 200 + data = response.json() + assert "session_id" in data + assert data["status"] == "created" + + def test_create_session_custom_cwd(self, client: TestClient) -> None: + """Test creating a session with custom working directory.""" + response = client.post("/session/new", json={"cwd": "/tmp"}) # noqa: S108 + assert response.status_code == 200 + data = response.json() + assert "session_id" in data + assert data["status"] == "created" + + def test_cancel_nonexistent_session(self, client: TestClient) -> None: + """Test cancelling a session that doesn't exist.""" + response = client.post("/session/nonexistent-id/cancel") + assert response.status_code == 404 + assert "Session not found" in response.json()["detail"] + + def test_cancel_existing_session(self, client: TestClient) -> None: + """Test cancelling an existing session.""" + # First create a session + create_response = client.post("/session/new", json={}) + session_id = create_response.json()["session_id"] + + # Then cancel it + cancel_response = client.post(f"/session/{session_id}/cancel") + assert cancel_response.status_code == 200 + assert cancel_response.json()["status"] == "cancelled" + + +class TestPromptEndpoint: + """Tests for the prompt endpoint.""" + + def test_prompt_nonexistent_session(self, client: TestClient) -> None: + """Test sending prompt to nonexistent session.""" + response = client.post( + "/session/nonexistent-id/prompt", + json={"prompt": "Hello"}, + ) + assert response.status_code == 404 + assert "Session not found" in response.json()["detail"] + + +class TestSessionManagerUnit: + """Unit tests for the SessionManager class.""" + + def test_create_session(self, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test session creation.""" + from agent_cli.claude_api import SessionManager # noqa: PLC0415 + + manager = SessionManager() + session = manager.create_session(cwd="/tmp") # noqa: S108 + + assert session.session_id is not None + assert str(session.cwd) == "/tmp" # noqa: S108 + assert session.cancelled is False + assert session.claude_session_id is None + + def test_get_session(self, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test retrieving a session.""" + from agent_cli.claude_api import SessionManager # noqa: PLC0415 + + manager = SessionManager() + created = manager.create_session() + + retrieved = manager.get_session(created.session_id) + assert retrieved is not None + assert retrieved.session_id == created.session_id + + def test_get_nonexistent_session(self, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test retrieving a session that doesn't exist.""" + from agent_cli.claude_api import SessionManager # noqa: PLC0415 + + manager = SessionManager() + assert manager.get_session("nonexistent") is None + + def test_cancel_session(self, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test cancelling a session.""" + from agent_cli.claude_api import SessionManager # noqa: PLC0415 + + manager = SessionManager() + session = manager.create_session() + + assert manager.cancel_session(session.session_id) is True + assert session.cancelled is True + + def test_cancel_nonexistent_session(self, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test cancelling a session that doesn't exist.""" + from agent_cli.claude_api import SessionManager # noqa: PLC0415 + + manager = SessionManager() + assert manager.cancel_session("nonexistent") is False + + def test_remove_session(self, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test removing a session.""" + from agent_cli.claude_api import SessionManager # noqa: PLC0415 + + manager = SessionManager() + session = manager.create_session() + session_id = session.session_id + + manager.remove_session(session_id) + assert manager.get_session(session_id) is None + + +class TestBuildOptions: + """Tests for the _build_options helper.""" + + def test_build_options_without_resume(self, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test building options for a new session without existing Claude session.""" + from agent_cli.claude_api import Session, _build_options # noqa: PLC0415 + + session = Session(session_id="test-id", cwd=Path("/tmp")) # noqa: S108 + # Session has no claude_session_id, so resume should not be set + assert session.claude_session_id is None + + options = _build_options(session) + assert options.cwd == "/tmp" # noqa: S108 + assert options.permission_mode == "bypassPermissions" + # With MagicMock, we can't easily check resume wasn't set, + # but we verify the logic path via session.claude_session_id being None + + def test_build_options_with_resume(self, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test building options for an existing Claude session.""" + from agent_cli.claude_api import Session, _build_options # noqa: PLC0415 + + session = Session( + session_id="test-id", + cwd=Path("/tmp"), # noqa: S108 + claude_session_id="claude-session-123", + ) + options = _build_options(session) + + assert options.resume == "claude-session-123" + + +class TestDefaultAllowedTools: + """Tests for the default allowed tools constant.""" + + def test_default_tools_defined(self, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test that default allowed tools are defined.""" + from agent_cli.claude_api import DEFAULT_ALLOWED_TOOLS # noqa: PLC0415 + + assert isinstance(DEFAULT_ALLOWED_TOOLS, list) + assert len(DEFAULT_ALLOWED_TOOLS) > 0 + + def test_default_tools_contains_essentials(self, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test that essential tools are in the default list.""" + from agent_cli.claude_api import DEFAULT_ALLOWED_TOOLS # noqa: PLC0415 + + essential_tools = ["Read", "Write", "Edit", "Bash"] + for tool in essential_tools: + assert tool in DEFAULT_ALLOWED_TOOLS + + +class TestProjectManager: + """Tests for the ProjectManager class.""" + + def test_configure_projects(self, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test configuring projects.""" + from agent_cli.claude_api import ProjectManager # noqa: PLC0415 + + manager = ProjectManager() + manager.configure( + {"proj1": "/path/to/proj1", "proj2": "/path/to/proj2"}, + default="proj1", + ) + + assert manager.projects == {"proj1": "/path/to/proj1", "proj2": "/path/to/proj2"} + assert manager.default_project == "proj1" + assert manager.current_project == "proj1" + + def test_get_project_path(self, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test getting project path.""" + from agent_cli.claude_api import ProjectManager # noqa: PLC0415 + + manager = ProjectManager() + manager.configure({"myproject": "/path/to/myproject"}) + + assert manager.get_project_path("myproject") == "/path/to/myproject" + assert manager.get_project_path("unknown") is None + + def test_resolve_project_explicit(self, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test resolving project with explicit parameter.""" + from agent_cli.claude_api import ProjectManager # noqa: PLC0415 + + manager = ProjectManager() + manager.configure({"proj1": "/path/1", "proj2": "/path/2"}, default="proj1") + + name, path, prompt = manager.resolve_project("do something", explicit_project="proj2") + assert name == "proj2" + assert path == "/path/2" + assert prompt == "do something" + + def test_resolve_project_from_prompt_prefix(self, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test resolving project from 'in {project},' prefix in prompt.""" + from agent_cli.claude_api import ProjectManager # noqa: PLC0415 + + manager = ProjectManager() + manager.configure({"myproject": "/path/to/myproject"}) + + name, path, prompt = manager.resolve_project("in myproject, fix the bug") + assert name == "myproject" + assert path == "/path/to/myproject" + assert prompt == "fix the bug" + + def test_resolve_project_sticky(self, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test sticky project session.""" + from agent_cli.claude_api import ProjectManager # noqa: PLC0415 + + manager = ProjectManager() + manager.configure({"proj1": "/path/1", "proj2": "/path/2"}) + + # First call sets current project + manager.resolve_project("task", explicit_project="proj1") + assert manager.current_project == "proj1" + + # Second call uses sticky project + name, _path, _prompt = manager.resolve_project("another task") + assert name == "proj1" + + def test_switch_project(self, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test switching projects.""" + from agent_cli.claude_api import ProjectManager # noqa: PLC0415 + + manager = ProjectManager() + manager.configure({"proj1": "/path/1", "proj2": "/path/2"}, default="proj1") + + path = manager.switch_project("proj2") + assert path == "/path/2" + assert manager.current_project == "proj2" + + +class TestLogStore: + """Tests for the LogStore class.""" + + def test_add_and_get_log(self, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test adding and retrieving a log entry.""" + from datetime import UTC, datetime # noqa: PLC0415 + + from agent_cli.claude_api import LogEntry, LogStore # noqa: PLC0415 + + store = LogStore() + entry = LogEntry( + log_id="test123", + project="myproject", + prompt="fix bug", + summary="Fixed the bug", + files_changed=["file.py"], + tool_calls=[], + full_response="I fixed it", + timestamp=datetime.now(UTC), + success=True, + ) + store.add(entry) + + retrieved = store.get("test123") + assert retrieved is not None + assert retrieved.log_id == "test123" + assert retrieved.project == "myproject" + + def test_get_nonexistent_log(self, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test getting a log that doesn't exist.""" + from agent_cli.claude_api import LogStore # noqa: PLC0415 + + store = LogStore() + assert store.get("nonexistent") is None + + def test_list_recent(self, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test listing recent log entries.""" + from datetime import UTC, datetime, timedelta # noqa: PLC0415 + + from agent_cli.claude_api import LogEntry, LogStore # noqa: PLC0415 + + store = LogStore() + now = datetime.now(UTC) + + for i in range(5): + entry = LogEntry( + log_id=f"log{i}", + project="proj", + prompt=f"prompt {i}", + summary=f"summary {i}", + files_changed=[], + tool_calls=[], + full_response="", + timestamp=now + timedelta(minutes=i), + success=True, + ) + store.add(entry) + + recent = store.list_recent(3) + assert len(recent) == 3 + # Most recent first + assert recent[0].log_id == "log4" + + +class TestHelperFunctions: + """Tests for helper functions.""" + + def test_extract_file_changes(self, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test extracting file changes from tool calls.""" + from agent_cli.claude_api import _extract_file_changes # noqa: PLC0415 + + tool_calls = [ + {"name": "Edit", "input": {"file_path": "/path/to/file1.py"}}, + {"name": "Write", "input": {"file_path": "/path/to/file2.py"}}, + {"name": "Read", "input": {"file_path": "/path/to/file3.py"}}, # Not a change + {"name": "Edit", "input": {"file_path": "/path/to/file1.py"}}, # Duplicate + ] + + files = _extract_file_changes(tool_calls) + assert files == ["/path/to/file1.py", "/path/to/file2.py"] + + def test_truncate_prompt(self, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test prompt truncation.""" + from agent_cli.claude_api import PROMPT_TRUNCATE_LENGTH, _truncate_prompt # noqa: PLC0415 + + short = "short prompt" + assert _truncate_prompt(short) == short + + long = "x" * (PROMPT_TRUNCATE_LENGTH + 10) + truncated = _truncate_prompt(long) + assert truncated.endswith("...") + assert len(truncated) == PROMPT_TRUNCATE_LENGTH + 3 + + +class TestNewEndpoints: + """Tests for new endpoints.""" + + def test_projects_endpoint(self, client: TestClient, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test the /projects endpoint.""" + response = client.get("/projects") + assert response.status_code == 200 + data = response.json() + assert "projects" in data + assert "default" in data + assert "current" in data + + def test_logs_endpoint_empty(self, client: TestClient, mock_claude_sdk: Any) -> None: # noqa: ARG002 + """Test the /logs endpoint when empty.""" + response = client.get("/logs") + assert response.status_code == 200 + assert "No logs yet" in response.text