From df66bc558d0413fc75f0a60808cd40d1ad2c2866 Mon Sep 17 00:00:00 2001 From: Michael Sitarzewski Date: Sun, 8 Mar 2026 07:30:09 -0500 Subject: [PATCH 1/2] Question refinement, tools-by-default, sidebar UX, provider tool format fix - Add pre-consensus question refinement: analyze_question() uses most expensive model to detect ambiguity, enrich_question() rewrites with user's clarifications. REST endpoints POST /api/refine + /api/enrich. Frontend: RefinementPanel with tabbed clarification UI, consensus store 'refining' status, CLI --refine flag. - Enable tools by default (ToolsConfig.enabled=True). Wire tool_registry through REST /api/ask and WebSocket /ws/ask paths (was CLI-only). - Fix tool format: each provider now transforms generic {name, description, parameters} to its native API format (Anthropic: input_schema, OpenAI/ Mistral/Perplexity: function wrapper, Google: FunctionDeclaration). - Sidebar: add new-question button (Heroicons pencil-square) + collapsible sidebar toggle. Brand left, icons right. Desktop sidebar hide/show. 1619 Python tests + 194 Vitest tests passing. Build clean. Co-Authored-By: Claude Opus 4.6 --- src/duh/api/app.py | 3 +- src/duh/api/routes/ask.py | 48 +++++- src/duh/api/routes/ws.py | 13 +- src/duh/cli/app.py | 39 +++++ src/duh/config/schema.py | 2 +- src/duh/consensus/refine.py | 156 ++++++++++++++++++ src/duh/providers/anthropic.py | 9 +- src/duh/providers/google.py | 12 +- src/duh/providers/mistral.py | 12 +- src/duh/providers/openai.py | 12 +- src/duh/providers/perplexity.py | 12 +- tests/unit/test_api_refine.py | 119 +++++++++++++ tests/unit/test_config_v02.py | 4 +- tests/unit/test_config_v03.py | 2 +- tests/unit/test_provider_tools.py | 60 +++++-- tests/unit/test_refine.py | 129 +++++++++++++++ web/src/__tests__/refinement.test.tsx | 144 ++++++++++++++++ web/src/api/client.ts | 20 +++ web/src/api/types.ts | 16 ++ .../components/consensus/ConsensusPanel.tsx | 33 +++- .../components/consensus/RefinementPanel.tsx | 114 +++++++++++++ web/src/components/layout/Shell.tsx | 29 +++- web/src/components/layout/Sidebar.tsx | 55 +++++- web/src/components/layout/TopBar.tsx | 11 +- web/src/stores/consensus.ts | 138 ++++++++++++---- web/tsconfig.tsbuildinfo | 2 +- 26 files changed, 1107 insertions(+), 87 deletions(-) create mode 100644 src/duh/consensus/refine.py create mode 100644 tests/unit/test_api_refine.py create mode 100644 tests/unit/test_refine.py create mode 100644 web/src/__tests__/refinement.test.tsx create mode 100644 web/src/components/consensus/RefinementPanel.tsx diff --git a/src/duh/api/app.py b/src/duh/api/app.py index c111c84..2bdf52e 100644 --- a/src/duh/api/app.py +++ b/src/duh/api/app.py @@ -20,7 +20,7 @@ @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncIterator[None]: """Lifespan handler: set up DB + providers on startup, tear down on shutdown.""" - from duh.cli.app import _create_db, _setup_providers + from duh.cli.app import _create_db, _setup_providers, _setup_tools config: DuhConfig = app.state.config factory, engine = await _create_db(config) @@ -29,6 +29,7 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: app.state.db_factory = factory app.state.engine = engine app.state.provider_manager = pm + app.state.tool_registry = _setup_tools(config) extra = getattr(app.state, "extra_lifespan", None) if extra is not None: diff --git a/src/duh/api/routes/ask.py b/src/duh/api/routes/ask.py index 5e2c98f..4b1e10b 100644 --- a/src/duh/api/routes/ask.py +++ b/src/duh/api/routes/ask.py @@ -15,6 +15,25 @@ router = APIRouter(prefix="/api", tags=["consensus"]) +class RefineRequest(BaseModel): + question: str + max_questions: int = 4 + + +class RefineResponse(BaseModel): + needs_refinement: bool + questions: list[dict[str, str | None]] = [] + + +class EnrichRequest(BaseModel): + original_question: str + clarifications: list[dict[str, str]] + + +class EnrichResponse(BaseModel): + enriched_question: str + + class AskRequest(BaseModel): question: str protocol: str = "consensus" # consensus, voting, auto @@ -46,6 +65,7 @@ async def ask(body: AskRequest, request: Request) -> AskResponse | JSONResponse: config.general.max_rounds = body.rounds db_factory = getattr(request.app.state, "db_factory", None) + tool_registry = getattr(request.app.state, "tool_registry", None) try: if body.decompose: @@ -55,7 +75,7 @@ async def ask(body: AskRequest, request: Request) -> AskResponse | JSONResponse: return await _handle_voting(body, config, pm) # Default: consensus - return await _handle_consensus(body, config, pm, db_factory) + return await _handle_consensus(body, config, pm, db_factory, tool_registry) except ProviderError as exc: logger.exception("Provider error during /api/ask") @@ -78,7 +98,7 @@ async def ask(body: AskRequest, request: Request) -> AskResponse | JSONResponse: async def _handle_consensus( # type: ignore[no-untyped-def] - body: AskRequest, config, pm, db_factory=None + body: AskRequest, config, pm, db_factory=None, tool_registry=None ) -> AskResponse: """Run the consensus protocol.""" from duh.cli.app import _run_consensus @@ -87,6 +107,7 @@ async def _handle_consensus( # type: ignore[no-untyped-def] body.question, config, pm, + tool_registry=tool_registry, panel=body.panel, proposer_override=body.proposer, challengers_override=body.challengers, @@ -203,3 +224,26 @@ async def _persist_result( ) await session.commit() return str(thread.id) + + +@router.post("/refine", response_model=RefineResponse) +async def refine(body: RefineRequest, request: Request) -> RefineResponse: + """Analyze a question for ambiguity and suggest clarifications.""" + from duh.consensus.refine import analyze_question + + pm = request.app.state.provider_manager + result = await analyze_question(body.question, pm, max_questions=body.max_questions) + return RefineResponse( + needs_refinement=result.get("needs_refinement", False), + questions=result.get("questions", []), + ) + + +@router.post("/enrich", response_model=EnrichResponse) +async def enrich(body: EnrichRequest, request: Request) -> EnrichResponse: + """Rewrite a question incorporating clarification answers.""" + from duh.consensus.refine import enrich_question + + pm = request.app.state.provider_manager + enriched = await enrich_question(body.original_question, body.clarifications, pm) + return EnrichResponse(enriched_question=enriched) diff --git a/src/duh/api/routes/ws.py b/src/duh/api/routes/ws.py index 2f094f1..0a54a3b 100644 --- a/src/duh/api/routes/ws.py +++ b/src/duh/api/routes/ws.py @@ -11,6 +11,7 @@ from duh.config.schema import DuhConfig from duh.consensus.machine import RoundResult from duh.providers.manager import ProviderManager + from duh.tools.registry import ToolRegistry logger = logging.getLogger(__name__) @@ -63,6 +64,8 @@ async def ws_ask(websocket: WebSocket) -> None: pm: ProviderManager = websocket.app.state.provider_manager config.general.max_rounds = rounds + tool_registry = getattr(websocket.app.state, "tool_registry", None) + await _stream_consensus( websocket, question, @@ -71,6 +74,7 @@ async def ws_ask(websocket: WebSocket) -> None: panel=panel, proposer_override=proposer_override, challengers_override=challengers_raw, + tool_registry=tool_registry, ) except WebSocketDisconnect: @@ -93,6 +97,7 @@ async def _stream_consensus( panel: list[str] | None = None, proposer_override: str | None = None, challengers_override: list[str] | None = None, + tool_registry: ToolRegistry | None = None, ) -> None: """Run consensus loop and stream events to WebSocket.""" from duh.consensus.convergence import check_convergence @@ -132,7 +137,9 @@ async def _stream_consensus( "round": ctx.current_round, } ) - propose_resp = await handle_propose(ctx, pm, proposer) + propose_resp = await handle_propose( + ctx, pm, proposer, tool_registry=tool_registry + ) await ws.send_json( { "type": "phase_complete", @@ -155,7 +162,9 @@ async def _stream_consensus( "round": ctx.current_round, } ) - challenge_resps = await handle_challenge(ctx, pm, challengers) + challenge_resps = await handle_challenge( + ctx, pm, challengers, tool_registry=tool_registry + ) succeeded = {ch.model_ref for ch in ctx.challenges} for i, ch in enumerate(ctx.challenges): resp_truncated = ( diff --git a/src/duh/cli/app.py b/src/duh/cli/app.py index c222e2b..745c30f 100644 --- a/src/duh/cli/app.py +++ b/src/duh/cli/app.py @@ -386,6 +386,11 @@ def cli(ctx: click.Context, config_path: str | None) -> None: default=None, help="Restrict to these models only (comma-separated model refs).", ) +@click.option( + "--refine/--no-refine", + default=False, + help="Pre-consensus question refinement (ask clarifying questions).", +) @click.pass_context def ask( ctx: click.Context, @@ -397,6 +402,7 @@ def ask( proposer: str | None, challengers: str | None, panel: str | None, + refine: bool, ) -> None: """Run a consensus query. @@ -415,6 +421,14 @@ def ask( panel_list = panel.split(",") if panel else None challengers_list = challengers.split(",") if challengers else None + # Question refinement (pre-consensus clarification) + if refine: + try: + question = asyncio.run(_refine_question(question, config)) + except DuhError as e: + _error(str(e)) + return + # Determine effective protocol effective_protocol = protocol or config.general.protocol @@ -463,6 +477,31 @@ def ask( ) +async def _refine_question(question: str, config: DuhConfig) -> str: + """Run question refinement interactively on the CLI.""" + from duh.consensus.refine import analyze_question, enrich_question + + pm = await _setup_providers(config) + if not pm.list_all_models(): + return question + + result = await analyze_question(question, pm) + if not result.get("needs_refinement"): + return question + + questions = result.get("questions", []) + click.echo("\nClarifying questions:") + clarifications = [] + for q in questions: + hint = f" ({q['hint']})" if q.get("hint") else "" + answer = click.prompt(f" {q['question']}{hint}") + clarifications.append({"question": q["question"], "answer": answer}) + + enriched = await enrich_question(question, clarifications, pm) + click.echo(f"\nRefined question: {enriched}\n") + return enriched + + async def _ask_async( question: str, config: DuhConfig, diff --git a/src/duh/config/schema.py b/src/duh/config/schema.py index 4c5294c..6c1a603 100644 --- a/src/duh/config/schema.py +++ b/src/duh/config/schema.py @@ -74,7 +74,7 @@ class CodeExecutionConfig(BaseModel): class ToolsConfig(BaseModel): """Tool framework configuration.""" - enabled: bool = False + enabled: bool = True max_rounds: int = 5 web_search: WebSearchConfig = Field(default_factory=WebSearchConfig) code_execution: CodeExecutionConfig = Field(default_factory=CodeExecutionConfig) diff --git a/src/duh/consensus/refine.py b/src/duh/consensus/refine.py new file mode 100644 index 0000000..5a96325 --- /dev/null +++ b/src/duh/consensus/refine.py @@ -0,0 +1,156 @@ +"""Question refinement: pre-consensus clarification step. + +Uses the most capable (most expensive) model to evaluate whether a +question is ambiguous and, if so, generate clarifying questions. A +second call rewrites the original question incorporating the user's +answers. This is the user's first impression — it must be exceptional. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from duh.consensus.json_extract import JSONExtractionError, extract_json +from duh.providers.base import PromptMessage + +if TYPE_CHECKING: + from duh.providers.manager import ProviderManager + +logger = logging.getLogger(__name__) + + +async def analyze_question( + question: str, + provider_manager: ProviderManager, + *, + max_questions: int = 4, +) -> dict[str, Any]: + """Evaluate whether *question* needs clarification. + + Returns ``{"needs_refinement": false}`` when the question is specific + enough, or ``{"needs_refinement": true, "questions": [...]}`` with up + to *max_questions* clarifying questions otherwise. + + On any failure (no models, JSON parse error, provider error) the + function returns ``{"needs_refinement": false}`` so consensus can + proceed uninterrupted. + """ + models = provider_manager.list_all_models() + if not models: + return {"needs_refinement": False} + + best = max(models, key=lambda m: m.input_cost_per_mtok) + provider, model_id = provider_manager.get_provider(best.model_ref) + + prompt = ( + "You are an expert question analyst and strategic thinker. Your job " + "is to determine whether a question contains enough context for a " + "panel of experts to give a truly excellent, specific, actionable " + "answer — or whether critical context is missing.\n\n" + "Think deeply: what assumptions would an expert panel have to make? " + "If those assumptions could lead to fundamentally different answers, " + "the question needs refinement.\n\n" + "Consider missing context such as: scale, budget, team size/expertise, " + "timeline, technical constraints, use-case, existing infrastructure, " + "success criteria, risk tolerance, regulatory requirements, or " + "geographic/market context.\n\n" + f"Question: {question}\n\n" + "Return ONLY a JSON object. If the question is already specific " + "enough for expert-quality advice:\n" + '{"needs_refinement": false}\n\n' + "If clarification would meaningfully improve the answer, return:\n" + '{"needs_refinement": true, "questions": [\n' + ' {"question": "...", "hint": "brief guidance on what kind of answer helps"}\n' + "]}\n\n" + f"Include at most {max_questions} questions. Each should be concise, " + "focused on one critical missing dimension, and phrased in a way that " + "feels natural and respectful — like a senior consultant clarifying " + "scope before giving advice. Only ask questions whose answers would " + "materially change the recommendation." + ) + + try: + response = await provider.send( + [PromptMessage(role="user", content=prompt)], + model_id, + max_tokens=500, + temperature=0.3, + response_format="json", + ) + data = extract_json(response.content) + provider_manager.record_usage(best, response.usage) + + if not data.get("needs_refinement"): + return {"needs_refinement": False} + + questions = data.get("questions", []) + if not isinstance(questions, list) or not questions: + return {"needs_refinement": False} + + # Normalise and cap + clean: list[dict[str, str | None]] = [] + for q in questions[:max_questions]: + if isinstance(q, dict) and q.get("question"): + clean.append( + { + "question": str(q["question"]), + "hint": str(q["hint"]) if q.get("hint") else None, + } + ) + + if not clean: + return {"needs_refinement": False} + + return {"needs_refinement": True, "questions": clean} + + except (JSONExtractionError, Exception): + logger.debug("Question refinement analysis failed, skipping", exc_info=True) + return {"needs_refinement": False} + + +async def enrich_question( + original: str, + clarifications: list[dict[str, str]], + provider_manager: ProviderManager, +) -> str: + """Rewrite *original* incorporating clarification answers. + + Each entry in *clarifications* has ``question`` and ``answer`` keys. + Returns the enriched question string, or the original on failure. + """ + models = provider_manager.list_all_models() + if not models: + return original + + best = max(models, key=lambda m: m.input_cost_per_mtok) + provider, model_id = provider_manager.get_provider(best.model_ref) + + qa_block = "\n".join( + f"Q: {c['question']}\nA: {c['answer']}" for c in clarifications + ) + + prompt = ( + "Rewrite the following question into a single, specific, " + "self-contained question that incorporates all the additional " + "context provided below. Keep the rewritten question natural and " + "concise — do not repeat the clarifications verbatim, just weave " + "the context in.\n\n" + f"Original question: {original}\n\n" + f"Additional context:\n{qa_block}\n\n" + "Return ONLY the rewritten question, nothing else." + ) + + try: + response = await provider.send( + [PromptMessage(role="user", content=prompt)], + model_id, + max_tokens=500, + temperature=0.3, + ) + provider_manager.record_usage(best, response.usage) + enriched = response.content.strip() + return enriched if enriched else original + except Exception: + logger.debug("Question enrichment failed, using original", exc_info=True) + return original diff --git a/src/duh/providers/anthropic.py b/src/duh/providers/anthropic.py index 8a79e62..1226868 100644 --- a/src/duh/providers/anthropic.py +++ b/src/duh/providers/anthropic.py @@ -125,7 +125,14 @@ async def send( if stop_sequences: kwargs["stop_sequences"] = stop_sequences if tools: - kwargs["tools"] = tools + kwargs["tools"] = [ + { + "name": t["name"], + "description": t.get("description", ""), + "input_schema": t.get("input_schema") or t.get("parameters", {}), + } + for t in tools + ] start = time.monotonic() try: diff --git a/src/duh/providers/google.py b/src/duh/providers/google.py index 30a400e..a81a324 100644 --- a/src/duh/providers/google.py +++ b/src/duh/providers/google.py @@ -120,7 +120,17 @@ async def send( if response_format == "json": config_kwargs["response_mime_type"] = "application/json" if tools: - config_kwargs["tools"] = tools + func_decls = [ + genai.types.FunctionDeclaration( + name=str(t["name"]), + description=str(t.get("description", "")), + parameters=t.get("parameters") or t.get("input_schema", {}), # type: ignore[arg-type] + ) + for t in tools + ] + config_kwargs["tools"] = [ + genai.types.Tool(function_declarations=func_decls) + ] config = genai.types.GenerateContentConfig(**config_kwargs) diff --git a/src/duh/providers/mistral.py b/src/duh/providers/mistral.py index 7aae04b..5284cea 100644 --- a/src/duh/providers/mistral.py +++ b/src/duh/providers/mistral.py @@ -129,7 +129,17 @@ async def send( if response_format == "json": kwargs["response_format"] = {"type": "json_object"} if tools: - kwargs["tools"] = tools + kwargs["tools"] = [ + { + "type": "function", + "function": { + "name": t["name"], + "description": t.get("description", ""), + "parameters": t.get("parameters") or t.get("input_schema", {}), + }, + } + for t in tools + ] start = time.monotonic() try: diff --git a/src/duh/providers/openai.py b/src/duh/providers/openai.py index 00b6bcc..72ed5f9 100644 --- a/src/duh/providers/openai.py +++ b/src/duh/providers/openai.py @@ -143,7 +143,17 @@ async def send( if response_format == "json": kwargs["response_format"] = {"type": "json_object"} if tools: - kwargs["tools"] = tools + kwargs["tools"] = [ + { + "type": "function", + "function": { + "name": t["name"], + "description": t.get("description", ""), + "parameters": t.get("parameters") or t.get("input_schema", {}), + }, + } + for t in tools + ] start = time.monotonic() try: diff --git a/src/duh/providers/perplexity.py b/src/duh/providers/perplexity.py index 6cd7567..f450d39 100644 --- a/src/duh/providers/perplexity.py +++ b/src/duh/providers/perplexity.py @@ -138,7 +138,17 @@ async def send( if response_format == "json": kwargs["response_format"] = {"type": "json_object"} if tools: - kwargs["tools"] = tools + kwargs["tools"] = [ + { + "type": "function", + "function": { + "name": t["name"], + "description": t.get("description", ""), + "parameters": t.get("parameters") or t.get("input_schema", {}), + }, + } + for t in tools + ] start = time.monotonic() try: diff --git a/tests/unit/test_api_refine.py b/tests/unit/test_api_refine.py new file mode 100644 index 0000000..bc31d0a --- /dev/null +++ b/tests/unit/test_api_refine.py @@ -0,0 +1,119 @@ +"""Tests for POST /api/refine and POST /api/enrich endpoints.""" + +from __future__ import annotations + +import json +from unittest.mock import AsyncMock, patch + +from fastapi.testclient import TestClient +from sqlalchemy import event +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + +from duh.api.app import create_app +from duh.config.schema import DuhConfig +from duh.memory.models import Base +from duh.providers.manager import ProviderManager +from tests.fixtures.providers import MockProvider + + +async def _make_app() -> TestClient: + """Create a test app with mocked providers and in-memory DB.""" + config = DuhConfig() + config.database.url = "sqlite+aiosqlite:///:memory:" + + engine = create_async_engine("sqlite+aiosqlite://") + + @event.listens_for(engine.sync_engine, "connect") + def _enable_fks(dbapi_conn, connection_record): # type: ignore[no-untyped-def] + cursor = dbapi_conn.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + factory = async_sessionmaker(engine, expire_on_commit=False) + + mock_prov = MockProvider( + provider_id="mock", + responses={"model-a": json.dumps({"needs_refinement": False})}, + input_cost=1.0, + output_cost=5.0, + ) + pm = ProviderManager(cost_hard_limit=100.0) + await pm.register(mock_prov) # type: ignore[arg-type] + + app = create_app(config) + app.state.db_factory = factory + app.state.engine = engine + app.state.provider_manager = pm + return TestClient(app, raise_server_exceptions=False) + + +class TestRefineEndpoint: + async def test_refine_no_refinement(self) -> None: + client = await _make_app() + with patch( + "duh.consensus.refine.analyze_question", + new_callable=AsyncMock, + return_value={"needs_refinement": False}, + ): + resp = client.post("/api/refine", json={"question": "What is 2+2?"}) + assert resp.status_code == 200 + data = resp.json() + assert data["needs_refinement"] is False + assert data["questions"] == [] + + async def test_refine_with_questions(self) -> None: + client = await _make_app() + questions = [ + {"question": "What scale?", "hint": "users/day"}, + {"question": "Budget?", "hint": None}, + ] + with patch( + "duh.consensus.refine.analyze_question", + new_callable=AsyncMock, + return_value={"needs_refinement": True, "questions": questions}, + ): + resp = client.post("/api/refine", json={"question": "What DB?"}) + assert resp.status_code == 200 + data = resp.json() + assert data["needs_refinement"] is True + assert len(data["questions"]) == 2 + + async def test_refine_custom_max_questions(self) -> None: + client = await _make_app() + with patch( + "duh.consensus.refine.analyze_question", + new_callable=AsyncMock, + return_value={"needs_refinement": False}, + ) as mock_analyze: + client.post( + "/api/refine", + json={"question": "Test?", "max_questions": 2}, + ) + mock_analyze.assert_called_once() + _, kwargs = mock_analyze.call_args + assert kwargs["max_questions"] == 2 + + +class TestEnrichEndpoint: + async def test_enrich(self) -> None: + client = await _make_app() + with patch( + "duh.consensus.refine.enrich_question", + new_callable=AsyncMock, + return_value="What DB for a 10k-user SaaS?", + ): + resp = client.post( + "/api/enrich", + json={ + "original_question": "What DB?", + "clarifications": [ + {"question": "Scale?", "answer": "10k users"}, + ], + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert "10k" in data["enriched_question"] diff --git a/tests/unit/test_config_v02.py b/tests/unit/test_config_v02.py index edc20d6..9a4227b 100644 --- a/tests/unit/test_config_v02.py +++ b/tests/unit/test_config_v02.py @@ -34,7 +34,7 @@ def test_defaults(self) -> None: class TestToolsConfig: def test_defaults(self) -> None: cfg = ToolsConfig() - assert cfg.enabled is False + assert cfg.enabled is True assert cfg.max_rounds == 5 assert cfg.web_search.backend == "duckduckgo" assert cfg.code_execution.enabled is False @@ -93,6 +93,6 @@ def test_backward_compatible(self) -> None: assert cfg.cost.hard_limit == 10.00 assert "anthropic" in cfg.providers # v0.2 fields have safe defaults - assert cfg.tools.enabled is False + assert cfg.tools.enabled is True assert cfg.voting.enabled is False assert cfg.taxonomy.enabled is False diff --git a/tests/unit/test_config_v03.py b/tests/unit/test_config_v03.py index c1319e9..7c43f8d 100644 --- a/tests/unit/test_config_v03.py +++ b/tests/unit/test_config_v03.py @@ -127,7 +127,7 @@ def test_backward_compatible(self) -> None: assert cfg.general.max_rounds == 3 assert cfg.cost.hard_limit == 10.00 assert "anthropic" in cfg.providers - assert cfg.tools.enabled is False + assert cfg.tools.enabled is True assert cfg.api.host == "127.0.0.1" diff --git a/tests/unit/test_provider_tools.py b/tests/unit/test_provider_tools.py index 14c4025..d1c6c33 100644 --- a/tests/unit/test_provider_tools.py +++ b/tests/unit/test_provider_tools.py @@ -18,17 +18,15 @@ # ── Shared fixtures ────────────────────────────────────────────── +# Generic tool format (as produced by tool_augmented_send) SAMPLE_TOOLS: list[dict[str, object]] = [ { - "type": "function", - "function": { - "name": "web_search", - "description": "Search the web", - "parameters": { - "type": "object", - "properties": {"query": {"type": "string"}}, - "required": ["query"], - }, + "name": "web_search", + "description": "Search the web", + "parameters": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], }, } ] @@ -90,7 +88,11 @@ async def test_tools_param_forwarded(self) -> None: provider = OpenAIProvider(client=client) await provider.send(USER_MSG, "gpt-5.2", tools=SAMPLE_TOOLS) call_kwargs = client.chat.completions.create.call_args.kwargs - assert call_kwargs["tools"] is SAMPLE_TOOLS + tools = call_kwargs["tools"] + assert len(tools) == 1 + assert tools[0]["type"] == "function" + assert tools[0]["function"]["name"] == "web_search" + assert tools[0]["function"]["parameters"]["type"] == "object" async def test_no_tools_param_omitted(self) -> None: client = _oai_make_client() @@ -258,7 +260,10 @@ async def test_tools_param_forwarded(self) -> None: provider = AnthropicProvider(client=client) await provider.send(USER_MSG, "claude-opus-4-6", tools=SAMPLE_TOOLS) call_kwargs = client.messages.create.call_args.kwargs - assert call_kwargs["tools"] is SAMPLE_TOOLS + tools = call_kwargs["tools"] + assert len(tools) == 1 + assert tools[0]["name"] == "web_search" + assert tools[0]["input_schema"]["type"] == "object" async def test_no_tools_param_omitted(self) -> None: from duh.providers.anthropic import AnthropicProvider @@ -454,10 +459,19 @@ def _mock_genai_config(**kwargs: Any) -> MagicMock: side_effect=_mock_genai_config, ) +_PATCH_FUNC_DECL = patch( + "duh.providers.google.genai.types.FunctionDeclaration", + side_effect=lambda **kwargs: MagicMock(**kwargs), +) + +_PATCH_TOOL = patch( + "duh.providers.google.genai.types.Tool", + side_effect=lambda **kwargs: MagicMock(**kwargs), +) + class TestGoogleToolForwarding: - @_PATCH_CONFIG - async def test_tools_param_forwarded(self, _mock_cfg: Any) -> None: + async def test_tools_param_forwarded(self) -> None: from duh.providers.google import GoogleProvider client = _google_make_client() @@ -465,7 +479,9 @@ async def test_tools_param_forwarded(self, _mock_cfg: Any) -> None: await provider.send(USER_MSG, "gemini-2.5-flash", tools=SAMPLE_TOOLS) call_kwargs = client.aio.models.generate_content.call_args config = call_kwargs.kwargs["config"] - assert config.tools is SAMPLE_TOOLS + # Google wraps tools in genai.types.Tool with function_declarations + assert config.tools is not None + assert len(config.tools) == 1 async def test_no_tools_param_not_in_config(self) -> None: from duh.providers.google import GoogleProvider @@ -490,8 +506,10 @@ async def test_response_format_json(self) -> None: class TestGoogleToolCallParsing: + @_PATCH_TOOL + @_PATCH_FUNC_DECL @_PATCH_CONFIG - async def test_single_function_call_parsed(self, _mock_cfg: Any) -> None: + async def test_single_function_call_parsed(self, *_mocks: Any) -> None: from duh.providers.google import GoogleProvider response = _google_make_response_with_function_calls( @@ -508,8 +526,10 @@ async def test_single_function_call_parsed(self, _mock_cfg: Any) -> None: assert tc.id == "google-web_search" assert json.loads(tc.arguments) == {"query": "cats"} + @_PATCH_TOOL + @_PATCH_FUNC_DECL @_PATCH_CONFIG - async def test_multiple_function_calls_parsed(self, _mock_cfg: Any) -> None: + async def test_multiple_function_calls_parsed(self, *_mocks: Any) -> None: from duh.providers.google import GoogleProvider response = _google_make_response_with_function_calls( @@ -524,8 +544,10 @@ async def test_multiple_function_calls_parsed(self, _mock_cfg: Any) -> None: assert resp.tool_calls is not None assert len(resp.tool_calls) == 2 + @_PATCH_TOOL + @_PATCH_FUNC_DECL @_PATCH_CONFIG - async def test_no_function_calls_returns_none(self, _mock_cfg: Any) -> None: + async def test_no_function_calls_returns_none(self, *_mocks: Any) -> None: from duh.providers.google import GoogleProvider response = _google_make_response_with_function_calls(text="No tools needed") @@ -534,8 +556,10 @@ async def test_no_function_calls_returns_none(self, _mock_cfg: Any) -> None: resp = await provider.send(USER_MSG, "gemini-2.5-flash", tools=SAMPLE_TOOLS) assert resp.tool_calls is None + @_PATCH_TOOL + @_PATCH_FUNC_DECL @_PATCH_CONFIG - async def test_function_call_with_no_args(self, _mock_cfg: Any) -> None: + async def test_function_call_with_no_args(self, *_mocks: Any) -> None: from duh.providers.google import GoogleProvider response = _google_make_response_with_function_calls( diff --git a/tests/unit/test_refine.py b/tests/unit/test_refine.py new file mode 100644 index 0000000..453d651 --- /dev/null +++ b/tests/unit/test_refine.py @@ -0,0 +1,129 @@ +"""Tests for question refinement (analyze_question / enrich_question).""" + +from __future__ import annotations + +import json +from unittest.mock import AsyncMock, MagicMock + +from duh.consensus.refine import analyze_question, enrich_question + + +def _mock_pm(response_content: str) -> MagicMock: + """Create a mock ProviderManager that returns *response_content*.""" + model = MagicMock() + model.input_cost_per_mtok = 0.5 + model.model_ref = "mock:cheap" + + provider = AsyncMock() + provider.send = AsyncMock( + return_value=MagicMock( + content=response_content, + usage=MagicMock(input_tokens=10, output_tokens=20), + ) + ) + + pm = MagicMock() + pm.list_all_models.return_value = [model] + pm.get_provider.return_value = (provider, "cheap") + pm.record_usage = MagicMock() + return pm + + +# ── analyze_question ────────────────────────────────────────── + + +class TestAnalyzeQuestion: + async def test_no_refinement_needed(self) -> None: + pm = _mock_pm(json.dumps({"needs_refinement": False})) + result = await analyze_question("What is 2+2?", pm) + assert result["needs_refinement"] is False + + async def test_refinement_needed(self) -> None: + payload = { + "needs_refinement": True, + "questions": [ + {"question": "What scale?", "hint": "users/requests"}, + {"question": "Budget?", "hint": None}, + ], + } + pm = _mock_pm(json.dumps(payload)) + result = await analyze_question("What database should I use?", pm) + assert result["needs_refinement"] is True + assert len(result["questions"]) == 2 + assert result["questions"][0]["question"] == "What scale?" + assert result["questions"][0]["hint"] == "users/requests" + assert result["questions"][1]["hint"] is None + + async def test_max_questions_capped(self) -> None: + payload = { + "needs_refinement": True, + "questions": [{"question": f"Q{i}?"} for i in range(10)], + } + pm = _mock_pm(json.dumps(payload)) + result = await analyze_question("Vague?", pm, max_questions=3) + assert len(result["questions"]) == 3 + + async def test_no_models_returns_no_refinement(self) -> None: + pm = MagicMock() + pm.list_all_models.return_value = [] + result = await analyze_question("anything", pm) + assert result["needs_refinement"] is False + + async def test_json_parse_error_returns_no_refinement(self) -> None: + pm = _mock_pm("This is not JSON at all") + result = await analyze_question("anything", pm) + assert result["needs_refinement"] is False + + async def test_provider_error_returns_no_refinement(self) -> None: + pm = _mock_pm("") + provider, _ = pm.get_provider("mock:cheap") + provider.send.side_effect = RuntimeError("API down") + result = await analyze_question("anything", pm) + assert result["needs_refinement"] is False + + async def test_empty_questions_returns_no_refinement(self) -> None: + payload = {"needs_refinement": True, "questions": []} + pm = _mock_pm(json.dumps(payload)) + result = await analyze_question("anything", pm) + assert result["needs_refinement"] is False + + async def test_json_in_code_fence(self) -> None: + fenced = '```json\n{"needs_refinement": false}\n```' + pm = _mock_pm(fenced) + result = await analyze_question("specific question", pm) + assert result["needs_refinement"] is False + + +# ── enrich_question ─────────────────────────────────────────── + + +class TestEnrichQuestion: + async def test_enrichment(self) -> None: + pm = _mock_pm("What database for a 10k-user SaaS on AWS with $500/mo budget?") + result = await enrich_question( + "What database should I use?", + [ + {"question": "Scale?", "answer": "10k users"}, + {"question": "Budget?", "answer": "$500/mo"}, + ], + pm, + ) + assert "10k" in result or "database" in result + + async def test_no_models_returns_original(self) -> None: + pm = MagicMock() + pm.list_all_models.return_value = [] + result = await enrich_question("original?", [], pm) + assert result == "original?" + + async def test_provider_error_returns_original(self) -> None: + pm = _mock_pm("") + provider, _ = pm.get_provider("mock:cheap") + provider.send.side_effect = RuntimeError("boom") + result = await enrich_question("original?", [], pm) + assert result == "original?" + + async def test_empty_response_returns_original(self) -> None: + pm = _mock_pm(" ") + result = await enrich_question("original?", [], pm) + assert result == "original?" diff --git a/web/src/__tests__/refinement.test.tsx b/web/src/__tests__/refinement.test.tsx new file mode 100644 index 0000000..a7a8896 --- /dev/null +++ b/web/src/__tests__/refinement.test.tsx @@ -0,0 +1,144 @@ +import { describe, it, expect, vi } from 'vitest' +import { render, screen, fireEvent } from '@testing-library/react' +import { RefinementPanel } from '@/components/consensus/RefinementPanel' +import type { ClarifyingQuestion } from '@/api/types' + +const questions: ClarifyingQuestion[] = [ + { question: 'What is the expected scale?', hint: 'users per day' }, + { question: 'What is your budget?', hint: null }, + { question: 'Any existing infrastructure?', hint: 'cloud provider' }, +] + +describe('RefinementPanel', () => { + it('renders all tabs', () => { + render( + , + ) + expect(screen.getByText('Q1')).toBeInTheDocument() + expect(screen.getByText('Q2')).toBeInTheDocument() + expect(screen.getByText('Q3')).toBeInTheDocument() + }) + + it('shows first question by default', () => { + render( + , + ) + expect(screen.getByText('What is the expected scale?')).toBeInTheDocument() + expect(screen.getByText('users per day')).toBeInTheDocument() + }) + + it('switches tab on click', () => { + render( + , + ) + fireEvent.click(screen.getByText('Q2')) + expect(screen.getByText('What is your budget?')).toBeInTheDocument() + }) + + it('submit disabled when not all answered', () => { + render( + , + ) + const submitBtn = screen.getByText('Start Consensus') + expect(submitBtn).toBeDisabled() + }) + + it('submit enabled when all answered', () => { + render( + , + ) + const submitBtn = screen.getByText('Start Consensus') + expect(submitBtn).not.toBeDisabled() + }) + + it('calls onSubmit when submit clicked', () => { + const onSubmit = vi.fn() + render( + , + ) + fireEvent.click(screen.getByText('Start Consensus')) + expect(onSubmit).toHaveBeenCalledOnce() + }) + + it('calls onSkip when skip clicked', () => { + const onSkip = vi.fn() + render( + , + ) + fireEvent.click(screen.getByText('Skip')) + expect(onSkip).toHaveBeenCalledOnce() + }) + + it('calls onAnswer when typing', () => { + const onAnswer = vi.fn() + render( + , + ) + const textarea = screen.getByPlaceholderText('Your answer...') + fireEvent.change(textarea, { target: { value: 'test answer' } }) + expect(onAnswer).toHaveBeenCalledWith(0, 'test answer') + }) + + it('shows checkmark on answered tabs', () => { + const { container } = render( + , + ) + // Tabs with answers should have SVG checkmarks + const svgs = container.querySelectorAll('svg') + expect(svgs).toHaveLength(2) + }) +}) diff --git a/web/src/api/client.ts b/web/src/api/client.ts index f276725..c2e0197 100644 --- a/web/src/api/client.ts +++ b/web/src/api/client.ts @@ -5,6 +5,7 @@ import type { CalibrationResponse, CostResponse, DecisionSpaceResponse, + EnrichResponse, FeedbackRequest, FeedbackResponse, ForgotPasswordRequest, @@ -13,6 +14,7 @@ import type { LoginRequest, ModelsResponse, RecallResponse, + RefineResponse, RegisterRequest, ResetPasswordRequest, ResetPasswordResponse, @@ -107,6 +109,24 @@ export const api = { return request('/health') }, + // Refinement + refine(question: string, maxQuestions?: number): Promise { + return request('/refine', { + method: 'POST', + body: JSON.stringify({ question, max_questions: maxQuestions }), + }) + }, + + enrich( + originalQuestion: string, + clarifications: { question: string; answer: string }[], + ): Promise { + return request('/enrich', { + method: 'POST', + body: JSON.stringify({ original_question: originalQuestion, clarifications }), + }) + }, + // Consensus ask(body: AskRequest): Promise { return request('/ask', { diff --git a/web/src/api/types.ts b/web/src/api/types.ts index 88c259b..3fd0fda 100644 --- a/web/src/api/types.ts +++ b/web/src/api/types.ts @@ -47,6 +47,22 @@ export interface ResetPasswordResponse { message: string } +// ── Refinement types ────────────────────────────────────── + +export interface ClarifyingQuestion { + question: string + hint?: string | null +} + +export interface RefineResponse { + needs_refinement: boolean + questions: ClarifyingQuestion[] +} + +export interface EnrichResponse { + enriched_question: string +} + // ── Request types ───────────────────────────────────────── export interface AskRequest { diff --git a/web/src/components/consensus/ConsensusPanel.tsx b/web/src/components/consensus/ConsensusPanel.tsx index 45d990c..9978a6a 100644 --- a/web/src/components/consensus/ConsensusPanel.tsx +++ b/web/src/components/consensus/ConsensusPanel.tsx @@ -1,25 +1,29 @@ import { useConsensusStore } from '@/stores' -import { GlassPanel, GlowButton } from '@/components/shared' +import { GlassPanel, GlowButton, Skeleton } from '@/components/shared' import { QuestionInput } from './QuestionInput' import { PhaseCard } from './PhaseCard' import { ConsensusComplete } from './ConsensusComplete' import { CostTicker } from './CostTicker' +import { RefinementPanel } from './RefinementPanel' export function ConsensusPanel() { const { status, error, currentPhase, currentRound, rounds, decision, confidence, rigor, dissent, cost, overview, - startConsensus, reset, + clarifyingQuestions, clarificationAnswers, + submitQuestion, answerClarification, submitClarifications, skipRefinement, + reset, } = useConsensusStore() const isActive = status === 'connecting' || status === 'streaming' + const isRefining = status === 'refining' const isComplete = status === 'complete' return (
startConsensus(q, r, p, ms)} - disabled={isActive} + onSubmit={(q, r, p, ms) => submitQuestion(q, r, p, ms)} + disabled={isActive || isRefining} /> {status === 'error' && error && ( @@ -31,6 +35,27 @@ export function ConsensusPanel() { )} + {isRefining && clarifyingQuestions.length === 0 && ( + +
+ + + Analyzing question... + +
+
+ )} + + {isRefining && clarifyingQuestions.length > 0 && ( + + )} + {isComplete && decision && confidence !== null && (
+ onAnswer: (index: number, answer: string) => void + onSubmit: () => void + onSkip: () => void +} + +export function RefinementPanel({ + questions, + answers, + onAnswer, + onSubmit, + onSkip, +}: RefinementPanelProps) { + const [activeTab, setActiveTab] = useState(0) + const allAnswered = questions.every((_, i) => (answers[i] ?? '').trim().length > 0) + + const handleTextChange = (value: string) => { + onAnswer(activeTab, value) + } + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (e.key === 'Tab' && !e.shiftKey && (answers[activeTab] ?? '').trim()) { + const nextUnanswered = questions.findIndex( + (_, i) => i > activeTab && !(answers[i] ?? '').trim(), + ) + if (nextUnanswered >= 0) { + e.preventDefault() + setActiveTab(nextUnanswered) + } + } + } + + return ( + +
+
+ + Clarifying Questions + +
+ + {/* Tab bar */} +
+ {questions.map((_, i) => { + const answered = (answers[i] ?? '').trim().length > 0 + const isActive = i === activeTab + return ( + + ) + })} +
+ + {/* Active question */} + {questions[activeTab] && ( +
+

+ {questions[activeTab].question} +

+ {questions[activeTab].hint && ( +

+ {questions[activeTab].hint} +

+ )} +