diff --git a/memory-bank/activeContext.md b/memory-bank/activeContext.md index fefb0b6..15e18e7 100644 --- a/memory-bank/activeContext.md +++ b/memory-bank/activeContext.md @@ -1,37 +1,72 @@ # Active Context -**Last Updated**: 2026-03-07 -**Current Phase**: Post v0.6.0 — z-index fix + GPT-5.4 + .env improvements -**Next Action**: PR ready for review +**Last Updated**: 2026-03-08 +**Current Phase**: `question-refinement` branch — pre-consensus question refinement, native web search, citations, tools-by-default +**Next Action**: Branch in progress, uncommitted changes staged -## Latest Work (2026-03-07) +## Latest Work (2026-03-08) -### Z-index stacking context fix -- **Problem**: Nested stacking contexts (`z-10` on main content, `z-20` on TopBar header) trapped dropdowns inside containers. Account menu's `fixed inset-0 z-40` backdrop was meaningless outside its container. -- **Fix**: Removed unnecessary z-index values creating stacking contexts, added `isolate` to Shell root, defined z-index tokens in CSS (`--z-background`, `--z-dropdown`, `--z-overlay`, `--z-modal`), replaced backdrop hack with `useRef` + `mousedown` click-outside pattern (matching ExportMenu). -- Files: `duh-theme.css`, `Shell.tsx`, `TopBar.tsx`, `GridOverlay.tsx`, `ParticleField.tsx`, `ExportMenu.tsx`, `ConsensusComplete.tsx`, `ThreadDetail.tsx` +### Question Refinement +- Pre-consensus clarification step: analyze question → ask clarifying questions → enrich with answers → proceed to consensus +- `src/duh/consensus/refine.py` — `analyze_question()` + `enrich_question()`, uses MOST EXPENSIVE model (not cheapest) +- API: `POST /api/refine` → `RefineResponse{needs_refinement, questions[]}`, `POST /api/enrich` → `EnrichResponse{enriched_question}` +- CLI: `duh ask --refine "question"` — interactive `click.prompt()` loop, default `--no-refine` +- Frontend: consensus store `'refining'` status, `submitQuestion` → refine → clarify → enrich → `startConsensus` +- `RefinementPanel.tsx` — tabbed UI inside GlassPanel, checkmarks on answered tabs, Skip + Start Consensus buttons +- Graceful fallback: any failure → proceed to consensus with original question -### GPT-5.4 added to model catalog -- `gpt-5.4`: 1M context, 128K output, $2.50/$15.00 per MTok, no temperature (uses reasoning.effort) -- Added to `NO_TEMPERATURE_MODELS` set -- File: `src/duh/providers/catalog.py` +### Native Provider Web Search +- Providers use server-side search instead of DDG proxy when `config.tools.web_search.native` is true +- `web_search: bool` param added to `ModelProvider.send()` protocol +- Anthropic: `web_search_20250305` server tool in tools[] +- Google: `GoogleSearch()` grounding (replaces function tools — can't coexist) +- Mistral: `{"type": "web_search"}` appended to tools +- OpenAI: `web_search_options={}` only for `_SEARCH_MODELS` set; others fall back to DDG +- Perplexity: no-op (always searches natively) +- `tool_augmented_send`: filters DDG `web_search` tool when native=True, passes flag to provider -### .env improvements -- Added provider API key placeholders to `.env.example` (ANTHROPIC, OPENAI, GOOGLE, PERPLEXITY, MISTRAL) -- Updated README quick start with all provider env vars + `.env` reference -- Note: Google key env var is `GOOGLE_API_KEY` (not `GEMINI_API_KEY`) +### Citations — Persisted + Domain-Grouped +- `Citation` dataclass (url, title, snippet) on `ModelResponse.citations` +- Extraction per provider: Anthropic (`web_search_tool_result`), Google (grounding metadata), Perplexity (`response.citations`) +- **Persistence**: `citations_json` TEXT column on `Contribution` model, SQLite auto-migration via `ensure_schema()` +- `proposal_citations` tracked on `ConsensusContext` → archived to `RoundResult` → persisted via `_persist_consensus` +- Thread detail API returns `citations` on `ContributionResponse` +- **Domain-grouped Sources nav**: ConsensusNav (live) + ThreadNav (stored) group citations by hostname + - Nested Disclosure: outer "Sources (17)" → inner "wikipedia.org (3)" → P/C/R role badges per citation + - P (green) = propose, C (amber) = challenge, R (blue) = revise +- `CitationList` shared component for inline display below content + +### Anthropic Streaming + max_tokens +- `AnthropicProvider.send()` now uses streaming internally via `_collect_stream()` — avoids 10-minute timeout +- `max_tokens` bumped from 16384 → 32768 across all 6 handler defaults (propose, challenge, revise, commit, voting, decomposition) +- Citations are part of the value — truncating them undermines trust + +### Parallel Challenge Streaming +- `_stream_challenges()` in `ws.py` uses `asyncio.as_completed()` to send each challenge result to the frontend as it finishes +- Previously: all challengers ran in parallel but results were batched after all completed +- Now: first challenger to respond appears immediately in the UI + +### Tools Enabled by Default +- `web_search` tool wired through CLI, REST, and WebSocket paths by default +- Provider tool format fix: `tool_augmented_send` builds generic `{name, description, parameters}` — each provider transforms to native format in `send()` + +### Sidebar UX +- New-question button (Heroicons pencil-square) + collapsible sidebar toggle +- Shell manages `desktopSidebarOpen` (default true) + `mobileSidebarOpen` separately +- TopBar shows sidebar toggle when desktop sidebar collapsed or always on mobile ### Test Results -- 1603 Python tests + 185 Vitest tests (1788 total) +- 1641 Python tests + 194 Vitest tests (1835 total) - Build clean, all tests pass --- ## Current State -- **Branch `ux-cleanup`** — z-index fix, GPT-5.4, .env docs -- **1603 Python tests + 185 Vitest tests** (1788 total) +- **Branch `question-refinement`** — in progress, not yet merged +- **1641 Python tests + 194 Vitest tests** (1835 total) - All previous features intact (v0.1–v0.6) +- Prior work merged: z-index fix, GPT-5.4, .env docs, password reset ## Open Questions (Still Unresolved) diff --git a/memory-bank/decisions.md b/memory-bank/decisions.md index b2eaef2..7160c1e 100644 --- a/memory-bank/decisions.md +++ b/memory-bank/decisions.md @@ -1,6 +1,6 @@ # Architectural Decisions -**Last Updated**: 2026-02-18 +**Last Updated**: 2026-03-08 --- @@ -354,3 +354,87 @@ - Manual migration instructions in docs (user friction) **Consequences**: File-based SQLite databases auto-migrate on startup. Zero friction for local users. PostgreSQL still requires `alembic upgrade head`. Lightweight and self-contained. **References**: `src/duh/memory/migrations.py`, `src/duh/cli/app.py:107-110` + +--- + +## 2026-03-08: Native Provider Web Search Over DDG Proxy + +**Status**: Approved +**Context**: The original web search tool used DuckDuckGo as a proxy — every provider's tool calls went through DDG, which returned index pages rather than real content. Most major providers now offer server-side web search that returns higher-quality results with citations. +**Decision**: Add `web_search: bool` parameter to the `ModelProvider.send()` protocol. When `config.tools.web_search.native` is true, each provider uses its native search capability: Anthropic (`web_search_20250305` server tool), Google (`GoogleSearch()` grounding), Mistral (`{"type": "web_search"}`), OpenAI (`web_search_options`), Perplexity (always native). DDG proxy remains as fallback for providers/models that don't support native search. +**Alternatives**: +- DDG-only (simpler, but returns low-quality index pages instead of real content) +- Single search provider for all (e.g., Bing API — adds external dependency and API key) +- Remove web search entirely (loses grounding capability) +**Consequences**: Higher quality search results with real content. Citations extractable from provider responses. Each provider has different native search API shape — increases per-provider complexity. Google grounding and function declarations can't coexist (grounding replaces function tools). +**References**: `src/duh/providers/anthropic.py`, `src/duh/providers/google.py`, `src/duh/providers/mistral.py`, `src/duh/providers/openai.py`, `src/duh/tools/augmented_send.py` + +--- + +## 2026-03-08: Question Refinement Uses Most Expensive Model + +**Status**: Approved +**Context**: Question refinement analyzes user questions before consensus to determine if clarification is needed. The analysis quality directly impacts downstream consensus quality — a poorly refined question wastes all subsequent model calls. +**Decision**: `analyze_question()` and `enrich_question()` in `src/duh/consensus/refine.py` use the most expensive configured model (sorted by cost), not the cheapest. The refinement step is a single model call, so the cost difference is minimal compared to the full consensus round it precedes. +**Alternatives**: +- Cheapest model (saves tokens, but poor analysis leads to poor consensus) +- User-configurable refinement model (adds UX complexity) +- Multi-model refinement (overkill — single strong model is sufficient for question analysis) +**Consequences**: Better question analysis quality. Marginal cost increase (one extra expensive model call). Graceful fallback on failure — original question proceeds to consensus unchanged. +**References**: `src/duh/consensus/refine.py`, `src/duh/api/routes/ask.py`, `src/duh/cli/app.py` + +--- + +## 2026-03-08: Tools Enabled by Default + +**Status**: Approved +**Context**: Web search was originally opt-in. Users who didn't know about the `--tools` flag got ungrounded responses. Most queries benefit from web search grounding. +**Decision**: `web_search` tool is enabled by default across CLI, REST API, and WebSocket paths. The `config.tools.web_search` section controls behavior. Native provider search is preferred when available. +**Alternatives**: +- Opt-in only (simpler, but most users miss it) +- Always-on with no config (inflexible for cost-sensitive users) +- Per-question tool selection (too much UX friction) +**Consequences**: Better default experience — responses are grounded in current information. Slightly higher cost per query (search tool calls). Users can disable via config if needed. +**References**: `src/duh/config/schema.py`, `src/duh/cli/app.py`, `src/duh/api/routes/ws.py` + +--- + +## 2026-03-08: Citation Persistence on Contributions + +**Status**: Approved +**Context**: Citations were emitted over WebSocket during live consensus but never persisted. Viewing a thread later from the Threads section showed no sources — undermining the trust value of native web search. +**Decision**: Add `citations_json` TEXT column to the `Contribution` model (nullable, JSON-encoded list of `{url, title}`). Track `proposal_citations` on `ConsensusContext` and archive to `RoundResult`. Serialize and persist during `_persist_consensus`. Thread detail API returns parsed citations on `ContributionResponse`. ThreadNav shows domain-grouped sources matching ConsensusNav. +**Alternatives**: +- Separate Citation table with FK to Contribution (more normalized, but adds query complexity for marginal benefit) +- Store citations only on Decision (loses per-role attribution) +- Don't persist (simpler, but citations are essential to trust) +**Consequences**: Citations survive beyond the WebSocket session. Thread detail view shows sources grouped by domain with role attribution (P/C/R). SQLite auto-migration handles existing databases. Slightly larger DB rows due to JSON text. +**References**: `src/duh/memory/models.py:146`, `src/duh/api/routes/threads.py`, `src/duh/api/routes/ws.py`, `web/src/components/threads/ThreadNav.tsx` + +--- + +## 2026-03-08: Anthropic Streaming Internally in send() + +**Status**: Approved +**Context**: Increasing `max_tokens` to 32768 triggered Anthropic SDK's 10-minute timeout error: "Streaming is required for operations that may take longer than 10 minutes." The `send()` method used non-streaming `messages.create()`. +**Decision**: `send()` now calls `_collect_stream()` which uses `messages.stream()` as a context manager and collects the final `Message` via `get_final_message()`. The returned object is identical to `messages.create()` output, so all downstream parsing (citations, tool calls, text concatenation) works unchanged. +**Alternatives**: +- Keep non-streaming and lower max_tokens (loses citation content to truncation) +- Full streaming to frontend (larger change, separate concern) +- Increase Anthropic client timeout (fragile, doesn't scale) +**Consequences**: No more timeout errors at any max_tokens value. Test mocks must mock `messages.stream` context manager instead of `messages.create`. Marginal latency increase from stream overhead (negligible vs network time). +**References**: `src/duh/providers/anthropic.py:222-229` + +--- + +## 2026-03-08: Parallel Challenge Streaming via as_completed + +**Status**: Approved +**Context**: Challengers were already running in parallel via `asyncio.gather` in `handle_challenge`, but the WebSocket handler sent all results after ALL challengers finished. Users saw nothing until the slowest challenger responded. +**Decision**: New `_stream_challenges()` function in `ws.py` uses `asyncio.as_completed()` to send each challenge result to the frontend immediately as each completes. Builds `ChallengeResult` objects and updates `ctx.challenges` directly, bypassing `handle_challenge`. +**Alternatives**: +- Keep batched approach (simpler, but poor UX — users wait for slowest model) +- Token-level streaming per challenger (much more complex, requires protocol changes) +- Sequential challengers (defeats the purpose of multi-model) +**Consequences**: First challenger to respond appears immediately. More engaging real-time experience. WS test mocks now patch `_stream_challenges` instead of `handle_challenge`. Challenge order in UI reflects completion speed, not configuration order. +**References**: `src/duh/api/routes/ws.py:253-347`, `tests/unit/test_api_ws.py` diff --git a/memory-bank/progress.md b/memory-bank/progress.md index 8063e38..dd0e044 100644 --- a/memory-bank/progress.md +++ b/memory-bank/progress.md @@ -4,7 +4,36 @@ --- -## Current State: v0.6.0 — "It's Honest" COMPLETE +## Current State: Post v0.6.0 — `question-refinement` Branch In Progress + +### Question Refinement + Native Web Search + Citations (2026-03-08) + +- **Question refinement**: pre-consensus clarification step (analyze → clarify → enrich → consensus) + - `src/duh/consensus/refine.py`, API routes (`/api/refine`, `/api/enrich`), CLI `--refine` flag + - Frontend: `RefinementPanel.tsx` tabbed UI, consensus store `'refining'` status + - Graceful fallback on failure → original question proceeds to consensus +- **Native provider web search**: Anthropic/Google/Mistral/OpenAI/Perplexity use server-side search + - `web_search: bool` param on `ModelProvider.send()` protocol + - `config.tools.web_search.native` flag controls behavior + - DDG proxy still available as fallback for non-native providers +- **Citations**: `Citation` dataclass on `ModelResponse`, extracted per provider, displayed in frontend + - `CitationList` shared component, `ConsensusNav` collapsible Sources sidebar section + - WS events include `citations` array for PROPOSE and CHALLENGE phases +- **Tools enabled by default**: `web_search` (DuckDuckGo) wired through all paths (CLI, REST, WS) +- **Provider tool format fix**: each provider transforms generic tool defs to native API format +- **Sidebar UX**: new-question button + collapsible sidebar toggle +- **Citation persistence**: `citations_json` on Contribution model, SQLite migration, thread detail API returns citations +- **Domain-grouped Sources**: ConsensusNav + ThreadNav group citations by hostname with Disclosure, P/C/R role badges +- **Anthropic streaming**: `send()` uses `_collect_stream()` internally to avoid 10-min timeout on large max_tokens +- **Parallel challenge streaming**: `_stream_challenges()` sends each result to frontend as it completes via `asyncio.as_completed` +- **max_tokens 32768**: bumped from 16384 across all handlers — citations are essential to trust +- 1641 Python tests + 194 Vitest tests (1835 total), build clean + +### Z-index Fix + GPT-5.4 + .env Docs (2026-03-07) + +- Z-index stacking context fix, GPT-5.4 model catalog entry, .env.example provider keys +- Password reset flow, SMTP mail module, JWT-scoped tokens +- 1603 Python tests + 185 Vitest tests (1788 total) ### Consensus Navigation & Collapsible Sections @@ -195,3 +224,9 @@ Phase 0 benchmark framework — fully functional, pilot-tested on 5 questions. | 2026-03-07 | GPT-5.4 added to model catalog (1M ctx, $2.50/$15.00, no-temperature) | Done | | 2026-03-07 | .env.example updated with provider API key placeholders | Done | | 2026-03-07 | README updated with all provider env vars | Done | +| 2026-03-08 | Question refinement (analyze → clarify → enrich → consensus) | In Progress | +| 2026-03-08 | Native provider web search (Anthropic/Google/Mistral/OpenAI/Perplexity) | In Progress | +| 2026-03-08 | Citations extraction + frontend CitationList + ConsensusNav Sources | In Progress | +| 2026-03-08 | Tools enabled by default (web_search wired through CLI/REST/WS) | In Progress | +| 2026-03-08 | Provider tool format fix (generic → native transform per provider) | In Progress | +| 2026-03-08 | Sidebar UX (new-question button, collapsible toggle) | In Progress | diff --git a/memory-bank/techContext.md b/memory-bank/techContext.md index a88aa0d..4799ad1 100644 --- a/memory-bank/techContext.md +++ b/memory-bank/techContext.md @@ -1,6 +1,6 @@ # Technical Context -**Last Updated**: 2026-02-17 +**Last Updated**: 2026-03-08 --- @@ -124,6 +124,61 @@ Most local providers already speak the OpenAI-compatible API, so the Ollama/LM S - FastAPI mounts `web/dist/` as static files with SPA fallback route - Docker: Node.js 22 build stage copies dist/ to runtime image +## Native Provider Web Search (2026-03-08) + +**Rationale**: DuckDuckGo proxy returned index pages, not real content. Major providers now offer server-side search with higher quality results and extractable citations. + +- `web_search: bool` param on `ModelProvider.send()` protocol +- `config.tools.web_search.native` flag (default: true) +- Per-provider implementation: + - **Anthropic**: `web_search_20250305` server tool in tools[] — skip generic→native transform for entries with `type` key + - **Google**: `GoogleSearch()` grounding replaces function tools (can't coexist with `function_declarations`) + - **Mistral**: `{"type": "web_search"}` appended to tools list + - **OpenAI**: `web_search_options={}` only for `_SEARCH_MODELS` set; standard models fall back to DDG + - **Perplexity**: no-op (always searches natively) +- `tool_augmented_send`: filters DDG `web_search` tool when native=True, passes `web_search` flag to provider + +## Citations (2026-03-08) + +**Rationale**: Native web search returns structured citation data. Surfacing sources improves trust and verifiability. + +- `Citation` dataclass (url, title, snippet) on `ModelResponse.citations` +- Extraction per provider: + - **Anthropic**: `web_search_tool_result` blocks → `web_search_result` entries + - **Google**: `candidate.grounding_metadata.grounding_chunks[].web` → uri/title + - **Perplexity**: `response.citations` list (strings or objects) +- WebSocket: `phase_complete` (PROPOSE) and `challenge` events include `citations` array +- Frontend: `CitationList` shared component (numbered deduped links, hostname fallback), `ConsensusNav` collapsible Sources + +## Question Refinement (2026-03-08) + +**Rationale**: Ambiguous or underspecified questions produce low-quality consensus. A pre-consensus refinement step catches missing context. + +- `src/duh/consensus/refine.py` — `analyze_question()` + `enrich_question()` +- Uses most expensive configured model (quality matters more than cost for a single call) +- API: `POST /api/refine`, `POST /api/enrich` +- CLI: `duh ask --refine "question"` (default `--no-refine`) +- Frontend: consensus store `'refining'` status → `RefinementPanel.tsx` (tabbed UI, Skip button) +- Graceful fallback: any failure → proceed with original question + +## Anthropic Streaming Internals (2026-03-08) + +**Rationale**: Anthropic SDK requires streaming for requests that may exceed 10 minutes. With `max_tokens=32768`, non-streaming `messages.create()` hit this limit. + +- `AnthropicProvider.send()` now calls `_collect_stream()` internally +- Uses `messages.stream()` context manager → `get_final_message()` returns identical `Message` object +- All downstream parsing (citations, tool calls, text blocks) unchanged +- Test mocks must mock `messages.stream` (async context manager), not `messages.create` + +## Parallel Challenge Streaming (2026-03-08) + +**Rationale**: Challengers ran in parallel but results were batched. Users saw nothing until the slowest challenger responded. + +- `_stream_challenges()` in `ws.py` uses `asyncio.as_completed()` +- Each challenge result sent to WebSocket immediately as it finishes +- Builds `ChallengeResult` objects and updates `ctx.challenges` directly +- WS test mocks now patch `_stream_challenges` instead of `handle_challenge` + ## Key Technical Patterns ### Async-First @@ -153,10 +208,10 @@ React app built by Vite to `web/dist/`, served by FastAPI as static files. SPA f ## Decided (v0.4) - **Project structure**: `src/duh/` with cli/, consensus/, providers/, memory/, config/, core/, tools/, api/, mcp/ + `web/` with src/, theme, api, stores, components, pages -- **Testing**: pytest + pytest-asyncio + pytest-cov (1318 tests), Vitest + @testing-library/react (117 tests), asyncio_mode=auto +- **Testing**: pytest + pytest-asyncio + pytest-cov (1641 tests), Vitest + @testing-library/react (194 tests), asyncio_mode=auto - **CI/CD**: GitHub Actions (lint, typecheck, test) + docs deployment to GitHub Pages - **Provider interface**: `typing.Protocol` (structural typing), stateless adapters -- **5 providers shipping**: Anthropic (3 models), OpenAI (3 models), Google (4 models), Mistral (4 models) — 14 total +- **6 providers shipping**: Anthropic, OpenAI, Google, Mistral, Perplexity — with native web search per provider - **Memory schema**: SQLAlchemy ORM — Thread, Turn, Contribution, TurnSummary, ThreadSummary, Decision, Outcome, Subtask, Vote, APIKey - **Configuration**: TOML with Pydantic validation, layered merge (defaults < user < project < env < CLI) - **Error handling**: DuhError hierarchy with ProviderError, ConsensusError, ConfigError, StorageError @@ -175,7 +230,11 @@ React app built by Vite to `web/dist/`, served by FastAPI as static files. SPA f - **Markdown rendering**: react-markdown + remark-gfm + rehype-highlight in `Markdown` shared component - **Light/dark mode**: `prefers-color-scheme` auto-detection + `.theme-dark`/`.theme-light` manual override classes - **Documentation**: MkDocs Material, deployed to GitHub Pages -- **50 Python source files + 66 frontend source files**, mypy strict clean, ruff clean, 0 TS errors +- **~63 Python source files + ~81 frontend source files**, mypy strict clean, ruff clean, 0 TS errors +- **Question refinement**: pre-consensus clarification step (analyze → clarify → enrich → consensus) +- **Native web search**: per-provider server-side search with citation extraction +- **Citations**: `Citation` dataclass, `CitationList` component, ConsensusNav Sources sidebar +- **Tools by default**: web_search enabled across CLI/REST/WS paths ## Dependencies diff --git a/memory-bank/toc.md b/memory-bank/toc.md index cc4a0b7..66f86ef 100644 --- a/memory-bank/toc.md +++ b/memory-bank/toc.md @@ -3,8 +3,8 @@ ## Core Files - [projectbrief.md](./projectbrief.md) — Vision, tenets, architecture, build sequence - [techContext.md](./techContext.md) — Tech stack decisions with rationale (Python, Docker, SQLAlchemy, frontend, tools, etc.) -- [decisions.md](./decisions.md) — Architectural decisions with context, alternatives, and consequences (20 ADRs) -- [activeContext.md](./activeContext.md) — Current state, epistemic confidence Phase A complete +- [decisions.md](./decisions.md) — Architectural decisions with context, alternatives, and consequences (26 ADRs) +- [activeContext.md](./activeContext.md) — Current state, question-refinement branch in progress - [progress.md](./progress.md) — Milestone tracking, what's built, what's next - [competitive-landscape.md](./competitive-landscape.md) — Research on existing tools, frameworks, and academic work - [quick-start.md](./quick-start.md) — Session entry point, v0.5 complete, key file references diff --git a/src/duh/api/app.py b/src/duh/api/app.py index c111c84..2bdf52e 100644 --- a/src/duh/api/app.py +++ b/src/duh/api/app.py @@ -20,7 +20,7 @@ @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncIterator[None]: """Lifespan handler: set up DB + providers on startup, tear down on shutdown.""" - from duh.cli.app import _create_db, _setup_providers + from duh.cli.app import _create_db, _setup_providers, _setup_tools config: DuhConfig = app.state.config factory, engine = await _create_db(config) @@ -29,6 +29,7 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: app.state.db_factory = factory app.state.engine = engine app.state.provider_manager = pm + app.state.tool_registry = _setup_tools(config) extra = getattr(app.state, "extra_lifespan", None) if extra is not None: diff --git a/src/duh/api/routes/ask.py b/src/duh/api/routes/ask.py index 5e2c98f..4768669 100644 --- a/src/duh/api/routes/ask.py +++ b/src/duh/api/routes/ask.py @@ -15,6 +15,25 @@ router = APIRouter(prefix="/api", tags=["consensus"]) +class RefineRequest(BaseModel): + question: str + max_questions: int = 4 + + +class RefineResponse(BaseModel): + needs_refinement: bool + questions: list[dict[str, str | None]] = [] + + +class EnrichRequest(BaseModel): + original_question: str + clarifications: list[dict[str, str]] + + +class EnrichResponse(BaseModel): + enriched_question: str + + class AskRequest(BaseModel): question: str protocol: str = "consensus" # consensus, voting, auto @@ -46,6 +65,7 @@ async def ask(body: AskRequest, request: Request) -> AskResponse | JSONResponse: config.general.max_rounds = body.rounds db_factory = getattr(request.app.state, "db_factory", None) + tool_registry = getattr(request.app.state, "tool_registry", None) try: if body.decompose: @@ -55,7 +75,7 @@ async def ask(body: AskRequest, request: Request) -> AskResponse | JSONResponse: return await _handle_voting(body, config, pm) # Default: consensus - return await _handle_consensus(body, config, pm, db_factory) + return await _handle_consensus(body, config, pm, db_factory, tool_registry) except ProviderError as exc: logger.exception("Provider error during /api/ask") @@ -78,18 +98,21 @@ async def ask(body: AskRequest, request: Request) -> AskResponse | JSONResponse: async def _handle_consensus( # type: ignore[no-untyped-def] - body: AskRequest, config, pm, db_factory=None + body: AskRequest, config, pm, db_factory=None, tool_registry=None ) -> AskResponse: """Run the consensus protocol.""" from duh.cli.app import _run_consensus + use_native_search = config.tools.enabled and config.tools.web_search.native decision, confidence, rigor, dissent, cost, _overview = await _run_consensus( body.question, config, pm, + tool_registry=tool_registry, panel=body.panel, proposer_override=body.proposer, challengers_override=body.challengers, + web_search=use_native_search, ) thread_id: str | None = None @@ -203,3 +226,26 @@ async def _persist_result( ) await session.commit() return str(thread.id) + + +@router.post("/refine", response_model=RefineResponse) +async def refine(body: RefineRequest, request: Request) -> RefineResponse: + """Analyze a question for ambiguity and suggest clarifications.""" + from duh.consensus.refine import analyze_question + + pm = request.app.state.provider_manager + result = await analyze_question(body.question, pm, max_questions=body.max_questions) + return RefineResponse( + needs_refinement=result.get("needs_refinement", False), + questions=result.get("questions", []), + ) + + +@router.post("/enrich", response_model=EnrichResponse) +async def enrich(body: EnrichRequest, request: Request) -> EnrichResponse: + """Rewrite a question incorporating clarification answers.""" + from duh.consensus.refine import enrich_question + + pm = request.app.state.provider_manager + enriched = await enrich_question(body.original_question, body.clarifications, pm) + return EnrichResponse(enriched_question=enriched) diff --git a/src/duh/api/routes/threads.py b/src/duh/api/routes/threads.py index ff47275..1fcb537 100644 --- a/src/duh/api/routes/threads.py +++ b/src/duh/api/routes/threads.py @@ -3,6 +3,7 @@ from __future__ import annotations import io +import json from fastapi import APIRouter, HTTPException, Query, Request from fastapi.responses import StreamingResponse @@ -11,6 +12,11 @@ router = APIRouter(prefix="/api", tags=["threads"]) +class CitationResponse(BaseModel): + url: str + title: str | None = None + + class ContributionResponse(BaseModel): model_ref: str role: str @@ -18,6 +24,7 @@ class ContributionResponse(BaseModel): input_tokens: int = 0 output_tokens: int = 0 cost_usd: float = 0.0 + citations: list[CitationResponse] | None = None class DecisionResponse(BaseModel): @@ -116,43 +123,7 @@ async def get_thread(thread_id: str, request: Request) -> ThreadDetailResponse: if thread is None: raise HTTPException(status_code=404, detail=f"Thread not found: {thread_id}") - turns = [] - for turn in thread.turns: - contribs = [ - ContributionResponse( - model_ref=c.model_ref, - role=c.role, - content=c.content, - input_tokens=c.input_tokens, - output_tokens=c.output_tokens, - cost_usd=c.cost_usd, - ) - for c in turn.contributions - ] - dec = None - if turn.decision: - dec = DecisionResponse( - content=turn.decision.content, - confidence=turn.decision.confidence, - rigor=turn.decision.rigor, - dissent=turn.decision.dissent, - ) - turns.append( - TurnResponse( - round_number=turn.round_number, - state=turn.state, - contributions=contribs, - decision=dec, - ) - ) - - return ThreadDetailResponse( - thread_id=thread.id, - question=thread.question, - status=thread.status, - created_at=thread.created_at.isoformat(), - turns=turns, - ) + return _build_thread_detail(thread) @router.get("/share/{share_token}", response_model=ThreadDetailResponse) @@ -171,8 +142,24 @@ async def get_shared_thread(share_token: str, request: Request) -> ThreadDetailR detail=f"Shared thread not found: {share_token}", ) + return _build_thread_detail(thread) + + +def _parse_citations(raw: str | None) -> list[CitationResponse] | None: + """Parse JSON-encoded citations from a contribution.""" + if not raw: + return None + try: + items = json.loads(raw) + return [CitationResponse(url=c["url"], title=c.get("title")) for c in items] + except (json.JSONDecodeError, KeyError, TypeError): + return None + + +def _build_thread_detail(thread: object) -> ThreadDetailResponse: + """Build a ThreadDetailResponse from a Thread ORM object.""" turns = [] - for turn in thread.turns: + for turn in thread.turns: # type: ignore[attr-defined] contribs = [ ContributionResponse( model_ref=c.model_ref, @@ -181,6 +168,7 @@ async def get_shared_thread(share_token: str, request: Request) -> ThreadDetailR input_tokens=c.input_tokens, output_tokens=c.output_tokens, cost_usd=c.cost_usd, + citations=_parse_citations(getattr(c, "citations_json", None)), ) for c in turn.contributions ] @@ -202,10 +190,10 @@ async def get_shared_thread(share_token: str, request: Request) -> ThreadDetailR ) return ThreadDetailResponse( - thread_id=thread.id, - question=thread.question, - status=thread.status, - created_at=thread.created_at.isoformat(), + thread_id=thread.id, # type: ignore[attr-defined] + question=thread.question, # type: ignore[attr-defined] + status=thread.status, # type: ignore[attr-defined] + created_at=thread.created_at.isoformat(), # type: ignore[attr-defined] turns=turns, ) diff --git a/src/duh/api/routes/ws.py b/src/duh/api/routes/ws.py index 2f094f1..eb27d38 100644 --- a/src/duh/api/routes/ws.py +++ b/src/duh/api/routes/ws.py @@ -2,8 +2,9 @@ from __future__ import annotations +import json import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from fastapi import APIRouter, WebSocket, WebSocketDisconnect @@ -11,6 +12,7 @@ from duh.config.schema import DuhConfig from duh.consensus.machine import RoundResult from duh.providers.manager import ProviderManager + from duh.tools.registry import ToolRegistry logger = logging.getLogger(__name__) @@ -63,6 +65,8 @@ async def ws_ask(websocket: WebSocket) -> None: pm: ProviderManager = websocket.app.state.provider_manager config.general.max_rounds = rounds + tool_registry = getattr(websocket.app.state, "tool_registry", None) + await _stream_consensus( websocket, question, @@ -71,6 +75,7 @@ async def ws_ask(websocket: WebSocket) -> None: panel=panel, proposer_override=proposer_override, challengers_override=challengers_raw, + tool_registry=tool_registry, ) except WebSocketDisconnect: @@ -93,12 +98,12 @@ async def _stream_consensus( panel: list[str] | None = None, proposer_override: str | None = None, challengers_override: list[str] | None = None, + tool_registry: ToolRegistry | None = None, ) -> None: """Run consensus loop and stream events to WebSocket.""" from duh.consensus.convergence import check_convergence from duh.consensus.handlers import ( generate_overview, - handle_challenge, handle_commit, handle_propose, handle_revise, @@ -119,6 +124,7 @@ async def _stream_consensus( sm = ConsensusStateMachine(ctx) effective_panel = panel or config.consensus.panel or None + use_native_search = config.tools.enabled and config.tools.web_search.native for _round in range(config.general.max_rounds): # PROPOSE @@ -132,17 +138,28 @@ async def _stream_consensus( "round": ctx.current_round, } ) - propose_resp = await handle_propose(ctx, pm, proposer) + propose_resp = await handle_propose( + ctx, + pm, + proposer, + tool_registry=tool_registry, + web_search=use_native_search, + ) + propose_citations = [ + {"url": c.url, "title": c.title} for c in (propose_resp.citations or []) + ] + ctx.proposal_citations = propose_citations await ws.send_json( { "type": "phase_complete", "phase": "PROPOSE", "content": ctx.proposal or "", "truncated": propose_resp.finish_reason != "stop", + "citations": propose_citations if propose_citations else None, } ) - # CHALLENGE + # CHALLENGE — fan out in parallel, stream each result as it arrives sm.transition(ConsensusState.CHALLENGE) challengers = challengers_override or select_challengers( pm, proposer, panel=effective_panel @@ -155,29 +172,14 @@ async def _stream_consensus( "round": ctx.current_round, } ) - challenge_resps = await handle_challenge(ctx, pm, challengers) - succeeded = {ch.model_ref for ch in ctx.challenges} - for i, ch in enumerate(ctx.challenges): - resp_truncated = ( - i < len(challenge_resps) and challenge_resps[i].finish_reason != "stop" - ) - await ws.send_json( - { - "type": "challenge", - "model": ch.model_ref, - "content": ch.content, - "truncated": resp_truncated, - } - ) - # Notify about challengers that failed - for ref in challengers: - if ref not in succeeded: - await ws.send_json( - { - "type": "challenge_error", - "model": ref, - } - ) + await _stream_challenges( + ws, + ctx, + pm, + challengers, + tool_registry=tool_registry, + web_search=use_native_search, + ) await ws.send_json({"type": "phase_complete", "phase": "CHALLENGE"}) # REVISE @@ -248,6 +250,98 @@ async def _stream_consensus( await ws.close() +async def _stream_challenges( + ws: WebSocket, + ctx: object, + pm: object, + challengers: list[str], + *, + tool_registry: object | None = None, + web_search: bool = False, +) -> None: + """Run challengers in parallel, streaming each result to WS as it arrives. + + Updates ``ctx.challenges`` with results. + """ + import asyncio + + from duh.consensus.handlers import ( + _FRAMING_ORDER, + _call_challenger, + detect_sycophancy, + ) + from duh.consensus.machine import ChallengeResult + + async def _run(idx: int, ref: str) -> tuple[int, tuple[str, str, Any]]: + result = await _call_challenger( + ctx, # type: ignore[arg-type] + pm, # type: ignore[arg-type] + ref, + _FRAMING_ORDER[idx % len(_FRAMING_ORDER)], + temperature=0.7, + max_tokens=32768, + tool_registry=tool_registry, # type: ignore[arg-type] + web_search=web_search, + ) + return idx, result + + tasks = [asyncio.create_task(_run(i, ref)) for i, ref in enumerate(challengers)] + + challenges: list[ChallengeResult] = [] + + for coro in asyncio.as_completed(tasks): + try: + _idx, (model_ref, framing, response) = await coro + citation_dicts = tuple( + { + "url": c.url, + "title": c.title, + "snippet": c.snippet, + } + for c in (response.citations or []) + ) + ch = ChallengeResult( + model_ref=model_ref, + content=response.content, + sycophantic=detect_sycophancy(response.content), + framing=framing, + citations=citation_dicts, + ) + challenges.append(ch) + + # Stream to client immediately + ch_citations = ( + [{"url": c["url"], "title": c.get("title")} for c in ch.citations] + if ch.citations + else None + ) + await ws.send_json( + { + "type": "challenge", + "model": ch.model_ref, + "content": ch.content, + "truncated": response.finish_reason != "stop", + "citations": ch_citations, + } + ) + except Exception: + logger.warning("Challenger failed", exc_info=True) + + # Report failures + succeeded = {ch.model_ref for ch in challenges} + for ref in challengers: + if ref not in succeeded: + await ws.send_json({"type": "challenge_error", "model": ref}) + + if not challenges: + from duh.core.errors import ConsensusError + + msg = "All challengers failed" + raise ConsensusError(msg) + + ctx.challenges = challenges # type: ignore[attr-defined] + + async def _persist_consensus( db_factory: object, question: str, @@ -267,12 +361,36 @@ async def _persist_consensus( for rr in round_history: turn = await repo.create_turn(thread.id, rr.round_number, "COMMIT") + proposal_cit = None + if rr.proposal_citations: + proposal_cit = json.dumps( + [ + {"url": c["url"], "title": c.get("title")} + for c in rr.proposal_citations + ] + ) await repo.add_contribution( - turn.id, rr.proposal_model, "proposer", rr.proposal + turn.id, + rr.proposal_model, + "proposer", + rr.proposal, + citations_json=proposal_cit, ) for ch in rr.challenges: + ch_cit = None + if ch.citations: + ch_cit = json.dumps( + [ + {"url": c["url"], "title": c.get("title")} + for c in ch.citations + ] + ) await repo.add_contribution( - turn.id, ch.model_ref, "challenger", ch.content + turn.id, + ch.model_ref, + "challenger", + ch.content, + citations_json=ch_cit, ) await repo.add_contribution( turn.id, rr.proposal_model, "reviser", rr.revision diff --git a/src/duh/cli/app.py b/src/duh/cli/app.py index c222e2b..1decb4d 100644 --- a/src/duh/cli/app.py +++ b/src/duh/cli/app.py @@ -209,6 +209,7 @@ async def _run_consensus( panel: list[str] | None = None, proposer_override: str | None = None, challengers_override: list[str] | None = None, + web_search: bool = False, ) -> tuple[str, float, float, str | None, float, str | None]: """Run the full consensus loop. @@ -249,10 +250,22 @@ async def _run_consensus( proposer = proposer_override or select_proposer(pm, panel=effective_panel) if display: with display.phase_status("PROPOSE", proposer): - await handle_propose(ctx, pm, proposer, tool_registry=tool_registry) + await handle_propose( + ctx, + pm, + proposer, + tool_registry=tool_registry, + web_search=web_search, + ) display.show_propose(proposer, ctx.proposal or "") else: - await handle_propose(ctx, pm, proposer, tool_registry=tool_registry) + await handle_propose( + ctx, + pm, + proposer, + tool_registry=tool_registry, + web_search=web_search, + ) # CHALLENGE sm.transition(ConsensusState.CHALLENGE) @@ -263,11 +276,21 @@ async def _run_consensus( detail = f"{len(challengers)} models" with display.phase_status("CHALLENGE", detail): await handle_challenge( - ctx, pm, challengers, tool_registry=tool_registry + ctx, + pm, + challengers, + tool_registry=tool_registry, + web_search=web_search, ) display.show_challenges(ctx.challenges) else: - await handle_challenge(ctx, pm, challengers, tool_registry=tool_registry) + await handle_challenge( + ctx, + pm, + challengers, + tool_registry=tool_registry, + web_search=web_search, + ) # REVISE sm.transition(ConsensusState.REVISE) @@ -386,6 +409,11 @@ def cli(ctx: click.Context, config_path: str | None) -> None: default=None, help="Restrict to these models only (comma-separated model refs).", ) +@click.option( + "--refine/--no-refine", + default=False, + help="Pre-consensus question refinement (ask clarifying questions).", +) @click.pass_context def ask( ctx: click.Context, @@ -397,6 +425,7 @@ def ask( proposer: str | None, challengers: str | None, panel: str | None, + refine: bool, ) -> None: """Run a consensus query. @@ -415,6 +444,14 @@ def ask( panel_list = panel.split(",") if panel else None challengers_list = challengers.split(",") if challengers else None + # Question refinement (pre-consensus clarification) + if refine: + try: + question = asyncio.run(_refine_question(question, config)) + except DuhError as e: + _error(str(e)) + return + # Determine effective protocol effective_protocol = protocol or config.general.protocol @@ -463,6 +500,31 @@ def ask( ) +async def _refine_question(question: str, config: DuhConfig) -> str: + """Run question refinement interactively on the CLI.""" + from duh.consensus.refine import analyze_question, enrich_question + + pm = await _setup_providers(config) + if not pm.list_all_models(): + return question + + result = await analyze_question(question, pm) + if not result.get("needs_refinement"): + return question + + questions = result.get("questions", []) + click.echo("\nClarifying questions:") + clarifications = [] + for q in questions: + hint = f" ({q['hint']})" if q.get("hint") else "" + answer = click.prompt(f" {q['question']}{hint}") + clarifications.append({"question": q["question"], "answer": answer}) + + enriched = await enrich_question(question, clarifications, pm) + click.echo(f"\nRefined question: {enriched}\n") + return enriched + + async def _ask_async( question: str, config: DuhConfig, @@ -483,6 +545,7 @@ async def _ask_async( ) tool_registry = _setup_tools(config) + use_native_search = config.tools.enabled and config.tools.web_search.native display = ConsensusDisplay() display.start() return await _run_consensus( @@ -494,6 +557,7 @@ async def _ask_async( panel=panel, proposer_override=proposer_override, challengers_override=challengers_override, + web_search=use_native_search, ) diff --git a/src/duh/config/schema.py b/src/duh/config/schema.py index 4c5294c..1c3f8d5 100644 --- a/src/duh/config/schema.py +++ b/src/duh/config/schema.py @@ -61,6 +61,7 @@ class WebSearchConfig(BaseModel): backend: str = "duckduckgo" api_key: str | None = None max_results: int = 5 + native: bool = True class CodeExecutionConfig(BaseModel): @@ -74,7 +75,7 @@ class CodeExecutionConfig(BaseModel): class ToolsConfig(BaseModel): """Tool framework configuration.""" - enabled: bool = False + enabled: bool = True max_rounds: int = 5 web_search: WebSearchConfig = Field(default_factory=WebSearchConfig) code_execution: CodeExecutionConfig = Field(default_factory=CodeExecutionConfig) diff --git a/src/duh/consensus/handlers.py b/src/duh/consensus/handlers.py index c3f6d4d..848845a 100644 --- a/src/duh/consensus/handlers.py +++ b/src/duh/consensus/handlers.py @@ -166,7 +166,7 @@ def _token_budget_note(max_tokens: int) -> str: def build_propose_prompt( - ctx: ConsensusContext, *, max_tokens: int = 16384 + ctx: ConsensusContext, *, max_tokens: int = 32768 ) -> list[PromptMessage]: """Build prompt messages for the PROPOSE phase. @@ -259,8 +259,9 @@ async def handle_propose( model_ref: str, *, temperature: float = 0.7, - max_tokens: int = 16384, + max_tokens: int = 32768, tool_registry: ToolRegistry | None = None, + web_search: bool = False, ) -> ModelResponse: """Execute the PROPOSE phase of consensus. @@ -308,11 +309,16 @@ async def handle_propose( tool_registry, max_tokens=max_tokens, temperature=temperature, + web_search=web_search, ) _log_tool_calls(ctx, response, "propose") else: response = await provider.send( - messages, model_id, max_tokens=max_tokens, temperature=temperature + messages, + model_id, + max_tokens=max_tokens, + temperature=temperature, + web_search=web_search, ) # Record cost @@ -333,7 +339,7 @@ def build_challenge_prompt( ctx: ConsensusContext, framing: str = "flaw", *, - max_tokens: int = 16384, + max_tokens: int = 32768, ) -> list[PromptMessage]: """Build prompt messages for the CHALLENGE phase. @@ -459,6 +465,7 @@ async def _call_challenger( temperature: float, max_tokens: int, tool_registry: ToolRegistry | None = None, + web_search: bool = False, ) -> tuple[str, str, ModelResponse]: """Call a single challenger model. @@ -477,11 +484,16 @@ async def _call_challenger( tool_registry, max_tokens=max_tokens, temperature=temperature, + web_search=web_search, ) _log_tool_calls(ctx, response, "challenge") else: response = await provider.send( - messages, model_id, max_tokens=max_tokens, temperature=temperature + messages, + model_id, + max_tokens=max_tokens, + temperature=temperature, + web_search=web_search, ) model_info = provider_manager.get_model_info(model_ref) @@ -495,8 +507,9 @@ async def handle_challenge( challenger_models: list[str], *, temperature: float = 0.7, - max_tokens: int = 16384, + max_tokens: int = 32768, tool_registry: ToolRegistry | None = None, + web_search: bool = False, ) -> list[ModelResponse]: """Execute the CHALLENGE phase of consensus. @@ -543,6 +556,7 @@ async def handle_challenge( temperature=temperature, max_tokens=max_tokens, tool_registry=tool_registry, + web_search=web_search, ) for i, ref in enumerate(challenger_models) ] @@ -557,12 +571,17 @@ async def handle_challenge( logger.warning("Challenger %s failed: %s", failed_ref, result) continue model_ref, framing, response = result + citation_dicts = tuple( + {"url": c.url, "title": c.title, "snippet": c.snippet} + for c in (response.citations or []) + ) challenges.append( ChallengeResult( model_ref=model_ref, content=response.content, sycophantic=detect_sycophancy(response.content), framing=framing, + citations=citation_dicts, ) ) responses.append(response) @@ -579,7 +598,7 @@ async def handle_challenge( def build_revise_prompt( - ctx: ConsensusContext, *, max_tokens: int = 16384 + ctx: ConsensusContext, *, max_tokens: int = 32768 ) -> list[PromptMessage]: """Build prompt messages for the REVISE phase. @@ -611,7 +630,7 @@ async def handle_revise( model_ref: str | None = None, *, temperature: float = 0.7, - max_tokens: int = 16384, + max_tokens: int = 32768, ) -> ModelResponse: """Execute the REVISE phase of consensus. diff --git a/src/duh/consensus/machine.py b/src/duh/consensus/machine.py index a5bb5f6..941dbf5 100644 --- a/src/duh/consensus/machine.py +++ b/src/duh/consensus/machine.py @@ -41,6 +41,7 @@ class ChallengeResult: content: str sycophantic: bool = False framing: str = "" + citations: tuple[dict[str, str | None], ...] = () @dataclass(frozen=True, slots=True) @@ -56,6 +57,7 @@ class RoundResult: confidence: float rigor: float = 0.0 dissent: str | None = None + proposal_citations: tuple[dict[str, str | None], ...] = () @dataclass(frozen=True, slots=True) @@ -86,6 +88,7 @@ class ConsensusContext: # Current round working data (cleared between rounds) proposal: str | None = None proposal_model: str | None = None + proposal_citations: list[dict[str, str | None]] = field(default_factory=list) challenges: list[ChallengeResult] = field(default_factory=list) revision: str | None = None revision_model: str | None = None @@ -115,6 +118,7 @@ def _clear_round_data(self) -> None: """Reset working data for a new round.""" self.proposal = None self.proposal_model = None + self.proposal_citations = [] self.challenges = [] self.revision = None self.revision_model = None @@ -137,6 +141,7 @@ def _archive_round(self) -> None: confidence=self.confidence, rigor=self.rigor, dissent=self.dissent, + proposal_citations=tuple(self.proposal_citations), ) ) diff --git a/src/duh/consensus/refine.py b/src/duh/consensus/refine.py new file mode 100644 index 0000000..5a96325 --- /dev/null +++ b/src/duh/consensus/refine.py @@ -0,0 +1,156 @@ +"""Question refinement: pre-consensus clarification step. + +Uses the most capable (most expensive) model to evaluate whether a +question is ambiguous and, if so, generate clarifying questions. A +second call rewrites the original question incorporating the user's +answers. This is the user's first impression — it must be exceptional. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from duh.consensus.json_extract import JSONExtractionError, extract_json +from duh.providers.base import PromptMessage + +if TYPE_CHECKING: + from duh.providers.manager import ProviderManager + +logger = logging.getLogger(__name__) + + +async def analyze_question( + question: str, + provider_manager: ProviderManager, + *, + max_questions: int = 4, +) -> dict[str, Any]: + """Evaluate whether *question* needs clarification. + + Returns ``{"needs_refinement": false}`` when the question is specific + enough, or ``{"needs_refinement": true, "questions": [...]}`` with up + to *max_questions* clarifying questions otherwise. + + On any failure (no models, JSON parse error, provider error) the + function returns ``{"needs_refinement": false}`` so consensus can + proceed uninterrupted. + """ + models = provider_manager.list_all_models() + if not models: + return {"needs_refinement": False} + + best = max(models, key=lambda m: m.input_cost_per_mtok) + provider, model_id = provider_manager.get_provider(best.model_ref) + + prompt = ( + "You are an expert question analyst and strategic thinker. Your job " + "is to determine whether a question contains enough context for a " + "panel of experts to give a truly excellent, specific, actionable " + "answer — or whether critical context is missing.\n\n" + "Think deeply: what assumptions would an expert panel have to make? " + "If those assumptions could lead to fundamentally different answers, " + "the question needs refinement.\n\n" + "Consider missing context such as: scale, budget, team size/expertise, " + "timeline, technical constraints, use-case, existing infrastructure, " + "success criteria, risk tolerance, regulatory requirements, or " + "geographic/market context.\n\n" + f"Question: {question}\n\n" + "Return ONLY a JSON object. If the question is already specific " + "enough for expert-quality advice:\n" + '{"needs_refinement": false}\n\n' + "If clarification would meaningfully improve the answer, return:\n" + '{"needs_refinement": true, "questions": [\n' + ' {"question": "...", "hint": "brief guidance on what kind of answer helps"}\n' + "]}\n\n" + f"Include at most {max_questions} questions. Each should be concise, " + "focused on one critical missing dimension, and phrased in a way that " + "feels natural and respectful — like a senior consultant clarifying " + "scope before giving advice. Only ask questions whose answers would " + "materially change the recommendation." + ) + + try: + response = await provider.send( + [PromptMessage(role="user", content=prompt)], + model_id, + max_tokens=500, + temperature=0.3, + response_format="json", + ) + data = extract_json(response.content) + provider_manager.record_usage(best, response.usage) + + if not data.get("needs_refinement"): + return {"needs_refinement": False} + + questions = data.get("questions", []) + if not isinstance(questions, list) or not questions: + return {"needs_refinement": False} + + # Normalise and cap + clean: list[dict[str, str | None]] = [] + for q in questions[:max_questions]: + if isinstance(q, dict) and q.get("question"): + clean.append( + { + "question": str(q["question"]), + "hint": str(q["hint"]) if q.get("hint") else None, + } + ) + + if not clean: + return {"needs_refinement": False} + + return {"needs_refinement": True, "questions": clean} + + except (JSONExtractionError, Exception): + logger.debug("Question refinement analysis failed, skipping", exc_info=True) + return {"needs_refinement": False} + + +async def enrich_question( + original: str, + clarifications: list[dict[str, str]], + provider_manager: ProviderManager, +) -> str: + """Rewrite *original* incorporating clarification answers. + + Each entry in *clarifications* has ``question`` and ``answer`` keys. + Returns the enriched question string, or the original on failure. + """ + models = provider_manager.list_all_models() + if not models: + return original + + best = max(models, key=lambda m: m.input_cost_per_mtok) + provider, model_id = provider_manager.get_provider(best.model_ref) + + qa_block = "\n".join( + f"Q: {c['question']}\nA: {c['answer']}" for c in clarifications + ) + + prompt = ( + "Rewrite the following question into a single, specific, " + "self-contained question that incorporates all the additional " + "context provided below. Keep the rewritten question natural and " + "concise — do not repeat the clarifications verbatim, just weave " + "the context in.\n\n" + f"Original question: {original}\n\n" + f"Additional context:\n{qa_block}\n\n" + "Return ONLY the rewritten question, nothing else." + ) + + try: + response = await provider.send( + [PromptMessage(role="user", content=prompt)], + model_id, + max_tokens=500, + temperature=0.3, + ) + provider_manager.record_usage(best, response.usage) + enriched = response.content.strip() + return enriched if enriched else original + except Exception: + logger.debug("Question enrichment failed, using original", exc_info=True) + return original diff --git a/src/duh/memory/migrations.py b/src/duh/memory/migrations.py index 296e2e1..be2f8f6 100644 --- a/src/duh/memory/migrations.py +++ b/src/duh/memory/migrations.py @@ -49,6 +49,14 @@ async def ensure_schema(engine: AsyncEngine) -> None: "ALTER TABLE users ADD COLUMN is_guest BOOLEAN DEFAULT 0" ) + # ── contributions table ── + contrib_cols = await _get_columns(conn, "contributions") + if "citations_json" not in contrib_cols: + logger.info("Adding 'citations_json' column to contributions table") + await conn.exec_driver_sql( + "ALTER TABLE contributions ADD COLUMN citations_json TEXT DEFAULT NULL" + ) + # ── threads table ── thread_cols = await _get_columns(conn, "threads") if "is_public" not in thread_cols: diff --git a/src/duh/memory/models.py b/src/duh/memory/models.py index c93ca4d..56286d0 100644 --- a/src/duh/memory/models.py +++ b/src/duh/memory/models.py @@ -143,6 +143,9 @@ class Contribution(Base): output_tokens: Mapped[int] = mapped_column(Integer, default=0) cost_usd: Mapped[float] = mapped_column(Float, default=0.0) latency_ms: Mapped[float] = mapped_column(Float, default=0.0) + citations_json: Mapped[str | None] = mapped_column( + Text, nullable=True, default=None + ) created_at: Mapped[datetime] = mapped_column(DateTime, default=_utcnow) turn: Mapped[Turn] = relationship(back_populates="contributions") diff --git a/src/duh/memory/repository.py b/src/duh/memory/repository.py index 58b1ab4..9f60aef 100644 --- a/src/duh/memory/repository.py +++ b/src/duh/memory/repository.py @@ -125,6 +125,7 @@ async def add_contribution( output_tokens: int = 0, cost_usd: float = 0.0, latency_ms: float = 0.0, + citations_json: str | None = None, ) -> Contribution: """Record a model's contribution to a turn.""" contrib = Contribution( @@ -136,6 +137,7 @@ async def add_contribution( output_tokens=output_tokens, cost_usd=cost_usd, latency_ms=latency_ms, + citations_json=citations_json, ) self._session.add(contrib) await self._session.flush() diff --git a/src/duh/providers/anthropic.py b/src/duh/providers/anthropic.py index 8a79e62..42f6d1b 100644 --- a/src/duh/providers/anthropic.py +++ b/src/duh/providers/anthropic.py @@ -16,6 +16,7 @@ ProviderTimeoutError, ) from duh.providers.base import ( + Citation, ModelInfo, ModelResponse, StreamChunk, @@ -112,6 +113,7 @@ async def send( stop_sequences: list[str] | None = None, response_format: str | None = None, tools: list[dict[str, object]] | None = None, + web_search: bool = False, ) -> ModelResponse: system, api_messages = _build_messages(messages) @@ -125,22 +127,44 @@ async def send( if stop_sequences: kwargs["stop_sequences"] = stop_sequences if tools: - kwargs["tools"] = tools + kwargs["tools"] = [ + { + "name": t["name"], + "description": t.get("description", ""), + "input_schema": t.get("input_schema") or t.get("parameters", {}), + } + for t in tools + ] + if web_search: + native_tool: dict[str, object] = { + "type": "web_search_20250305", + "name": "web_search", + "max_uses": 5, + } + if "tools" in kwargs: + kwargs["tools"].insert(0, native_tool) + else: + kwargs["tools"] = [native_tool] start = time.monotonic() try: - response = await self._client.messages.create(**kwargs) + # Use streaming internally to avoid Anthropic's 10-minute + # timeout on non-streaming requests with large max_tokens. + response = await self._collect_stream(kwargs) except anthropic.APIError as e: raise _map_error(e) from e latency_ms = (time.monotonic() - start) * 1000 - # Extract text content and tool use blocks - content = "" + # Extract text content and tool use blocks. + # With server tools (e.g. web_search), there may be multiple text + # blocks interleaved with tool-use/result blocks — concatenate them. + text_parts: list[str] = [] tool_calls_data: list[ToolCallData] = [] + citations_data: list[Citation] = [] for block in response.content: if hasattr(block, "text"): - content = block.text + text_parts.append(block.text) elif hasattr(block, "type") and block.type == "tool_use": import json @@ -151,6 +175,26 @@ async def send( arguments=json.dumps(block.input), ) ) + elif hasattr(block, "type") and block.type == "web_search_tool_result": + # Extract citations from server-side web search results + search_content = getattr(block, "content", None) + if isinstance(search_content, list): + for entry in search_content: + entry_type = getattr(entry, "type", None) + if entry_type == "web_search_result": + url = getattr(entry, "url", None) + if url: + citations_data.append( + Citation( + url=url, + title=getattr(entry, "title", None), + snippet=getattr( + entry, "encrypted_content", None + ), + ) + ) + + content = "\n\n".join(text_parts) usage = TokenUsage( input_tokens=response.usage.input_tokens, @@ -172,8 +216,19 @@ async def send( latency_ms=latency_ms, raw_response=response, tool_calls=tool_calls_data if tool_calls_data else None, + citations=citations_data if citations_data else None, ) + async def _collect_stream(self, kwargs: dict[str, Any]) -> anthropic.types.Message: + """Stream a request and return the final Message. + + Streaming avoids Anthropic's 10-minute timeout for large + max_tokens values while still returning a complete Message + object compatible with non-streaming response parsing. + """ + async with self._client.messages.stream(**kwargs) as s: + return await s.get_final_message() + async def stream( self, messages: list[PromptMessage], diff --git a/src/duh/providers/base.py b/src/duh/providers/base.py index e9f00d2..5932f2c 100644 --- a/src/duh/providers/base.py +++ b/src/duh/providers/base.py @@ -23,6 +23,7 @@ class ModelCapability(enum.Flag): VISION = enum.auto() JSON_MODE = enum.auto() SYSTEM_PROMPT = enum.auto() + WEB_SEARCH = enum.auto() @dataclass(frozen=True, slots=True) @@ -61,6 +62,15 @@ def total_tokens(self) -> int: return self.input_tokens + self.output_tokens +@dataclass(frozen=True, slots=True) +class Citation: + """A single citation from a web search result.""" + + url: str + title: str | None = None + snippet: str | None = None + + @dataclass(frozen=True, slots=True) class ToolCallData: """A tool call from a model response.""" @@ -81,6 +91,7 @@ class ModelResponse: latency_ms: float # Wall-clock time for the call raw_response: object = field(default=None, repr=False) tool_calls: list[ToolCallData] | None = None + citations: list[Citation] | None = None @dataclass(frozen=True, slots=True) @@ -130,6 +141,7 @@ async def send( stop_sequences: list[str] | None = None, response_format: str | None = None, tools: list[dict[str, object]] | None = None, + web_search: bool = False, ) -> ModelResponse: """Send a prompt and wait for complete response. @@ -141,6 +153,7 @@ async def send( stop_sequences: Sequences that stop generation. response_format: If ``"json"``, request JSON output mode. tools: Tool definitions for function calling. + web_search: Enable provider-native web search. Raises ProviderError on failure. """ diff --git a/src/duh/providers/catalog.py b/src/duh/providers/catalog.py index fb0e23b..a6cbd32 100644 --- a/src/duh/providers/catalog.py +++ b/src/duh/providers/catalog.py @@ -18,6 +18,7 @@ | ModelCapability.STREAMING | ModelCapability.SYSTEM_PROMPT | ModelCapability.JSON_MODE + | ModelCapability.WEB_SEARCH ), "openai": ( ModelCapability.TEXT @@ -30,18 +31,21 @@ | ModelCapability.STREAMING | ModelCapability.SYSTEM_PROMPT | ModelCapability.JSON_MODE + | ModelCapability.WEB_SEARCH ), "mistral": ( ModelCapability.TEXT | ModelCapability.STREAMING | ModelCapability.SYSTEM_PROMPT | ModelCapability.JSON_MODE + | ModelCapability.WEB_SEARCH ), "perplexity": ( ModelCapability.TEXT | ModelCapability.STREAMING | ModelCapability.SYSTEM_PROMPT | ModelCapability.JSON_MODE + | ModelCapability.WEB_SEARCH ), } diff --git a/src/duh/providers/google.py b/src/duh/providers/google.py index 30a400e..7936758 100644 --- a/src/duh/providers/google.py +++ b/src/duh/providers/google.py @@ -16,6 +16,7 @@ ProviderTimeoutError, ) from duh.providers.base import ( + Citation, ModelInfo, ModelResponse, StreamChunk, @@ -108,6 +109,7 @@ async def send( stop_sequences: list[str] | None = None, response_format: str | None = None, tools: list[dict[str, object]] | None = None, + web_search: bool = False, ) -> ModelResponse: system, contents = _build_contents(messages) @@ -120,7 +122,23 @@ async def send( if response_format == "json": config_kwargs["response_mime_type"] = "application/json" if tools: - config_kwargs["tools"] = tools + func_decls = [ + genai.types.FunctionDeclaration( + name=str(t["name"]), + description=str(t.get("description", "")), + parameters=t.get("parameters") or t.get("input_schema", {}), # type: ignore[arg-type] + ) + for t in tools + ] + config_kwargs["tools"] = [ + genai.types.Tool(function_declarations=func_decls) + ] + if web_search: + # GoogleSearch grounding cannot coexist with function_declarations + # in the same request — replace function tools entirely. + config_kwargs["tools"] = [ + genai.types.Tool(google_search=genai.types.GoogleSearch()) + ] config = genai.types.GenerateContentConfig(**config_kwargs) @@ -136,14 +154,19 @@ async def send( latency_ms = (time.monotonic() - start) * 1000 - # Extract text and function calls - content = response.text or "" + # Extract text and function calls. + # response.text can raise ValueError when grounding metadata is + # present, so we iterate parts directly. + text_parts: list[str] = [] tool_calls_data: list[ToolCallData] = [] if response.candidates: cand_content = response.candidates[0].content parts = cand_content.parts if cand_content else None if parts: for part in parts: + part_text = getattr(part, "text", None) + if isinstance(part_text, str) and part_text: + text_parts.append(part_text) fc = getattr(part, "function_call", None) if fc and fc.name: import json @@ -156,6 +179,26 @@ async def send( arguments=json.dumps(args), ) ) + content = "\n\n".join(text_parts) + + # Extract citations from grounding metadata (GoogleSearch) + citations_data: list[Citation] = [] + if response.candidates: + cand = response.candidates[0] + grounding = getattr(cand, "grounding_metadata", None) + if grounding: + chunks = getattr(grounding, "grounding_chunks", None) or [] + for chunk in chunks: + web = getattr(chunk, "web", None) + if web: + uri = getattr(web, "uri", None) + if uri: + citations_data.append( + Citation( + url=uri, + title=getattr(web, "title", None), + ) + ) input_tokens = 0 output_tokens = 0 @@ -177,6 +220,7 @@ async def send( latency_ms=latency_ms, raw_response=response, tool_calls=tool_calls_data if tool_calls_data else None, + citations=citations_data if citations_data else None, ) async def stream( diff --git a/src/duh/providers/mistral.py b/src/duh/providers/mistral.py index 7aae04b..b47a95b 100644 --- a/src/duh/providers/mistral.py +++ b/src/duh/providers/mistral.py @@ -115,6 +115,7 @@ async def send( stop_sequences: list[str] | None = None, response_format: str | None = None, tools: list[dict[str, object]] | None = None, + web_search: bool = False, ) -> ModelResponse: api_messages = _build_messages(messages) @@ -129,7 +130,23 @@ async def send( if response_format == "json": kwargs["response_format"] = {"type": "json_object"} if tools: - kwargs["tools"] = tools + kwargs["tools"] = [ + { + "type": "function", + "function": { + "name": t["name"], + "description": t.get("description", ""), + "parameters": t.get("parameters") or t.get("input_schema", {}), + }, + } + for t in tools + ] + if web_search: + ws: dict[str, str] = {"type": "web_search"} + if "tools" in kwargs: + kwargs["tools"].append(ws) + else: + kwargs["tools"] = [ws] start = time.monotonic() try: diff --git a/src/duh/providers/openai.py b/src/duh/providers/openai.py index 00b6bcc..82b70a7 100644 --- a/src/duh/providers/openai.py +++ b/src/duh/providers/openai.py @@ -37,6 +37,11 @@ _KNOWN_MODELS = MODEL_CATALOG[PROVIDER_ID] _DEFAULT_CAPS = PROVIDER_CAPS[PROVIDER_ID] _NO_TEMPERATURE_MODELS = NO_TEMPERATURE_MODELS +_SEARCH_MODELS: set[str] = { + "gpt-5-search-api", + "gpt-4o-search-preview", + "gpt-4o-mini-search-preview", +} def _map_error(e: openai.APIError) -> Exception: @@ -128,6 +133,7 @@ async def send( stop_sequences: list[str] | None = None, response_format: str | None = None, tools: list[dict[str, object]] | None = None, + web_search: bool = False, ) -> ModelResponse: api_messages = _build_messages(messages) @@ -143,7 +149,19 @@ async def send( if response_format == "json": kwargs["response_format"] = {"type": "json_object"} if tools: - kwargs["tools"] = tools + kwargs["tools"] = [ + { + "type": "function", + "function": { + "name": t["name"], + "description": t.get("description", ""), + "parameters": t.get("parameters") or t.get("input_schema", {}), + }, + } + for t in tools + ] + if web_search and model_id in _SEARCH_MODELS: + kwargs["web_search_options"] = {} start = time.monotonic() try: diff --git a/src/duh/providers/perplexity.py b/src/duh/providers/perplexity.py index 6cd7567..9aea342 100644 --- a/src/duh/providers/perplexity.py +++ b/src/duh/providers/perplexity.py @@ -16,6 +16,7 @@ ProviderTimeoutError, ) from duh.providers.base import ( + Citation, ModelInfo, ModelResponse, StreamChunk, @@ -123,6 +124,7 @@ async def send( stop_sequences: list[str] | None = None, response_format: str | None = None, tools: list[dict[str, object]] | None = None, + web_search: bool = False, ) -> ModelResponse: api_messages = _build_messages(messages) clamped = min(max_tokens, self._max_output_for(model_id)) @@ -138,7 +140,17 @@ async def send( if response_format == "json": kwargs["response_format"] = {"type": "json_object"} if tools: - kwargs["tools"] = tools + kwargs["tools"] = [ + { + "type": "function", + "function": { + "name": t["name"], + "description": t.get("description", ""), + "parameters": t.get("parameters") or t.get("input_schema", {}), + }, + } + for t in tools + ] start = time.monotonic() try: @@ -178,10 +190,20 @@ async def send( model_info = self._resolve_model_info(model_id) # Capture citations from Perplexity response if present - citations = getattr(response, "citations", None) + raw_citations = getattr(response, "citations", None) raw = response - if citations is not None: - raw = {"response": response, "citations": citations} + if raw_citations is not None: + raw = {"response": response, "citations": raw_citations} + + citations_data: list[Citation] = [] + if isinstance(raw_citations, list): + for c in raw_citations: + if isinstance(c, str): + citations_data.append(Citation(url=c)) + elif hasattr(c, "url"): + citations_data.append( + Citation(url=c.url, title=getattr(c, "title", None)) + ) return ModelResponse( content=content, @@ -191,6 +213,7 @@ async def send( latency_ms=latency_ms, raw_response=raw, tool_calls=tool_calls_data, + citations=citations_data if citations_data else None, ) async def stream( diff --git a/src/duh/tools/augmented_send.py b/src/duh/tools/augmented_send.py index a3c4517..743203a 100644 --- a/src/duh/tools/augmented_send.py +++ b/src/duh/tools/augmented_send.py @@ -26,6 +26,7 @@ async def tool_augmented_send( max_tool_rounds: int = 5, temperature: float = 0.7, max_tokens: int = 4096, + web_search: bool = False, ) -> ModelResponse: """Send a prompt with tool-use loop. @@ -34,6 +35,9 @@ async def tool_augmented_send( 3. Feed tool results back as messages 4. Repeat until text response or max_tool_rounds reached + When ``web_search=True``, the ``web_search`` tool definition is + filtered from the tools list (the provider handles search natively). + Args: provider: The model provider to use. model_id: Model to call. @@ -42,6 +46,7 @@ async def tool_augmented_send( max_tool_rounds: Maximum tool-use iterations. temperature: Sampling temperature. max_tokens: Max output tokens. + web_search: Pass native web search flag to the provider. Returns: Final ModelResponse (text content or last tool round). @@ -58,6 +63,11 @@ async def tool_augmented_send( for td in tool_defs ] + # When native web search is active, remove the DDG web_search tool + # so the model doesn't see both native and local search. + if web_search: + tools_param = [t for t in tools_param if t["name"] != "web_search"] + current_messages = list(messages) for _round in range(max_tool_rounds): @@ -67,6 +77,7 @@ async def tool_augmented_send( max_tokens=max_tokens, temperature=temperature, tools=tools_param if tools_param else None, + web_search=web_search, ) # If no tool calls, return the response as-is diff --git a/tests/fixtures/providers.py b/tests/fixtures/providers.py index 7282e65..3e99e69 100644 --- a/tests/fixtures/providers.py +++ b/tests/fixtures/providers.py @@ -71,6 +71,7 @@ async def send( stop_sequences: list[str] | None = None, response_format: str | None = None, tools: list[dict[str, object]] | None = None, + web_search: bool = False, ) -> ModelResponse: self.call_log.append( { diff --git a/tests/unit/test_api_refine.py b/tests/unit/test_api_refine.py new file mode 100644 index 0000000..bc31d0a --- /dev/null +++ b/tests/unit/test_api_refine.py @@ -0,0 +1,119 @@ +"""Tests for POST /api/refine and POST /api/enrich endpoints.""" + +from __future__ import annotations + +import json +from unittest.mock import AsyncMock, patch + +from fastapi.testclient import TestClient +from sqlalchemy import event +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + +from duh.api.app import create_app +from duh.config.schema import DuhConfig +from duh.memory.models import Base +from duh.providers.manager import ProviderManager +from tests.fixtures.providers import MockProvider + + +async def _make_app() -> TestClient: + """Create a test app with mocked providers and in-memory DB.""" + config = DuhConfig() + config.database.url = "sqlite+aiosqlite:///:memory:" + + engine = create_async_engine("sqlite+aiosqlite://") + + @event.listens_for(engine.sync_engine, "connect") + def _enable_fks(dbapi_conn, connection_record): # type: ignore[no-untyped-def] + cursor = dbapi_conn.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + factory = async_sessionmaker(engine, expire_on_commit=False) + + mock_prov = MockProvider( + provider_id="mock", + responses={"model-a": json.dumps({"needs_refinement": False})}, + input_cost=1.0, + output_cost=5.0, + ) + pm = ProviderManager(cost_hard_limit=100.0) + await pm.register(mock_prov) # type: ignore[arg-type] + + app = create_app(config) + app.state.db_factory = factory + app.state.engine = engine + app.state.provider_manager = pm + return TestClient(app, raise_server_exceptions=False) + + +class TestRefineEndpoint: + async def test_refine_no_refinement(self) -> None: + client = await _make_app() + with patch( + "duh.consensus.refine.analyze_question", + new_callable=AsyncMock, + return_value={"needs_refinement": False}, + ): + resp = client.post("/api/refine", json={"question": "What is 2+2?"}) + assert resp.status_code == 200 + data = resp.json() + assert data["needs_refinement"] is False + assert data["questions"] == [] + + async def test_refine_with_questions(self) -> None: + client = await _make_app() + questions = [ + {"question": "What scale?", "hint": "users/day"}, + {"question": "Budget?", "hint": None}, + ] + with patch( + "duh.consensus.refine.analyze_question", + new_callable=AsyncMock, + return_value={"needs_refinement": True, "questions": questions}, + ): + resp = client.post("/api/refine", json={"question": "What DB?"}) + assert resp.status_code == 200 + data = resp.json() + assert data["needs_refinement"] is True + assert len(data["questions"]) == 2 + + async def test_refine_custom_max_questions(self) -> None: + client = await _make_app() + with patch( + "duh.consensus.refine.analyze_question", + new_callable=AsyncMock, + return_value={"needs_refinement": False}, + ) as mock_analyze: + client.post( + "/api/refine", + json={"question": "Test?", "max_questions": 2}, + ) + mock_analyze.assert_called_once() + _, kwargs = mock_analyze.call_args + assert kwargs["max_questions"] == 2 + + +class TestEnrichEndpoint: + async def test_enrich(self) -> None: + client = await _make_app() + with patch( + "duh.consensus.refine.enrich_question", + new_callable=AsyncMock, + return_value="What DB for a 10k-user SaaS?", + ): + resp = client.post( + "/api/enrich", + json={ + "original_question": "What DB?", + "clarifications": [ + {"question": "Scale?", "answer": "10k users"}, + ], + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert "10k" in data["enriched_question"] diff --git a/tests/unit/test_api_ws.py b/tests/unit/test_api_ws.py index 212f0da..eeeea65 100644 --- a/tests/unit/test_api_ws.py +++ b/tests/unit/test_api_ws.py @@ -113,10 +113,33 @@ def mock_convergence(ctx, **kwargs): propose_mock = AsyncMock(side_effect=mock_propose) stack.enter_context(patch(f"{_HANDLERS}.handle_propose", propose_mock)) + + # _stream_challenges in ws.py replaces the old handle_challenge call + async def mock_stream_challenges(ws, ctx, pm, challengers, **kwargs): + ctx.challenges = [ + ChallengeResult( + model_ref=ref, + content=challenge_content, + sycophantic=False, + framing="flaw", + ) + for ref in challengers + ] + for ch in ctx.challenges: + await ws.send_json( + { + "type": "challenge", + "model": ch.model_ref, + "content": ch.content, + "truncated": False, + "citations": None, + } + ) + stack.enter_context( patch( - f"{_HANDLERS}.handle_challenge", - AsyncMock(side_effect=mock_challenge), + "duh.api.routes.ws._stream_challenges", + AsyncMock(side_effect=mock_stream_challenges), ) ) stack.enter_context( diff --git a/tests/unit/test_config_v02.py b/tests/unit/test_config_v02.py index edc20d6..9a4227b 100644 --- a/tests/unit/test_config_v02.py +++ b/tests/unit/test_config_v02.py @@ -34,7 +34,7 @@ def test_defaults(self) -> None: class TestToolsConfig: def test_defaults(self) -> None: cfg = ToolsConfig() - assert cfg.enabled is False + assert cfg.enabled is True assert cfg.max_rounds == 5 assert cfg.web_search.backend == "duckduckgo" assert cfg.code_execution.enabled is False @@ -93,6 +93,6 @@ def test_backward_compatible(self) -> None: assert cfg.cost.hard_limit == 10.00 assert "anthropic" in cfg.providers # v0.2 fields have safe defaults - assert cfg.tools.enabled is False + assert cfg.tools.enabled is True assert cfg.voting.enabled is False assert cfg.taxonomy.enabled is False diff --git a/tests/unit/test_config_v03.py b/tests/unit/test_config_v03.py index c1319e9..7c43f8d 100644 --- a/tests/unit/test_config_v03.py +++ b/tests/unit/test_config_v03.py @@ -127,7 +127,7 @@ def test_backward_compatible(self) -> None: assert cfg.general.max_rounds == 3 assert cfg.cost.hard_limit == 10.00 assert "anthropic" in cfg.providers - assert cfg.tools.enabled is False + assert cfg.tools.enabled is True assert cfg.api.host == "127.0.0.1" diff --git a/tests/unit/test_native_web_search.py b/tests/unit/test_native_web_search.py new file mode 100644 index 0000000..419008d --- /dev/null +++ b/tests/unit/test_native_web_search.py @@ -0,0 +1,592 @@ +"""Tests for native provider web search support.""" + +from __future__ import annotations + +from typing import Any, ClassVar +from unittest.mock import AsyncMock, MagicMock + +from duh.config.schema import WebSearchConfig +from duh.providers.base import ( + ModelCapability, + ModelInfo, + ModelResponse, + PromptMessage, + TokenUsage, +) + +# ── Helpers ──────────────────────────────────────────────────────── + + +def _model_info(provider: str = "test", model: str = "m1") -> ModelInfo: + return ModelInfo( + provider_id=provider, + model_id=model, + display_name="Test", + capabilities=ModelCapability.TEXT, + context_window=128_000, + max_output_tokens=4096, + input_cost_per_mtok=0.0, + output_cost_per_mtok=0.0, + ) + + +def _text_response(provider: str = "test") -> ModelResponse: + return ModelResponse( + content="ok", + model_info=_model_info(provider), + usage=TokenUsage(input_tokens=10, output_tokens=5), + finish_reason="stop", + latency_ms=1.0, + ) + + +_MESSAGES = [PromptMessage(role="user", content="What happened today?")] + + +def _make_anthropic_client(mock_msg: object) -> AsyncMock: + """Create a mocked Anthropic client with stream support.""" + mock_client = AsyncMock() + stream_cm = MagicMock() + stream_cm.get_final_message = AsyncMock(return_value=mock_msg) + stream_cm.__aenter__ = AsyncMock(return_value=stream_cm) + stream_cm.__aexit__ = AsyncMock(return_value=False) + mock_client.messages.stream = MagicMock(return_value=stream_cm) + return mock_client + + +def _make_anthropic_msg() -> MagicMock: + """Create a minimal Anthropic response message.""" + mock_msg = MagicMock() + mock_msg.content = [MagicMock(text="result", type="text")] + mock_msg.stop_reason = "stop" + mock_msg.usage = MagicMock( + input_tokens=10, + output_tokens=5, + cache_read_input_tokens=0, + cache_creation_input_tokens=0, + ) + return mock_msg + + +# ── Anthropic ────────────────────────────────────────────────────── + + +class TestAnthropicNativeSearch: + async def test_web_search_injects_server_tool(self) -> None: + """web_search=True adds Anthropic server tool to kwargs.""" + from duh.providers.anthropic import AnthropicProvider + + mock_msg = _make_anthropic_msg() + mock_client = _make_anthropic_client(mock_msg) + + provider = AnthropicProvider(client=mock_client) + await provider.send( + _MESSAGES, + "claude-sonnet-4-6", + web_search=True, + ) + + call_kwargs = mock_client.messages.stream.call_args[1] + tools = call_kwargs["tools"] + assert any(t.get("type", "").startswith("web_search") for t in tools), ( + "Server tool not found in tools" + ) + + async def test_web_search_false_no_server_tool(self) -> None: + """web_search=False does not add server tool.""" + from duh.providers.anthropic import AnthropicProvider + + mock_msg = _make_anthropic_msg() + mock_client = _make_anthropic_client(mock_msg) + + provider = AnthropicProvider(client=mock_client) + await provider.send( + _MESSAGES, + "claude-sonnet-4-6", + web_search=False, + ) + + call_kwargs = mock_client.messages.stream.call_args[1] + assert "tools" not in call_kwargs + + async def test_web_search_with_function_tools(self) -> None: + """web_search=True alongside function tools keeps both.""" + from duh.providers.anthropic import AnthropicProvider + + mock_msg = _make_anthropic_msg() + mock_client = _make_anthropic_client(mock_msg) + + func_tool: dict[str, object] = { + "name": "calculator", + "description": "Math", + "parameters": {"type": "object", "properties": {}}, + } + + provider = AnthropicProvider(client=mock_client) + await provider.send( + _MESSAGES, + "claude-sonnet-4-6", + tools=[func_tool], + web_search=True, + ) + + call_kwargs = mock_client.messages.stream.call_args[1] + tools = call_kwargs["tools"] + # Server tool should be first + assert tools[0].get("type", "").startswith("web_search") + # Function tool should follow + assert tools[1]["name"] == "calculator" + + +# ── Google ───────────────────────────────────────────────────────── + + +class TestGoogleNativeSearch: + async def test_web_search_injects_grounding(self) -> None: + """web_search=True adds GoogleSearch grounding tool.""" + from duh.providers.google import GoogleProvider + + mock_response = MagicMock() + mock_response.candidates = [] + mock_response.usage_metadata = MagicMock( + prompt_token_count=10, + candidates_token_count=5, + ) + + mock_client = MagicMock() + mock_client.aio.models.generate_content = AsyncMock( + return_value=mock_response, + ) + + provider = GoogleProvider(client=mock_client) + await provider.send( + _MESSAGES, + "gemini-2.5-flash", + web_search=True, + ) + + call_kwargs = mock_client.aio.models.generate_content.call_args[1] + config = call_kwargs["config"] + tools = config.tools + assert tools is not None + # Should have ONLY the GoogleSearch grounding tool + assert len(tools) == 1 + assert getattr(tools[0], "google_search", None) is not None + + async def test_web_search_replaces_function_tools(self) -> None: + """web_search=True replaces function tools (they can't coexist).""" + from duh.providers.google import GoogleProvider + + mock_response = MagicMock() + mock_response.candidates = [] + mock_response.usage_metadata = MagicMock( + prompt_token_count=10, + candidates_token_count=5, + ) + + mock_client = MagicMock() + mock_client.aio.models.generate_content = AsyncMock( + return_value=mock_response, + ) + + func_tool: dict[str, object] = { + "name": "calculator", + "description": "Math", + "parameters": {"type": "object", "properties": {}}, + } + + provider = GoogleProvider(client=mock_client) + await provider.send( + _MESSAGES, + "gemini-2.5-flash", + tools=[func_tool], + web_search=True, + ) + + call_kwargs = mock_client.aio.models.generate_content.call_args[1] + config = call_kwargs["config"] + tools = config.tools + # Grounding replaces function tools — only 1 tool, the grounding one + assert len(tools) == 1 + assert getattr(tools[0], "google_search", None) is not None + + async def test_web_search_false_no_grounding(self) -> None: + """web_search=False does not add grounding tool.""" + from duh.providers.google import GoogleProvider + + mock_response = MagicMock() + mock_response.candidates = [] + mock_response.usage_metadata = MagicMock( + prompt_token_count=10, + candidates_token_count=5, + ) + + mock_client = MagicMock() + mock_client.aio.models.generate_content = AsyncMock( + return_value=mock_response, + ) + + provider = GoogleProvider(client=mock_client) + await provider.send( + _MESSAGES, + "gemini-2.5-flash", + web_search=False, + ) + + call_kwargs = mock_client.aio.models.generate_content.call_args[1] + config = call_kwargs["config"] + assert config.tools is None or len(config.tools) == 0 + + +# ── Mistral ──────────────────────────────────────────────────────── + + +class TestMistralNativeSearch: + async def test_web_search_injects_tool(self) -> None: + """web_search=True adds web_search tool to kwargs.""" + from duh.providers.mistral import MistralProvider + + mock_response = MagicMock() + mock_response.choices = [ + MagicMock( + message=MagicMock(content="result", tool_calls=None), + finish_reason="stop", + ) + ] + mock_response.usage = MagicMock( + prompt_tokens=10, + completion_tokens=5, + ) + + mock_client = MagicMock() + mock_client.chat.complete_async = AsyncMock(return_value=mock_response) + + provider = MistralProvider(client=mock_client) + await provider.send( + _MESSAGES, + "mistral-large-latest", + web_search=True, + ) + + call_kwargs = mock_client.chat.complete_async.call_args[1] + tools = call_kwargs["tools"] + assert any(t.get("type") == "web_search" for t in tools) + + async def test_web_search_false_no_tool(self) -> None: + """web_search=False does not add web_search tool.""" + from duh.providers.mistral import MistralProvider + + mock_response = MagicMock() + mock_response.choices = [ + MagicMock( + message=MagicMock(content="result", tool_calls=None), + finish_reason="stop", + ) + ] + mock_response.usage = MagicMock( + prompt_tokens=10, + completion_tokens=5, + ) + + mock_client = MagicMock() + mock_client.chat.complete_async = AsyncMock(return_value=mock_response) + + provider = MistralProvider(client=mock_client) + await provider.send( + _MESSAGES, + "mistral-large-latest", + web_search=False, + ) + + call_kwargs = mock_client.chat.complete_async.call_args[1] + assert "tools" not in call_kwargs + + +# ── OpenAI ───────────────────────────────────────────────────────── + + +class TestOpenAINativeSearch: + async def test_search_model_gets_web_search_options(self) -> None: + """web_search=True + search model adds web_search_options.""" + from duh.providers.openai import OpenAIProvider + + mock_response = MagicMock() + mock_response.choices = [ + MagicMock( + message=MagicMock(content="result", tool_calls=None), + finish_reason="stop", + ) + ] + mock_response.usage = MagicMock( + prompt_tokens=10, + completion_tokens=5, + ) + + mock_client = AsyncMock() + mock_client.chat.completions.create = AsyncMock( + return_value=mock_response, + ) + + provider = OpenAIProvider(client=mock_client) + await provider.send( + _MESSAGES, + "gpt-4o-search-preview", + web_search=True, + ) + + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert "web_search_options" in call_kwargs + + async def test_standard_model_ignores_web_search(self) -> None: + """web_search=True + non-search model does NOT add web_search_options.""" + from duh.providers.openai import OpenAIProvider + + mock_response = MagicMock() + mock_response.choices = [ + MagicMock( + message=MagicMock(content="result", tool_calls=None), + finish_reason="stop", + ) + ] + mock_response.usage = MagicMock( + prompt_tokens=10, + completion_tokens=5, + ) + + mock_client = AsyncMock() + mock_client.chat.completions.create = AsyncMock( + return_value=mock_response, + ) + + provider = OpenAIProvider(client=mock_client) + await provider.send( + _MESSAGES, + "gpt-5.4", + web_search=True, + ) + + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert "web_search_options" not in call_kwargs + + async def test_web_search_false_no_options(self) -> None: + """web_search=False never adds web_search_options.""" + from duh.providers.openai import OpenAIProvider + + mock_response = MagicMock() + mock_response.choices = [ + MagicMock( + message=MagicMock(content="result", tool_calls=None), + finish_reason="stop", + ) + ] + mock_response.usage = MagicMock( + prompt_tokens=10, + completion_tokens=5, + ) + + mock_client = AsyncMock() + mock_client.chat.completions.create = AsyncMock( + return_value=mock_response, + ) + + provider = OpenAIProvider(client=mock_client) + await provider.send( + _MESSAGES, + "gpt-4o-search-preview", + web_search=False, + ) + + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert "web_search_options" not in call_kwargs + + +# ── Perplexity ───────────────────────────────────────────────────── + + +class TestPerplexityNativeSearch: + async def test_web_search_is_noop(self) -> None: + """web_search=True changes nothing for Perplexity (always-on).""" + from duh.providers.perplexity import PerplexityProvider + + mock_response = MagicMock() + mock_response.choices = [ + MagicMock( + message=MagicMock(content="result", tool_calls=None), + finish_reason="stop", + ) + ] + mock_response.usage = MagicMock( + prompt_tokens=10, + completion_tokens=5, + ) + + mock_client = AsyncMock() + mock_client.chat.completions.create = AsyncMock( + return_value=mock_response, + ) + + provider = PerplexityProvider(client=mock_client) + await provider.send( + _MESSAGES, + "sonar", + web_search=True, + ) + + call_kwargs = mock_client.chat.completions.create.call_args[1] + # No web_search_options or special tools added + assert "web_search_options" not in call_kwargs + assert "tools" not in call_kwargs + + +# ── Config ───────────────────────────────────────────────────────── + + +class TestWebSearchConfig: + def test_native_defaults_true(self) -> None: + """WebSearchConfig.native defaults to True.""" + cfg = WebSearchConfig() + assert cfg.native is True + + def test_native_explicit_false(self) -> None: + """WebSearchConfig.native can be set to False.""" + cfg = WebSearchConfig(native=False) + assert cfg.native is False + + +# ── Catalog capabilities ────────────────────────────────────────── + + +class TestWebSearchCapability: + def test_web_search_in_flag(self) -> None: + """WEB_SEARCH exists as a ModelCapability flag.""" + assert hasattr(ModelCapability, "WEB_SEARCH") + + def test_anthropic_has_web_search(self) -> None: + from duh.providers.catalog import PROVIDER_CAPS + + assert ModelCapability.WEB_SEARCH in PROVIDER_CAPS["anthropic"] + + def test_google_has_web_search(self) -> None: + from duh.providers.catalog import PROVIDER_CAPS + + assert ModelCapability.WEB_SEARCH in PROVIDER_CAPS["google"] + + def test_mistral_has_web_search(self) -> None: + from duh.providers.catalog import PROVIDER_CAPS + + assert ModelCapability.WEB_SEARCH in PROVIDER_CAPS["mistral"] + + def test_perplexity_has_web_search(self) -> None: + from duh.providers.catalog import PROVIDER_CAPS + + assert ModelCapability.WEB_SEARCH in PROVIDER_CAPS["perplexity"] + + def test_openai_no_web_search(self) -> None: + from duh.providers.catalog import PROVIDER_CAPS + + assert ModelCapability.WEB_SEARCH not in PROVIDER_CAPS["openai"] + + +# ── tool_augmented_send integration ────────────────────────────── + + +class TestAugmentedSendWebSearch: + async def test_web_search_filters_ddg_tool(self) -> None: + """web_search=True removes 'web_search' from tools list.""" + from duh.tools.augmented_send import tool_augmented_send + from duh.tools.registry import ToolRegistry + + call_log: list[dict[str, Any]] = [] + + class _MockProvider: + provider_id = "test" + + async def send( + self, messages: Any, model_id: str, **kwargs: Any + ) -> ModelResponse: + call_log.append(kwargs) + return _text_response() + + class _WebSearchTool: + name = "web_search" + description = "Search the web" + parameters_schema: ClassVar[dict[str, Any]] = { + "type": "object", + "properties": {"query": {"type": "string"}}, + } + + async def execute(self, **kwargs: Any) -> str: + return "results" + + class _OtherTool: + name = "calculator" + description = "Math" + parameters_schema: ClassVar[dict[str, Any]] = { + "type": "object", + "properties": {}, + } + + async def execute(self, **kwargs: Any) -> str: + return "42" + + registry = ToolRegistry() + registry.register(_WebSearchTool()) + registry.register(_OtherTool()) + + await tool_augmented_send( + _MockProvider(), # type: ignore[arg-type] + "m1", + _MESSAGES, + registry, + web_search=True, + ) + + # web_search tool should be filtered out, calculator remains + tools = call_log[0]["tools"] + assert len(tools) == 1 + assert tools[0]["name"] == "calculator" + # web_search flag should be passed through + assert call_log[0]["web_search"] is True + + async def test_no_web_search_keeps_ddg_tool(self) -> None: + """web_search=False keeps 'web_search' in tools list.""" + from duh.tools.augmented_send import tool_augmented_send + from duh.tools.registry import ToolRegistry + + call_log: list[dict[str, Any]] = [] + + class _MockProvider: + provider_id = "test" + + async def send( + self, messages: Any, model_id: str, **kwargs: Any + ) -> ModelResponse: + call_log.append(kwargs) + return _text_response() + + class _WebSearchTool: + name = "web_search" + description = "Search the web" + parameters_schema: ClassVar[dict[str, Any]] = { + "type": "object", + "properties": {"query": {"type": "string"}}, + } + + async def execute(self, **kwargs: Any) -> str: + return "results" + + registry = ToolRegistry() + registry.register(_WebSearchTool()) + + await tool_augmented_send( + _MockProvider(), # type: ignore[arg-type] + "m1", + _MESSAGES, + registry, + web_search=False, + ) + + tools = call_log[0]["tools"] + assert len(tools) == 1 + assert tools[0]["name"] == "web_search" + assert call_log[0]["web_search"] is False diff --git a/tests/unit/test_provider_tools.py b/tests/unit/test_provider_tools.py index 14c4025..9c8ebc6 100644 --- a/tests/unit/test_provider_tools.py +++ b/tests/unit/test_provider_tools.py @@ -18,17 +18,15 @@ # ── Shared fixtures ────────────────────────────────────────────── +# Generic tool format (as produced by tool_augmented_send) SAMPLE_TOOLS: list[dict[str, object]] = [ { - "type": "function", - "function": { - "name": "web_search", - "description": "Search the web", - "parameters": { - "type": "object", - "properties": {"query": {"type": "string"}}, - "required": ["query"], - }, + "name": "web_search", + "description": "Search the web", + "parameters": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], }, } ] @@ -90,7 +88,11 @@ async def test_tools_param_forwarded(self) -> None: provider = OpenAIProvider(client=client) await provider.send(USER_MSG, "gpt-5.2", tools=SAMPLE_TOOLS) call_kwargs = client.chat.completions.create.call_args.kwargs - assert call_kwargs["tools"] is SAMPLE_TOOLS + tools = call_kwargs["tools"] + assert len(tools) == 1 + assert tools[0]["type"] == "function" + assert tools[0]["function"]["name"] == "web_search" + assert tools[0]["function"]["parameters"]["type"] == "object" async def test_no_tools_param_omitted(self) -> None: client = _oai_make_client() @@ -242,11 +244,16 @@ def _anth_make_response_with_tool_use( def _anth_make_client(response: Any = None) -> MagicMock: import anthropic + resp = response or _anth_make_response_with_tool_use(text="Hello") client = MagicMock(spec=anthropic.AsyncAnthropic) client.messages = MagicMock() - if response is None: - response = _anth_make_response_with_tool_use(text="Hello") - client.messages.create = AsyncMock(return_value=response) + client.messages.create = AsyncMock(return_value=resp) + + stream_cm = MagicMock() + stream_cm.get_final_message = AsyncMock(return_value=resp) + stream_cm.__aenter__ = AsyncMock(return_value=stream_cm) + stream_cm.__aexit__ = AsyncMock(return_value=False) + client.messages.stream = MagicMock(return_value=stream_cm) return client @@ -257,8 +264,11 @@ async def test_tools_param_forwarded(self) -> None: client = _anth_make_client() provider = AnthropicProvider(client=client) await provider.send(USER_MSG, "claude-opus-4-6", tools=SAMPLE_TOOLS) - call_kwargs = client.messages.create.call_args.kwargs - assert call_kwargs["tools"] is SAMPLE_TOOLS + call_kwargs = client.messages.stream.call_args.kwargs + tools = call_kwargs["tools"] + assert len(tools) == 1 + assert tools[0]["name"] == "web_search" + assert tools[0]["input_schema"]["type"] == "object" async def test_no_tools_param_omitted(self) -> None: from duh.providers.anthropic import AnthropicProvider @@ -266,7 +276,7 @@ async def test_no_tools_param_omitted(self) -> None: client = _anth_make_client() provider = AnthropicProvider(client=client) await provider.send(USER_MSG, "claude-opus-4-6") - call_kwargs = client.messages.create.call_args.kwargs + call_kwargs = client.messages.stream.call_args.kwargs assert "tools" not in call_kwargs @@ -454,10 +464,19 @@ def _mock_genai_config(**kwargs: Any) -> MagicMock: side_effect=_mock_genai_config, ) +_PATCH_FUNC_DECL = patch( + "duh.providers.google.genai.types.FunctionDeclaration", + side_effect=lambda **kwargs: MagicMock(**kwargs), +) + +_PATCH_TOOL = patch( + "duh.providers.google.genai.types.Tool", + side_effect=lambda **kwargs: MagicMock(**kwargs), +) + class TestGoogleToolForwarding: - @_PATCH_CONFIG - async def test_tools_param_forwarded(self, _mock_cfg: Any) -> None: + async def test_tools_param_forwarded(self) -> None: from duh.providers.google import GoogleProvider client = _google_make_client() @@ -465,7 +484,9 @@ async def test_tools_param_forwarded(self, _mock_cfg: Any) -> None: await provider.send(USER_MSG, "gemini-2.5-flash", tools=SAMPLE_TOOLS) call_kwargs = client.aio.models.generate_content.call_args config = call_kwargs.kwargs["config"] - assert config.tools is SAMPLE_TOOLS + # Google wraps tools in genai.types.Tool with function_declarations + assert config.tools is not None + assert len(config.tools) == 1 async def test_no_tools_param_not_in_config(self) -> None: from duh.providers.google import GoogleProvider @@ -490,8 +511,10 @@ async def test_response_format_json(self) -> None: class TestGoogleToolCallParsing: + @_PATCH_TOOL + @_PATCH_FUNC_DECL @_PATCH_CONFIG - async def test_single_function_call_parsed(self, _mock_cfg: Any) -> None: + async def test_single_function_call_parsed(self, *_mocks: Any) -> None: from duh.providers.google import GoogleProvider response = _google_make_response_with_function_calls( @@ -508,8 +531,10 @@ async def test_single_function_call_parsed(self, _mock_cfg: Any) -> None: assert tc.id == "google-web_search" assert json.loads(tc.arguments) == {"query": "cats"} + @_PATCH_TOOL + @_PATCH_FUNC_DECL @_PATCH_CONFIG - async def test_multiple_function_calls_parsed(self, _mock_cfg: Any) -> None: + async def test_multiple_function_calls_parsed(self, *_mocks: Any) -> None: from duh.providers.google import GoogleProvider response = _google_make_response_with_function_calls( @@ -524,8 +549,10 @@ async def test_multiple_function_calls_parsed(self, _mock_cfg: Any) -> None: assert resp.tool_calls is not None assert len(resp.tool_calls) == 2 + @_PATCH_TOOL + @_PATCH_FUNC_DECL @_PATCH_CONFIG - async def test_no_function_calls_returns_none(self, _mock_cfg: Any) -> None: + async def test_no_function_calls_returns_none(self, *_mocks: Any) -> None: from duh.providers.google import GoogleProvider response = _google_make_response_with_function_calls(text="No tools needed") @@ -534,8 +561,10 @@ async def test_no_function_calls_returns_none(self, _mock_cfg: Any) -> None: resp = await provider.send(USER_MSG, "gemini-2.5-flash", tools=SAMPLE_TOOLS) assert resp.tool_calls is None + @_PATCH_TOOL + @_PATCH_FUNC_DECL @_PATCH_CONFIG - async def test_function_call_with_no_args(self, _mock_cfg: Any) -> None: + async def test_function_call_with_no_args(self, *_mocks: Any) -> None: from duh.providers.google import GoogleProvider response = _google_make_response_with_function_calls( diff --git a/tests/unit/test_providers_anthropic.py b/tests/unit/test_providers_anthropic.py index 88356fb..35059f4 100644 --- a/tests/unit/test_providers_anthropic.py +++ b/tests/unit/test_providers_anthropic.py @@ -58,10 +58,23 @@ def _make_response( def _make_client(response: Any = None) -> MagicMock: - """Create a mocked AsyncAnthropic client.""" + """Create a mocked AsyncAnthropic client. + + Mocks both ``messages.create`` (for direct calls) and + ``messages.stream`` (used by ``_collect_stream`` in ``send()``). + """ + resp = response or _make_response() client = MagicMock(spec=anthropic.AsyncAnthropic) client.messages = MagicMock() - client.messages.create = AsyncMock(return_value=response or _make_response()) + client.messages.create = AsyncMock(return_value=resp) + + # _collect_stream uses: async with client.messages.stream(**kw) as s: + # return await s.get_final_message() + stream_cm = MagicMock() + stream_cm.get_final_message = AsyncMock(return_value=resp) + stream_cm.__aenter__ = AsyncMock(return_value=stream_cm) + stream_cm.__aexit__ = AsyncMock(return_value=False) + client.messages.stream = MagicMock(return_value=stream_cm) return client @@ -206,7 +219,7 @@ async def test_passes_params_to_sdk(self): temperature=0.3, stop_sequences=["STOP"], ) - call_kwargs = client.messages.create.call_args.kwargs + call_kwargs = client.messages.stream.call_args.kwargs assert call_kwargs["model"] == "claude-opus-4-6" assert call_kwargs["max_tokens"] == 1000 assert call_kwargs["temperature"] == 0.3 @@ -275,7 +288,7 @@ def test_unknown_api_error_maps_to_overloaded(self): async def test_send_raises_mapped_error(self): client = _make_client() - client.messages.create.side_effect = anthropic.AuthenticationError( + client.messages.stream.side_effect = anthropic.AuthenticationError( message="bad key", response=MagicMock(status_code=401, headers={}), body=None, diff --git a/tests/unit/test_providers_base.py b/tests/unit/test_providers_base.py index 42f357a..e542267 100644 --- a/tests/unit/test_providers_base.py +++ b/tests/unit/test_providers_base.py @@ -283,6 +283,7 @@ async def send( max_tokens: int = 4096, temperature: float = 0.7, stop_sequences: list[str] | None = None, + web_search: bool = False, ) -> ModelResponse: raise NotImplementedError diff --git a/tests/unit/test_providers_google.py b/tests/unit/test_providers_google.py index 12497df..9c31b04 100644 --- a/tests/unit/test_providers_google.py +++ b/tests/unit/test_providers_google.py @@ -44,6 +44,16 @@ def _make_response( ) -> MagicMock: resp = MagicMock() resp.text = text + # Build candidates with text parts so the provider can extract + # content from parts (needed for grounding-safe parsing). + text_part = MagicMock() + text_part.text = text + text_part.function_call = None + cand_content = MagicMock() + cand_content.parts = [text_part] + candidate = MagicMock() + candidate.content = cand_content + resp.candidates = [candidate] resp.usage_metadata = MagicMock() resp.usage_metadata.prompt_token_count = prompt_tokens resp.usage_metadata.candidates_token_count = candidate_tokens @@ -179,6 +189,7 @@ async def test_send_no_usage_metadata(): async def test_send_empty_text(): resp = _make_response() resp.text = None + resp.candidates[0].content.parts[0].text = None client = _make_client(resp) prov = GoogleProvider(client=client) result = await prov.send( diff --git a/tests/unit/test_refine.py b/tests/unit/test_refine.py new file mode 100644 index 0000000..453d651 --- /dev/null +++ b/tests/unit/test_refine.py @@ -0,0 +1,129 @@ +"""Tests for question refinement (analyze_question / enrich_question).""" + +from __future__ import annotations + +import json +from unittest.mock import AsyncMock, MagicMock + +from duh.consensus.refine import analyze_question, enrich_question + + +def _mock_pm(response_content: str) -> MagicMock: + """Create a mock ProviderManager that returns *response_content*.""" + model = MagicMock() + model.input_cost_per_mtok = 0.5 + model.model_ref = "mock:cheap" + + provider = AsyncMock() + provider.send = AsyncMock( + return_value=MagicMock( + content=response_content, + usage=MagicMock(input_tokens=10, output_tokens=20), + ) + ) + + pm = MagicMock() + pm.list_all_models.return_value = [model] + pm.get_provider.return_value = (provider, "cheap") + pm.record_usage = MagicMock() + return pm + + +# ── analyze_question ────────────────────────────────────────── + + +class TestAnalyzeQuestion: + async def test_no_refinement_needed(self) -> None: + pm = _mock_pm(json.dumps({"needs_refinement": False})) + result = await analyze_question("What is 2+2?", pm) + assert result["needs_refinement"] is False + + async def test_refinement_needed(self) -> None: + payload = { + "needs_refinement": True, + "questions": [ + {"question": "What scale?", "hint": "users/requests"}, + {"question": "Budget?", "hint": None}, + ], + } + pm = _mock_pm(json.dumps(payload)) + result = await analyze_question("What database should I use?", pm) + assert result["needs_refinement"] is True + assert len(result["questions"]) == 2 + assert result["questions"][0]["question"] == "What scale?" + assert result["questions"][0]["hint"] == "users/requests" + assert result["questions"][1]["hint"] is None + + async def test_max_questions_capped(self) -> None: + payload = { + "needs_refinement": True, + "questions": [{"question": f"Q{i}?"} for i in range(10)], + } + pm = _mock_pm(json.dumps(payload)) + result = await analyze_question("Vague?", pm, max_questions=3) + assert len(result["questions"]) == 3 + + async def test_no_models_returns_no_refinement(self) -> None: + pm = MagicMock() + pm.list_all_models.return_value = [] + result = await analyze_question("anything", pm) + assert result["needs_refinement"] is False + + async def test_json_parse_error_returns_no_refinement(self) -> None: + pm = _mock_pm("This is not JSON at all") + result = await analyze_question("anything", pm) + assert result["needs_refinement"] is False + + async def test_provider_error_returns_no_refinement(self) -> None: + pm = _mock_pm("") + provider, _ = pm.get_provider("mock:cheap") + provider.send.side_effect = RuntimeError("API down") + result = await analyze_question("anything", pm) + assert result["needs_refinement"] is False + + async def test_empty_questions_returns_no_refinement(self) -> None: + payload = {"needs_refinement": True, "questions": []} + pm = _mock_pm(json.dumps(payload)) + result = await analyze_question("anything", pm) + assert result["needs_refinement"] is False + + async def test_json_in_code_fence(self) -> None: + fenced = '```json\n{"needs_refinement": false}\n```' + pm = _mock_pm(fenced) + result = await analyze_question("specific question", pm) + assert result["needs_refinement"] is False + + +# ── enrich_question ─────────────────────────────────────────── + + +class TestEnrichQuestion: + async def test_enrichment(self) -> None: + pm = _mock_pm("What database for a 10k-user SaaS on AWS with $500/mo budget?") + result = await enrich_question( + "What database should I use?", + [ + {"question": "Scale?", "answer": "10k users"}, + {"question": "Budget?", "answer": "$500/mo"}, + ], + pm, + ) + assert "10k" in result or "database" in result + + async def test_no_models_returns_original(self) -> None: + pm = MagicMock() + pm.list_all_models.return_value = [] + result = await enrich_question("original?", [], pm) + assert result == "original?" + + async def test_provider_error_returns_original(self) -> None: + pm = _mock_pm("") + provider, _ = pm.get_provider("mock:cheap") + provider.send.side_effect = RuntimeError("boom") + result = await enrich_question("original?", [], pm) + assert result == "original?" + + async def test_empty_response_returns_original(self) -> None: + pm = _mock_pm(" ") + result = await enrich_question("original?", [], pm) + assert result == "original?" diff --git a/tests/unit/test_tool_augmented_send.py b/tests/unit/test_tool_augmented_send.py index 378b420..3f116db 100644 --- a/tests/unit/test_tool_augmented_send.py +++ b/tests/unit/test_tool_augmented_send.py @@ -43,6 +43,7 @@ async def send( stop_sequences: list[str] | None = None, response_format: str | None = None, tools: list[dict[str, object]] | None = None, + web_search: bool = False, ) -> ModelResponse: self.call_log.append( {"messages": messages, "tools": tools, "model_id": model_id} diff --git a/tests/unit/test_tool_integration.py b/tests/unit/test_tool_integration.py index 1fd4e29..e637e99 100644 --- a/tests/unit/test_tool_integration.py +++ b/tests/unit/test_tool_integration.py @@ -146,6 +146,7 @@ async def send( stop_sequences: list[str] | None = None, response_format: str | None = None, tools: list[dict[str, object]] | None = None, + web_search: bool = False, ) -> ModelResponse: self.call_log.append( { diff --git a/web/src/__tests__/refinement.test.tsx b/web/src/__tests__/refinement.test.tsx new file mode 100644 index 0000000..a7a8896 --- /dev/null +++ b/web/src/__tests__/refinement.test.tsx @@ -0,0 +1,144 @@ +import { describe, it, expect, vi } from 'vitest' +import { render, screen, fireEvent } from '@testing-library/react' +import { RefinementPanel } from '@/components/consensus/RefinementPanel' +import type { ClarifyingQuestion } from '@/api/types' + +const questions: ClarifyingQuestion[] = [ + { question: 'What is the expected scale?', hint: 'users per day' }, + { question: 'What is your budget?', hint: null }, + { question: 'Any existing infrastructure?', hint: 'cloud provider' }, +] + +describe('RefinementPanel', () => { + it('renders all tabs', () => { + render( + , + ) + expect(screen.getByText('Q1')).toBeInTheDocument() + expect(screen.getByText('Q2')).toBeInTheDocument() + expect(screen.getByText('Q3')).toBeInTheDocument() + }) + + it('shows first question by default', () => { + render( + , + ) + expect(screen.getByText('What is the expected scale?')).toBeInTheDocument() + expect(screen.getByText('users per day')).toBeInTheDocument() + }) + + it('switches tab on click', () => { + render( + , + ) + fireEvent.click(screen.getByText('Q2')) + expect(screen.getByText('What is your budget?')).toBeInTheDocument() + }) + + it('submit disabled when not all answered', () => { + render( + , + ) + const submitBtn = screen.getByText('Start Consensus') + expect(submitBtn).toBeDisabled() + }) + + it('submit enabled when all answered', () => { + render( + , + ) + const submitBtn = screen.getByText('Start Consensus') + expect(submitBtn).not.toBeDisabled() + }) + + it('calls onSubmit when submit clicked', () => { + const onSubmit = vi.fn() + render( + , + ) + fireEvent.click(screen.getByText('Start Consensus')) + expect(onSubmit).toHaveBeenCalledOnce() + }) + + it('calls onSkip when skip clicked', () => { + const onSkip = vi.fn() + render( + , + ) + fireEvent.click(screen.getByText('Skip')) + expect(onSkip).toHaveBeenCalledOnce() + }) + + it('calls onAnswer when typing', () => { + const onAnswer = vi.fn() + render( + , + ) + const textarea = screen.getByPlaceholderText('Your answer...') + fireEvent.change(textarea, { target: { value: 'test answer' } }) + expect(onAnswer).toHaveBeenCalledWith(0, 'test answer') + }) + + it('shows checkmark on answered tabs', () => { + const { container } = render( + , + ) + // Tabs with answers should have SVG checkmarks + const svgs = container.querySelectorAll('svg') + expect(svgs).toHaveLength(2) + }) +}) diff --git a/web/src/api/client.ts b/web/src/api/client.ts index f276725..c2e0197 100644 --- a/web/src/api/client.ts +++ b/web/src/api/client.ts @@ -5,6 +5,7 @@ import type { CalibrationResponse, CostResponse, DecisionSpaceResponse, + EnrichResponse, FeedbackRequest, FeedbackResponse, ForgotPasswordRequest, @@ -13,6 +14,7 @@ import type { LoginRequest, ModelsResponse, RecallResponse, + RefineResponse, RegisterRequest, ResetPasswordRequest, ResetPasswordResponse, @@ -107,6 +109,24 @@ export const api = { return request('/health') }, + // Refinement + refine(question: string, maxQuestions?: number): Promise { + return request('/refine', { + method: 'POST', + body: JSON.stringify({ question, max_questions: maxQuestions }), + }) + }, + + enrich( + originalQuestion: string, + clarifications: { question: string; answer: string }[], + ): Promise { + return request('/enrich', { + method: 'POST', + body: JSON.stringify({ original_question: originalQuestion, clarifications }), + }) + }, + // Consensus ask(body: AskRequest): Promise { return request('/ask', { diff --git a/web/src/api/types.ts b/web/src/api/types.ts index 88c259b..ffcd29c 100644 --- a/web/src/api/types.ts +++ b/web/src/api/types.ts @@ -47,6 +47,22 @@ export interface ResetPasswordResponse { message: string } +// ── Refinement types ────────────────────────────────────── + +export interface ClarifyingQuestion { + question: string + hint?: string | null +} + +export interface RefineResponse { + needs_refinement: boolean + questions: ClarifyingQuestion[] +} + +export interface EnrichResponse { + enriched_question: string +} + // ── Request types ───────────────────────────────────────── export interface AskRequest { @@ -96,6 +112,7 @@ export interface Contribution { input_tokens: number output_tokens: number cost_usd: number + citations?: Citation[] | null } export interface Decision { @@ -224,6 +241,13 @@ export interface CalibrationResponse { ece: number } +// ── Citation types ──────────────────────────────────────── + +export interface Citation { + url: string + title?: string | null +} + // ── WebSocket event types ───────────────────────────────── export type WSEventType = @@ -250,6 +274,7 @@ export interface WSPhaseComplete { phase: ConsensusPhase content?: string truncated?: boolean + citations?: Citation[] | null } export interface WSChallenge { @@ -257,6 +282,7 @@ export interface WSChallenge { model: string content: string truncated?: boolean + citations?: Citation[] | null } export interface WSCommit { diff --git a/web/src/components/consensus/ConsensusNav.tsx b/web/src/components/consensus/ConsensusNav.tsx index fcf44e7..c1247f8 100644 --- a/web/src/components/consensus/ConsensusNav.tsx +++ b/web/src/components/consensus/ConsensusNav.tsx @@ -1,5 +1,6 @@ -import { GlassPanel } from '@/components/shared' +import { GlassPanel, Disclosure } from '@/components/shared' import { useConsensusStore } from '@/stores/consensus' +import type { Citation } from '@/api/types' type PhaseStatus = 'complete' | 'active' | 'pending' @@ -40,6 +41,14 @@ function StatusDot({ status }: { status: PhaseStatus }) { return } +function displayHost(url: string): string { + try { + return new URL(url).hostname.replace(/^www\./, '') + } catch { + return url + } +} + function shortModel(model: string): string { const parts = model.split(':') return parts.length > 1 ? parts[1]! : model @@ -53,6 +62,33 @@ export function ConsensusNav() { const isStreaming = status === 'connecting' || status === 'streaming' const isComplete = status === 'complete' + // Collect all citations with their phase role, grouped by domain + type TaggedCitation = Citation & { role: 'propose' | 'challenge' | 'revise' } + const domainGroups = (() => { + const seen = new Set() + const tagged: TaggedCitation[] = [] + for (const round of rounds) { + for (const c of round.proposalCitations ?? []) { + if (!seen.has(c.url)) { seen.add(c.url); tagged.push({ ...c, role: 'propose' }) } + } + for (const ch of round.challenges) { + for (const c of ch.citations ?? []) { + if (!seen.has(c.url)) { seen.add(c.url); tagged.push({ ...c, role: 'challenge' }) } + } + } + } + // Group by hostname + const groups = new Map() + for (const c of tagged) { + const domain = displayHost(c.url) + const list = groups.get(domain) ?? [] + list.push(c) + groups.set(domain, list) + } + // Sort groups by citation count descending + return [...groups.entries()].sort((a, b) => b[1].length - a[1].length) + })() + const scrollTo = (id: string) => { document.getElementById(id)?.scrollIntoView({ behavior: 'smooth', block: 'start' }) } @@ -131,6 +167,55 @@ export function ConsensusNav() { ) })} + {domainGroups.length > 0 && ( +
+ + Sources ({domainGroups.reduce((sum, [, cs]) => sum + cs.length, 0)}) + + } + defaultOpen={false} + > +
+ {domainGroups.map(([domain, citations]) => ( + + {domain} ({citations.length}) + + } + defaultOpen={false} + > +
    + {citations.map((c, i) => ( +
  • + + {c.role === 'propose' ? 'P' : c.role === 'challenge' ? 'C' : 'R'} + + + {c.title || c.url} + +
  • + ))} +
+
+ ))} +
+
+
+ )} + ) diff --git a/web/src/components/consensus/ConsensusPanel.tsx b/web/src/components/consensus/ConsensusPanel.tsx index 45d990c..e2e2897 100644 --- a/web/src/components/consensus/ConsensusPanel.tsx +++ b/web/src/components/consensus/ConsensusPanel.tsx @@ -1,25 +1,29 @@ import { useConsensusStore } from '@/stores' -import { GlassPanel, GlowButton } from '@/components/shared' +import { GlassPanel, GlowButton, Skeleton } from '@/components/shared' import { QuestionInput } from './QuestionInput' import { PhaseCard } from './PhaseCard' import { ConsensusComplete } from './ConsensusComplete' import { CostTicker } from './CostTicker' +import { RefinementPanel } from './RefinementPanel' export function ConsensusPanel() { const { status, error, currentPhase, currentRound, rounds, decision, confidence, rigor, dissent, cost, overview, - startConsensus, reset, + clarifyingQuestions, clarificationAnswers, + submitQuestion, answerClarification, submitClarifications, skipRefinement, + reset, } = useConsensusStore() const isActive = status === 'connecting' || status === 'streaming' + const isRefining = status === 'refining' const isComplete = status === 'complete' return (
startConsensus(q, r, p, ms)} - disabled={isActive} + onSubmit={(q, r, p, ms) => submitQuestion(q, r, p, ms)} + disabled={isActive || isRefining} /> {status === 'error' && error && ( @@ -31,6 +35,27 @@ export function ConsensusPanel() { )} + {isRefining && clarifyingQuestions.length === 0 && ( + +
+ + + Analyzing question... + +
+
+ )} + + {isRefining && clarifyingQuestions.length > 0 && ( + + )} + {isComplete && decision && confidence !== null && (
)} diff --git a/web/src/components/consensus/PhaseCard.tsx b/web/src/components/consensus/PhaseCard.tsx index c2d54fb..03bc182 100644 --- a/web/src/components/consensus/PhaseCard.tsx +++ b/web/src/components/consensus/PhaseCard.tsx @@ -1,6 +1,7 @@ -import { GlassPanel, Markdown, Disclosure } from '@/components/shared' +import { GlassPanel, Markdown, Disclosure, CitationList } from '@/components/shared' import { ModelBadge } from './ModelBadge' import { StreamingText } from './StreamingText' +import type { Citation } from '@/api/types' interface PhaseCardProps { phase: string @@ -8,13 +9,14 @@ interface PhaseCardProps { models?: string[] content?: string | null isActive?: boolean - challenges?: Array<{ model: string; content: string; truncated?: boolean; error?: boolean }> + challenges?: Array<{ model: string; content: string; truncated?: boolean; error?: boolean; citations?: Citation[] | null }> collapsible?: boolean defaultOpen?: boolean truncated?: boolean + citations?: Citation[] | null } -export function PhaseCard({ phase, model, models, content, isActive, challenges, collapsible, defaultOpen = true, truncated }: PhaseCardProps) { +export function PhaseCard({ phase, model, models, content, isActive, challenges, collapsible, defaultOpen = true, truncated, citations }: PhaseCardProps) { const header = ( <> {phase} @@ -35,6 +37,9 @@ export function PhaseCard({ phase, model, models, content, isActive, challenges, ) : ( {content} )} + {!isActive && citations && citations.length > 0 && ( + + )}
)} @@ -54,6 +59,9 @@ export function PhaseCard({ phase, model, models, content, isActive, challenges, >
{ch.error ? {ch.content} : {ch.content}} + {!ch.error && ch.citations && ch.citations.length > 0 && ( + + )}
))} diff --git a/web/src/components/consensus/RefinementPanel.tsx b/web/src/components/consensus/RefinementPanel.tsx new file mode 100644 index 0000000..14e9881 --- /dev/null +++ b/web/src/components/consensus/RefinementPanel.tsx @@ -0,0 +1,114 @@ +import { useState } from 'react' +import { GlassPanel, GlowButton } from '@/components/shared' +import type { ClarifyingQuestion } from '@/api/types' + +interface RefinementPanelProps { + questions: ClarifyingQuestion[] + answers: Record + onAnswer: (index: number, answer: string) => void + onSubmit: () => void + onSkip: () => void +} + +export function RefinementPanel({ + questions, + answers, + onAnswer, + onSubmit, + onSkip, +}: RefinementPanelProps) { + const [activeTab, setActiveTab] = useState(0) + const allAnswered = questions.every((_, i) => (answers[i] ?? '').trim().length > 0) + + const handleTextChange = (value: string) => { + onAnswer(activeTab, value) + } + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (e.key === 'Tab' && !e.shiftKey && (answers[activeTab] ?? '').trim()) { + const nextUnanswered = questions.findIndex( + (_, i) => i > activeTab && !(answers[i] ?? '').trim(), + ) + if (nextUnanswered >= 0) { + e.preventDefault() + setActiveTab(nextUnanswered) + } + } + } + + return ( + +
+
+ + Clarifying Questions + +
+ + {/* Tab bar */} +
+ {questions.map((_, i) => { + const answered = (answers[i] ?? '').trim().length > 0 + const isActive = i === activeTab + return ( + + ) + })} +
+ + {/* Active question */} + {questions[activeTab] && ( +
+

+ {questions[activeTab].question} +

+ {questions[activeTab].hint && ( +

+ {questions[activeTab].hint} +

+ )} +