diff --git a/src/bot/builder_handlers.py b/src/bot/builder_handlers.py new file mode 100644 index 000000000..ab7558699 --- /dev/null +++ b/src/bot/builder_handlers.py @@ -0,0 +1,46 @@ +"""Telegram /builder commands for Dex Phase 2 Builder.""" +from __future__ import annotations +import json, os +from pathlib import Path +from telegram import Update +from telegram.ext import ContextTypes + +def _data_dir() -> Path: + return Path(os.environ.get("BUILDER_DATA_DIR", r"C:\Users\odral\data\builder")) + +def _state() -> dict: + p = _data_dir() / "state.json" + return json.loads(p.read_text(encoding="utf-8")) if p.exists() else {"stage": "IDLE"} + +async def handle_builder_status(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + s = _state() + if s.get("stage", "IDLE") == "IDLE": + await update.message.reply_text("Builder idle. No ticket in flight.") + return + await update.message.reply_text( + f"Builder: {s.get('ticket_id')} at {s['stage']} " + f"(tick {s.get('ticks_used',0)}/10, attempts {s.get('implement_attempts',0)})") + +async def handle_builder_kill(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + if not context.args: + await update.message.reply_text("Usage: /builder kill "); return + p = _data_dir() / "state.json" + s = _state(); s["kill_requested"] = context.args[0] + p.write_text(json.dumps(s, indent=2), encoding="utf-8") + await update.message.reply_text(f"Kill requested for {context.args[0]}. Builder stops at next tick; branch preserved.") + +async def handle_builder_queue(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + q = sorted((_data_dir() / "queue").glob("*.md")) + inflight = _state().get("ticket_id") + lines = [f"In flight: {inflight or 'none'}"] + [f"Queued: {p.stem}" for p in q] + await update.message.reply_text("\n".join(lines) if lines else "Queue empty.") + +async def handle_builder(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Router for `/builder ` — sub = status (default) | kill | queue.""" + sub = (context.args[0].lower() if context.args else "status") + if sub == "kill": + context.args = context.args[1:] + return await handle_builder_kill(update, context) + if sub == "queue": + return await handle_builder_queue(update, context) + return await handle_builder_status(update, context) diff --git a/src/bot/dex_handlers.py b/src/bot/dex_handlers.py new file mode 100644 index 000000000..82c6174f0 --- /dev/null +++ b/src/bot/dex_handlers.py @@ -0,0 +1,307 @@ +"""Telegram command handlers for Dex's async decision queue. + +Each command appends a ``## Resolution`` block to the matching +``pending_decisions/.md`` file in John's Obsidian vault. Dex picks up +resolved files on its next scheduled tick. + +The ``fire-now`` variant of ``/yes`` additionally shells out to ``claude -p`` +to invoke the ``update_scheduled_task`` MCP tool so Dex fires within ~1 minute +instead of waiting for the next cron tick. +""" + +from __future__ import annotations + +import asyncio +import os +import shutil +from datetime import datetime +from pathlib import Path + +from telegram import Update +from telegram.ext import ContextTypes + +CUSTOM_VERBS = {"revive", "archive", "fold", "delete", "pause", "resume"} + +# Plain-English meaning + inverse-verb hint for every confirmation reply. +# Keys are the verb passed to ``_compose_echo``. +_VERB_INFO: dict[str, dict[str, str]] = { + "yes_approval": { + "meaning": "KEEP ACTIVE", + "next": "Dex applies next tick.", + "reverse": "/no to snooze 7d · /pause · /archive ", + }, + "yes_dispatch": { + "meaning": "DISPATCH (Writer/Coder/Researcher will run)", + "next": "Dex applies next tick.", + "reverse": "/no to skip", + }, + "yes_dispatch_now": { + "meaning": "DISPATCH IMMEDIATELY (fires within 1 min)", + "next": "Dex fires within 1 min.", + "reverse": "/no to skip (must reply before tick fires)", + }, + "no": { + "meaning": "SKIP (7-day cooldown)", + "next": "Dex applies next tick.", + "reverse": "/yes to reverse", + }, + "pause": { + "meaning": "PAUSE (flip status to paused)", + "next": "Dex applies next tick.", + "reverse": "/yes to keep active · /revive ", + }, + "revive": { + "meaning": "REVIVE (flip status to active)", + "next": "Dex applies next tick.", + "reverse": "/pause to re-pause", + }, + "archive": { + "meaning": "ARCHIVE (move to 11 Archive/)", + "next": "Dex applies next tick.", + "reverse": "/yes to keep active (cannot undo archive easily)", + }, + "fold": { + "meaning": "FOLD (inline + delete source file)", + "next": "Dex applies next tick.", + "reverse": "/no to skip (cannot undo)", + }, + "delete": { + "meaning": "DELETE", + "next": "Dex applies next tick.", + "reverse": "/no to skip (cannot undo)", + }, + "resume": { + "meaning": "RESUME (flip status to active)", + "next": "Dex applies next tick.", + "reverse": "/pause to re-pause", + }, +} + + +def _read_decision_meta(path: Path) -> dict: + """Parse the YAML frontmatter of a decision file. + + Simple line-by-line parser — no external deps. Returns a dict with at + least ``project``, ``agent``, ``type``, and ``task`` keys if present. + Missing fields are simply absent from the returned dict. + """ + meta: dict[str, str] = {} + if not path.exists(): + return meta + try: + text = path.read_text(encoding="utf-8") + except OSError: + return meta + lines = text.splitlines() + # Frontmatter must start at the very first line with ``---``. + if not lines or lines[0].strip() != "---": + return meta + in_frontmatter = True + body_lines: list[str] = [] + for line in lines[1:]: + if in_frontmatter: + if line.strip() == "---": + in_frontmatter = False + continue + if ":" in line: + key, _, value = line.partition(":") + meta[key.strip()] = value.strip() + else: + body_lines.append(line) + # Capture the first non-empty line under ``## Task`` for dispatch context. + for idx, line in enumerate(body_lines): + if line.strip().lower() == "## task": + for follow in body_lines[idx + 1 :]: + if follow.strip(): + meta["task"] = follow.strip()[:80] + break + break + return meta + + +def _decision_context(meta: dict) -> str: + """Return a short human label for the decision (project, or agent + task).""" + dtype = (meta.get("type") or "").lower() + if dtype == "dispatch": + agent = meta.get("agent") or "" + task = meta.get("task") + if task: + return f"{agent}: {task}" + return agent + project = meta.get("project") + if project: + return project + return "" + + +def _compose_echo( + verb: str, decision_id: str, meta: dict, fire_now: bool = False +) -> str: + """Build the confirmation reply for a resolved decision. + + ``verb`` is the bare command verb (e.g. ``"yes"``, ``"no"``, ``"pause"``). + For ``yes`` the dispatch vs approval branch (and ``/now`` variant) is + chosen from ``meta`` and ``fire_now``. + """ + if verb == "yes": + dtype = (meta.get("type") or "").lower() + if dtype == "dispatch": + key = "yes_dispatch_now" if fire_now else "yes_dispatch" + else: + key = "yes_approval" + else: + key = verb + info = _VERB_INFO.get(key) + if info is None: + # Defensive fallback — unknown verb, keep terse but still informative. + return ( + f"{decision_id} {_decision_context(meta)} " + f"→ {verb.upper()}. Dex applies next tick." + ) + context_label = _decision_context(meta) + reverse_hint = info["reverse"].replace("", decision_id) + return ( + f"{decision_id} {context_label} → {info['meaning']}. " + f"{info['next']}\n" + f"↩ If wrong: {reverse_hint}" + ) + + +def _pending_dir() -> Path: + """Return the directory holding pending-decision markdown files. + + Honours the ``DEX_PENDING_DIR`` env var for tests; falls back to the + canonical vault path otherwise. + """ + env = os.environ.get("DEX_PENDING_DIR") + if env: + return Path(env) + return Path( + r"C:\Users\odral\Documents\Obsidian\John Gallardo\pending_decisions" + ) + + +def _decision_path(decision_id: str) -> Path: + return _pending_dir() / f"{decision_id}.md" + + +def _append_resolution( + decision_id: str, status: str, fire_now: bool = False +) -> bool: + """Append a ``## Resolution`` block to the decision file. + + Returns ``True`` on success, ``False`` if the file does not exist. + """ + path = _decision_path(decision_id) + if not path.exists(): + return False + block = [ + "", + "## Resolution", + f"status: {status}", + f"resolved-at: {datetime.now().strftime('%Y-%m-%d %H:%M')}", + ] + if fire_now: + block.append("fire-now: true") + with path.open("a", encoding="utf-8") as f: + f.write("\n".join(block) + "\n") + return True + + +async def _fire_now_dex() -> None: + """Fire Dex within ~1 min via the ``update_scheduled_task`` MCP tool. + + Spawns a one-shot ``claude -p`` process; we don't care about its output, + only that it returns within a reasonable timeout. + """ + claude_cli = shutil.which("claude") or r"C:\Users\odral\.local\bin\claude.exe" + prompt = ( + "Use the update_scheduled_task tool to update the scheduled task " + "named 'dex' so its next fire time is within the next minute. " + "After updating, exit with no further output." + ) + proc = await asyncio.create_subprocess_exec( + claude_cli, + "-p", + prompt, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + try: + await asyncio.wait_for(proc.communicate(), timeout=60) + except asyncio.TimeoutError: + proc.kill() + + +async def handle_yes( + update: Update, context: ContextTypes.DEFAULT_TYPE +) -> None: + """Handle ``/yes [/now]`` — mark a pending decision as resolved-yes.""" + if not context.args: + await update.message.reply_text("Usage: /yes [/now]") + return + decision_id = context.args[0] + fire_now = ( + len(context.args) > 1 and context.args[1].lower() == "/now" + ) + path = _decision_path(decision_id) + meta = _read_decision_meta(path) + ok = _append_resolution(decision_id, "resolved-yes", fire_now=fire_now) + if not ok: + await update.message.reply_text(f"Decision {decision_id} not found.") + return + if fire_now: + await _fire_now_dex() + await update.message.reply_text( + _compose_echo("yes", decision_id, meta, fire_now=fire_now) + ) + + +async def handle_no( + update: Update, context: ContextTypes.DEFAULT_TYPE +) -> None: + """Handle ``/no `` — mark a pending decision as resolved-no.""" + if not context.args: + await update.message.reply_text("Usage: /no ") + return + decision_id = context.args[0] + path = _decision_path(decision_id) + meta = _read_decision_meta(path) + ok = _append_resolution(decision_id, "resolved-no") + if not ok: + await update.message.reply_text(f"Decision {decision_id} not found.") + return + await update.message.reply_text( + _compose_echo("no", decision_id, meta) + ) + + +async def handle_custom_verb( + verb: str, update: Update, context: ContextTypes.DEFAULT_TYPE +) -> None: + """Generic handler for project-lifecycle verbs (revive/archive/etc).""" + if not context.args: + await update.message.reply_text(f"Usage: /{verb} ") + return + decision_id = context.args[0] + path = _decision_path(decision_id) + meta = _read_decision_meta(path) + ok = _append_resolution(decision_id, f"resolved-{verb}") + if not ok: + await update.message.reply_text(f"Decision {decision_id} not found.") + return + await update.message.reply_text( + _compose_echo(verb, decision_id, meta) + ) + + +def make_verb_handler(verb: str): + """Build a CommandHandler-compatible coroutine bound to a specific verb.""" + + async def _h( + update: Update, context: ContextTypes.DEFAULT_TYPE + ) -> None: + await handle_custom_verb(verb, update, context) + + _h.__name__ = f"handle_{verb}" + return _h diff --git a/src/bot/handlers/callback.py b/src/bot/handlers/callback.py index 66dd660c4..e7ed002de 100644 --- a/src/bot/handlers/callback.py +++ b/src/bot/handlers/callback.py @@ -12,6 +12,7 @@ from ...security.audit import AuditLogger from ...security.validators import SecurityValidator from ..utils.html_format import escape_html +from .command import _handle_model_selection logger = structlog.get_logger() @@ -66,6 +67,8 @@ async def handle_callback_query( "conversation": handle_conversation_callback, "git": handle_git_callback, "export": handle_export_callback, + "model": lambda q, p, ctx: _handle_model_selection(q, f"model:{p}", ctx), + "effort": lambda q, p, ctx: _handle_model_selection(q, f"effort:{p}", ctx), } handler = handlers.get(action) diff --git a/src/bot/handlers/command.py b/src/bot/handlers/command.py index 651a08f8c..668e023cd 100644 --- a/src/bot/handlers/command.py +++ b/src/bot/handlers/command.py @@ -7,7 +7,7 @@ from typing import Optional import structlog -from telegram import InlineKeyboardButton, InlineKeyboardMarkup, Update +from telegram import CallbackQuery, InlineKeyboardButton, InlineKeyboardMarkup, Update from telegram.ext import ContextTypes from ...claude.facade import ClaudeIntegration @@ -174,7 +174,8 @@ async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No "• /status - Show session and usage status\n" "• /export - Export session history\n" "• /actions - Show context-aware quick actions\n" - "• /git - Git repository information\n\n" + "• /git - Git repository information\n" + "• /model [name] - View or switch Claude model\n\n" "Session Behavior:\n" "• Sessions are automatically maintained per project directory\n" "• Switching directories with /cd resumes the session for that project\n" @@ -1232,6 +1233,165 @@ async def git_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> Non logger.error("Error in git_command", error=str(e), user_id=user_id) +# Short CLI aliases passed directly to the Claude CLI, which resolves them to +# the current latest model of each family. No version numbers to maintain here. +# See: https://docs.anthropic.com/en/docs/about-claude/models/overview +_MODEL_FAMILIES = ["opus", "sonnet", "haiku"] + +# Effort levels per model family. Haiku has none; "max" is Opus-only. +# Update here if a future model's effort support changes. +_EFFORT_BY_MODEL = { + "opus": ["low", "medium", "high", "max"], + "sonnet": ["low", "medium", "high"], + "haiku": [], +} + + +def _current_model_label(context: ContextTypes.DEFAULT_TYPE) -> str: + """Return a human-friendly label for the active model + effort.""" + override = context.user_data.get("model_override") # "opus", "sonnet", "haiku", or None + effort = context.user_data.get("effort_override") + if not override: + settings = context.bot_data.get("settings") + server_model = getattr(settings, "claude_model", None) if settings else None + label = f"Default ({server_model or 'CLI default'})" + else: + label = override.capitalize() + return f"{label} | effort={effort}" if effort else label + + +async def model_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Handle /model command - show model selection keyboard.""" + current = _current_model_label(context) + + keyboard = [ + [ + InlineKeyboardButton("Opus", callback_data="model:opus"), + InlineKeyboardButton("Sonnet", callback_data="model:sonnet"), + InlineKeyboardButton("Haiku", callback_data="model:haiku"), + ], + [InlineKeyboardButton("Reset to default", callback_data="model:default")], + ] + + await update.message.reply_text( + f"🤖 Current: {escape_html(current)}\n\n" + "Choose a model:\n" + "⚠️ Switching will start a new session.", + parse_mode="HTML", + reply_markup=InlineKeyboardMarkup(keyboard), + ) + + +async def _handle_model_selection( + query: CallbackQuery, + data: str, + context: ContextTypes.DEFAULT_TYPE, +) -> None: + """Shared logic for model/effort selection (used by both callback routes).""" + if data.startswith("model:"): + choice = data.split(":", 1)[1] + + if choice == "default": + context.user_data.pop("model_override", None) + context.user_data.pop("effort_override", None) + # Note: if PR #165 merges first, change this to context.chat_data + context.user_data["force_new_session"] = True + await query.edit_message_text( + "🤖 Model and effort reset to server defaults.\n" + "Next message starts a fresh session.", + parse_mode="HTML", + ) + logger.info("Model override cleared", user_id=query.from_user.id) + return + + if choice not in _MODEL_FAMILIES: + await query.edit_message_text("Unknown model.") + return + + # Store short CLI alias ("opus"/"sonnet"/"haiku") — the CLI resolves it + # to the current latest model, so no version numbers to maintain. + context.user_data["model_override"] = choice + # Clear stale effort when switching models + context.user_data.pop("effort_override", None) + # Force new session so the model change takes effect immediately + # Note: if PR #165 merges first, change this to context.chat_data + context.user_data["force_new_session"] = True + + logger.info( + "Model override set", + user_id=query.from_user.id, + model=choice, + ) + + # Show effort level selection (if supported by this model) + effort_levels = _EFFORT_BY_MODEL.get(choice, []) + if not effort_levels: + # Model doesn't support effort (e.g. Haiku) + current = _current_model_label(context) + await query.edit_message_text( + f"🤖 {escape_html(current)} — ready.\n" + "New session will start with your next message.", + parse_mode="HTML", + ) + return + + rows = [] + row = [] + for level in effort_levels: + row.append( + InlineKeyboardButton(level.capitalize(), callback_data=f"effort:{level}") + ) + if len(row) == 2: + rows.append(row) + row = [] + if row: + rows.append(row) + rows.append( + [InlineKeyboardButton("Skip (keep current)", callback_data="effort:skip")] + ) + + await query.edit_message_text( + f"🤖 Model set to {escape_html(choice.capitalize())}.\n\n" + "Choose effort level:", + parse_mode="HTML", + reply_markup=InlineKeyboardMarkup(rows), + ) + + elif data.startswith("effort:"): + level = data.split(":", 1)[1] + + if level == "skip": + current = _current_model_label(context) + await query.edit_message_text( + f"🤖 {escape_html(current)} — ready.\n" + "New session will start with your next message.", + parse_mode="HTML", + ) + return + + all_effort_levels = {"low", "medium", "high", "max"} + if level in all_effort_levels: + context.user_data["effort_override"] = level + current = _current_model_label(context) + await query.edit_message_text( + f"🤖 {escape_html(current)} — ready.\n" + "New session will start with your next message.", + parse_mode="HTML", + ) + logger.info( + "Effort override set", + user_id=query.from_user.id, + effort=level, + ) + + +async def model_callback(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Handle model and effort selection callbacks (agentic mode route).""" + query = update.callback_query + await query.answer() + await _handle_model_selection(query, query.data, context) + + async def restart_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: """Handle /restart command - gracefully restart the bot process. diff --git a/src/bot/handlers/message.py b/src/bot/handlers/message.py index bbd240840..f0d87b7ff 100644 --- a/src/bot/handlers/message.py +++ b/src/bot/handlers/message.py @@ -393,6 +393,8 @@ async def stream_handler(update_obj): session_id=session_id, on_stream=stream_handler, force_new=force_new, + model_override=context.user_data.get("model_override"), + effort_override=context.user_data.get("effort_override"), ) # New session created successfully — clear the one-shot flag @@ -818,6 +820,8 @@ async def handle_document(update: Update, context: ContextTypes.DEFAULT_TYPE) -> working_directory=current_dir, user_id=user_id, session_id=session_id, + model_override=context.user_data.get("model_override"), + effort_override=context.user_data.get("effort_override"), ) # Update session ID @@ -945,6 +949,8 @@ async def handle_photo(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No working_directory=current_dir, user_id=user_id, session_id=session_id, + model_override=context.user_data.get("model_override"), + effort_override=context.user_data.get("effort_override"), ) # Update session ID @@ -1073,6 +1079,8 @@ async def handle_voice(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No working_directory=current_dir, user_id=user_id, session_id=session_id, + model_override=context.user_data.get("model_override"), + effort_override=context.user_data.get("effort_override"), ) context.user_data["claude_session_id"] = claude_response.session_id diff --git a/src/bot/orchestrator.py b/src/bot/orchestrator.py index 6d9719f0d..7540a4ec4 100644 --- a/src/bot/orchestrator.py +++ b/src/bot/orchestrator.py @@ -319,6 +319,13 @@ def register_handlers(self, app: Application) -> None: def _register_agentic_handlers(self, app: Application) -> None: """Register agentic handlers: commands + text/file/photo.""" from .handlers import command + from .dex_handlers import ( + CUSTOM_VERBS, + handle_no, + handle_yes, + make_verb_handler, + ) + from .builder_handlers import handle_builder # Commands handlers = [ @@ -327,8 +334,16 @@ def _register_agentic_handlers(self, app: Application) -> None: ("status", self.agentic_status), ("verbose", self.agentic_verbose), ("repo", self.agentic_repo), + ("model", command.model_command), ("restart", command.restart_command), + # Dex async decision queue + ("yes", handle_yes), + ("no", handle_no), + # Dex Phase 2 Builder + ("builder", handle_builder), ] + for verb in CUSTOM_VERBS: + handlers.append((verb, make_verb_handler(verb))) if self.settings.enable_project_threads: handlers.append(("sync_threads", command.sync_threads)) @@ -388,6 +403,14 @@ def _register_agentic_handlers(self, app: Application) -> None: ) ) + # Model/effort selection callbacks + app.add_handler( + CallbackQueryHandler( + self._inject_deps(command.model_callback), + pattern=r"^(model|effort):", + ) + ) + # Only cd: callbacks (for project selection), scoped by pattern app.add_handler( CallbackQueryHandler( @@ -401,6 +424,13 @@ def _register_agentic_handlers(self, app: Application) -> None: def _register_classic_handlers(self, app: Application) -> None: """Register full classic handler set (moved from core.py).""" from .handlers import callback, command, message + from .dex_handlers import ( + CUSTOM_VERBS, + handle_no, + handle_yes, + make_verb_handler, + ) + from .builder_handlers import handle_builder handlers = [ ("start", command.start_command), @@ -416,8 +446,16 @@ def _register_classic_handlers(self, app: Application) -> None: ("export", command.export_session), ("actions", command.quick_actions), ("git", command.git_command), + ("model", command.model_command), ("restart", command.restart_command), + # Dex async decision queue + ("yes", handle_yes), + ("no", handle_no), + # Dex Phase 2 Builder + ("builder", handle_builder), ] + for verb in CUSTOM_VERBS: + handlers.append((verb, make_verb_handler(verb))) if self.settings.enable_project_threads: handlers.append(("sync_threads", command.sync_threads)) @@ -460,7 +498,11 @@ async def get_bot_commands(self) -> list: # type: ignore[type-arg] BotCommand("status", "Show session status"), BotCommand("verbose", "Set output verbosity (0/1/2)"), BotCommand("repo", "List repos / switch workspace"), + BotCommand("model", "Switch Claude model and effort"), BotCommand("restart", "Restart the bot"), + BotCommand("yes", "Approve a pending Dex decision"), + BotCommand("no", "Reject a pending Dex decision"), + BotCommand("builder", "Builder status/kill/queue"), ] if self.settings.enable_project_threads: commands.append(BotCommand("sync_threads", "Sync project topics")) @@ -480,7 +522,11 @@ async def get_bot_commands(self) -> list: # type: ignore[type-arg] BotCommand("export", "Export current session"), BotCommand("actions", "Show quick actions"), BotCommand("git", "Git repository commands"), + BotCommand("model", "Switch Claude model and effort"), BotCommand("restart", "Restart the bot"), + BotCommand("yes", "Approve a pending Dex decision"), + BotCommand("no", "Reject a pending Dex decision"), + BotCommand("builder", "Builder status/kill/queue"), ] if self.settings.enable_project_threads: commands.append(BotCommand("sync_threads", "Sync project topics")) @@ -549,6 +595,10 @@ async def agentic_new( context.user_data["claude_session_id"] = None context.user_data["session_started"] = True context.user_data["force_new_session"] = True + # Reset session cost tracking — new session, fresh tiers + context.user_data["session_total_cost"] = 0.0 + context.user_data["session_turn_count"] = 0 + context.user_data["cost_warning_tiers"] = set() await update.message.reply_text("Session reset. What's next?") @@ -1014,14 +1064,52 @@ async def agentic_text( on_stream=on_stream, force_new=force_new, interrupt_event=interrupt_event, + model_override=context.user_data.get("model_override"), + effort_override=context.user_data.get("effort_override"), ) # New session created successfully — clear the one-shot flag if force_new: context.user_data["force_new_session"] = False + # Reset cost tracking for the brand-new session + context.user_data["session_total_cost"] = 0.0 + context.user_data["session_turn_count"] = 0 + context.user_data["cost_warning_tiers"] = set() context.user_data["claude_session_id"] = claude_response.session_id + # Track cumulative session cost + fire tiered warnings (fire-once-per-tier). + # Resets on /new or model swap (force_new path above). + if self.settings.enable_cost_warnings: + cumulative = ( + context.user_data.get("session_total_cost", 0.0) + + claude_response.cost + ) + turn_count = context.user_data.get("session_turn_count", 0) + 1 + context.user_data["session_total_cost"] = cumulative + context.user_data["session_turn_count"] = turn_count + + fired_tiers = context.user_data.setdefault( + "cost_warning_tiers", set() + ) + # Fire the lowest unfired tier that we've crossed (one warning per turn). + for tier in sorted(self.settings.session_cost_tiers): + if cumulative >= tier and tier not in fired_tiers: + fired_tiers.add(tier) + try: + await update.message.reply_text( + f"⚠️ Session at ${cumulative:.2f} / " + f"{turn_count} turns. Consider /new." + ) + except Exception as e: + logger.warning( + "Failed to send cost warning", + error=str(e), + tier=tier, + cumulative=cumulative, + ) + break + # Track directory changes from .handlers.message import _update_working_directory_from_claude_response @@ -1264,6 +1352,8 @@ async def agentic_document( session_id=session_id, on_stream=on_stream, force_new=force_new, + model_override=context.user_data.get("model_override"), + effort_override=context.user_data.get("effort_override"), ) if force_new: @@ -1474,6 +1564,8 @@ async def _handle_agentic_media_message( on_stream=on_stream, force_new=force_new, images=images, + model_override=context.user_data.get("model_override"), + effort_override=context.user_data.get("effort_override"), ) finally: heartbeat.cancel() diff --git a/src/claude/facade.py b/src/claude/facade.py index b1cafba49..03b416c6a 100644 --- a/src/claude/facade.py +++ b/src/claude/facade.py @@ -40,6 +40,8 @@ async def run_command( force_new: bool = False, interrupt_event: Optional["asyncio.Event"] = None, images: Optional[List[Dict[str, str]]] = None, + model_override: Optional[str] = None, + effort_override: Optional[str] = None, ) -> ClaudeResponse: """Run Claude Code command with full integration.""" logger.info( @@ -49,6 +51,8 @@ async def run_command( session_id=session_id, prompt_length=len(prompt), force_new=force_new, + model_override=model_override, + effort_override=effort_override, ) # If no session_id provided, try to find an existing session for this @@ -90,6 +94,8 @@ async def run_command( stream_callback=on_stream, interrupt_event=interrupt_event, images=images, + model_override=model_override, + effort_override=effort_override, ) except Exception as resume_error: # If resume failed (e.g., session expired/missing on Claude's side), @@ -116,6 +122,8 @@ async def run_command( stream_callback=on_stream, interrupt_event=interrupt_event, images=images, + model_override=model_override, + effort_override=effort_override, ) else: raise @@ -161,6 +169,8 @@ async def _execute( stream_callback: Optional[Callable] = None, interrupt_event: Optional[asyncio.Event] = None, images: Optional[List[Dict[str, str]]] = None, + model_override: Optional[str] = None, + effort_override: Optional[str] = None, ) -> ClaudeResponse: """Execute command via SDK.""" return await self.sdk_manager.execute_command( @@ -171,6 +181,8 @@ async def _execute( stream_callback=stream_callback, interrupt_event=interrupt_event, images=images, + model_override=model_override, + effort_override=effort_override, ) async def _find_resumable_session( diff --git a/src/claude/sdk_integration.py b/src/claude/sdk_integration.py index 5a95f16da..cdf570550 100644 --- a/src/claude/sdk_integration.py +++ b/src/claude/sdk_integration.py @@ -277,6 +277,8 @@ async def execute_command( stream_callback: Optional[Callable[[StreamUpdate], None]] = None, interrupt_event: Optional[asyncio.Event] = None, images: Optional[List[Dict[str, str]]] = None, + model_override: Optional[str] = None, + effort_override: Optional[str] = None, ) -> ClaudeResponse: """Execute Claude Code command via SDK.""" start_time = asyncio.get_event_loop().time() @@ -321,7 +323,7 @@ def _stderr_callback(line: str) -> None: # Build Claude Agent options options = ClaudeAgentOptions( max_turns=self.config.claude_max_turns, - model=self.config.claude_model or None, + model=model_override or self.config.claude_model or None, max_budget_usd=self.config.claude_max_cost_per_request, cwd=str(working_directory), allowed_tools=sdk_allowed_tools, @@ -334,10 +336,18 @@ def _stderr_callback(line: str) -> None: "excludedCommands": self.config.sandbox_excluded_commands or [], }, system_prompt=base_prompt, + effort=effort_override, setting_sources=["project"], stderr=_stderr_callback, ) + if model_override or effort_override: + logger.info( + "Runtime model/effort override active", + model_override=model_override, + effort_override=effort_override, + ) + # Pass MCP server configuration if enabled if self.config.enable_mcp and self.config.mcp_config_path: options.mcp_servers = self._load_mcp_config(self.config.mcp_config_path) @@ -521,6 +531,20 @@ async def _cancel_on_interrupt() -> None: cost = getattr(message, "total_cost_usd", 0.0) or 0.0 claude_session_id = getattr(message, "session_id", None) result_content = getattr(message, "result", None) + # Log actual model returned by Claude (proof of model swap). + # ResultMessage.model_usage is a dict keyed by model name, + # e.g. {"claude-opus-4-7": {...token counts...}}. + model_usage = getattr(message, "model_usage", None) + actual_models = ( + list(model_usage.keys()) + if isinstance(model_usage, dict) and model_usage + else [] + ) + logger.info( + "Turn complete — actual model from Claude", + actual_models=actual_models, + cost_usd=cost, + ) current_time = asyncio.get_event_loop().time() for msg in messages: if isinstance(msg, AssistantMessage): diff --git a/src/config/settings.py b/src/config/settings.py index c4f7cb18b..4003602f9 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -10,10 +10,10 @@ import json from pathlib import Path -from typing import Any, List, Literal, Optional +from typing import Annotated, Any, List, Literal, Optional from pydantic import Field, SecretStr, field_validator, model_validator -from pydantic_settings import BaseSettings, SettingsConfigDict +from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict from src.utils.constants import ( DEFAULT_CLAUDE_MAX_COST_PER_REQUEST, @@ -330,6 +330,21 @@ class Settings(BaseSettings): ge=0.0, ) + # Session cost threshold warnings — alert user when cumulative session cost + # crosses configured tiers. Fires once per tier per session, resets on /new. + # Helps avoid silent Claude Max quota burn from long-running sessions. + enable_cost_warnings: bool = Field( + True, + description="Enable per-session cumulative cost threshold warnings", + ) + session_cost_tiers: Annotated[List[float], NoDecode] = Field( + default=[5.0, 10.0, 20.0], + description=( + "Cost thresholds (USD) at which to warn per-session. " + "Comma-separated in env: SESSION_COST_TIERS=5,10,20" + ), + ) + model_config = SettingsConfigDict( env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="ignore" ) @@ -348,6 +363,20 @@ def parse_int_list(cls, v: Any) -> Optional[List[int]]: return [int(uid) for uid in v] return v # type: ignore[no-any-return] + @field_validator("session_cost_tiers", mode="before") + @classmethod + def parse_float_list(cls, v: Any) -> List[float]: + """Parse comma-separated float lists for cost tiers.""" + if v is None: + return [5.0, 10.0, 20.0] + if isinstance(v, (int, float)): + return [float(v)] + if isinstance(v, str): + return [float(t.strip()) for t in v.split(",") if t.strip()] + if isinstance(v, list): + return [float(t) for t in v] + return v # type: ignore[no-any-return] + @field_validator("claude_allowed_tools", mode="before") @classmethod def parse_claude_allowed_tools(cls, v: Any) -> Optional[List[str]]: diff --git a/tests/unit/test_bot/test_builder_handlers.py b/tests/unit/test_bot/test_builder_handlers.py new file mode 100644 index 000000000..d513d1fb8 --- /dev/null +++ b/tests/unit/test_bot/test_builder_handlers.py @@ -0,0 +1,30 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock +from src.bot.builder_handlers import handle_builder_status, handle_builder_kill, handle_builder_queue + +@pytest.fixture +def builder_dirs(tmp_path, monkeypatch): + monkeypatch.setenv("BUILDER_DATA_DIR", str(tmp_path)) + (tmp_path / "state.json").write_text('{"stage":"IMPLEMENT","ticket_id":"B001","ticks_used":3,"implement_attempts":0}', encoding="utf-8") + (tmp_path / "queue").mkdir() + (tmp_path / "queue" / "B002.md").write_text("---\nid: B002\n---\n", encoding="utf-8") + return tmp_path + +@pytest.mark.asyncio +async def test_status_reports_inflight(builder_dirs): + u = MagicMock(); u.message.reply_text = AsyncMock(); ctx = MagicMock() + await handle_builder_status(u, ctx) + msg = u.message.reply_text.call_args[0][0] + assert "B001" in msg and "IMPLEMENT" in msg + +@pytest.mark.asyncio +async def test_kill_writes_kill_flag(builder_dirs): + u = MagicMock(); u.message.reply_text = AsyncMock(); ctx = MagicMock(); ctx.args = ["B001"] + await handle_builder_kill(u, ctx) + assert "kill_requested" in (builder_dirs / "state.json").read_text(encoding="utf-8") + +@pytest.mark.asyncio +async def test_queue_lists_pending(builder_dirs): + u = MagicMock(); u.message.reply_text = AsyncMock(); ctx = MagicMock() + await handle_builder_queue(u, ctx) + assert "B002" in u.message.reply_text.call_args[0][0] diff --git a/tests/unit/test_bot/test_dex_handlers.py b/tests/unit/test_bot/test_dex_handlers.py new file mode 100644 index 000000000..6159d5aec --- /dev/null +++ b/tests/unit/test_bot/test_dex_handlers.py @@ -0,0 +1,177 @@ +"""Tests for Dex decision-queue command handlers. + +Covers: +- /yes appends a Resolution block with status: resolved-yes +- /no appends a Resolution block with status: resolved-no +- /yes replies with a not-found message +- / appends a Resolution block with status: resolved- +- /yes /now triggers _fire_now_dex (mocked) +""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from src.bot.dex_handlers import handle_custom_verb, handle_no, handle_yes + + +@pytest.fixture +def pending_dir(tmp_path, monkeypatch): + """Point dex_handlers at a tmp dir and seed it with D042.md.""" + monkeypatch.setenv("DEX_PENDING_DIR", str(tmp_path)) + decision = tmp_path / "D042.md" + decision.write_text( + "---\nid: D042\ntype: approval\nstatus: pending\n---\n\n" + "## Question\nTest?\n", + encoding="utf-8", + ) + return tmp_path + + +@pytest.mark.asyncio +async def test_yes_appends_resolution(pending_dir): + update = MagicMock() + update.message.reply_text = AsyncMock() + context = MagicMock() + context.args = ["D042"] + await handle_yes(update, context) + content = (pending_dir / "D042.md").read_text(encoding="utf-8") + assert "## Resolution" in content + assert "status: resolved-yes" in content + update.message.reply_text.assert_called_once() + + +@pytest.mark.asyncio +async def test_no_appends_resolution(pending_dir): + update = MagicMock() + update.message.reply_text = AsyncMock() + context = MagicMock() + context.args = ["D042"] + await handle_no(update, context) + content = (pending_dir / "D042.md").read_text(encoding="utf-8") + assert "status: resolved-no" in content + + +@pytest.mark.asyncio +async def test_yes_unknown_id_replies_error(pending_dir): + update = MagicMock() + update.message.reply_text = AsyncMock() + context = MagicMock() + context.args = ["D999"] + await handle_yes(update, context) + update.message.reply_text.assert_called_once() + msg = update.message.reply_text.call_args[0][0] + assert "not found" in msg.lower() + + +@pytest.mark.asyncio +async def test_revive_custom_verb(pending_dir): + (pending_dir / "D050.md").write_text( + "---\nid: D050\ntype: approval\nstatus: pending\n---\n", + encoding="utf-8", + ) + update = MagicMock() + update.message.reply_text = AsyncMock() + context = MagicMock() + context.args = ["D050"] + await handle_custom_verb("revive", update, context) + content = (pending_dir / "D050.md").read_text(encoding="utf-8") + assert "status: resolved-revive" in content + + +@pytest.mark.asyncio +async def test_yes_now_invokes_fire_now(pending_dir, monkeypatch): + calls = [] + + async def fake_fire(): + calls.append("fired") + + monkeypatch.setattr( + "src.bot.dex_handlers._fire_now_dex", fake_fire + ) + update = MagicMock() + update.message.reply_text = AsyncMock() + context = MagicMock() + context.args = ["D042", "/now"] + await handle_yes(update, context) + assert calls == ["fired"] + content = (pending_dir / "D042.md").read_text(encoding="utf-8") + assert "fire-now: true" in content + + +@pytest.mark.asyncio +async def test_yes_approval_echo_shows_project(pending_dir): + (pending_dir / "D042.md").write_text( + "---\nid: D042\ntype: approval\nproject: The Grove\nstatus: pending\n---\n", + encoding="utf-8", + ) + update = MagicMock() + update.message.reply_text = AsyncMock() + context = MagicMock() + context.args = ["D042"] + await handle_yes(update, context) + msg = update.message.reply_text.call_args[0][0] + assert "D042" in msg + assert "The Grove" in msg + assert "KEEP ACTIVE" in msg + + +@pytest.mark.asyncio +async def test_pause_echo_shows_meaning(pending_dir): + (pending_dir / "D050.md").write_text( + "---\nid: D050\ntype: approval\nproject: JAVLab\nstatus: pending\n---\n", + encoding="utf-8", + ) + update = MagicMock() + update.message.reply_text = AsyncMock() + context = MagicMock() + context.args = ["D050"] + await handle_custom_verb("pause", update, context) + msg = update.message.reply_text.call_args[0][0] + assert "JAVLab" in msg + assert "PAUSE" in msg + assert "/yes D050" in msg # reverse hint present + + +@pytest.mark.asyncio +async def test_yes_dispatch_echo_shows_agent(pending_dir): + (pending_dir / "D100.md").write_text( + "---\nid: D100\ntype: dispatch\nagent: writer\nstatus: pending\n---\n\n## Task\nWrite jcode synthesis note from session log.\n", + encoding="utf-8", + ) + update = MagicMock() + update.message.reply_text = AsyncMock() + context = MagicMock() + context.args = ["D100"] + await handle_yes(update, context) + msg = update.message.reply_text.call_args[0][0] + assert "D100" in msg + assert "DISPATCH" in msg + assert "writer" in msg.lower() or "Writer" in msg + + +@pytest.mark.asyncio +async def test_yes_now_echo_says_immediately(pending_dir, monkeypatch): + monkeypatch.setattr("src.bot.dex_handlers._fire_now_dex", AsyncMock()) + (pending_dir / "D200.md").write_text( + "---\nid: D200\ntype: dispatch\nagent: writer\nstatus: pending\n---\n", + encoding="utf-8", + ) + update = MagicMock() + update.message.reply_text = AsyncMock() + context = MagicMock() + context.args = ["D200", "/now"] + await handle_yes(update, context) + msg = update.message.reply_text.call_args[0][0] + assert "IMMEDIATELY" in msg or "1 min" in msg + + +@pytest.mark.asyncio +async def test_no_echo_says_skip(pending_dir): + update = MagicMock() + update.message.reply_text = AsyncMock() + context = MagicMock() + context.args = ["D042"] + await handle_no(update, context) + msg = update.message.reply_text.call_args[0][0] + assert "SKIP" in msg or "skip" in msg diff --git a/tests/unit/test_bot/test_model_command.py b/tests/unit/test_bot/test_model_command.py new file mode 100644 index 000000000..681669f99 --- /dev/null +++ b/tests/unit/test_bot/test_model_command.py @@ -0,0 +1,296 @@ +"""Tests for the /model command — runtime model and effort switching. + +Covers: +- /model shows inline keyboard with model choices +- Model selection sets model_override and force_new_session +- Effort selection sets effort_override +- "default" clears all overrides +- Haiku skips effort keyboard (not supported) +- Opus shows "max" effort, Sonnet does not +- _current_model_label returns correct labels +- Regression: model: callback data prefix is never rewritten as effort: +- force_new_session is always set on model change +""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from telegram import InlineKeyboardMarkup + +from src.bot.handlers.command import ( + _EFFORT_BY_MODEL, + _MODEL_FAMILIES, + _current_model_label, + _handle_model_selection, + model_command, +) +from src.config.settings import Settings + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def settings(tmp_path): + return Settings( + telegram_bot_token="test:token", + telegram_bot_username="testbot", + approved_directory=tmp_path, + ) + + +@pytest.fixture +def context(settings): + ctx = MagicMock() + ctx.user_data = {} + ctx.bot_data = {"settings": settings} + ctx.args = None + return ctx + + +@pytest.fixture +def update(context): + upd = MagicMock() + upd.message = AsyncMock() + upd.effective_user.id = 12345 + return upd + + +@pytest.fixture +def callback_query(): + query = MagicMock() + query.answer = AsyncMock() + query.edit_message_text = AsyncMock() + query.from_user.id = 12345 + return query + + +# --------------------------------------------------------------------------- +# /model command (keyboard display) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_model_command_shows_keyboard(update, context): + """Verify /model sends an inline keyboard with model choices.""" + await model_command(update, context) + + update.message.reply_text.assert_called_once() + call_kwargs = update.message.reply_text.call_args + assert isinstance(call_kwargs.kwargs["reply_markup"], InlineKeyboardMarkup) + # Should contain the session warning + assert "new session" in call_kwargs.args[0].lower() + + +@pytest.mark.asyncio +async def test_model_command_shows_current_override(update, context): + """When an override is active, /model should show it.""" + context.user_data["model_override"] = "sonnet" + await model_command(update, context) + + text = update.message.reply_text.call_args.args[0] + assert "Sonnet" in text + + +# --------------------------------------------------------------------------- +# Model selection callback +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_select_opus_sets_override(callback_query, context): + """Selecting Opus sets model_override to short alias and force_new_session.""" + await _handle_model_selection(callback_query, "model:opus", context) + + assert context.user_data["model_override"] == "opus" + assert context.user_data["force_new_session"] is True + + +@pytest.mark.asyncio +async def test_select_sonnet_sets_override(callback_query, context): + """Selecting Sonnet sets the correct short alias.""" + await _handle_model_selection(callback_query, "model:sonnet", context) + + assert context.user_data["model_override"] == "sonnet" + + +@pytest.mark.asyncio +async def test_select_haiku_skips_effort(callback_query, context): + """Selecting Haiku should not show effort keyboard (not supported).""" + await _handle_model_selection(callback_query, "model:haiku", context) + + assert context.user_data["model_override"] == "haiku" + # Final message, no reply_markup (no effort keyboard) + call_kwargs = callback_query.edit_message_text.call_args + assert "reply_markup" not in call_kwargs.kwargs or call_kwargs.kwargs.get("reply_markup") is None + assert "ready" in call_kwargs.args[0].lower() + + +@pytest.mark.asyncio +async def test_select_opus_shows_effort_with_max(callback_query, context): + """Opus should show effort keyboard including 'max'.""" + await _handle_model_selection(callback_query, "model:opus", context) + + call_kwargs = callback_query.edit_message_text.call_args + markup = call_kwargs.kwargs.get("reply_markup") + assert markup is not None + # Flatten button labels + labels = [btn.text for row in markup.inline_keyboard for btn in row] + assert "Max" in labels + assert "effort" in call_kwargs.args[0].lower() + + +@pytest.mark.asyncio +async def test_select_sonnet_shows_effort_without_max(callback_query, context): + """Sonnet should show effort keyboard without 'max'.""" + await _handle_model_selection(callback_query, "model:sonnet", context) + + call_kwargs = callback_query.edit_message_text.call_args + markup = call_kwargs.kwargs.get("reply_markup") + assert markup is not None + labels = [btn.text for row in markup.inline_keyboard for btn in row] + assert "Max" not in labels + assert "High" in labels + + +@pytest.mark.asyncio +async def test_default_clears_overrides(callback_query, context): + """Selecting 'default' clears model, effort, and forces new session.""" + context.user_data["model_override"] = "opus" + context.user_data["effort_override"] = "high" + + await _handle_model_selection(callback_query, "model:default", context) + + assert "model_override" not in context.user_data + assert "effort_override" not in context.user_data + assert context.user_data["force_new_session"] is True + + +# --------------------------------------------------------------------------- +# Effort selection callback +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_effort_sets_override(callback_query, context): + """Selecting an effort level stores it in user_data.""" + context.user_data["model_override"] = "opus" + + await _handle_model_selection(callback_query, "effort:high", context) + + assert context.user_data["effort_override"] == "high" + + +@pytest.mark.asyncio +async def test_effort_skip_keeps_existing(callback_query, context): + """Selecting 'skip' should not set effort_override.""" + context.user_data["model_override"] = "sonnet" + + await _handle_model_selection(callback_query, "effort:skip", context) + + assert "effort_override" not in context.user_data + + +@pytest.mark.asyncio +async def test_model_switch_clears_stale_effort(callback_query, context): + """Switching models should clear any previous effort override.""" + context.user_data["effort_override"] = "high" + + await _handle_model_selection(callback_query, "model:haiku", context) + + assert "effort_override" not in context.user_data + + +# --------------------------------------------------------------------------- +# Regression: callback data prefix integrity (closure-bug guard) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_model_callback_sets_model_not_effort(callback_query, context): + """model: callback must set model_override, not effort_override. + + Regression guard: a shared closure capturing 'action' from outer scope + could silently rewrite 'model:opus' as 'effort:opus'. This test would + have caught that. + """ + await _handle_model_selection(callback_query, "model:sonnet", context) + + assert context.user_data.get("model_override") == "sonnet" + assert "effort_override" not in context.user_data + + +@pytest.mark.asyncio +async def test_effort_callback_sets_effort_not_model(callback_query, context): + """effort: callback must set effort_override and not overwrite model_override.""" + context.user_data["model_override"] = "sonnet" + + await _handle_model_selection(callback_query, "effort:high", context) + + assert context.user_data.get("effort_override") == "high" + assert context.user_data.get("model_override") == "sonnet" # unchanged + + +@pytest.mark.asyncio +async def test_force_new_session_set_on_model_switch(callback_query, context): + """force_new_session must be True after any model switch.""" + await _handle_model_selection(callback_query, "model:opus", context) + + assert context.user_data.get("force_new_session") is True + + +# --------------------------------------------------------------------------- +# Label helper +# --------------------------------------------------------------------------- + + +def test_label_default_no_settings(): + ctx = MagicMock() + ctx.user_data = {} + ctx.bot_data = {} + assert _current_model_label(ctx) == "Default (CLI default)" + + +def test_label_default_with_server_model(): + ctx = MagicMock() + ctx.user_data = {} + ctx.bot_data = {"settings": MagicMock(claude_model="claude-sonnet-4-6")} + assert _current_model_label(ctx) == "Default (claude-sonnet-4-6)" + + +def test_label_with_model_and_effort(): + ctx = MagicMock() + ctx.user_data = {"model_override": "sonnet", "effort_override": "medium"} + assert _current_model_label(ctx) == "Sonnet | effort=medium" + + +def test_label_model_only(): + ctx = MagicMock() + ctx.user_data = {"model_override": "opus"} + assert _current_model_label(ctx) == "Opus" + + +# --------------------------------------------------------------------------- +# Effort level configuration +# --------------------------------------------------------------------------- + + +def test_haiku_has_no_effort_levels(): + assert _EFFORT_BY_MODEL["haiku"] == [] + + +def test_sonnet_has_no_max(): + assert "max" not in _EFFORT_BY_MODEL["sonnet"] + assert "high" in _EFFORT_BY_MODEL["sonnet"] + + +def test_opus_has_max(): + assert "max" in _EFFORT_BY_MODEL["opus"] + + +def test_model_families_contains_expected(): + assert "opus" in _MODEL_FAMILIES + assert "sonnet" in _MODEL_FAMILIES + assert "haiku" in _MODEL_FAMILIES diff --git a/tests/unit/test_orchestrator.py b/tests/unit/test_orchestrator.py index ce5e419e9..30507182c 100644 --- a/tests/unit/test_orchestrator.py +++ b/tests/unit/test_orchestrator.py @@ -82,8 +82,13 @@ def deps(): } -def test_agentic_registers_6_commands(agentic_settings, deps): - """Agentic mode registers start, new, status, verbose, repo, restart commands.""" +def test_agentic_registers_15_commands(agentic_settings, deps): + """Agentic mode registers core commands + Dex decision-queue handlers. + + Core: start, new, status, verbose, repo, model, restart (7). + Dex: yes, no, revive, archive, fold, delete, pause, resume (8). + Total: 15. + """ orchestrator = MessageOrchestrator(agentic_settings, deps) app = MagicMock() app.add_handler = MagicMock() @@ -100,17 +105,22 @@ def test_agentic_registers_6_commands(agentic_settings, deps): ] commands = [h[0][0].commands for h in cmd_handlers] - assert len(cmd_handlers) == 6 + assert len(cmd_handlers) == 15 assert frozenset({"start"}) in commands assert frozenset({"new"}) in commands assert frozenset({"status"}) in commands assert frozenset({"verbose"}) in commands assert frozenset({"repo"}) in commands + assert frozenset({"model"}) in commands assert frozenset({"restart"}) in commands + assert frozenset({"yes"}) in commands + assert frozenset({"no"}) in commands + for verb in ("revive", "archive", "fold", "delete", "pause", "resume"): + assert frozenset({verb}) in commands -def test_classic_registers_14_commands(classic_settings, deps): - """Classic mode registers all 14 commands.""" +def test_classic_registers_23_commands(classic_settings, deps): + """Classic mode registers all 15 classic commands + 8 Dex handlers = 23.""" orchestrator = MessageOrchestrator(classic_settings, deps) app = MagicMock() app.add_handler = MagicMock() @@ -125,7 +135,7 @@ def test_classic_registers_14_commands(classic_settings, deps): if isinstance(call[0][0], CommandHandler) ] - assert len(cmd_handlers) == 14 + assert len(cmd_handlers) == 23 def test_agentic_registers_text_document_photo_handlers(agentic_settings, deps): @@ -151,31 +161,44 @@ def test_agentic_registers_text_document_photo_handlers(agentic_settings, deps): # 5 message handlers (text, document, photo, voice, unknown commands passthrough) assert len(msg_handlers) == 5 - # 2 callback handlers (stop: + cd:) - assert len(cb_handlers) == 2 + # 3 callback handlers (stop: + model/effort: + cd:) + assert len(cb_handlers) == 3 async def test_agentic_bot_commands(agentic_settings, deps): - """Agentic mode returns 6 bot commands.""" + """Agentic mode returns 9 bot commands (7 core + /yes + /no).""" orchestrator = MessageOrchestrator(agentic_settings, deps) commands = await orchestrator.get_bot_commands() - assert len(commands) == 6 + assert len(commands) == 9 cmd_names = [c.command for c in commands] - assert cmd_names == ["start", "new", "status", "verbose", "repo", "restart"] + assert cmd_names == [ + "start", + "new", + "status", + "verbose", + "repo", + "model", + "restart", + "yes", + "no", + ] async def test_classic_bot_commands(classic_settings, deps): - """Classic mode returns 14 bot commands.""" + """Classic mode returns 17 bot commands (15 classic + /yes + /no).""" orchestrator = MessageOrchestrator(classic_settings, deps) commands = await orchestrator.get_bot_commands() - assert len(commands) == 14 + assert len(commands) == 17 cmd_names = [c.command for c in commands] assert "start" in cmd_names assert "help" in cmd_names assert "git" in cmd_names + assert "model" in cmd_names assert "restart" in cmd_names + assert "yes" in cmd_names + assert "no" in cmd_names async def test_restart_command_sends_sigterm(deps): @@ -338,7 +361,7 @@ async def test_agentic_callback_scoped_to_cd_pattern(agentic_settings, deps): if isinstance(call[0][0], CallbackQueryHandler) ] - assert len(cb_handlers) == 2 + assert len(cb_handlers) == 3 # Find the cd: handler by pattern cd_handler = [h for h in cb_handlers if h.pattern and h.pattern.match("cd:x")] assert len(cd_handler) == 1 diff --git a/tests/unit/test_security/test_audit.py b/tests/unit/test_security/test_audit.py index bc719bdff..de3e75140 100644 --- a/tests/unit/test_security/test_audit.py +++ b/tests/unit/test_security/test_audit.py @@ -331,17 +331,18 @@ async def test_log_command_risk_assessment(self, audit_logger, storage): user_id=123, command="rm", args=["-rf", "/tmp/test"], success=True ) - events = await storage.get_events() - high_risk_event = events[0] - assert high_risk_event.risk_level == "high" - # Test low-risk command await audit_logger.log_command( user_id=123, command="echo", args=["hello"], success=True ) events = await storage.get_events() - low_risk_event = events[0] # Most recent + # Filter by command content rather than list position: two events can + # share the same microsecond timestamp, making sort order ambiguous. + high_risk_event = next(e for e in events if e.details["command"] == "rm") + assert high_risk_event.risk_level == "high" + + low_risk_event = next(e for e in events if e.details["command"] == "echo") assert low_risk_event.risk_level == "low" async def test_log_file_access(self, audit_logger, storage): @@ -369,17 +370,22 @@ async def test_log_file_access_risk_assessment(self, audit_logger, storage): user_id=123, file_path="/etc/passwd", action="delete", success=True ) - events = await storage.get_events() - high_risk_event = events[0] - assert high_risk_event.risk_level == "high" - # Low-risk: read normal file await audit_logger.log_file_access( user_id=123, file_path="/projects/readme.txt", action="read", success=True ) events = await storage.get_events() - low_risk_event = events[0] # Most recent + # Filter by file path rather than list position: two events can share + # the same microsecond timestamp, making sort order ambiguous. + high_risk_event = next( + e for e in events if e.details["file_path"] == "/etc/passwd" + ) + assert high_risk_event.risk_level == "high" + + low_risk_event = next( + e for e in events if e.details["file_path"] == "/projects/readme.txt" + ) assert low_risk_event.risk_level == "low" async def test_log_security_violation(self, audit_logger, storage): diff --git a/tests/unit/test_security/test_auth.py b/tests/unit/test_security/test_auth.py index ba28ceb98..26b17908a 100644 --- a/tests/unit/test_security/test_auth.py +++ b/tests/unit/test_security/test_auth.py @@ -249,6 +249,12 @@ async def test_session_management(self, auth_manager): assert session is not None assert session.user_id == user_id + # Backdate the session so refresh produces a strictly later + # last_activity: without this, refresh can land in the same microsecond + # as creation, making last_activity == created_at and the comparison flaky. + session.created_at = datetime.now(UTC) - timedelta(seconds=1) + session.last_activity = session.created_at + # Refresh session old_activity = session.last_activity result = auth_manager.refresh_session(user_id)