diff --git a/anton/chat.py b/anton/chat.py index 7f28a39..0164cd0 100644 --- a/anton/chat.py +++ b/anton/chat.py @@ -35,7 +35,7 @@ handle_setup, handle_setup_models, ) -from anton.commands.ui import handle_theme, print_slash_help +from anton.commands.ui import handle_explain, handle_theme, print_slash_help from anton.utils.clipboard import ( ensure_clipboard, format_clipboard_image_message, @@ -1274,6 +1274,9 @@ def _bottom_toolbar(): elif cmd == "/unpublish": await _handle_unpublish(console, settings, workspace) continue + elif cmd == "/explain": + handle_explain(console, settings.workspace_path) + continue elif cmd == "/help": print_slash_help(console) continue diff --git a/anton/commands/ui.py b/anton/commands/ui.py index f5bcabf..c07a62b 100644 --- a/anton/commands/ui.py +++ b/anton/commands/ui.py @@ -1,9 +1,11 @@ -"""Slash-command handlers for /theme and /help.""" +"""Slash-command handlers for /theme, /explain, and /help.""" from __future__ import annotations from rich.console import Console +from anton.explainability import ExplainabilityStore + def handle_theme(console: Console, arg: str) -> None: """Switch the color theme (light/dark).""" @@ -17,7 +19,9 @@ def handle_theme(console: Console, arg: str) -> None: elif arg in ("light", "dark"): new_mode = arg else: - console.print(f"[anton.warning]Unknown theme '{arg}'. Use: /theme light | /theme dark[/]") + console.print( + f"[anton.warning]Unknown theme '{arg}'. Use: /theme light | /theme dark[/]" + ) console.print() return @@ -37,7 +41,9 @@ def print_slash_help(console: Console) -> None: console.print(" [bold]/llm[/] — Change LLM provider or API key") console.print("\n[bold]Data Connections[/]") - console.print(" [bold]/connect[/] — Connect a database or API to your Local Vault") + console.print( + " [bold]/connect[/] — Connect a database or API to your Local Vault" + ) console.print(" [bold]/list[/] — List all saved connections") console.print(" [bold]/edit[/] — Edit credentials for an existing connection") console.print(" [bold]/remove[/] — Remove a saved connection") @@ -53,9 +59,62 @@ def print_slash_help(console: Console) -> None: console.print(" [bold]/resume[/] — Continue a previous session") console.print(" [bold]/publish[/] — Publish an HTML report to the web") console.print(" [bold]/unpublish[/] — Remove a published report") + console.print( + " [bold]/explain[/] — Show explainability details for the latest answer" + ) console.print("\n[bold]General[/]") console.print(" [bold]/help[/] — Show this help menu") console.print(" [bold]exit[/] — Exit the chat") console.print() + + +def handle_explain(console: Console, workspace_path) -> None: + """Print explainability details for the latest answer in the workspace.""" + store = ExplainabilityStore(workspace_path) + record = store.load_latest() + if record is None: + console.print( + "[anton.warning]No explainability record found yet for this workspace.[/]" + ) + console.print() + return + + console.print() + console.print("[anton.cyan]Explain This Answer[/]") + console.print(f"[anton.muted]Turn {record.turn} • {record.created_at}[/]") + console.print() + + console.print("[bold]Summary[/]") + console.print(record.summary or "No summary available.") + console.print() + + console.print("[bold]Data Sources Used[/]") + if record.data_sources: + for source in record.data_sources: + engine = source.get("engine") + if engine: + console.print(f" - {source.get('name', 'Unknown')} ({engine})") + else: + console.print(f" - {source.get('name', 'Unknown')}") + else: + console.print(" - None captured") + console.print() + + console.print("[bold]Generated SQL[/]") + if record.sql_queries: + for i, query in enumerate(record.sql_queries, 1): + header = f" Query {i}: {query.get('datasource', 'Unknown datasource')}" + if query.get("engine"): + header += f" ({query['engine']})" + console.print(header) + console.print("```sql") + console.print(query.get("sql", "")) + console.print("```") + if query.get("status") == "error" and query.get("error_message"): + console.print(f"[anton.warning]{query['error_message']}[/]") + console.print() + else: + console.print(" - No SQL generated") + console.print() diff --git a/anton/core/backends/scratchpad_boot.py b/anton/core/backends/scratchpad_boot.py index 814a550..adf8647 100644 --- a/anton/core/backends/scratchpad_boot.py +++ b/anton/core/backends/scratchpad_boot.py @@ -12,6 +12,7 @@ # Persistent namespace across cells namespace = {"__builtins__": __builtins__} +namespace["_anton_explainability_queries"] = [] # --- Inject get_llm() for LLM access from scratchpad code --- _scratchpad_model = os.environ.get("ANTON_SCRATCHPAD_MODEL", "") @@ -243,6 +244,7 @@ def agentic_loop(*, system, user_message, tools, handle_tool, max_turns=10, max_ _minds_datasource = os.environ.get("ANTON_MINDS_DATASOURCE", "") _minds_api_key = os.environ.get("ANTON_MINDS_API_KEY", "") _minds_url = os.environ.get("ANTON_MINDS_URL", "") +_minds_engine = os.environ.get("ANTON_MINDS_DATASOURCE_ENGINE", "") if _minds_datasource and _minds_api_key and _minds_url: try: import ssl as _minds_ssl @@ -273,13 +275,28 @@ def query_minds_data(query, datasource=None): try: with _minds_urllib.urlopen(req, context=ctx, timeout=60) as resp: - return json.loads(resp.read().decode()) + parsed = json.loads(resp.read().decode()) + namespace.setdefault("_anton_explainability_queries", []).append({ + "datasource": ds, + "sql": query, + "engine": _minds_engine or None, + "status": "ok", + "error_message": None, + }) + return parsed except _minds_urllib.HTTPError as e: body = "" try: body = e.read().decode() except Exception: pass + namespace.setdefault("_anton_explainability_queries", []).append({ + "datasource": ds, + "sql": query, + "engine": _minds_engine or None, + "status": "error", + "error_message": f"HTTP {e.code}: {body or e.reason}", + }) return { "type": "error", "data": None, @@ -287,6 +304,13 @@ def query_minds_data(query, datasource=None): "error_message": f"HTTP {e.code}: {body or e.reason}", } except Exception as e: + namespace.setdefault("_anton_explainability_queries", []).append({ + "datasource": ds, + "sql": query, + "engine": _minds_engine or None, + "status": "error", + "error_message": str(e), + }) return { "type": "error", "data": None, @@ -561,6 +585,7 @@ def emit(self, record): err_buf = io.StringIO() log_buf = io.StringIO() error = None + namespace["_anton_explainability_queries"] = [] _cell_log_handler.buf = log_buf sys.stdout = out_buf @@ -625,6 +650,7 @@ def emit(self, record): "stderr": err_buf.getvalue(), "logs": log_buf.getvalue(), "error": error, + "explainability_queries": list(namespace.get("_anton_explainability_queries", [])), } if _auto_installed: result["auto_installed"] = _auto_installed diff --git a/anton/core/session.py b/anton/core/session.py index a920346..cdb8ca5 100644 --- a/anton/core/session.py +++ b/anton/core/session.py @@ -21,6 +21,8 @@ from anton.core.tools.tool_defs import SCRATCHPAD_TOOL, MEMORIZE_TOOL, RECALL_TOOL, ToolDef from anton.core.utils.scratchpad import prepare_scratchpad_exec, format_cell_result +from anton.explainability import ExplainabilityCollector, ExplainabilityStore + from anton.utils.datasources import ( build_datasource_context, scrub_credentials, @@ -102,6 +104,10 @@ def __init__( workspace_path=workspace.base if workspace else None, ) self.tool_registry = ToolRegistry() + self._explainability_store = ( + ExplainabilityStore(workspace.base) if workspace is not None else None + ) + self._active_explainability: ExplainabilityCollector | None = None @property def history(self) -> list[dict]: @@ -182,6 +188,44 @@ def _persist_history(self) -> None: if self._history_store and self._session_id: self._history_store.save(self._session_id, self._history) + def _record_cell_explainability( + self, *, pad_name: str, description: str, cell + ) -> None: + if self._active_explainability is None: + return + if description: + self._active_explainability.add_scratchpad_step(description) + elif pad_name: + self._active_explainability.add_scratchpad_step( + f"work in scratchpad {pad_name}" + ) + for query in getattr(cell, "explainability_queries", []) or []: + if not isinstance(query, dict): + continue + self._active_explainability.add_query( + datasource=str(query.get("datasource", "")), + sql=str(query.get("sql", "")), + engine=( + str(query.get("engine")) + if query.get("engine") is not None + else None + ), + status=str(query.get("status", "ok")), + error_message=( + str(query.get("error_message")) + if query.get("error_message") is not None + else None + ), + ) + self._active_explainability.add_sources_from_text( + getattr(cell, "code", ""), + getattr(cell, "stdout", ""), + getattr(cell, "logs", ""), + ) + self._active_explainability.add_inferred_queries_from_code( + getattr(cell, "code", "") + ) + async def _build_system_prompt(self, user_message: str = "") -> str: import datetime as _dt _now = _dt.datetime.now() @@ -572,64 +616,75 @@ async def turn_stream( assistant_text_parts: list[str] = [] _max_auto_retries = 2 _retry_count = 0 + self._active_explainability = ExplainabilityCollector( + self._explainability_store, + turn=self._turn_count + 1, + user_message=user_msg_str, + ) - while True: - try: - async for event in self._stream_and_handle_tools(user_msg_str): - if isinstance(event, StreamTextDelta): - assistant_text_parts.append(event.text) - yield event - break # completed successfully - except Exception as _agent_exc: - # Token/billing limit — don't retry, let the chat loop handle it - if isinstance(_agent_exc, TokenLimitExceeded): - raise - _retry_count += 1 - if _retry_count <= _max_auto_retries: - # Inject the error into history and let the LLM try to recover - self._history.append( - { - "role": "user", - "content": ( - f"SYSTEM: An error interrupted execution: {_agent_exc}\n\n" - "If you can diagnose and fix the issue, continue working on the task. " - "Adjust your approach to avoid the same error. " - "If this is unrecoverable, summarize what you accomplished and suggest next steps." - ), - } - ) - # Continue the while loop — _stream_and_handle_tools will be called - # again with the error context now in history - continue - else: - # Exhausted retries — stop and summarize for the user - self._history.append( - { - "role": "user", - "content": ( - f"SYSTEM: The task has failed {_retry_count} times. Latest error: {_agent_exc}\n\n" - "Stop retrying. Please:\n" - "1. Summarize what you accomplished so far.\n" - "2. Explain what went wrong in plain language.\n" - "3. Suggest next steps — what the user can try (e.g. rephrase, " - "simplify the request, or ask you to continue from where you left off).\n" - "Be concise and helpful." - ), - } - ) - try: - async for event in self._llm.plan_stream( - system=await self._build_system_prompt(user_msg_str), - messages=self._history, - ): - if isinstance(event, StreamTextDelta): - assistant_text_parts.append(event.text) - yield event - except Exception: - fallback = f"An unexpected error occurred: {_agent_exc}. Please try again or rephrase your request." - assistant_text_parts.append(fallback) - yield StreamTextDelta(text=fallback) - break + try: + while True: + try: + async for event in self._stream_and_handle_tools(user_msg_str): + if isinstance(event, StreamTextDelta): + assistant_text_parts.append(event.text) + yield event + break # completed successfully + except Exception as _agent_exc: + # Token/billing limit — don't retry, let the chat loop handle it + if isinstance(_agent_exc, TokenLimitExceeded): + raise + _retry_count += 1 + if _retry_count <= _max_auto_retries: + # Inject the error into history and let the LLM try to recover + self._history.append( + { + "role": "user", + "content": ( + f"SYSTEM: An error interrupted execution: {_agent_exc}\n\n" + "If you can diagnose and fix the issue, continue working on the task. " + "Adjust your approach to avoid the same error. " + "If this is unrecoverable, summarize what you accomplished and suggest next steps." + ), + } + ) + # Continue the while loop — _stream_and_handle_tools will be called + # again with the error context now in history + continue + else: + # Exhausted retries — stop and summarize for the user + self._history.append( + { + "role": "user", + "content": ( + f"SYSTEM: The task has failed {_retry_count} times. Latest error: {_agent_exc}\n\n" + "Stop retrying. Please:\n" + "1. Summarize what you accomplished so far.\n" + "2. Explain what went wrong in plain language.\n" + "3. Suggest next steps — what the user can try (e.g. rephrase, " + "simplify the request, or ask you to continue from where you left off).\n" + "Be concise and helpful." + ), + } + ) + try: + async for event in self._llm.plan_stream( + system=await self._build_system_prompt(user_msg_str), + messages=self._history, + ): + if isinstance(event, StreamTextDelta): + assistant_text_parts.append(event.text) + yield event + except Exception: + fallback = f"An unexpected error occurred: {_agent_exc}. Please try again or rephrase your request." + assistant_text_parts.append(fallback) + yield StreamTextDelta(text=fallback) + break + finally: + if self._active_explainability is not None: + self._active_explainability.finalize( + "".join(assistant_text_parts)[:2000] + ) # Log assistant response to episodic memory if self._episodic is not None and assistant_text_parts: @@ -869,6 +924,12 @@ async def _stream_and_handle_tools( if cell else "No result produced." ) + if cell is not None: + self._record_cell_explainability( + pad_name=tc.input.get("name", ""), + description=description, + cell=cell, + ) if self._episodic is not None and cell is not None: self._episodic.log_turn( self._turn_count + 1, diff --git a/anton/core/tools/tool_handlers.py b/anton/core/tools/tool_handlers.py index d94d724..c9ba531 100644 --- a/anton/core/tools/tool_handlers.py +++ b/anton/core/tools/tool_handlers.py @@ -106,6 +106,10 @@ async def handle_scratchpad(session: ChatSession, tc_input: dict) -> str: estimated_time=estimated_time, estimated_seconds=estimated_seconds, ) + if cell is not None: + session._record_cell_explainability( + pad_name=name, description=description, cell=cell, + ) return format_cell_result(cell) elif action == "view": diff --git a/anton/explainability.py b/anton/explainability.py new file mode 100644 index 0000000..627e9b8 --- /dev/null +++ b/anton/explainability.py @@ -0,0 +1,271 @@ +from __future__ import annotations + +import json +import re +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from urllib.parse import urlparse + + +def _utc_now() -> str: + return datetime.now(timezone.utc).isoformat() + + +@dataclass +class ExplainabilityQuery: + datasource: str + sql: str + engine: str | None = None + status: str = "ok" + error_message: str | None = None + + def to_dict(self) -> dict: + return { + "datasource": self.datasource, + "sql": self.sql, + "engine": self.engine, + "status": self.status, + "error_message": self.error_message, + } + + +@dataclass +class ExplainabilityRecord: + turn: int + created_at: str + user_message: str + answer_text: str + summary: str + data_sources: list[dict] = field(default_factory=list) + sql_queries: list[dict] = field(default_factory=list) + scratchpad_steps: list[str] = field(default_factory=list) + + def to_dict(self) -> dict: + return { + "turn": self.turn, + "created_at": self.created_at, + "user_message": self.user_message, + "answer_text": self.answer_text, + "summary": self.summary, + "data_sources": self.data_sources, + "sql_queries": self.sql_queries, + "scratchpad_steps": self.scratchpad_steps, + } + + +class ExplainabilityStore: + def __init__(self, workspace_path: Path) -> None: + self._dir = workspace_path / ".anton" / "explainability" + + def save(self, record: ExplainabilityRecord) -> None: + self._dir.mkdir(parents=True, exist_ok=True) + payload = json.dumps(record.to_dict(), ensure_ascii=False, indent=2) + "\n" + latest = self._dir / "latest.json" + latest.write_text(payload, encoding="utf-8") + turn_file = self._dir / f"turn-{record.turn:04d}.json" + turn_file.write_text(payload, encoding="utf-8") + + def load_latest(self) -> ExplainabilityRecord | None: + latest = self._dir / "latest.json" + if not latest.is_file(): + return None + try: + payload = json.loads(latest.read_text(encoding="utf-8")) + except Exception: + return None + try: + return ExplainabilityRecord( + turn=int(payload.get("turn", 0)), + created_at=str(payload.get("created_at", "")), + user_message=str(payload.get("user_message", "")), + answer_text=str(payload.get("answer_text", "")), + summary=str(payload.get("summary", "")), + data_sources=list(payload.get("data_sources", [])), + sql_queries=list(payload.get("sql_queries", [])), + scratchpad_steps=list(payload.get("scratchpad_steps", [])), + ) + except Exception: + return None + + +class ExplainabilityCollector: + def __init__(self, store: ExplainabilityStore, *, turn: int, user_message: str) -> None: + self._store = store + self._turn = turn + self._user_message = user_message + self._created_at = _utc_now() + self._scratchpad_steps: list[str] = [] + self._queries: list[ExplainabilityQuery] = [] + self._sources: list[dict[str, str | None]] = [] + + def add_scratchpad_step(self, description: str) -> None: + cleaned = (description or "").strip() + if cleaned and cleaned not in self._scratchpad_steps: + self._scratchpad_steps.append(cleaned) + + def add_query( + self, + *, + datasource: str, + sql: str, + engine: str | None = None, + status: str = "ok", + error_message: str | None = None, + ) -> None: + cleaned_sql = (sql or "").strip() + cleaned_ds = (datasource or "").strip() or "Unknown datasource" + if not cleaned_sql: + return + entry = ExplainabilityQuery( + datasource=cleaned_ds, + sql=cleaned_sql, + engine=(engine or "").strip() or None, + status=status, + error_message=(error_message or "").strip() or None, + ) + if any( + existing.datasource == entry.datasource + and existing.sql == entry.sql + and existing.status == entry.status + for existing in self._queries + ): + return + self._queries.append(entry) + self.add_source(name=cleaned_ds, engine=(engine or "").strip() or None) + + def add_source(self, *, name: str, engine: str | None = None) -> None: + cleaned_name = (name or "").strip() + if not cleaned_name: + return + entry = {"name": cleaned_name, "engine": (engine or "").strip() or None} + if entry not in self._sources: + self._sources.append(entry) + + def add_sources_from_text(self, *texts: str) -> None: + for text in texts: + if not text: + continue + for source in _extract_sources_from_text(text): + self.add_source(name=source, engine=None) + + def add_inferred_queries_from_code(self, code: str) -> None: + if self._queries: + return + sql_statements = _extract_sql_from_code(code) + datasource_names = _extract_datasource_names_from_code(code) + datasource = datasource_names[0] if datasource_names else "connected datasource" + for sql in sql_statements: + self.add_query( + datasource=datasource, + sql=sql, + engine=None, + status="ok", + error_message=None, + ) + + def finalize(self, answer_text: str) -> ExplainabilityRecord: + data_sources: list[dict] = [] + seen_sources: set[tuple[str, str | None]] = set() + for source in self._sources: + key = (str(source.get("name", "")), source.get("engine")) + if key in seen_sources: + continue + seen_sources.add(key) + data_sources.append({"name": key[0], "engine": key[1]}) + + summary = self._build_summary(answer_text, data_sources) + record = ExplainabilityRecord( + turn=self._turn, + created_at=self._created_at, + user_message=self._user_message, + answer_text=answer_text.strip(), + summary=summary, + data_sources=data_sources, + sql_queries=[query.to_dict() for query in self._queries], + scratchpad_steps=list(self._scratchpad_steps), + ) + if self._store is not None: + self._store.save(record) + return record + + def _build_summary(self, answer_text: str, data_sources: list[dict]) -> str: + if self._queries: + source_names = ", ".join(source["name"] for source in data_sources[:3]) + query_count = len(self._queries) + step_text = "" + if self._scratchpad_steps: + lead = self._scratchpad_steps[0].rstrip(".") + step_text = f" I used the scratchpad to {lead.lower()}." + return ( + f"I queried {source_names} with {query_count} SQL " + f"{'statement' if query_count == 1 else 'statements'} to gather the data behind this answer." + f"{step_text}" + ) + if data_sources: + source_names = ", ".join(source["name"] for source in data_sources[:3]) + if self._scratchpad_steps: + lead = self._scratchpad_steps[0].rstrip(".").lower() + return ( + f"I gathered information from {source_names} and used the scratchpad to " + f"{lead} before drafting the answer." + ) + return f"I gathered information from {source_names} before drafting the answer." + if self._scratchpad_steps: + primary_step = self._scratchpad_steps[0].rstrip(".").lower() + return f"I used the scratchpad to {primary_step} before drafting the answer." + if answer_text.strip(): + return "I answered directly from the conversation context without querying a datasource or generating SQL." + return "No explainability details were captured for this answer." + + +_URL_RE = re.compile(r"https?://[^\s)\"'>]+") +_SQL_LITERAL_RE = re.compile( + r"(?P'''|\"\"\"|'|\")(?P.*?)(?P=quote)", + re.DOTALL, +) +_DS_PREFIX_RE = re.compile(r"\b(DS_[A-Z0-9_]+)__[A-Z0-9_]+\b") + + +def _extract_sources_from_text(text: str) -> list[str]: + sources: list[str] = [] + for match in _URL_RE.findall(text): + parsed = urlparse(match) + host = (parsed.hostname or "").lower() + host = host.removeprefix("www.") + if not host: + continue + if host not in sources: + sources.append(host) + return sources + + +def _looks_like_sql(text: str) -> bool: + normalized = " ".join(text.strip().split()).upper() + if len(normalized) < 12: + return False + starters = ("SELECT ", "WITH ", "INSERT ", "UPDATE ", "DELETE ", "SHOW ", "DESCRIBE ") + if not normalized.startswith(starters): + return False + return any(keyword in normalized for keyword in (" FROM ", " JOIN ", " INTO ", " TABLE ", "SELECT ")) + + +def _extract_sql_from_code(code: str) -> list[str]: + sql_statements: list[str] = [] + for match in _SQL_LITERAL_RE.finditer(code or ""): + body = match.group("body").strip() + if not _looks_like_sql(body): + continue + cleaned = "\n".join(line.rstrip() for line in body.splitlines()).strip() + if cleaned and cleaned not in sql_statements: + sql_statements.append(cleaned) + return sql_statements + + +def _extract_datasource_names_from_code(code: str) -> list[str]: + names: list[str] = [] + for prefix in _DS_PREFIX_RE.findall(code or ""): + slug = prefix.removeprefix("DS_").lower().replace("_", "-") + if slug not in names: + names.append(slug) + return names diff --git a/tests/test_explainability.py b/tests/test_explainability.py new file mode 100644 index 0000000..1b835cc --- /dev/null +++ b/tests/test_explainability.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +import json +from unittest.mock import MagicMock + +from anton.commands.ui import handle_explain +from anton.explainability import ExplainabilityCollector, ExplainabilityStore + + +def test_explainability_store_persists_latest_and_turn_file(tmp_path): + store = ExplainabilityStore(tmp_path) + collector = ExplainabilityCollector(store, turn=3, user_message="How did revenue change?") + collector.add_scratchpad_step("Query monthly revenue") + collector.add_query( + datasource="warehouse.orders", + sql="SELECT month, revenue FROM revenue_by_month", + engine="postgres", + ) + + record = collector.finalize("Revenue increased 12% month over month.") + + latest = tmp_path / ".anton" / "explainability" / "latest.json" + turn_file = tmp_path / ".anton" / "explainability" / "turn-0003.json" + + assert latest.is_file() + assert turn_file.is_file() + + latest_payload = json.loads(latest.read_text()) + assert latest_payload["turn"] == 3 + assert latest_payload["sql_queries"][0]["datasource"] == "warehouse.orders" + assert "queried warehouse.orders" in latest_payload["summary"].lower() + assert record.summary == latest_payload["summary"] + + +def test_explainability_sql_shape_includes_datasources_and_queries(tmp_path): + store = ExplainabilityStore(tmp_path) + collector = ExplainabilityCollector(store, turn=4, user_message="What was monthly revenue?") + collector.add_scratchpad_step("Query monthly revenue") + collector.add_query( + datasource="finance.monthly_revenue", + sql="SELECT month, revenue FROM monthly_revenue ORDER BY month DESC", + engine="snowflake", + ) + + record = collector.finalize("Revenue rose in March.") + + assert record.data_sources == [{"name": "finance.monthly_revenue", "engine": "snowflake"}] + assert record.sql_queries == [ + { + "datasource": "finance.monthly_revenue", + "sql": "SELECT month, revenue FROM monthly_revenue ORDER BY month DESC", + "engine": "snowflake", + "status": "ok", + "error_message": None, + } + ] + assert "sql statement" in record.summary.lower() + + +def test_explainability_summary_without_queries_is_direct_answer(tmp_path): + store = ExplainabilityStore(tmp_path) + collector = ExplainabilityCollector(store, turn=1, user_message="What is Anton?") + + record = collector.finalize("Anton is MindsDB's autonomous AI coworker.") + + assert record.sql_queries == [] + assert ( + record.summary + == "I answered directly from the conversation context without querying a datasource or generating SQL." + ) + + +def test_explainability_extracts_non_sql_sources_from_text(tmp_path): + store = ExplainabilityStore(tmp_path) + collector = ExplainabilityCollector(store, turn=2, user_message="Compare green coffee prices") + collector.add_scratchpad_step("Fetch green coffee bean prices and compute roasting cost comparison") + collector.add_sources_from_text( + 'See https://www.happymugcoffee.com/collections/green-coffee and https://burmancoffee.com/' + ) + + record = collector.finalize("Home roasting is much cheaper.") + + source_names = [source["name"] for source in record.data_sources] + assert "happymugcoffee.com" in source_names + assert "burmancoffee.com" in source_names + assert "gathered information from" in record.summary.lower() + + +def test_handle_explain_prints_sections_for_latest_record(tmp_path): + store = ExplainabilityStore(tmp_path) + collector = ExplainabilityCollector(store, turn=5, user_message="What was revenue?") + collector.add_scratchpad_step("Query monthly revenue") + collector.add_query( + datasource="finance.monthly_revenue", + sql="SELECT month, revenue FROM monthly_revenue", + engine="postgres", + ) + collector.finalize("Revenue rose.") + + console = MagicMock() + handle_explain(console, tmp_path) + + rendered = "\n".join( + str(call.args[0]) for call in console.print.call_args_list if call.args + ) + assert "Explain This Answer" in rendered + assert "Summary" in rendered + assert "Data Sources Used" in rendered + assert "Generated SQL" in rendered + assert "finance.monthly_revenue" in rendered + assert "SELECT month, revenue FROM monthly_revenue" in rendered + + +def test_explainability_infers_sql_and_datasource_from_scratchpad_code(tmp_path): + store = ExplainabilityStore(tmp_path) + collector = ExplainabilityCollector(store, turn=6, user_message="Average revenue") + collector.add_scratchpad_step("Average annual revenue over last 10 years in the dataset") + collector.add_inferred_queries_from_code( + """ +import os +sql = \"\"\" +SELECT EXTRACT(YEAR FROM sale_date) AS year, AVG(revenue) AS avg_revenue +FROM sales +GROUP BY 1 +ORDER BY 1 +\"\"\" +host = os.environ["DS_POSTGRES_PROD_DB__HOST"] +cur.execute(sql) +""" + ) + + record = collector.finalize("Average revenue is stable.") + + assert record.data_sources == [{"name": "postgres-prod-db", "engine": None}] + assert len(record.sql_queries) == 1 + assert "SELECT EXTRACT(YEAR FROM sale_date)" in record.sql_queries[0]["sql"] + assert record.sql_queries[0]["datasource"] == "postgres-prod-db"