diff --git a/apps/api/app/api/v1/routes/retrieval.py b/apps/api/app/api/v1/routes/retrieval.py index 5629e733..eb6a5846 100644 --- a/apps/api/app/api/v1/routes/retrieval.py +++ b/apps/api/app/api/v1/routes/retrieval.py @@ -1,4 +1,4 @@ -"""Retrieval API routes for lexical + graph-routing baseline.""" +"""Retrieval API routes for evidence-only agentic retrieval.""" from __future__ import annotations @@ -59,7 +59,7 @@ class RetrievalQueryRequest(BaseModel): ) use_agentic: bool | None = Field( None, - description="Per-request agentic mode toggle. true=force agentic, false=force legacy, null=use server default.", + description="Deprecated mode hint retained for cache/request compatibility; retrieval always uses the agentic workflow.", ) @field_validator("channels") @@ -99,10 +99,8 @@ class RetrievalQueryResponse(BaseModel): decision_trace: list[dict] | None = Field( default=None, description=( - "Per-step navigation decisions from agentic retrieval. " - "Each entry has phase, document, action, reason, collected_paths, " - "and drill_into. Use this to understand " - "why KNOWHERE stopped or made specific navigation choices." + "Per-step agentic retrieval trace. Each entry follows the " + "observation/decision/result schema." 
), ) diff --git a/apps/api/tests/contract/test_agentic_discovery_selection_contract.py b/apps/api/tests/contract/test_agentic_discovery_selection_contract.py index 7d906918..a27dc654 100644 --- a/apps/api/tests/contract/test_agentic_discovery_selection_contract.py +++ b/apps/api/tests/contract/test_agentic_discovery_selection_contract.py @@ -1,96 +1,79 @@ -from shared.services.retrieval.agentic.core.types import DocTreeNode -from shared.services.retrieval.agentic.discovery.selection import ( - _build_discovery_path_selections, - _project_discovery_hints, -) +from shared.services.retrieval.agentic.navigation.actions import build_legal_actions -def test_root_discovery_hint_is_projected_for_llm_selection() -> None: - hint_lines, hint_by_path, excluded_hints = _project_discovery_hints( - [ +def test_discovery_hint_is_projected_as_collect_action() -> None: + action_set = build_legal_actions( + items=[], + current_scope=None, + collected_paths=[], + expanded_scopes=set(), + discovery_hints=[ { - "section_path": "Root", - "chunk_id": "chunk_root_relevant", - "summary": "document-level market chart", + "section_path": "2 阶段性调整还是牛熊切换? / 2.1 牛熊切换缘何开启?", + "discovery_score": 0.82, + "chunk_type": "text", } ], - exclude_paths=None, + rejected_paths=set(), + rejected_collect_paths=set(), + total_images=0, + total_tables=0, + budget_snapshot=None, ) - assert hint_lines == [ - '▸ path="Root"', - " document-level market chart", - ] - assert hint_by_path["Root"]["chunk_id"] == "chunk_root_relevant" + assert len(action_set.collect) == 1 + action = action_set.collect[0] + assert action.id == "D1" + assert action.action == "COLLECT" + assert action.source == "discovery" + assert action.path == "2 阶段性调整还是牛熊切换? / 2.1 牛熊切换缘何开启?" 
+ assert action.score == 0.82 -def test_root_discovery_hint_without_llm_selection_does_not_hydrate() -> None: - node = DocTreeNode() - - path_selections, chunk_refs = _build_discovery_path_selections( - selections=[], - hint_by_path={ - "Root": { - "section_path": "Root", - "chunk_id": "chunk_root_relevant", +def test_discovery_hint_under_collected_path_is_not_repeated() -> None: + action_set = build_legal_actions( + items=[], + current_scope=None, + collected_paths=[ + { + "path": "2 阶段性调整还是牛熊切换?", + "hydrate_mode": "chunks", } - }, - document_id="doc_root", - node=node, - ) - - assert path_selections == [] - assert chunk_refs == [] - assert node.confidence == {} - - -def test_explicit_root_discovery_selection_with_chunk_id_uses_exact_chunk_ref() -> None: - node = DocTreeNode() - - path_selections, chunk_refs = _build_discovery_path_selections( - selections=[{"path": "Root", "confidence": 0.91}], - hint_by_path={ - "Root": { - "section_path": "Root", - "chunk_id": "chunk_root_relevant", + ], + expanded_scopes=set(), + discovery_hints=[ + { + "section_path": "2 阶段性调整还是牛熊切换? 
/ 2.1 牛熊切换缘何开启?", + "discovery_score": 0.82, } - }, - document_id="doc_root", - node=node, + ], + rejected_paths=set(), + rejected_collect_paths=set(), + total_images=0, + total_tables=0, + budget_snapshot=None, ) - assert path_selections == [] - assert chunk_refs == [ - { - "document_id": "doc_root", - "chunk_id": "chunk_root_relevant", - "section_path": "Root", - } - ] - assert node.confidence["Root"] == 0.91 + assert action_set.collect == [] -def test_explicit_root_discovery_selection_without_chunk_id_keeps_path_fallback() -> None: - node = DocTreeNode() - - path_selections, chunk_refs = _build_discovery_path_selections( - selections=[{"path": "Root", "confidence": 0.7}], - hint_by_path={ - "Root": { - "section_path": "Root", - "chunk_id": "", +def test_discovery_hint_under_rejected_collect_path_is_not_repeated() -> None: + action_set = build_legal_actions( + items=[], + current_scope=None, + collected_paths=[], + expanded_scopes=set(), + discovery_hints=[ + { + "section_path": "1、2016:机构行为助推行情演绎 / 二是英国“脱欧”影响下", + "discovery_score": 0.7, } - }, - document_id="doc_root", - node=node, + ], + rejected_paths=set(), + rejected_collect_paths={"1、2016:机构行为助推行情演绎"}, + total_images=0, + total_tables=0, + budget_snapshot=None, ) - assert path_selections == [ - { - "path": "Root", - "confidence": 0.7, - "hydrate_mode": "self_only", - } - ] - assert chunk_refs == [] - assert node.confidence["Root"] == 0.7 + assert action_set.collect == [] diff --git a/apps/api/tests/contract/test_retrieval_contract.py b/apps/api/tests/contract/test_retrieval_contract.py index 852cb745..ac5d7876 100644 --- a/apps/api/tests/contract/test_retrieval_contract.py +++ b/apps/api/tests/contract/test_retrieval_contract.py @@ -1,6 +1,5 @@ from collections.abc import Callable, Coroutine, Sequence from contextlib import AbstractAsyncContextManager -from datetime import datetime, timezone from typing import Any, cast from uuid import uuid4 @@ -357,128 +356,55 @@ async def 
test_should_return_empty_results_for_an_empty_query( @pytest.mark.asyncio -async def test_legacy_retrieval_should_rank_hot_chunk_before_cold_chunk_when_discovery_scores_tie( +async def test_retrieval_should_ignore_false_agentic_hint_and_use_workflow( developer_api_client_factory: Callable[ [], AbstractAsyncContextManager[AsyncClient] ], monkeypatch: MonkeyPatch, ) -> None: - async with developer_api_client_factory() as api_client: - cold_document = await _seed_retrieval_document( - user_id="local-dev-user", - namespace="contract-hot-ranking", - source_file_name="cold.pdf", - section_path="ranking/cold", - content="same ranking marker cold", + async def fake_run_request( + self: object, + db: AsyncSession, + *, + request: WorkflowRunRequest, + llm_fn: object | None = None, + ) -> WorkflowResult: + return WorkflowResult( + namespace=request.namespace, + query=request.query, + router_used="workflow_single_step", + answer_text="", + plan=QueryPlan.single_step(request.query), + referenced_chunks=[], + results=[], ) - hot_document = await _seed_retrieval_document( + + monkeypatch.setattr( + "shared.services.retrieval.workflow.orchestrator.WorkflowOrchestrator.run_request", + fake_run_request, + ) + + async with developer_api_client_factory() as api_client: + await _seed_retrieval_document( user_id="local-dev-user", - namespace="contract-hot-ranking", - source_file_name="hot.pdf", - section_path="ranking/hot", - content="same ranking marker hot", + namespace="contract-agentic-only", + source_file_name="a.pdf", + section_path="agentic/a", + content="same ranking marker a", ) await _seed_retrieval_document( user_id="local-dev-user", - namespace="contract-hot-ranking", - source_file_name="filler.pdf", - section_path="ranking/filler", - content="same ranking marker filler", - ) - - now = datetime.now(timezone.utc).replace(tzinfo=None) - await ContractDatabase.execute( - """ - INSERT INTO retrieval_hit_stats ( - id, - user_id, - namespace, - hit_kind, - document_id, - 
chunk_id, - hit_count, - last_hit_at, - created_at, - updated_at - ) VALUES ( - :id, - :user_id, - :namespace, - 'chunk', - :document_id, - :chunk_id, - :hit_count, - :now, - :now, - :now - ) - """, - { - "id": f"rhs_{uuid4().hex[:12]}", - "user_id": "local-dev-user", - "namespace": "contract-hot-ranking", - "document_id": hot_document["document_id"], - "chunk_id": hot_document["chunk_id"], - "hit_count": 100, - "now": now, - }, + namespace="contract-agentic-only", + source_file_name="b.pdf", + section_path="agentic/b", + content="same ranking marker b", ) - - def to_channel_row(document: dict[str, str]) -> dict[str, object]: - return { - "document_id": document["document_id"], - "chunk_id": document["chunk_id"], - "section_id": document["section_id"], - "section_path": document["section_path"], - "source_file_name": "cold.pdf" - if document["document_id"] == cold_document["document_id"] - else "hot.pdf", - "chunk_type": "text", - "content": "same ranking marker", - "score": 1.0, - "file_path": None, - "chunk_metadata": {}, - "job_result_id": document["job_result_id"], - "job_id": document["job_id"], - "sort_order": 0, - } - - async def fake_content_channel(*_args: object, **_kwargs: object) -> list[dict[str, object]]: - return [ - to_channel_row(hot_document), - to_channel_row(cold_document), - ] - - async def fake_path_channel(*_args: object, **_kwargs: object) -> list[dict[str, object]]: - return [ - to_channel_row(cold_document), - to_channel_row(hot_document), - ] - - async def fake_graph_routing(*_args: object, **_kwargs: object) -> list[dict[str, object]]: - return [] - - monkeypatch.setattr( - "shared.services.retrieval.execution.legacy_route.path_channel", - fake_path_channel, - ) - monkeypatch.setattr( - "shared.services.retrieval.execution.legacy_route.content_channel", - fake_content_channel, - ) - monkeypatch.setattr( - "shared.services.retrieval.execution.legacy_route.list_graph_routed_chunks", - fake_graph_routing, - ) - response = await 
api_client.post( "/api/v1/retrieval/query", json={ - "namespace": "contract-hot-ranking", + "namespace": "contract-agentic-only", "query": "same ranking marker", "top_k": 1, - "channels": ["path", "content"], - "channel_weights": {"path": 1.0, "content": 1.0}, "use_agentic": False, }, ) @@ -486,10 +412,8 @@ async def fake_graph_routing(*_args: object, **_kwargs: object) -> list[dict[str assert response.status_code == 200 response_json = cast(dict[str, object], response.json()) - results = cast(list[dict[str, object]], response_json["results"]) - - assert len(results) == 1 - assert _result_source(results[0])["document_id"] == hot_document["document_id"] + assert response_json["router_used"] == "workflow_single_step" + assert response_json["results"] == [] @pytest.mark.asyncio diff --git a/apps/worker/app/services/document_agent/manifest.py b/apps/worker/app/services/document_agent/manifest.py index 03866523..3fd747ca 100644 --- a/apps/worker/app/services/document_agent/manifest.py +++ b/apps/worker/app/services/document_agent/manifest.py @@ -125,7 +125,7 @@ class H1Candidate: page: int confidence: float matched_line: str - source: Literal["toc_exact_top", "toc_fuzzy_top", "heading_grep", "none"] + source: Literal["toc_exact_top", "toc_fuzzy_top", "heading_grep", "toc_grep", "h2_refine", "none"] evidence: dict[str, Any] = field(default_factory=dict) def to_dict(self) -> dict[str, Any]: @@ -163,6 +163,8 @@ class Shard: anchor_type: Literal["h1_boundary", "blank_separator", "forced_max_size"] anchor_evidence: str confidence: float + split_depth: int = 1 # 1=H1 cut, 2=H2 cut, etc. 
+ is_continuation: bool = False # True for continuation shards that don't contain parent heading def to_dict(self) -> dict[str, Any]: return asdict(self) diff --git a/apps/worker/app/services/document_agent/planner/planner.py b/apps/worker/app/services/document_agent/planner/planner.py index 61d63f9f..bada42b0 100644 --- a/apps/worker/app/services/document_agent/planner/planner.py +++ b/apps/worker/app/services/document_agent/planner/planner.py @@ -77,21 +77,35 @@ def _segment_sample(candidates: list[int], count: int) -> list[int]: return [candidates[round(index * step)] for index in range(count)] -def _sample_pages(page_count: int, extrema_pages: list[int]) -> list[int]: +def _sample_pages( + page_count: int, + extrema_pages: list[int], + exclude_pages: set[int] | None = None, +) -> list[int]: + """Select representative pages for VLM profiling. + + Args: + page_count: Total number of pages. + extrema_pages: Pages with statistical extrema (min/max text, tables, etc.). + exclude_pages: Pages to skip entirely (e.g. TOC pages already detected + by the TOC pipeline). These inflate text-density metrics without + adding profiling value. 
+ """ if page_count <= 0: return [] - extrema = [page for page in extrema_pages if 1 <= page <= page_count] - remaining = [page for page in range(1, page_count + 1) if page not in set(extrema)] - if not remaining: + skip = exclude_pages or set() + extrema = [page for page in extrema_pages if 1 <= page <= page_count and page not in skip] + pool = [page for page in range(1, page_count + 1) if page not in set(extrema) and page not in skip] + if not pool: return sorted(set(extrema)) - third = max(len(remaining) // 3, 1) - front = remaining[:third] - middle = remaining[third : third * 2] - back = remaining[third * 2 :] + third = max(len(pool) // 3, 1) + front = pool[:third] + middle = pool[third : third * 2] + back = pool[third * 2 :] sampled = ( _segment_sample(front, 4) - + _segment_sample(middle or remaining, 3) - + _segment_sample(back or remaining, 3) + + _segment_sample(middle or pool, 3) + + _segment_sample(back or pool, 3) ) ordered = [] for page in extrema + sampled: @@ -168,7 +182,16 @@ def propose(self) -> tuple[DocumentProfile, ReflexionDecision, ToolResult]: or self.ctx.settings.get("vlm_model") or os.environ.get("IMAGE_MODEL") ) - pages = _sample_pages(self.ctx.blackboard.page_count, self.ctx.blackboard.extrema_pages) + toc_pages = set( + self.ctx.blackboard.toc_result.toc_pages + if self.ctx.blackboard.toc_result + else [] + ) + pages = _sample_pages( + self.ctx.blackboard.page_count, + self.ctx.blackboard.extrema_pages, + exclude_pages=toc_pages, + ) if not model: profile = DocumentProfile( is_scanned=False, diff --git a/apps/worker/app/services/document_agent/tools/match_h1_pages.py b/apps/worker/app/services/document_agent/tools/match_h1_pages.py index a60ad269..1a6d0d07 100644 --- a/apps/worker/app/services/document_agent/tools/match_h1_pages.py +++ b/apps/worker/app/services/document_agent/tools/match_h1_pages.py @@ -2,10 +2,13 @@ from __future__ import annotations +import base64 +import json +import os import re import time import unicodedata -from 
typing import Any +from typing import Any, cast from app.services.document_agent.manifest import ( H1BoundaryResult, @@ -15,6 +18,7 @@ ) from app.services.document_agent.pdf_text import read_page_texts from app.services.document_agent.registry import has_toc_result, register_tool +from app.services.document_agent.visual import render_pages from loguru import logger @@ -53,6 +57,89 @@ def _clean_toc_title(title: str) -> str: return cleaned +# ── C1: Unified grep matching ──────────────────────────────────────────── + + +def grep_titles_in_pages( + titles: list[str], + search_pages: list[int], + page_texts: dict[int, str], + *, + source: str = "toc_grep", + confidence: float = 0.88, +) -> tuple[list[H1Candidate], list[str]]: + """Grep a list of titles across specified pages, returning match results. + + H1/H2 share this function. Callers control scope via *titles* and + *search_pages*. + + Returns: + (matched_candidates, unmatched_titles) + """ + candidates: list[H1Candidate] = [] + unmatched: list[str] = [] + + for title in titles: + normalized_title = _normalize(title) + found = False + for page in search_pages: + text = page_texts.get(page, "") + if normalized_title in _normalize(text): + matched_line = "" + for line in text.splitlines(): + if normalized_title in _normalize(line): + matched_line = line.strip()[:100] + break + candidates.append( + H1Candidate( + title=title, + page=page, + confidence=confidence, + matched_line=matched_line, + source=source, # type: ignore[arg-type] + evidence={ + "normalized_needle": normalized_title, + "page_text_length": len(text), + }, + ) + ) + found = True + break # First match per title + if not found: + unmatched.append(title) + + # Deduplicate by page – keep first hit + seen: set[int] = set() + deduped: list[H1Candidate] = [] + for c in candidates: + if c.page not in seen: + seen.add(c.page) + deduped.append(c) + + return deduped, unmatched + + +def extract_children_titles( + toc_hierarchies: list[dict[str, Any]], + 
parent_title: str, +) -> list[str]: + """Extract level-2 titles under a given H1 parent from toc_with_level.""" + titles: list[str] = [] + for hier in toc_hierarchies or []: + entries = hier.get("toc_with_level", []) + in_scope = False + for entry in entries: + if entry.get("level") == 1: + cleaned = _clean_toc_title(entry.get("heading", "")) + in_scope = _normalize(cleaned) == _normalize(parent_title) + continue + if in_scope and entry.get("level") == 2: + cleaned = _clean_toc_title(entry.get("heading", "")) + if cleaned and len(cleaned) >= 2: + titles.append(cleaned) + return titles + + def _extract_level1_titles(toc_hierarchies: list[dict[str, Any]]) -> list[str]: """Extract level-1 titles from toc_hierarchies. @@ -69,6 +156,87 @@ def _extract_level1_titles(toc_hierarchies: list[dict[str, Any]]) -> list[str]: return titles +# ── C2: Lazy VLM verification ──────────────────────────────────────────── + + +def verify_section_start( + *, + page: int, + title: str, + ctx: ToolContext, +) -> bool: + """VLM-confirm whether *page* is the start of a section titled *title*. + + Used for lazy verification before committing a shard cut. + If VLM is unavailable (no model / budget exhausted / render fails), + returns ``True`` (trust GREP). 
+ """ + model = ctx.settings.get("vlm_model") or os.environ.get("IMAGE_MODEL") + if not model: + return True # No VLM → trust GREP + + # Render 1 page PNG + png_items = render_pages( + ctx, [page], folder_name="verify_pages", prefix="verify", timeout=60, + ) + if not png_items: + return True # Render failed → trust GREP + + prompt = ( + f"This is page {page} of a PDF document.\n" + f"Question: Is this page the START of a section titled '{title}'?\n" + "Criteria: The title appears as a prominent heading/title on this page, " + "not merely mentioned in body text.\n" + 'Return JSON: {"is_section_start": true/false, "reason": "brief"}' + ) + est = 800 # ~800 tokens for 1 image + if not ctx.budget.try_reserve("visual", est): + return True # Budget exhausted → trust GREP + + try: + png_path = str(png_items[0]["png_path"]) + with open(png_path, "rb") as f: + img_b64 = base64.b64encode(f.read()).decode() + content_parts: list[dict[str, Any]] = [ + {"type": "text", "text": prompt}, + {"type": "text", "text": f"\n--- Page {page} ---"}, + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{img_b64}"}, + }, + ] + + from shared.services.ai.openai_compatible_client_sync import ( + get_openai_client, + ) + + client = get_openai_client(model=model) + raw, usage = client.chat_completion_with_usage( + messages=cast(Any, [{"role": "user", "content": content_parts}]), + model=model, + temperature=0.0, + max_tokens=256, + response_format={"type": "json_object"}, + ) + ctx.budget.commit( + "visual", actual=usage.get("total_tokens", est), est=est, + ) + data = json.loads(raw) + result = bool(data.get("is_section_start", True)) + logger.info( + "[verify_section_start] page={} title='{}' → {} reason={}", + page, title[:30], result, data.get("reason", ""), + ) + return result + except Exception as exc: + ctx.budget.refund("visual", est=est) + logger.warning("[verify_section_start] VLM failed for page {}: {}", page, exc) + return True # VLM failure → trust GREP + + +# ── 
Tool registration ──────────────────────────────────────────────────── + + @register_tool( name="match.h1_pages", description=( @@ -119,53 +287,12 @@ def match_h1_pages(ctx: ToolContext, _args: dict[str, Any]) -> ToolResult: ) page_texts = read_page_texts(ctx.pdf_path, search_pages, timeout=300) - # Strict substring matching: for each level-1 title, find the first body page - h1_candidates: list[H1Candidate] = [] - matched_titles: list[str] = [] - unmatched_titles: list[str] = [] - - for title in level1_titles: - normalized_title = _normalize(title) - found = False - for page in search_pages: - text = page_texts.get(page, "") - normalized_text = _normalize(text) - if normalized_title in normalized_text: - # Find the matched line for evidence - matched_line = "" - for line in text.splitlines(): - if normalized_title in _normalize(line): - matched_line = line.strip()[:100] - break - - h1_candidates.append( - H1Candidate( - title=title, - page=page, - confidence=0.88, - matched_line=matched_line, - source="toc_exact_top", - evidence={ - "normalized_needle": normalized_title, - "page_text_length": len(text), - }, - ) - ) - matched_titles.append(title) - found = True - break # Only first match per title - - if not found: - unmatched_titles.append(title) + # Delegate to unified grep function + h1_candidates, unmatched_titles = grep_titles_in_pages( + level1_titles, search_pages, page_texts, source="toc_exact_top", + ) - # Deduplicate: if multiple titles map to the same page, keep the first - seen_pages: set[int] = set() - deduped: list[H1Candidate] = [] - for candidate in h1_candidates: - if candidate.page not in seen_pages: - seen_pages.add(candidate.page) - deduped.append(candidate) - h1_candidates = deduped + matched_titles = [c.title for c in h1_candidates] ctx.blackboard.h1_result = H1BoundaryResult( h1_candidates=h1_candidates, diff --git a/apps/worker/app/services/document_agent/tools/propose_shard_plan.py 
b/apps/worker/app/services/document_agent/tools/propose_shard_plan.py index 82be670f..a52eeb05 100644 --- a/apps/worker/app/services/document_agent/tools/propose_shard_plan.py +++ b/apps/worker/app/services/document_agent/tools/propose_shard_plan.py @@ -8,13 +8,21 @@ from typing import Any from app.services.document_agent.manifest import ( + H1Candidate, Shard, ShardPlan, ToolContext, ToolResult, ) +from app.services.document_agent.pdf_text import read_page_texts from app.services.document_agent.registry import has_doc_stats, has_h1_result, has_toc_result, register_tool +from app.services.document_agent.tools.match_h1_pages import ( + extract_children_titles, + grep_titles_in_pages, + verify_section_start, +) from app.services.document_agent.validators import single_shard_plan, validate_shard_plan +from loguru import logger from shared.utils.token_estimate import estimate_tokens @@ -37,9 +45,18 @@ def _thresholds(ctx: ToolContext) -> tuple[int, int, int]: def _cuts_to_shards(cuts: list[tuple[int, str, str, float]], page_count: int) -> list[Shard]: shards: list[Shard] = [] previous = 0 + # Track which cuts came from H2 refinement to mark continuation shards + h2_cut_pages: set[int] = set() + for cut_page, _anchor_type, evidence, _confidence in cuts: + if evidence.startswith("H2 refine:"): + h2_cut_pages.add(cut_page) + for cut_page, anchor_type, evidence, confidence in cuts: if cut_page <= previous: continue + # A shard is continuation if it starts AFTER an H2 cut (previous cut was H2) + _is_continuation = previous in h2_cut_pages + _split_depth = 2 if (evidence.startswith("H2 refine:") or _is_continuation) else 1 shards.append( Shard( shard_index=len(shards), @@ -49,10 +66,13 @@ def _cuts_to_shards(cuts: list[tuple[int, str, str, float]], page_count: int) -> anchor_type=anchor_type, # type: ignore[arg-type] anchor_evidence=evidence, confidence=confidence, + split_depth=_split_depth, + is_continuation=_is_continuation, ) ) previous = cut_page if previous < 
page_count: + _is_continuation = previous in h2_cut_pages shards.append( Shard( shard_index=len(shards), @@ -62,6 +82,8 @@ def _cuts_to_shards(cuts: list[tuple[int, str, str, float]], page_count: int) -> anchor_type="forced_max_size", anchor_evidence="final shard", confidence=1.0, + split_depth=2 if _is_continuation else 1, + is_continuation=_is_continuation, ) ) return shards @@ -212,14 +234,169 @@ def _deterministic_guardrail_plan( cuts.append((cut_page, "h1_boundary", f"guardrail H1 start page {chosen}", 0.35)) previous = cut_page else: - cuts.append((target, "forced_max_size", "guardrail max shard size", 0.25)) - previous = target - # Merge final shard into previous if it's smaller than min_pages - if cuts and (page_count - cuts[-1][0]) < min_pages: - cuts.pop() + break # No more H1 in range → leave oversized shard for H2 refinement return cuts, "too_large" +# ── C3: H2-aware shard refinement ──────────────────────────────────────── + + +def _find_h1_for_range( + h1_candidates: list[H1Candidate], + range_start: int, + range_end: int, +) -> str | None: + """Find the H1 title whose start page falls in [range_start+1, range_end].""" + for c in h1_candidates: + if range_start < c.page <= range_end: + return c.title + # Fallback: the H1 whose page is closest to and <= range_start+1 + best: H1Candidate | None = None + for c in h1_candidates: + if c.page <= range_start + 1: + if best is None or c.page > best.page: + best = c + return best.title if best else None + + +def _pick_and_verify_best_cut( + h2_candidates: list[H1Candidate], + shard_start: int, + shard_end: int, + min_pages: int, + max_pages: int, + ctx: ToolContext, +) -> tuple[int, str, str, float] | None: + """Pick the H2 candidate that produces the most balanced sub-shards. + + Candidates are ranked by how close they split the shard to the midpoint. + Each candidate is VLM-verified before acceptance. 
+ """ + if not h2_candidates: + return None + + shard_length = shard_end - shard_start + midpoint = shard_start + shard_length // 2 + + # Sort by distance to midpoint (most balanced first) + ranked = sorted(h2_candidates, key=lambda c: abs(c.page - midpoint)) + + for candidate in ranked: + cut_page = candidate.page - 1 # Cut *before* the H2 start page + left_len = cut_page - shard_start + right_len = shard_end - cut_page + if left_len < min_pages or right_len < min_pages: + continue + if left_len > max_pages or right_len > max_pages: + continue + # VLM verification + if not verify_section_start(page=candidate.page, title=candidate.title, ctx=ctx): + logger.info( + "[h2_refine] VLM rejected H2 cut at page {} ('{}')", + candidate.page, candidate.title[:30], + ) + continue + logger.info( + "[h2_refine] accepted H2 cut at page {} ('{}'), left={} right={}", + candidate.page, candidate.title[:30], left_len, right_len, + ) + return ( + cut_page, + "h1_boundary", + f"H2 refine: '{candidate.title[:60]}' at page {candidate.page}", + candidate.confidence * 0.9, # Slightly lower confidence than H1 + ) + + return None + + +def _refine_with_h2( + cuts: list[tuple[int, str, str, float]], + page_count: int, + min_pages: int, + max_pages: int, + ctx: ToolContext, + h1_candidates: list[H1Candidate], +) -> list[tuple[int, str, str, float]]: + """Post-process cuts: split any shard that exceeds max_pages using H2 boundaries.""" + if not ctx.blackboard.toc_hierarchies: + return cuts + + refined: list[tuple[int, str, str, float]] = [] + previous = 0 + + # Build endpoints: each cut + the implicit final boundary + endpoints = [(cp, at, ev, cf) for cp, at, ev, cf in cuts] + [ + (page_count, "final", "", 1.0) + ] + + for cut_page, anchor_type, evidence, confidence in endpoints: + shard_length = cut_page - previous + if shard_length > max_pages: + # Try H2 refinement for this oversized shard + h1_title = _find_h1_for_range(h1_candidates, previous, cut_page) + h2_cut_found = False + if 
h1_title: + h2_titles = extract_children_titles( + ctx.blackboard.toc_hierarchies, h1_title, + ) + if h2_titles: + search_pages = list(range(previous + 1, cut_page + 1)) + page_texts = read_page_texts( + ctx.pdf_path, search_pages, timeout=120, + ) + h2_candidates, _ = grep_titles_in_pages( + h2_titles, search_pages, page_texts, + source="h2_refine", + ) + best = _pick_and_verify_best_cut( + h2_candidates, previous, cut_page, + min_pages, max_pages, ctx, + ) + if best: + refined.append(best) + h2_cut_found = True + logger.info( + "[h2_refine] split oversized shard [{}-{}] at page {}", + previous + 1, cut_page, best[0], + ) + else: + logger.warning( + "[h2_refine] no valid H2 cut for shard [{}-{}]", + previous + 1, cut_page, + ) + else: + logger.info( + "[h2_refine] no H2 titles found under H1 '{}' for shard [{}-{}]", + h1_title[:30], previous + 1, cut_page, + ) + else: + logger.info( + "[h2_refine] no H1 found for oversized shard [{}-{}]", + previous + 1, cut_page, + ) + + # Ultimate fallback: forced_max_size + if not h2_cut_found: + fallback_page = previous + max_pages + if fallback_page < cut_page: + refined.append(( + fallback_page, "forced_max_size", + "H2 refine fallback: forced max size", 0.2, + )) + logger.warning( + "[h2_refine] forced_max_size fallback at page {} for shard [{}-{}]", + fallback_page, previous + 1, cut_page, + ) + + # Append the original cut (skip the synthetic "final" endpoint) + if anchor_type != "final": + refined.append((cut_page, anchor_type, evidence, confidence)) + previous = cut_page + + return refined + + @register_tool( name="propose.shard_plan", description="Ask the LLM to decide whether and where to split using profile, TOC, and H1 evidence.", @@ -319,6 +496,12 @@ def propose_shard_plan(ctx: ToolContext, _args: dict[str, Any]) -> ToolResult: }, ) + # C3: H2 refinement – split any shard that still exceeds max_pages + if cuts: + cuts = _refine_with_h2( + cuts, page_count, min_pages, max_pages, ctx, h1_candidates, + ) + shards = 
_cuts_to_shards(cuts, page_count) enabled = len(shards) > 1 if not enabled: diff --git a/apps/worker/app/services/document_agent/validators.py b/apps/worker/app/services/document_agent/validators.py index 2f06371b..6f6f01f5 100644 --- a/apps/worker/app/services/document_agent/validators.py +++ b/apps/worker/app/services/document_agent/validators.py @@ -22,8 +22,10 @@ def validate_shard_plan( if not plan.shards: errors.append("shard_plan has no shards") return ValidationReport(valid=False, errors=errors, warnings=warnings) + sorted_shards = sorted(plan.shards, key=lambda item: item.shard_index) expected_start = 1 - for shard in sorted(plan.shards, key=lambda item: item.shard_index): + for idx, shard in enumerate(sorted_shards): + is_last = idx == len(sorted_shards) - 1 if shard.page_start != expected_start: errors.append( f"shard {shard.shard_index} starts at {shard.page_start}, expected {expected_start}" @@ -36,7 +38,13 @@ def validate_shard_plan( if plan.enabled and length > max_pages: errors.append(f"shard {shard.shard_index} exceeds max_pages={max_pages}") if plan.enabled and length < min_pages: - errors.append(f"shard {shard.shard_index} shorter than min_pages={min_pages}") + if is_last: + warnings.append( + f"shard {shard.shard_index} (final) shorter than min_pages={min_pages} " + f"({length} pages)" + ) + else: + errors.append(f"shard {shard.shard_index} shorter than min_pages={min_pages}") expected_start = shard.page_end + 1 if expected_start != page_count + 1: errors.append("shard_plan does not cover full document") diff --git a/apps/worker/app/services/document_ingestion/parse_execution.py b/apps/worker/app/services/document_ingestion/parse_execution.py index a65856d3..665194b6 100644 --- a/apps/worker/app/services/document_ingestion/parse_execution.py +++ b/apps/worker/app/services/document_ingestion/parse_execution.py @@ -4,10 +4,18 @@ from app.services.document_ingestion.source_preparation import PreparedSourceFile from app.services.document_parser 
import parse_service from app.services.document_parser.orchestration.parse_output import ParseOutput -from app.services.document_parser.support.stage_profiler import stage_timer +from app.services.document_parser.support.stage_profiler import ( + stage_timer, + init_stage_tracker, + cleanup_stage_tracker, +) from loguru import logger from shared.models.schemas.job_metadata import JobMetadataHelper +from shared.services.ai.token_tracking import ( + init_token_tracker, + cleanup_token_tracker, +) def execute_document_parse( @@ -29,50 +37,63 @@ def execute_document_parse( f"internal_filename={prepared_source.internal_parse_name}, type={doc_type}" ) - with stage_timer( - "worker.parse.document", - job_id=job_id, - filename=prepared_source.source_file_name, - doc_type=doc_type, - ): - parse_output = parse_service.checkerboard_parse_output( - file_full_path=prepared_source.local_file_path, - filename=prepared_source.source_file_name, - output_dir=output_dir, + token_usage_dict = init_token_tracker() + stage_timing_dict = init_stage_tracker() + + try: + with stage_timer( + "worker.parse.document", job_id=job_id, - internal_output_filename=prepared_source.internal_parse_name, + filename=prepared_source.source_file_name, doc_type=doc_type, - smart_title_parse=JobMetadataHelper.get_parsing_param( - job_context.job_metadata, - "smart_title_parse", - True, - ), - summary_image=JobMetadataHelper.get_parsing_param( - job_context.job_metadata, - "summary_image", - True, - ), - summary_table=JobMetadataHelper.get_parsing_param( - job_context.job_metadata, - "summary_table", - True, - ), - summary_txt=JobMetadataHelper.get_parsing_param( - job_context.job_metadata, - "summary_txt", - True, - ), - add_frag_desc=JobMetadataHelper.get_parsing_param( - job_context.job_metadata, - "add_frag_desc", - "", - ), - s3_key=job_context.s3_key, + ): + parse_output = parse_service.checkerboard_parse_output( + file_full_path=prepared_source.local_file_path, + 
filename=prepared_source.source_file_name, + output_dir=output_dir, + job_id=job_id, + internal_output_filename=prepared_source.internal_parse_name, + doc_type=doc_type, + smart_title_parse=JobMetadataHelper.get_parsing_param( + job_context.job_metadata, + "smart_title_parse", + True, + ), + summary_image=JobMetadataHelper.get_parsing_param( + job_context.job_metadata, + "summary_image", + True, + ), + summary_table=JobMetadataHelper.get_parsing_param( + job_context.job_metadata, + "summary_table", + True, + ), + summary_txt=JobMetadataHelper.get_parsing_param( + job_context.job_metadata, + "summary_txt", + True, + ), + add_frag_desc=JobMetadataHelper.get_parsing_param( + job_context.job_metadata, + "add_frag_desc", + "", + ), + s3_key=job_context.s3_key, + ) + + logger.info( + "File parsing completed: " + f"job_id={job_id}, output_dir={parse_output.output_dir}, " + f"chunks={parse_output.rows_count}" ) - logger.info( - "File parsing completed: " - f"job_id={job_id}, output_dir={parse_output.output_dir}, " - f"chunks={parse_output.rows_count}" - ) + job_context.job_metadata["stages"] = { + "timing_ms": dict(stage_timing_dict), + "token_usage": dict(token_usage_dict), + } + finally: + cleanup_token_tracker() + cleanup_stage_tracker() + return parse_output diff --git a/apps/worker/app/services/document_parser/formats/image/parser.py b/apps/worker/app/services/document_parser/formats/image/parser.py index 9dd3808a..251e2e28 100755 --- a/apps/worker/app/services/document_parser/formats/image/parser.py +++ b/apps/worker/app/services/document_parser/formats/image/parser.py @@ -142,7 +142,7 @@ def ask_image( if task in ("summary-images", "atlas-page-info"): image_model = settings.IMAGE_MODEL or "gpt-4-vision-preview" - else: # Image Q&A and OCR use better models + else: # OCR and image type classification use higher-capability models image_model = settings.IMAGE_MODEL_MAX or "gpt-4-vision-preview" if len(urls_) > 0: diff --git 
a/apps/worker/app/services/document_parser/formats/pdf/parser.py b/apps/worker/app/services/document_parser/formats/pdf/parser.py index ab1de83a..217c89f6 100755 --- a/apps/worker/app/services/document_parser/formats/pdf/parser.py +++ b/apps/worker/app/services/document_parser/formats/pdf/parser.py @@ -1,5 +1,7 @@ # pyright: reportArgumentType=false import os +import re +import shutil from app.services.document_parser.formats.markdown.parser import parse_md from app.services.document_parser.orchestration.oversized_pdf_policy import ( @@ -10,6 +12,7 @@ from loguru import logger from shared.core.config import settings +from shared.services.storage.job_file_storage import JobFileStorage def parse_pdfs( @@ -20,6 +23,7 @@ def parse_pdfs( profile=None, relative_root=None, s3_key=None, + job_id=None, ): route = profile.route if profile else "standard" base_llm_paras.update({"doc_name": filename}) @@ -43,6 +47,7 @@ def parse_pdfs( return _parse_oversized_pdf( pdf_path, filename, output_dir, base_llm_paras, profile=profile, relative_root=relative_root, s3_key=s3_key, + job_id=job_id, ) except Exception as exc: logger.exception( @@ -74,7 +79,7 @@ def parse_pdfs( def _parse_oversized_pdf( pdf_path, filename, output_dir, base_llm_paras, - profile=None, relative_root=None, s3_key=None, + profile=None, relative_root=None, s3_key=None, job_id=None, ): """Handle PDFs exceeding MinerU's page limit via shard-first hierarchy. @@ -101,176 +106,267 @@ def _parse_oversized_pdf( split_pdf, ) - job_id = base_llm_paras.get("doc_name", filename) - - # 1. Run doc_agent to get full anatomy map (shard plan + TOC info) - with stage_timer("pdf.doc_agent", filename=filename): - anatomy = run_doc_agent(pdf_path, job_id=job_id, output_dir=output_dir) + doc_agent_job_id = job_id or base_llm_paras.get("doc_name", filename) + work_dir: str | None = None + temp_shard_s3_keys: list[str] = [] + + try: + # 1. 
Run doc_agent to get full anatomy map (shard plan + TOC info) + with stage_timer("pdf.doc_agent", filename=filename): + anatomy = run_doc_agent( + pdf_path, + job_id=doc_agent_job_id, + output_dir=output_dir, + ) - agent_shards = anatomy.shard_plan.shards + agent_shards = anatomy.shard_plan.shards + + # 2. Extract TOC info from anatomy for page exclusion and heading constraint + toc_pages: set[int] = set() + toc_hierarchies = None + if anatomy.toc_result and anatomy.toc_result.toc_pages: + toc_pages = set(anatomy.toc_result.toc_pages) + toc_hierarchies = anatomy.toc_hierarchies + logger.info( + f"📌 DOC_AGENT TOC detected: {len(toc_pages)} pages to exclude " + f"({sorted(toc_pages)}), " + f"{len(toc_hierarchies) if toc_hierarchies else 0} hierarchy regions" + ) - # 2. Extract TOC info from anatomy for page exclusion and heading constraint - toc_pages: set[int] = set() - toc_hierarchies = None - if anatomy.toc_result and anatomy.toc_result.toc_pages: - toc_pages = set(anatomy.toc_result.toc_pages) - toc_hierarchies = anatomy.toc_hierarchies - logger.info( - f"📌 DOC_AGENT TOC detected: {len(toc_pages)} pages to exclude " - f"({sorted(toc_pages)}), " - f"{len(toc_hierarchies) if toc_hierarchies else 0} hierarchy regions" + # 3. Bin-pack agent shards to maximize MinerU page limit + merged_shards = bin_pack_shards( + agent_shards, + max_pages=settings.MAX_PDF_PAGE_LIMIT, ) - - # 3. Bin-pack agent shards to maximize MinerU page limit - merged_shards = bin_pack_shards(agent_shards, max_pages=settings.MAX_PDF_PAGE_LIMIT) - logger.info( - f"📦 Bin-packed {len(agent_shards)} agent shards → " - f"{len(merged_shards)} MinerU shards" - ) - for ms in merged_shards: logger.info( - f" shard_{ms.shard_index}: pages {ms.page_start}-{ms.page_end} " - f"({ms.page_count} pages)" - ) - - # 4. 
Physically split PDF (exclude TOC pages if detected) - work_dir = os.path.join(output_dir, "_shards") - os.makedirs(work_dir, exist_ok=True) - with stage_timer("pdf.split", filename=filename): - shard_pdf_paths, _page_remap = split_pdf( - pdf_path, merged_shards, work_dir, - exclude_pages=toc_pages if toc_pages else None, + f"📦 Bin-packed {len(agent_shards)} agent shards → " + f"{len(merged_shards)} MinerU shards" ) + for ms in merged_shards: + logger.info( + f" shard_{ms.shard_index}: pages {ms.page_start}-{ms.page_end} " + f"({ms.page_count} pages)" + ) - # 5. Parse each shard via MinerU (parallel) - shard_output_dirs: list[str | None] = [None] * len(shard_pdf_paths) - concurrency = settings.MINERU_SHARD_CONCURRENCY + # 4. Physically split PDF (exclude TOC pages if detected) + work_dir = os.path.join(output_dir, "_shards") + os.makedirs(work_dir, exist_ok=True) + with stage_timer("pdf.split", filename=filename): + shard_pdf_paths, _page_remap = split_pdf( + pdf_path, merged_shards, work_dir, + exclude_pages=toc_pages if toc_pages else None, + ) - def _parse_single_shard(shard_idx, shard_pdf): - shard_out = os.path.join(work_dir, f"shard_{shard_idx}_output") - os.makedirs(shard_out, exist_ok=True) - shard_filename = ( - f"{os.path.splitext(filename)[0]}_shard{shard_idx}.pdf" - ) - logger.info( - f" 🔄 MinerU shard_{shard_idx}: parsing" + temp_shard_s3_keys = [ + _build_temp_shard_s3_key( + source_s3_key=s3_key, + job_id=job_id, + filename=filename, + shard_index=shard_index, + ) + for shard_index, _shard_pdf_path in enumerate(shard_pdf_paths) + ] + + # 5. 
Parse each shard via MinerU (parallel) + shard_output_dirs: list[str | None] = [None] * len(shard_pdf_paths) + concurrency = settings.MINERU_SHARD_CONCURRENCY + + def _parse_single_shard(shard_idx, shard_pdf): + assert work_dir is not None + shard_out = os.path.join(work_dir, f"shard_{shard_idx}_output") + os.makedirs(shard_out, exist_ok=True) + shard_filename = ( + f"{os.path.splitext(filename)[0]}_shard{shard_idx}.pdf" + ) + shard_s3_key = temp_shard_s3_keys[shard_idx] + logger.info( + f" 🔄 MinerU shard_{shard_idx}: parsing via S3 URL " + f"({shard_s3_key})" + ) + parse_via_full(shard_pdf, shard_filename, shard_out, s3_key=shard_s3_key) + return shard_out + + with stage_timer( + "pdf.mineru_parallel", filename=filename, shard_count=len(shard_pdf_paths) + ): + with ThreadPoolExecutor(max_workers=concurrency) as executor: + futures = { + executor.submit(_parse_single_shard, i, shard_pdf_path): i + for i, shard_pdf_path in enumerate(shard_pdf_paths) + } + for future in as_completed(futures): + idx = futures[future] + shard_output_dirs[idx] = future.result() + + # 6. Per-shard heading prediction (parallel) + @dataclass + class ShardHeadingResult: + shard_index: int + lines_with_heading: list[str] + heading_count: int + + smart_parse = base_llm_paras.get("smart_title_parse", True) + hierarchy_model_name = ( + base_llm_paras.get("hierarchy_model_name") + or base_llm_paras.get("model_name", settings.NORMOL_MODEL) ) - parse_via_full(shard_pdf, shard_filename, shard_out, s3_key=None) - return shard_out - - with stage_timer( - "pdf.mineru_parallel", filename=filename, shard_count=len(shard_pdf_paths) - ): - with ThreadPoolExecutor(max_workers=concurrency) as executor: - futures = { - executor.submit(_parse_single_shard, i, shard_pdf_path): i - for i, shard_pdf_path in enumerate(shard_pdf_paths) - } - for future in as_completed(futures): - idx = futures[future] - shard_output_dirs[idx] = future.result() - - # 6. 
Per-shard heading prediction (parallel) - @dataclass - class ShardHeadingResult: - shard_index: int - lines_with_heading: list[str] - heading_count: int - - smart_parse = base_llm_paras.get("smart_title_parse", True) - hierarchy_model_name = ( - base_llm_paras.get("hierarchy_model_name") - or base_llm_paras.get("model_name", settings.NORMOL_MODEL) - ) - def _predict_shard_headings(shard_idx: int, shard_out_dir: str) -> ShardHeadingResult: - """Run full heading prediction pipeline on a single shard's full.md.""" - md_path = os.path.join(shard_out_dir, "full.md") - if not os.path.exists(md_path): - raise FileNotFoundError(f"shard_{shard_idx}: full.md not found") - - with open(md_path, "r", encoding="utf-8") as f: - md_lines = f.readlines() - md_lines = [line.strip() for line in md_lines if line.strip() != ""] - md_lines = merge_html_tables(md_lines) + def _predict_shard_headings(shard_idx: int, shard_out_dir: str) -> ShardHeadingResult: + """Run full heading prediction pipeline on a single shard's full.md.""" + md_path = os.path.join(shard_out_dir, "full.md") + if not os.path.exists(md_path): + raise FileNotFoundError(f"shard_{shard_idx}: full.md not found") + + with open(md_path, "r", encoding="utf-8") as f: + md_lines = f.readlines() + md_lines = [line.strip() for line in md_lines if line.strip() != ""] + md_lines = merge_html_tables(md_lines) + + # TOC context: first TOC shared by all shards; subsequent TOCs assigned + # by page boundary. For simplicity, all TOCs are passed since pred_titles + # only matches headings actually present in this shard's content. 
+ shard_toc = toc_hierarchies + + lines_with_heading = eval_md_headings( + md_lines, + source_type="md", + toc_hierarchies=shard_toc, + smart_parse=smart_parse, + model_name=hierarchy_model_name, + output_dir=shard_out_dir, + layout_json_path=( + os.path.join(shard_out_dir, "layout.json") + if os.path.exists(os.path.join(shard_out_dir, "layout.json")) + else None + ), + ) - # TOC context: first TOC shared by all shards; subsequent TOCs assigned - # by page boundary. For simplicity, all TOCs are passed since pred_titles - # only matches headings actually present in this shard's content. - shard_toc = toc_hierarchies + heading_count = sum(1 for line in lines_with_heading if line.startswith("#")) + logger.info( + f" ✅ shard_{shard_idx}: {heading_count} headings identified " + f"from {len(lines_with_heading)} lines" + ) + return ShardHeadingResult( + shard_index=shard_idx, + lines_with_heading=lines_with_heading, + heading_count=heading_count, + ) - lines_with_heading = eval_md_headings( - md_lines, - source_type="md", - toc_hierarchies=shard_toc, - smart_parse=smart_parse, - model_name=hierarchy_model_name, - output_dir=shard_out_dir, - layout_json_path=( - os.path.join(shard_out_dir, "layout.json") - if os.path.exists(os.path.join(shard_out_dir, "layout.json")) - else None - ), + shard_heading_results: list[ShardHeadingResult | None] = [None] * len(shard_output_dirs) + + with stage_timer( + "pdf.shard_headings", filename=filename, shard_count=len(shard_output_dirs) + ): + with ThreadPoolExecutor(max_workers=concurrency) as executor: + futures = { + executor.submit(_predict_shard_headings, i, shard_dir): i + for i, shard_dir in enumerate(shard_output_dirs) + if shard_dir is not None + } + for future in as_completed(futures): + idx = futures[future] + shard_heading_results[idx] = future.result() + + # 7. 
Merge: concatenate lines_with_heading (in shard order) + merge images + complete_heading_results: list[ShardHeadingResult] = [] + for index, result in enumerate(shard_heading_results): + if result is None: + raise RuntimeError(f"Missing heading result for shard_{index}") + complete_heading_results.append(result) + + # Compute level offsets: continuation shards get shifted deeper. + shard_offsets: list[int] = [] + for shard in agent_shards: + if shard.is_continuation: + shard_offsets.append(max(shard.split_depth - 1, 0)) + else: + shard_offsets.append(0) + if any(offset > 0 for offset in shard_offsets): + logger.info(f"📐 Shard level offsets: {shard_offsets}") + + all_lines_with_heading: list[str] = merge_shard_lines( + [result.lines_with_heading for result in complete_heading_results], + shard_offsets=shard_offsets, + ) + total_headings = sum( + 1 for line in all_lines_with_heading if line.startswith("#") ) - heading_count = sum(1 for line in lines_with_heading if line.startswith("#")) logger.info( - f" ✅ shard_{shard_idx}: {heading_count} headings identified " - f"from {len(lines_with_heading)} lines" - ) - return ShardHeadingResult( - shard_index=shard_idx, - lines_with_heading=lines_with_heading, - heading_count=heading_count, + f"📎 Merged {len(complete_heading_results)} shards: " + f"{len(all_lines_with_heading)} lines, {total_headings} headings" ) - shard_heading_results: list[ShardHeadingResult | None] = [None] * len(shard_output_dirs) - - with stage_timer( - "pdf.shard_headings", filename=filename, shard_count=len(shard_output_dirs) - ): - with ThreadPoolExecutor(max_workers=concurrency) as executor: - futures = { - executor.submit(_predict_shard_headings, i, shard_dir): i - for i, shard_dir in enumerate(shard_output_dirs) - if shard_dir is not None - } - for future in as_completed(futures): - idx = futures[future] - shard_heading_results[idx] = future.result() - - # 7. 
Merge: concatenate lines_with_heading (in shard order) + merge images - complete_heading_results: list[ShardHeadingResult] = [] - for index, result in enumerate(shard_heading_results): - if result is None: - raise RuntimeError(f"Missing heading result for shard_{index}") - complete_heading_results.append(result) - - all_lines_with_heading: list[str] = merge_shard_lines( - [result.lines_with_heading for result in complete_heading_results] - ) - total_headings = sum( - 1 for line in all_lines_with_heading if line.startswith("#") - ) + with stage_timer("pdf.merge_images", filename=filename): + merge_images(shard_output_dirs, output_dir) + + logger.info("✅ Shard-first hierarchy complete, entering parse_md Phase B") - logger.info( - f"📎 Merged {len(complete_heading_results)} shards: " - f"{len(all_lines_with_heading)} lines, {total_headings} headings" + # 8. parse_md Phase B only (skip TOC detection + heading prediction) + with stage_timer("pdf.parse_md", filename=filename): + return parse_md( + output_dir, + source_type="md", + base_llm_paras=base_llm_paras, + relative_root=relative_root, + lines_with_heading=all_lines_with_heading, + ) + finally: + _cleanup_temp_shard_s3_assets(temp_shard_s3_keys) + _cleanup_local_shard_workspace(work_dir) + + +def _build_temp_shard_s3_key( + *, + source_s3_key: str | None, + job_id: str | None, + filename: str, + shard_index: int, +) -> str: + owner_segment = _sanitize_temp_storage_segment( + job_id or _source_key_stem(source_s3_key) or os.path.splitext(filename)[0] ) + return f"tmp/mineru-shards/{owner_segment}/shard_{shard_index}.pdf" - with stage_timer("pdf.merge_images", filename=filename): - merge_images(shard_output_dirs, output_dir) - logger.info("✅ Shard-first hierarchy complete, entering parse_md Phase B") +def _source_key_stem(source_s3_key: str | None) -> str | None: + if not source_s3_key: + return None + key_name = os.path.basename(source_s3_key.rstrip("/")) + stem, _extension = os.path.splitext(key_name) + return stem 
or None - # 8. parse_md Phase B only (skip TOC detection + heading prediction) - with stage_timer("pdf.parse_md", filename=filename): - return parse_md( - output_dir, - source_type="md", - base_llm_paras=base_llm_paras, - relative_root=relative_root, - lines_with_heading=all_lines_with_heading, - ) +def _sanitize_temp_storage_segment(value: object) -> str: + normalized = re.sub(r"[^A-Za-z0-9_.-]+", "-", str(value)).strip(".-") + return normalized or "document" + + +def _cleanup_temp_shard_s3_assets(s3_keys: list[str]) -> None: + if not s3_keys: + return + storage = JobFileStorage() + for s3_key in s3_keys: + try: + deleted = storage.delete_upload_file(s3_key) + if deleted: + logger.info(f"Deleted temporary MinerU shard S3 object: {s3_key}") + else: + logger.debug(f"Temporary MinerU shard S3 object was absent: {s3_key}") + except Exception as exc: + logger.warning( + f"Failed to delete temporary MinerU shard S3 object " + f"{s3_key}: {exc}" + ) + +def _cleanup_local_shard_workspace(work_dir: str | None) -> None: + if not work_dir or not os.path.exists(work_dir): + return + try: + shutil.rmtree(work_dir) + logger.info(f"Deleted temporary MinerU shard workspace: {work_dir}") + except Exception as exc: + logger.warning( + f"Failed to delete temporary MinerU shard workspace {work_dir}: {exc}" + ) diff --git a/apps/worker/app/services/document_parser/formats/pdf/shard_merger.py b/apps/worker/app/services/document_parser/formats/pdf/shard_merger.py index 143cfabb..748413e5 100644 --- a/apps/worker/app/services/document_parser/formats/pdf/shard_merger.py +++ b/apps/worker/app/services/document_parser/formats/pdf/shard_merger.py @@ -17,9 +17,33 @@ def _extract_heading_key(line: str) -> tuple[int, str] | None: return len(m.group(1)), m.group(2).strip() -def merge_shard_lines(shard_lines_list: list[list[str]]) -> list[str]: +def _apply_level_offset(lines: list[str], offset: int) -> list[str]: + """Shift all markdown heading levels by *offset* (e.g. 
## → ### when offset=1).""" + if offset <= 0: + return lines + result: list[str] = [] + for line in lines: + key = _extract_heading_key(line) + if key is not None: + level, text = key + new_level = level + offset + result.append(f"{'#' * new_level} {text}") + else: + result.append(line) + return result + + +def merge_shard_lines( + shard_lines_list: list[list[str]], + shard_offsets: list[int] | None = None, +) -> list[str]: """Concatenate per-shard lines_with_heading in order, removing boundary duplicates. + When ``shard_offsets`` is provided, each shard's heading levels are shifted + by the corresponding offset. This is used for continuation shards (from + H2+ splitting) whose heading predictor starts from L1 but should be deeper + in the global hierarchy. + When a PDF section-divider page falls at the end of shard N and the same heading opens shard N+1, each shard independently identifies it as a heading, resulting in two consecutive identical headings after naïve concatenation. @@ -37,13 +61,20 @@ def merge_shard_lines(shard_lines_list: list[list[str]]) -> list[str]: if not lines: continue + # Apply level offset for continuation shards + offset = shard_offsets[shard_idx] if shard_offsets else 0 + lines = _apply_level_offset(lines, offset) + # Determine next shard's first heading (if any) next_first_heading: tuple[int, str] | None = None for future_idx in range(shard_idx + 1, len(shard_lines_list)): - for next_line in shard_lines_list[future_idx]: + future_lines = shard_lines_list[future_idx] + future_offset = shard_offsets[future_idx] if shard_offsets else 0 + for next_line in future_lines: key = _extract_heading_key(next_line) if key is not None: - next_first_heading = key + # Compare with the offset-adjusted level + next_first_heading = (key[0] + future_offset, key[1]) break if next_first_heading is not None: break diff --git a/apps/worker/app/services/document_parser/formats/pdf/shard_splitter.py 
b/apps/worker/app/services/document_parser/formats/pdf/shard_splitter.py index ebf59914..99413e75 100644 --- a/apps/worker/app/services/document_parser/formats/pdf/shard_splitter.py +++ b/apps/worker/app/services/document_parser/formats/pdf/shard_splitter.py @@ -69,28 +69,16 @@ def bin_pack_shards( agent_shards: list["Shard"], max_pages: int, ) -> list[MergedShard]: - """Greedy left-to-right bin-packing: merge adjacent agent shards up to max_pages.""" - if not agent_shards: - return [] - - merged: list[MergedShard] = [] - cur_start = agent_shards[0].page_start - cur_end = agent_shards[0].page_end - - for shard in agent_shards[1:]: - if shard.page_end - cur_start + 1 <= max_pages: - cur_end = shard.page_end - else: - merged.append( - MergedShard(len(merged), page_start=cur_start, page_end=cur_end) - ) - cur_start = shard.page_start - cur_end = shard.page_end + """1:1 mapping: each agent shard becomes its own MinerU shard. - merged.append( - MergedShard(len(merged), page_start=cur_start, page_end=cur_end) - ) - return merged + Agent shards are cut at semantic boundaries (H1/H2) by the document + agent. Merging them would cross those boundaries and degrade heading + prediction quality, so we preserve them as-is. 
+ """ + return [ + MergedShard(idx, page_start=s.page_start, page_end=s.page_end) + for idx, s in enumerate(agent_shards) + ] def split_pdf( diff --git a/apps/worker/app/services/document_parser/orchestration/format_adapters.py b/apps/worker/app/services/document_parser/orchestration/format_adapters.py index 6dc3409c..e6af2f8e 100644 --- a/apps/worker/app/services/document_parser/orchestration/format_adapters.py +++ b/apps/worker/app/services/document_parser/orchestration/format_adapters.py @@ -85,6 +85,7 @@ def parse(self, session: ParseSession) -> ParseOutput: profile=session.profile, relative_root=session.relative_root, s3_key=session.s3_key, + job_id=session.job_id, ) return ParseOutput(output_dir=session.full_output_dir, parsed_df=parsed_df) diff --git a/apps/worker/app/services/document_parser/structure/heading_candidates.py b/apps/worker/app/services/document_parser/structure/heading_candidates.py index 976d8127..ea9812f4 100644 --- a/apps/worker/app/services/document_parser/structure/heading_candidates.py +++ b/apps/worker/app/services/document_parser/structure/heading_candidates.py @@ -10,7 +10,7 @@ from loguru import logger from pandas import Index -from app.services.document_parser.support.text_helpers import count_cn_en +from app.services.document_parser.support.text_helpers import count_cn_en, detect_primary_lang HEADING_COLUMNS = Index(["id", "heading", "level", "reason"]) @@ -136,8 +136,10 @@ def remove_by_conditions(text, *, include_punc: bool = False): else: neg_triggered_code.append(0) - MAX_HEADING_TOKENS = 10 - neg_triggered_code.append(1 if count_cn_en(text) > MAX_HEADING_TOKENS else 0) + MAX_HEADING_TOKENS_ZH = 30 + MAX_HEADING_TOKENS_EN = 10 + _limit = MAX_HEADING_TOKENS_ZH if detect_primary_lang(text) == "zh" else MAX_HEADING_TOKENS_EN + neg_triggered_code.append(1 if count_cn_en(text) > _limit else 0) return neg_triggered_code diff --git a/apps/worker/app/services/document_parser/support/stage_profiler.py 
b/apps/worker/app/services/document_parser/support/stage_profiler.py index 3d9a9177..7ec11c34 100644 --- a/apps/worker/app/services/document_parser/support/stage_profiler.py +++ b/apps/worker/app/services/document_parser/support/stage_profiler.py @@ -1,12 +1,77 @@ -"""Structured stage timing helper for the document parsing pipeline.""" +"""Structured stage timing helper for the document parsing pipeline. + +Stage-timing accumulation uses the same greenlet-parent-chain approach +as token_tracking.py — see that module's docstring for rationale. +""" + +from __future__ import annotations from contextlib import contextmanager from time import perf_counter from typing import Any, Iterator +import threading from loguru import logger +# ── Greenlet-safe stage timing accumulator ── + +_trackers: dict[int, dict[str, int]] = {} +_lock = threading.Lock() +_root_ids: dict[int, int] = {} + + +def _current_greenlet_id() -> int: + try: + import gevent + return id(gevent.getcurrent()) + except ImportError: + import threading as _threading + return _threading.get_ident() + + +def _find_root_id() -> int | None: + """Walk up the greenlet parent chain to find a registered root id.""" + gid = _current_greenlet_id() + if gid in _trackers: + return gid + if gid in _root_ids: + return _root_ids[gid] + try: + import gevent + g = gevent.getcurrent() + while g is not None: + pid = id(g) + if pid in _trackers: + _root_ids[gid] = pid + return pid + g = getattr(g, 'parent', None) + except ImportError: + pass + return None + + +def init_stage_tracker() -> dict[str, int]: + """Create a new stage timing accumulator for the current parse task.""" + gid = _current_greenlet_id() + tracker: dict[str, int] = {} + with _lock: + _trackers[gid] = tracker + return tracker + + +def cleanup_stage_tracker() -> None: + """Remove the stage tracker for the current greenlet.""" + gid = _current_greenlet_id() + with _lock: + _trackers.pop(gid, None) + stale = [k for k, v in _root_ids.items() if v == gid] + for k 
in stale: + del _root_ids[k] + + +# ── Public timer context manager (unchanged API) ── + def _compact_fields(fields: dict[str, Any]) -> dict[str, Any]: """Drop empty values so timing logs stay compact and readable.""" compacted_fields: dict[str, Any] = {} @@ -37,6 +102,14 @@ def stage_timer(stage: str, **fields: Any) -> Iterator[None]: raise elapsed_ms = int((perf_counter() - start_time) * 1000) + + root = _find_root_id() + if root is not None: + tracker = _trackers.get(root) + if tracker is not None: + with _lock: + tracker[stage] = tracker.get(stage, 0) + elapsed_ms + logger.bind( event="document_parser.stage", stage=stage, diff --git a/apps/worker/tests/contract/test_parse_task_contract.py b/apps/worker/tests/contract/test_parse_task_contract.py index 843825bc..355966cd 100644 --- a/apps/worker/tests/contract/test_parse_task_contract.py +++ b/apps/worker/tests/contract/test_parse_task_contract.py @@ -558,6 +558,13 @@ class _Profile: page_count = 3 calls: dict[str, object] = {} + parse_s3_keys: list[str | None] = [] + deleted_s3_keys: list[str] = [] + + class _FakeJobFileStorage: + def delete_upload_file(self, storage_key: str) -> bool: + deleted_s3_keys.append(storage_key) + return True def _fake_run_doc_agent(pdf_path_arg: str, job_id: str, output_dir: str): calls["doc_agent"] = { @@ -610,6 +617,7 @@ def _fake_split_pdf(pdf_path_arg, shards, work_dir, exclude_pages=None): return paths, None def _fake_parse_via_full(shard_pdf, shard_filename, shard_out, s3_key=None): + parse_s3_keys.append(s3_key) shard_index = 0 if "shard0" in shard_filename else 1 lines_by_shard = { 0: ["# Chapter 1", "Shard one body."], @@ -648,6 +656,7 @@ def _identity_eval_md_headings( "app.services.document_parser.formats.markdown.parser.eval_md_headings", _identity_eval_md_headings, ) + monkeypatch.setattr(pdf_parser, "JobFileStorage", _FakeJobFileStorage) df = pdf_parser.parse_pdfs( str(pdf_path), @@ -664,10 +673,20 @@ def _identity_eval_md_headings( }, profile=_Profile(), 
relative_root="oversized.pdf", + s3_key="uploads/job-oversized.pdf", + job_id="job-oversized", ) assert calls["exclude_pages"] == {1} + assert calls["doc_agent"]["job_id"] == "job-oversized" assert len(calls["heading_dirs"]) == 2 + expected_s3_keys = [ + "tmp/mineru-shards/job-oversized/shard_0.pdf", + "tmp/mineru-shards/job-oversized/shard_1.pdf", + ] + assert parse_s3_keys == expected_s3_keys + assert deleted_s3_keys == expected_s3_keys + assert not (output_dir / "_shards").exists() assert list(df["type"]) == ["PTXT", "PTXT"] assert list(df["content"]) == ["Shard one body.", "Shard two body."] assert list(df["path"]) == [ diff --git a/packages/shared-python/shared/core/config/ai.py b/packages/shared-python/shared/core/config/ai.py index 315511f6..9d035051 100644 --- a/packages/shared-python/shared/core/config/ai.py +++ b/packages/shared-python/shared/core/config/ai.py @@ -35,7 +35,7 @@ class AIConfig(BaseModel): IMAGE_MODEL_MAX: str = Field( default="qwen3.5-flash", - description="Higher-capability image model for OCR and ask-image Q&A", + description="Higher-capability image model for OCR and image type classification", ) RETRIEVAL_DECOMPOSITION_ENABLED: bool = Field( default=False, diff --git a/packages/shared-python/shared/services/ai/llm_mock.py b/packages/shared-python/shared/services/ai/llm_mock.py index c751e325..ccb84143 100644 --- a/packages/shared-python/shared/services/ai/llm_mock.py +++ b/packages/shared-python/shared/services/ai/llm_mock.py @@ -139,15 +139,11 @@ def _detect_mock_task(prompt_text: str) -> str: if "perform ocr operation" in normalized_prompt: return "ocr-image" if ( - "you will receive an image" in normalized_prompt - and "line 1: output a short title" in normalized_prompt + "you will receive an image from a document" in normalized_prompt + and "identify the image type" in normalized_prompt ): return "summary-images" - if ( - "you will receive one or more images and the user's current question" - in normalized_prompt - ): - return 
"ask-image" + if "summaries of sub-sections from a document section" in normalized_prompt: return "file-summary" if ( @@ -302,7 +298,7 @@ def _build_mock_response(task_name: str) -> str: "atlas-page-info": "Mock atlas page info", "ocr-image": "Mock OCR text", "summary-images": "Mock Image Title\nMock image summary", - "ask-image": "Mock image answer", + "file-summary": "Mock section summary", "summary-titled": "Mock Title\nMock summary", "summary": "Mock summary", diff --git a/packages/shared-python/shared/services/ai/openai_compatible_client_sync.py b/packages/shared-python/shared/services/ai/openai_compatible_client_sync.py index e8d5f413..9ce586eb 100644 --- a/packages/shared-python/shared/services/ai/openai_compatible_client_sync.py +++ b/packages/shared-python/shared/services/ai/openai_compatible_client_sync.py @@ -20,6 +20,7 @@ from shared.services.http.client_pool import get_sync_client from shared.services.ai.llm_mock import build_mock_chat_completion_response from shared.utils.security_utils import mask_api_key +from shared.services.ai.token_tracking import record_tokens LOCAL_DEBUG = os.getenv("LOCAL_DEBUG", "0") == "1" LLMUsage = dict[str, int] @@ -213,7 +214,9 @@ def _make_ali_pool_raw_call( internal_message="AI returned empty result", provider=self.default_model, ) - return response, _extract_usage(response) + usage = _extract_usage(response) + record_tokens(usage) + return response, usage except openai.RateLimitError as exc: retry_after = _parse_retry_after(exc) quota_manager.mark_rate_limited(lease.token_id, retry_after) @@ -295,18 +298,18 @@ def chat_completion_raw_with_usage( allowed_api_params = { "n", "stop", "presence_penalty", "frequency_penalty", "logit_bias", "user", "seed", "tools", "tool_choice", - "response_format", "logprobs", "top_logprobs", + "response_format", "logprobs", "top_logprobs", "extra_body", } for key, value in kwargs.items(): if key in allowed_api_params: api_kwargs[key] = value - extra_body = api_kwargs.get("extra_body", 
{}) - if isinstance(extra_body, dict): - extra_body.setdefault("enable_thinking", False) + if "extra_body" not in api_kwargs: + api_kwargs["extra_body"] = {"enable_thinking": False} else: - extra_body = {"enable_thinking": False} - api_kwargs["extra_body"] = extra_body + extra_body = api_kwargs["extra_body"] + if isinstance(extra_body, dict): + extra_body.setdefault("enable_thinking", False) effective_model = model or self.default_model if _should_mock_llm_calls(): @@ -344,7 +347,9 @@ def chat_completion_raw_with_usage( internal_message="AI returned empty result", provider=self.default_model, ) - return response, _extract_usage(response) + usage = _extract_usage(response) + record_tokens(usage) + return response, usage except LLMServiceException: raise except Exception as exc: @@ -385,21 +390,22 @@ def chat_completion_with_usage( allowed_api_params = { "n", "stop", "presence_penalty", "frequency_penalty", "logit_bias", "user", "seed", "tools", "tool_choice", - "response_format", "logprobs", "top_logprobs", + "response_format", "logprobs", "top_logprobs", "extra_body", } for key, value in kwargs.items(): if key in allowed_api_params: api_kwargs[key] = value - # ── disable thinking mode ── + # ── disable thinking mode (Qwen default) ── # Qwen3.5 by default enables thinking mode, which wastes tokens by outputting ... - # Explicitly disable it in all API calls - extra_body = api_kwargs.get("extra_body", {}) - if isinstance(extra_body, dict): - extra_body.setdefault("enable_thinking", False) + # Only inject the disable flag when the caller hasn't already set extra_body + # (e.g. DeepSeek thinking mode explicitly passes its own extra_body). 
+ if "extra_body" not in api_kwargs: + api_kwargs["extra_body"] = {"enable_thinking": False} else: - extra_body = {"enable_thinking": False} - api_kwargs["extra_body"] = extra_body + extra_body = api_kwargs["extra_body"] + if isinstance(extra_body, dict): + extra_body.setdefault("enable_thinking", False) effective_model = model or self.default_model if _should_mock_llm_calls(): @@ -453,7 +459,9 @@ def chat_completion_with_usage( ) content = choices[0].message.content or "" - return content, _extract_usage(response) + usage = _extract_usage(response) + record_tokens(usage) + return content, usage except LLMServiceException: raise except Exception as exc: diff --git a/packages/shared-python/shared/services/ai/prompt_service.py b/packages/shared-python/shared/services/ai/prompt_service.py index 20bc6e3f..451f9b50 100755 --- a/packages/shared-python/shared/services/ai/prompt_service.py +++ b/packages/shared-python/shared/services/ai/prompt_service.py @@ -566,16 +566,65 @@ def build_prompt(task, texts, query, **kwargs): temperature = 0.1 max_tokens = int(kwargs["paras"]["max_tokens"] * 1.2) if texts.strip(): - img_context = f"- Image context is [{texts}], you may reference the title for summarization" + img_context = f"- Image context is [{texts}], you may reference the context for summarization" else: img_context = "" prompt = f""" - You will receive an image, which may be a photo, chart, or an image requiring OCR. - Your task is to extract the main content described in the image. Note: - - Line 1: Output a short title (no more than 15 characters) summarizing the image's core topic - - Line 2 onward: Provide a precise and concise summary, using text descriptions only, avoid extracting specific data from the image - - Your response **MUST BE in the SAME LANGUAGE** as any text visible in the image (if there is no text, English is preferred) + You will receive an image from a document. 
Your task is to extract the most + USEFUL information from this image based on its type. + + **STEP 1: Identify the image type** (do NOT output this step, use it internally): + - Credential/ID: identity cards, passports, driver licenses, business licenses, certificates, permits + - Data Chart: bar charts, line charts, pie charts, scatter plots, heatmaps, gauge charts + - Table Screenshot: tabular data rendered as an image + - Diagram: flowcharts, org charts, architecture diagrams, mind maps, UML diagrams + - Engineering Drawing: architectural plans, circuit diagrams, CAD drawings, mechanical drawings + - Photo: real-world photographs of people, objects, scenes, products + - Other: anything not fitting the above categories + + **STEP 2: Extract information according to image type**: + + For Credential/ID images: + - Extract ALL visible fields: name, ID number, date of birth, expiry date, + issuing authority, company name, registration number, legal representative, + business scope, qualification level, etc. 
+ - Preserve exact values as shown (numbers, dates, codes) + + For Data Charts: + - Chart title, axis labels and units + - Key data points, trends, and notable patterns + - Time range or categories covered + - Data source if visible + + For Table Screenshots: + - Table title and column headers + - Key data entries and notable values + - Number of rows/columns and what the table represents + + For Diagrams (flow/architecture/org): + - All node names and their relationships + - Flow direction and process steps + - Hierarchy levels and key connections + + For Engineering/Technical Drawings: + - Drawing title, drawing number, scale + - Key dimensions and annotations + - Component/part names, material specifications + + For Photos: + - Primary subject and scene description + - Notable features, text, or signage visible + - Context clues about location or purpose + + For Other: + - Describe the most important visual information + + **Output format**: + - Line 1: A concise title (no more than 20 characters) capturing the core topic + - Line 2 onward: The extracted information following the type-specific guidelines above + - Your response **MUST BE in the SAME LANGUAGE** as any text visible in the image + (if no text, use English) - If the image is blank, unreadable, or contains no meaningful content, return exactly: null {img_context} @@ -595,23 +644,6 @@ def build_prompt(task, texts, query, **kwargs): - Do not add any format wrappers, prefixes, or explanations beyond the text content """ - elif task == "ask-image": - temperature = 0.1 - max_tokens = int(kwargs["paras"]["max_tokens"] * 1.2) - - prompt = f""" - You will receive one or more images and the user's current question: [{query}] - You may also receive context related to the image(s). - - {texts} - - Your task is to answer the user's question based on the image(s) and context (if any). 
Note: - - Your answer must be in the SAME LANGUAGE as the user's question - - Provide a complete and accurate answer with some explanation, but not exceeding {max_tokens} characters - - If the image content is unrelated to the user's question, return exactly: null - - Do not return any additional explanations or descriptions beyond the answer - """ - elif task == "judge-image-type": temperature = 0.1 prompt = """ diff --git a/packages/shared-python/shared/services/ai/token_tracking.py b/packages/shared-python/shared/services/ai/token_tracking.py new file mode 100644 index 00000000..63b84dda --- /dev/null +++ b/packages/shared-python/shared/services/ai/token_tracking.py @@ -0,0 +1,106 @@ +"""Greenlet-safe token usage tracker for document parsing pipeline. + +In the gevent worker, child greenlets (GeventPool.spawn) do NOT inherit +``ContextVar`` or ``threading.local`` from the parent. We therefore use +a module-level dict keyed by the *root* greenlet id of the current parse +task. ``init_tracker`` sets the greenlet id; ``record_tokens`` looks it +up the chain via ``gevent.getcurrent()`` so that child greenlets spawned +within the same parse task all accumulate into the same dict. + +Thread-safety: multiple parse tasks in the same worker process each have +distinct root greenlet ids, so their accumulators never collide. +""" + +from __future__ import annotations + +import threading + +_trackers: dict[int, dict[str, int]] = {} +_lock = threading.Lock() + +# The root greenlet id for the current parse task. Stored so that child +# greenlets (which cannot inherit ContextVar) can be associated back to +# their root. We walk the greenlet parent chain to find the id. 
+_root_ids: dict[int, int] = {} + + +def _current_greenlet_id() -> int: + try: + import gevent + return id(gevent.getcurrent()) + except ImportError: + import threading as _threading + return _threading.get_ident() + + +def _find_root_id() -> int | None: + """Walk up the greenlet parent chain to find a registered root id.""" + gid = _current_greenlet_id() + # Check self first (fastest path for the root greenlet) + if gid in _trackers: + return gid + # Check if this greenlet was registered as a child + if gid in _root_ids: + return _root_ids[gid] + # Walk parent chain + try: + import gevent + g = gevent.getcurrent() + while g is not None: + pid = id(g) + if pid in _trackers: + # Cache for future lookups + _root_ids[gid] = pid + return pid + g = getattr(g, 'parent', None) + except ImportError: + pass + return None + + +def init_token_tracker() -> dict[str, int]: + """Create a new token accumulator for the current parse task. + + Must be called from the root greenlet of the task (i.e. from + ``execute_document_parse``). Returns the mutable dict that will + accumulate all token usage for the lifetime of this task. + """ + gid = _current_greenlet_id() + tracker: dict[str, int] = { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + } + with _lock: + _trackers[gid] = tracker + return tracker + + +def cleanup_token_tracker() -> None: + """Remove the tracker for the current greenlet. Call after parsing.""" + gid = _current_greenlet_id() + with _lock: + _trackers.pop(gid, None) + # Also clean any child mappings that pointed to this root + stale = [k for k, v in _root_ids.items() if v == gid] + for k in stale: + del _root_ids[k] + + +def record_tokens(usage: dict[str, int]) -> None: + """Accumulate token usage into the current task's tracker. + + Safe to call from any greenlet (root or child). If no tracker is + active (e.g. called from retrieval or outside a parse task), this + is a silent no-op. 
+ """ + root = _find_root_id() + if root is None: + return + tracker = _trackers.get(root) + if tracker is None: + return + with _lock: + tracker["prompt_tokens"] += usage.get("prompt_tokens", 0) + tracker["completion_tokens"] += usage.get("completion_tokens", 0) + tracker["total_tokens"] += usage.get("total_tokens", 0) diff --git a/packages/shared-python/shared/services/retrieval/agentic/__init__.py b/packages/shared-python/shared/services/retrieval/agentic/__init__.py index 1a814959..961ba1da 100644 --- a/packages/shared-python/shared/services/retrieval/agentic/__init__.py +++ b/packages/shared-python/shared/services/retrieval/agentic/__init__.py @@ -5,7 +5,7 @@ Phase 2: Per-document iterative navigation (navigate_step) Phase 3: Render evidence text for downstream agents -Each navigate_step decides action (NAVIGATE/STOP), optional asset tools, -and section selections in a single LLM call. KNOWHERE does not generate -final answers; downstream agents decide whether the evidence is sufficient. +Each navigate_step chooses one observe-act action plus optional collection +side effects. KNOWHERE does not generate final answers; downstream agents +decide whether the evidence is sufficient. 
""" diff --git a/packages/shared-python/shared/services/retrieval/agentic/core/budget.py b/packages/shared-python/shared/services/retrieval/agentic/core/budget.py index cc8e58c7..ff7c7f2f 100644 --- a/packages/shared-python/shared/services/retrieval/agentic/core/budget.py +++ b/packages/shared-python/shared/services/retrieval/agentic/core/budget.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio +import copy from dataclasses import dataclass from typing import Any, Literal @@ -10,9 +11,69 @@ BudgetStatus = Literal["HEALTHY", "TIGHT", "CRITICAL", "EXHAUSTED"] +def status_from_usage(*, remaining: int, used_pct: int) -> BudgetStatus: + if remaining <= 0: + return "EXHAUSTED" + if used_pct >= 80: + return "CRITICAL" + if used_pct >= 50: + return "TIGHT" + return "HEALTHY" + + +def project_budget_snapshot( + snapshot: dict | None, + *, + pool: BudgetPoolName, + additional_tokens: int, +) -> dict | None: + """Return a snapshot projected after an estimated upcoming token cost.""" + if not snapshot: + return snapshot + adjusted = copy.deepcopy(snapshot) + pool_data = adjusted.get(pool) + if not isinstance(pool_data, dict): + return adjusted + + capacity = int(pool_data.get("capacity") or 0) + used = int(pool_data.get("used") or 0) + reserved = int(pool_data.get("reserved") or 0) + projected_used_total = max(0, used + reserved + max(int(additional_tokens), 0)) + used_pct = ( + int(round(projected_used_total * 100 / capacity)) + if capacity > 0 else 100 + ) + remaining = max(capacity - projected_used_total, 0) + pool_data["used_pct"] = used_pct + pool_data["remaining"] = remaining + pool_data["status"] = status_from_usage( + remaining=remaining, + used_pct=used_pct, + ) + return adjusted + + +def budget_status_from_snapshot( + snapshot: dict[str, Any] | None, + *, + pool: BudgetPoolName = "planning", +) -> str: + """Read a pool status from a serialized budget snapshot.""" + if not isinstance(snapshot, dict): + return "UNKNOWN" + pool_data = 
snapshot.get(pool) + if not isinstance(pool_data, dict): + return "UNKNOWN" + return str(pool_data.get("status") or "UNKNOWN") + + class BudgetExceeded(Exception): """Raised when a planned LLM call cannot reserve budget.""" + def __init__(self, message: str, *, details: dict[str, Any] | None = None) -> None: + super().__init__(message) + self.details = details or {} + @dataclass class BudgetPool: @@ -29,7 +90,7 @@ def remaining(self) -> int: def used_pct(self) -> int: if self.capacity <= 0: return 100 - return min(100, int(round((self.used + self.reserved) * 100 / self.capacity))) + return int(round((self.used + self.reserved) * 100 / self.capacity)) class BudgetLedger: @@ -65,20 +126,17 @@ def __init__( self.explored_chunks = 0 self.explored_docs = 0 self.trimmed_paths: list[dict[str, Any]] = [] + self._overdraft_events: list[dict[str, Any]] = [] def remaining(self, pool: BudgetPoolName) -> int: return self._pools[pool].remaining def status(self, pool: BudgetPoolName) -> BudgetStatus: pool_state = self._pools[pool] - if pool_state.remaining <= 0: - return "EXHAUSTED" - used_pct = pool_state.used_pct - if used_pct >= 80: - return "CRITICAL" - if used_pct >= 50: - return "TIGHT" - return "HEALTHY" + return status_from_usage( + remaining=pool_state.remaining, + used_pct=pool_state.used_pct, + ) async def allocate_doc_caps(self, doc_chunks: dict[str, int]) -> None: """Allocate planning soft caps by document chunk counts.""" @@ -107,21 +165,71 @@ async def try_reserve( *, priority: Literal["normal", "low"] = "normal", ) -> bool: + reservation = await self.reserve( + pool, + est, + doc_id=doc_id, + priority=priority, + allow_overdraft=False, + ) + return bool(reservation.get("reserved")) + + async def reserve( + self, + pool: BudgetPoolName, + est: int, + doc_id: str | None = None, + *, + priority: Literal["normal", "low"] = "normal", + allow_overdraft: bool = False, + overdraft_reason: str = "", + ) -> dict[str, Any]: est = max(int(est), 0) if est == 0: - return True + 
return {"reserved": True, "overdraft": False, "failure": None} async with self._lock: pool_state = self._pools[pool] if priority == "low" and self.status(pool) == "CRITICAL": - return False + failure = self._reserve_failure(pool, est, doc_id, "low_priority_critical") + return {"reserved": False, "overdraft": False, "failure": failure} + + failure_reason = "" if pool_state.remaining < est: - return False + failure_reason = "pool_remaining_lt_est" + + # Per-doc cap enforcement: prevent one document from consuming + # the entire planning pool. + doc_remaining: int | None = None + if pool == "planning" and doc_id and doc_id in self._doc_caps: + doc_remaining = self._doc_caps[doc_id] - ( + self._doc_used.get(doc_id, 0) + + self._doc_reserved.get(doc_id, 0) + ) + if doc_remaining < est: + failure_reason = "doc_remaining_lt_est" + + if failure_reason: + failure = self._reserve_failure(pool, est, doc_id, failure_reason) + if not allow_overdraft or pool != "planning": + return {"reserved": False, "overdraft": False, "failure": failure} + self._record_overdraft( + pool=pool, + est=est, + doc_id=doc_id, + reason=overdraft_reason or failure_reason, + failure=failure, + ) pool_state.reserved += est if pool == "planning" and doc_id: self._doc_reserved[doc_id] = self._doc_reserved.get(doc_id, 0) + est - return True + return { + "reserved": True, + "overdraft": bool(failure_reason), + "failure": self._reserve_failure(pool, est, doc_id, failure_reason) + if failure_reason else None, + } async def commit( self, @@ -137,7 +245,7 @@ async def commit( pool_state = self._pools[pool] reserved_delta = min(est, pool_state.reserved) pool_state.reserved -= reserved_delta - pool_state.used = min(pool_state.capacity, pool_state.used + actual) + pool_state.used = max(0, pool_state.used + actual) if pool == "planning" and doc_id: doc_reserved = min(est, self._doc_reserved.get(doc_id, 0)) @@ -183,10 +291,13 @@ def snapshot(self) -> dict[str, object]: "reserved": pool.reserved, "remaining": 
pool.remaining, "used_pct": pool.used_pct, + "overdraft": max(pool.used + pool.reserved - pool.capacity, 0), "status": self.status(name), } for name, pool in self._pools.items() } + if self._overdraft_events: + snapshot["overdraft_events"] = list(self._overdraft_events) snapshot.update({ "total_chunks": self.total_chunks, "total_docs": self.total_docs, @@ -197,3 +308,54 @@ def snapshot(self) -> dict[str, object]: "trimmed_paths": list(self.trimmed_paths), }) return snapshot + + def _reserve_failure( + self, + pool: BudgetPoolName, + est: int, + doc_id: str | None, + reason: str, + ) -> dict[str, Any]: + pool_state = self._pools[pool] + details: dict[str, Any] = { + "reason": reason, + "pool": pool, + "prompt_est": max(int(est), 0), + "pool_capacity": pool_state.capacity, + "pool_used": pool_state.used, + "pool_reserved": pool_state.reserved, + "pool_remaining": pool_state.remaining, + } + if pool == "planning" and doc_id and doc_id in self._doc_caps: + doc_used = self._doc_used.get(doc_id, 0) + doc_reserved = self._doc_reserved.get(doc_id, 0) + details.update({ + "doc_id": doc_id, + "doc_cap": self._doc_caps[doc_id], + "doc_used": doc_used, + "doc_reserved": doc_reserved, + "doc_remaining": max(self._doc_caps[doc_id] - doc_used - doc_reserved, 0), + }) + return details + + def _record_overdraft( + self, + *, + pool: BudgetPoolName, + est: int, + doc_id: str | None, + reason: str, + failure: dict[str, Any], + ) -> None: + pool_shortfall = max(int(est) - int(failure.get("pool_remaining") or 0), 0) + doc_shortfall = 0 + if "doc_remaining" in failure: + doc_shortfall = max(int(est) - int(failure.get("doc_remaining") or 0), 0) + self._overdraft_events.append({ + "pool": pool, + "doc_id": doc_id, + "prompt_est": max(int(est), 0), + "shortfall": max(pool_shortfall, doc_shortfall), + "reason": reason, + "failure": failure, + }) diff --git a/packages/shared-python/shared/services/retrieval/agentic/core/runtime.py 
b/packages/shared-python/shared/services/retrieval/agentic/core/runtime.py index 3f1563b4..91d7a9a8 100644 --- a/packages/shared-python/shared/services/retrieval/agentic/core/runtime.py +++ b/packages/shared-python/shared/services/retrieval/agentic/core/runtime.py @@ -1,6 +1,7 @@ """Runtime setup helpers for agentic retrieval.""" from __future__ import annotations + import json import os from typing import Any @@ -66,6 +67,8 @@ async def call( pool: BudgetPoolName, doc_id: str | None = None, priority: str = "normal", + allow_overdraft: bool = False, + overdraft_reason: str = "", ) -> str: ledger = self._state.ledger if ledger is None: @@ -73,14 +76,19 @@ async def call( prompt_text = _stringify_llm_input(prompt) est = estimate_tokens(prompt_text) - reserved = await ledger.try_reserve( + reservation = await ledger.reserve( pool, est, doc_id=doc_id, priority="low" if priority == "low" else "normal", + allow_overdraft=allow_overdraft, + overdraft_reason=overdraft_reason, ) - if not reserved: - raise BudgetExceeded(f"{pool} budget exhausted") + if not reservation.get("reserved"): + raise BudgetExceeded( + f"{pool} budget exhausted", + details=reservation.get("failure") or {}, + ) try: response = await llm_fn(prompt) @@ -89,7 +97,7 @@ async def call( raise usage = current_llm_usage.get() or {} - actual = int(usage.get("prompt_tokens") or est) + actual = _extract_actual_tokens(usage, est) await ledger.commit(pool, actual=actual, est=est, doc_id=doc_id) return response @@ -105,6 +113,8 @@ def for_document( *, doc_id: str, step: int = 0, + allow_overdraft: bool = False, + overdraft_reason: str = "", ) -> LLMFn: async def _call(prompt: Any) -> str: return await self.call( @@ -112,7 +122,9 @@ async def _call(prompt: Any) -> str: prompt, pool="planning", doc_id=doc_id, - priority="low" if step >= 4 else "normal", + priority="normal", + allow_overdraft=allow_overdraft, + overdraft_reason=overdraft_reason, ) return _call @@ -136,6 +148,20 @@ async def _call(prompt: Any) -> str: 
return _call +def _extract_actual_tokens(usage: dict, est: int) -> int: + """Derive actual token consumption from LLM usage dict. + + Checks ``total_tokens`` first, then sums ``prompt_tokens`` and + ``completion_tokens``. Falls back to the pre-call estimate. + """ + total = usage.get("total_tokens") + if total: + return int(total) + prompt = int(usage.get("prompt_tokens") or 0) + completion = int(usage.get("completion_tokens") or 0) + return (prompt + completion) or est + + def _stringify_llm_input(prompt: Any) -> str: if isinstance(prompt, str): return prompt diff --git a/packages/shared-python/shared/services/retrieval/agentic/core/trace.py b/packages/shared-python/shared/services/retrieval/agentic/core/trace.py index 7925bcb0..b059a078 100644 --- a/packages/shared-python/shared/services/retrieval/agentic/core/trace.py +++ b/packages/shared-python/shared/services/retrieval/agentic/core/trace.py @@ -15,7 +15,11 @@ from loguru import logger from sqlalchemy.ext.asyncio import AsyncSession -from shared.services.retrieval.agentic.core.types import AgentRunConfig, ToolResult +from shared.services.retrieval.agentic.core.types import ( + AgentRunConfig, + DecisionTraceStep, + ToolResult, +) from shared.services.retrieval.settings import DEFAULT_TOP_K @@ -132,6 +136,27 @@ def record_step( 'created_at': _now_utc(), }) + def record_decision_trace_step(self, step: DecisionTraceStep) -> None: + """Buffer a DB trace row derived from the public decision trace step.""" + result_status = str(step.result.get("status") or "unknown") + self._steps.append({ + "step_index": len(self._steps), + "action_type": f"{step.phase}:{step.agent}:{step.decision.get('action', '')}", + "action_input": { + "public_step_index": step.step_index, + "decision": step.decision, + "scope": step.scope, + "document_id": step.document_id, + "parent_step_index": step.parent_step_index, + }, + "observation_status": result_status, + "observation_payload_keys": list(step.observation.keys()), + "latency_ms": 
step.elapsed_ms or 0, + "error": step.result.get("error"), + "tokens_used": 0, + "created_at": _now_utc(), + }) + def record_budget_stop(self, reason: str) -> None: """Record that the agent loop stopped due to a budget guard.""" self._steps.append({ diff --git a/packages/shared-python/shared/services/retrieval/agentic/core/types.py b/packages/shared-python/shared/services/retrieval/agentic/core/types.py index f0d6b9f3..8d75618e 100644 --- a/packages/shared-python/shared/services/retrieval/agentic/core/types.py +++ b/packages/shared-python/shared/services/retrieval/agentic/core/types.py @@ -8,7 +8,7 @@ import time from dataclasses import dataclass, field -from typing import Any +from typing import Any, Literal from shared.services.retrieval.agentic.core.budget import BudgetLedger @@ -183,6 +183,53 @@ def merge(self, other: 'DocTreeNode') -> None: self.reparent_leaf_content() +NavAction = Literal[ + "EXPAND", + "BACK", + "FINISH", + "SEARCH_IMAGES", + "SEARCH_TABLES", + "ERROR", +] + + +@dataclass +class DecisionTraceStep: + """Uniform observe-act-result trace entry exposed to downstream agents.""" + + step_index: int + agent: str + phase: str + observation: dict[str, Any] + decision: dict[str, Any] + result: dict[str, Any] + parent_step_index: int | None = None + document_id: str | None = None + document: str | None = None + scope: str | None = None + budget: dict[str, Any] | None = None + elapsed_ms: int | None = None + + def to_dict(self) -> dict[str, Any]: + data: dict[str, Any] = { + "step_index": self.step_index, + "agent": self.agent, + "parent_step_index": self.parent_step_index, + "phase": self.phase, + "document_id": self.document_id, + "document": self.document, + "scope": self.scope, + "observation": self.observation, + "decision": self.decision, + "result": self.result, + } + if self.budget is not None: + data["budget"] = self.budget + if self.elapsed_ms is not None: + data["elapsed_ms"] = self.elapsed_ms + return data + + @dataclass class 
NavigateStepResult: """Return type for navigate_step — Collector Agent model. @@ -190,17 +237,27 @@ class NavigateStepResult: Each step returns: - ``collect``: paths to add to the evidence collection (full hydration) - ``drill``: paths to explore deeper in subsequent steps - - ``action``: navigation direction — DRILL/BACK/STOP - - ``tools``: optional asset tools (FIND_IMAGES/FIND_TABLES) + - ``action``: one explicit action — EXPAND/BACK/FINISH/SEARCH_*/ERROR - ``node``: outline tree node for rendering context - ``reason``: LLM reasoning for trace + - ``error_reason``: set when action is ERROR — distinguishes system + errors from intentional FINISH so callers can decide retry vs skip. + - ``search_assets_params``: parameters for SEARCH_IMAGES/SEARCH_TABLES + - ``observation``: what the navigator saw before choosing the action + - ``result_status`` / ``result_note``: executor-visible action validation """ - action: str = "STOP" # DRILL | BACK | STOP + action: NavAction = "FINISH" collect: list[dict[str, Any]] = field(default_factory=list) drill: list[dict[str, Any]] = field(default_factory=list) + back_to: str | None = None # BACK target ancestor path (None = root) tools: list[str] = field(default_factory=list) node: DocTreeNode = field(default_factory=DocTreeNode) reason: str = "" + error_reason: str | None = None + search_assets_params: dict[str, Any] | None = None + observation: dict[str, Any] = field(default_factory=dict) + result_status: str = "ok" + result_note: str | None = None @property def drill_into(self) -> str | None: @@ -209,14 +266,27 @@ def drill_into(self) -> str | None: @property def is_terminal(self) -> bool: - """True when navigation should stop (STOP or empty collect+drill).""" - return self.action == "STOP" or (not self.collect and not self.drill) + """True only for explicit terminal actions.""" + return self.action in ("FINISH", "ERROR") + + @staticmethod + def stop(scope_path: str | None = None, *, reason: str = "") -> 'NavigateStepResult': + 
return NavigateStepResult( + action="FINISH", + node=DocTreeNode.empty(scope_path), + reason=reason, + ) @staticmethod - def stop(scope_path: str | None = None) -> 'NavigateStepResult': + def error(scope_path: str | None = None, *, reason: str = "") -> 'NavigateStepResult': + """Return an ERROR result distinguishable from intentional FINISH.""" return NavigateStepResult( - action="STOP", + action="ERROR", node=DocTreeNode.empty(scope_path), + reason=f"navigation_error: {reason[:200]}" if reason else "navigation_error", + error_reason=reason[:500] if reason else "unknown_error", + result_status="error", + result_note=reason[:500] if reason else "unknown_error", ) @@ -243,7 +313,7 @@ class AgenticResult: - ``router_used``: routing path identifier - ``budget_snapshot``: final budget ledger state at run completion - ``stop_reason``: why the run terminated (evidence_only / - latency_budget / context_budget / no_llm / etc.) + budget / latency / max_steps / error / llm_stop) - ``failure_reason``: fatal retrieval failure reason, if any. - ``decision_trace``: per-step navigation decisions with reasons, exposed to downstream agents for stop/retry/modify-query decisions. 
diff --git a/packages/shared-python/shared/services/retrieval/agentic/discovery/phase.py b/packages/shared-python/shared/services/retrieval/agentic/discovery/phase.py index e5110a70..4efa8ab4 100644 --- a/packages/shared-python/shared/services/retrieval/agentic/discovery/phase.py +++ b/packages/shared-python/shared/services/retrieval/agentic/discovery/phase.py @@ -4,10 +4,8 @@ from typing import Any from loguru import logger -from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from shared.models.database.document import Document from shared.services.retrieval.agentic import tools from shared.services.retrieval.agentic.core.budget import BudgetExceeded from shared.services.retrieval.agentic.core.trace import TraceRecorder @@ -168,8 +166,6 @@ async def _select_documents( ) _append_selected_docs(state, kg_result) - if not state.selected_docs and state.discovery_top_doc_ids: - await _append_discovery_hints(db, state=state) logger.info( f" agentic step {state.step_count}: kg_document_select " @@ -178,34 +174,6 @@ async def _select_documents( ) -async def _append_discovery_hints(db: AsyncSession, *, state: AgentState) -> None: - hint_ids = [ - doc_id - for doc_id in state.discovery_top_doc_ids - if doc_id not in state.ever_explored_doc_ids - ] - if not hint_ids: - return - doc_stmt = ( - select(Document.document_id, Document.source_file_name, Document.current_job_result_id) - .where(Document.document_id.in_(hint_ids)) - ) - doc_result = await db.execute(doc_stmt) - for doc_id, source_file_name, job_result_id in doc_result.all(): - state.selected_docs.append( - CandidateDoc( - document_id=doc_id, - source_file_name=source_file_name or doc_id, - confidence=0.5, - reason="discovery_hint (KG returned 0)", - source="discovery_hint", - ) - ) - state.doc_id_to_name[doc_id] = source_file_name or doc_id - if job_result_id: - state.doc_job_map[doc_id] = job_result_id - - def _append_selected_docs(state: AgentState, kg_result: ToolResult) -> None: if 
kg_result.status != "selected_docs": return diff --git a/packages/shared-python/shared/services/retrieval/agentic/discovery/selection.py b/packages/shared-python/shared/services/retrieval/agentic/discovery/selection.py deleted file mode 100644 index 45ae3010..00000000 --- a/packages/shared-python/shared/services/retrieval/agentic/discovery/selection.py +++ /dev/null @@ -1,248 +0,0 @@ -"""Post-navigation discovery selection for agentic retrieval.""" -from __future__ import annotations - -import time -from dataclasses import dataclass, field -from typing import Any - -from loguru import logger -from sqlalchemy.ext.asyncio import AsyncSession - -from shared.services.retrieval.agentic.core.budget import BudgetExceeded -from shared.services.retrieval.agentic.prompts import ( - DISCOVERY_SELECT_PROMPT, - format_budget_block, - parse_action_response, -) -from shared.services.retrieval.agentic.navigation.selection_hydration import ( - hydrate_chunk_refs_into_node, - hydrate_path_selections_into_node, -) -from shared.services.retrieval.agentic.core.types import DocTreeNode -from shared.services.retrieval.search.lexical_text import normalize_section_path -from shared.services.retrieval.llm_adapter import LLMFn - - -_MAX_DISCOVERY_PER_DOC = 10 - - -@dataclass -class DiscoverySelectResult: - """Result of discovery_select_step: node + dedup metadata.""" - - node: DocTreeNode - excluded_hints: list[dict[str, str]] = field(default_factory=list) - candidate_count: int = 0 - - -async def discovery_select_step( - db: AsyncSession, - *, - document_id: str, - query: str, - llm_fn: LLMFn, - user_id: str, - namespace: str, - doc_name: str = "", - discovery_hints: list[dict[str, Any]], - exclude_paths: set[str] | None = None, - budget_snapshot: dict | None = None, -) -> DiscoverySelectResult: - """Select and hydrate discovery-found sections after BFS navigation.""" - node = DocTreeNode(scope_path=None) - if not discovery_hints: - return DiscoverySelectResult(node=node) - - hints = 
discovery_hints[:_MAX_DISCOVERY_PER_DOC] - - t0 = time.monotonic() - try: - hint_lines, hint_by_path, excluded_hints = _project_discovery_hints( - hints, - exclude_paths=exclude_paths, - ) - if excluded_hints: - logger.info( - f' discovery_select_step doc="{doc_name}": ' - f"{len(excluded_hints)} hints excluded by navigation COLLECT: " - + ", ".join( - f'"{h["path"]}" (covered by "{h["covered_by"]}")' - for h in excluded_hints[:3] - ) - + (f" (+{len(excluded_hints) - 3} more)" if len(excluded_hints) > 3 else "") - ) - if not hint_lines: - return DiscoverySelectResult( - node=node, - excluded_hints=excluded_hints, - candidate_count=len(hints), - ) - - selections: list[dict[str, Any]] = [] - if hint_lines: - prompt = _build_discovery_selection_prompt( - document_id=document_id, - doc_name=doc_name, - query=query, - hint_lines=hint_lines, - budget_snapshot=budget_snapshot, - ) - response = await llm_fn(prompt) - parsed = parse_action_response(response) - selections = parsed.get("selections", []) - - logger.info( - f' discovery_select_step doc="{doc_name}": ' - f"hints={len(hints)} selections={len(selections)}" - ) - - path_selections, chunk_refs = _build_discovery_path_selections( - selections=selections, - hint_by_path=hint_by_path, - document_id=document_id, - node=node, - ) - await hydrate_chunk_refs_into_node( - db, - node=node, - refs=chunk_refs, - user_id=user_id, - namespace=namespace, - document_id=document_id, - ) - await hydrate_path_selections_into_node( - db, - node=node, - path_selections=path_selections, - user_id=user_id, - namespace=namespace, - document_id=document_id, - ) - - latency = int((time.monotonic() - t0) * 1000) - logger.info( - f" discovery_select_step done: hydrated={len(node.leaf_content)} " - f"latency={latency}ms" - ) - return DiscoverySelectResult( - node=node, - excluded_hints=excluded_hints, - candidate_count=len(hints), - ) - - except BudgetExceeded: - raise - except Exception as exc: - logger.error(f" discovery_select_step failed 
for doc={document_id}: {exc}") - return DiscoverySelectResult(node=node) - - -def _project_discovery_hints( - hints: list[dict[str, Any]], - *, - exclude_paths: set[str] | None, -) -> tuple[list[str], dict[str, dict], list[dict[str, str]]]: - """Project discovery hints into prompt lines, filtering excluded paths. - - Returns ``(hint_lines, hint_by_path, excluded_hints)`` where - *excluded_hints* records each path that was dropped and which - navigation-collected path covered it. - """ - exclude_set = { - normalize_section_path(path) - for path in (exclude_paths or set()) - if path - } - hint_lines: list[str] = [] - hint_by_path: dict[str, dict] = {} - excluded_hints: list[dict[str, str]] = [] - for hint in hints: - section_path = normalize_section_path(hint.get("section_path", "")) - if not section_path: - continue - covered_by = _find_covering_path(section_path, exclude_set) - if covered_by is not None: - excluded_hints.append({"path": section_path, "covered_by": covered_by}) - continue - if section_path in hint_by_path: - continue - - hint_by_path[section_path] = hint - summary = hint.get("summary", "") or "" - hint_lines.append(f'▸ path="{section_path}"') - if summary: - hint_lines.append(f" {summary[:300]}") - - return hint_lines, hint_by_path, excluded_hints - - -def _find_covering_path(path: str, exclude_set: set[str]) -> str | None: - """Return the exclude-set entry that covers *path*, or ``None``. - - A path is covered if it exactly matches an exclude entry, OR if any - exclude entry is a prefix of this path (i.e. the parent path was - already collected by navigation). 
- """ - if path in exclude_set: - return path - for excluded in exclude_set: - if path.startswith(excluded + " / "): - return excluded - return None - - -def _build_discovery_selection_prompt( - *, - document_id: str, - doc_name: str, - query: str, - hint_lines: list[str], - budget_snapshot: dict | None, -) -> str: - return DISCOVERY_SELECT_PROMPT.format( - doc_name=doc_name or document_id, - budget_block=format_budget_block(budget_snapshot), - items="\n".join(hint_lines), - query=query, - ) - - -def _build_discovery_path_selections( - *, - selections: list[dict[str, Any]], - hint_by_path: dict[str, dict], - document_id: str, - node: DocTreeNode, -) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: - valid_selections = [ - selection for selection in selections if selection["path"] in hint_by_path - ] - path_selections: list[dict[str, Any]] = [] - chunk_refs: list[dict[str, Any]] = [] - for selection in valid_selections: - path = selection["path"] - confidence = selection.get("confidence", 0.7) - node.confidence[path] = confidence - hint = hint_by_path[path] - if path == "Root": - chunk_id = str(hint.get("chunk_id") or "").strip() - if chunk_id: - chunk_refs.append({ - "document_id": document_id, - "chunk_id": chunk_id, - "section_path": path, - }) - continue - path_selections.append({ - "path": path, - "confidence": confidence, - "hydrate_mode": "self_only", - }) - continue - path_selections.append({ - "path": path, - "confidence": confidence, - "hydrate_mode": "self_only", - }) - - return path_selections, chunk_refs diff --git a/packages/shared-python/shared/services/retrieval/agentic/evidence/builder.py b/packages/shared-python/shared/services/retrieval/agentic/evidence/builder.py index c8187401..f2a59c3c 100644 --- a/packages/shared-python/shared/services/retrieval/agentic/evidence/builder.py +++ b/packages/shared-python/shared/services/retrieval/agentic/evidence/builder.py @@ -142,7 +142,7 @@ async def render_evidence( for doc_id, doc_tree in 
doc_trees.items(): # Only render if there is actual hydrated evidence (chunks collected # via COLLECT or discovery). Outline-only trees (e.g. navigation - # STOP with empty collect) must not leak into evidence_text. + # FINISH with empty collect) must not leak into evidence_text. if not doc_tree.has_leaf_content(): continue doc_name = doc_id_to_name.get(doc_id, doc_id) diff --git a/packages/shared-python/shared/services/retrieval/agentic/evidence/renderer.py b/packages/shared-python/shared/services/retrieval/agentic/evidence/renderer.py index 42c543f7..21761b47 100644 --- a/packages/shared-python/shared/services/retrieval/agentic/evidence/renderer.py +++ b/packages/shared-python/shared/services/retrieval/agentic/evidence/renderer.py @@ -64,12 +64,15 @@ def min_sort(path: str) -> float: leaf_tag = " [Leaf]" if is_leaf else "" level_tag = f"[L{level}] " if level else "" + # Indent based on the section's own level relative to the tree depth, + # so L2 items are indented even when they sit in the root node. 
+ item_indent = indent + " " * max(level - 1, 0) if level <= 1: - parts.append(f"{indent}▸ {level_tag}{title}{leaf_tag}") + parts.append(f"{item_indent}▸ {level_tag}{title}{leaf_tag}") else: - parts.append(f"{indent}└ {level_tag}{title}{leaf_tag}") + parts.append(f"{item_indent}└ {level_tag}{title}{leaf_tag}") - sub_indent = indent + " " + sub_indent = item_indent + " " if path in node.children: child = node.children[path] if path in node.leaf_content: diff --git a/packages/shared-python/shared/services/retrieval/agentic/navigation/actions.py b/packages/shared-python/shared/services/retrieval/agentic/navigation/actions.py new file mode 100644 index 00000000..3b151378 --- /dev/null +++ b/packages/shared-python/shared/services/retrieval/agentic/navigation/actions.py @@ -0,0 +1,689 @@ +"""Legal action projection for agentic document navigation.""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Literal + +from shared.services.retrieval.agentic.core.budget import budget_status_from_snapshot +from shared.services.retrieval.agentic.navigation.path_ledger import PathLedger +from shared.services.retrieval.search.lexical_text import normalize_section_path +from shared.utils.text_utils import truncate_content_preview + + +ActionKind = Literal[ + "EXPAND", + "COLLECT", + "BACK", + "SEARCH_IMAGES", + "SEARCH_TABLES", + "FINISH", +] + + +@dataclass(frozen=True) +class LegalAction: + id: str + action: ActionKind + path: str | None = None + target_scope: str | None = None + asset_type: str | None = None + note: str | None = None + source: str = "tree" + score: float = 0.0 + critical_expand: bool = False + + +@dataclass +class LegalActionSet: + by_id: dict[str, LegalAction] = field(default_factory=dict) + expand: list[LegalAction] = field(default_factory=list) + collect: list[LegalAction] = field(default_factory=list) + back: list[LegalAction] = field(default_factory=list) + search: list[LegalAction] = 
field(default_factory=list) + finish: LegalAction | None = None + + def add(self, action: LegalAction) -> None: + self.by_id[action.id] = action + if action.action == "EXPAND": + self.expand.append(action) + elif action.action == "COLLECT": + self.collect.append(action) + elif action.action == "BACK": + self.back.append(action) + elif action.action in ("SEARCH_IMAGES", "SEARCH_TABLES"): + self.search.append(action) + elif action.action == "FINISH": + self.finish = action + + def get(self, action_id: str | None) -> LegalAction | None: + if not action_id: + return None + return self.by_id.get(action_id) + + +def build_legal_actions( + *, + items: list[dict[str, Any]], + current_scope: str | None, + collected_paths: list[dict[str, Any]], + expanded_scopes: set[str], + discovery_hints: list[dict[str, Any]] | None = None, + rejected_paths: set[str] | None = None, + rejected_collect_paths: set[str] | None = None, + total_images: int, + total_tables: int, + disabled_asset_types: set[str] | None = None, + budget_snapshot: dict[str, Any] | None = None, +) -> LegalActionSet: + action_set = LegalActionSet() + covered_paths = _covered_paths(collected_paths) + outline_paths = _outline_paths(collected_paths) + budget_mode = budget_status_from_snapshot(budget_snapshot) + rejected = {PathLedger.normalize(path) for path in rejected_paths or set()} + rejected_collects = { + normalized + for path in rejected_collect_paths or set() + if (normalized := PathLedger.normalize(path)) + } + discovery_scores = _discovery_scores_by_path(discovery_hints or []) + scored_items = _score_items(items, discovery_scores) + ranked_items = _rank_items(scored_items) + expand_allowlist = _expand_allowlist( + ranked_items, + budget_mode=budget_mode, + limit=3, + ) + + expand_index = 1 + collect_index = 1 + for item in scored_items: + path = str(item.get("path") or "").strip() + if not path or path == "Root": + continue + if PathLedger.is_covered(path, covered_paths): + continue + if 
PathLedger.is_covered(path, rejected_collects): + continue + + action_set.add(LegalAction( + id=f"C{collect_index}", + action="COLLECT", + path=path, + target_scope=path, + note=( + "upgrade outline to full evidence" + if path in outline_paths + else _item_note(item) + ), + score=float(item.get("relevance_score") or 0.0), + )) + collect_index += 1 + + critical_expand = False + if budget_mode == "EXHAUSTED": + continue + if budget_mode == "CRITICAL": + if action_set.collect: + continue + critical_expand = True + if item.get("is_leaf"): + continue + if path == current_scope: + continue + if current_scope and PathLedger.is_ancestor(path, current_scope): + continue + if path in expanded_scopes: + continue + if ( + budget_mode == "TIGHT" + and path not in expand_allowlist + ): + continue + if path in rejected and not _path_has_discovery_signal(path, discovery_scores): + continue + + action_set.add(LegalAction( + id=f"E{expand_index}", + action="EXPAND", + path=path, + target_scope=path, + note=_item_note(item), + score=float(item.get("relevance_score") or 0.0), + critical_expand=critical_expand, + )) + expand_index += 1 + + discovery_index = 1 + seen_discovery_paths: set[str] = set() + for hint in sorted( + discovery_hints or [], + key=lambda item: float(item.get("discovery_score") or 0.0), + reverse=True, + ): + path = normalize_section_path(str(hint.get("section_path") or "")) + if not path or path in seen_discovery_paths: + continue + seen_discovery_paths.add(path) + if PathLedger.is_covered(path, covered_paths): + continue + # TODO: allow tool-specific LLM adjudicators to revive rejected + # collects when validity cannot be determined structurally. 
+ if PathLedger.is_covered(path, rejected_collects): + continue + if any(action.path == path for action in action_set.collect): + continue + action_set.add(LegalAction( + id=f"D{discovery_index}", + action="COLLECT", + path=path, + target_scope=path, + note=_discovery_note(hint), + source="discovery", + score=float(hint.get("discovery_score") or 0.0), + )) + discovery_index += 1 + + search_allowed = budget_mode not in ("CRITICAL", "EXHAUSTED") + disabled_assets = {item.lower() for item in disabled_asset_types or set()} + if ( + search_allowed + and "image" not in disabled_assets + and total_images > 0 + and _asset_search_worthwhile(ranked_items, "image") + ): + action_set.add(LegalAction( + id="S1", + action="SEARCH_IMAGES", + asset_type="image", + note=f"{total_images} images available in current scope", + )) + if ( + search_allowed + and "table" not in disabled_assets + and total_tables > 0 + and _asset_search_worthwhile(ranked_items, "table") + ): + action_set.add(LegalAction( + id="S2", + action="SEARCH_TABLES", + asset_type="table", + note=f"{total_tables} tables available in current scope", + )) + + if current_scope and budget_mode != "EXHAUSTED": + back_index = 1 + for target in PathLedger.back_targets(current_scope): + label = target if target is not None else "root" + action_set.add(LegalAction( + id=f"B{back_index}", + action="BACK", + path=target, + target_scope=target, + note=f"return to {label}", + )) + back_index += 1 + + action_set.add(LegalAction( + id="F1", + action="FINISH", + note="finish this document", + )) + return action_set + + +def format_agent_state_block( + *, + current_scope: str | None, + query_intent: str, + expanded_scopes: set[str], + rejected_paths: set[str], + collected_paths: list[dict[str, Any]], + rejected_collect_paths: set[str] | None = None, + prior_tool_result: dict[str, Any] | None, + search_context: str, + budget_snapshot: dict[str, Any] | None, +) -> str: + lines = [ + "=== Agent State ===", + f"Current scope: 
{current_scope or 'root'}", + f"Advisory query intent: {query_intent or 'UNKNOWN'}", + _format_budget_state(budget_snapshot), + ] + budget_mode = budget_status_from_snapshot(budget_snapshot) + if budget_mode == "CRITICAL": + lines.append( + "Budget policy: exploration actions are closed; collect the best visible " + "evidence before FINISH." + ) + elif budget_mode == "EXHAUSTED": + lines.append( + "Budget policy: planning budget is exhausted or in overdraft. Do not " + "explore or search again. Use the current observation and tool results " + "to decide FINISH, or collect only indispensable visible evidence." + ) + if expanded_scopes: + lines.append("Expanded scopes:") + for path in sorted(expanded_scopes): + lines.append(f' - "{path}"') + else: + lines.append("Expanded scopes: none") + rejected_collects = set(rejected_collect_paths or set()) + low_value_rejected = set(rejected_paths) - rejected_collects + if low_value_rejected: + lines.append("Low-value scopes avoided unless revived by discovery:") + for path in sorted(low_value_rejected): + lines.append(f' - "{path}"') + if rejected_collect_paths: + lines.append("Collects rejected by tool reconciliation:") + for path in sorted(rejected_collect_paths): + lines.append(f' - "{path}"') + + full_paths, outline_paths = _dedupe_collection_modes(collected_paths) + if full_paths or outline_paths: + if full_paths: + lines.append(f"Full evidence collected: {len(full_paths)} item(s)") + for path in full_paths: + lines.append(f' - "{path}"') + else: + lines.append("Full evidence collected: none") + if outline_paths: + lines.append( + f"Outline-only evidence: {len(outline_paths)} item(s) " + "(structure only; not hydrated as full chunks)" + ) + for path in outline_paths: + lines.append(f' - "{path}"') + else: + lines.append("Collected evidence: none") + + if prior_tool_result: + lines.append(f"Last tool result: {_compact_dict(prior_tool_result)}") + if search_context: + lines.append("Tool observation:") + 
lines.append(search_context.strip()) + lines.append("=== End Agent State ===") + return "\n".join(lines) + + +def format_actionable_observation( + *, + items: list[dict[str, Any]], + action_set: LegalActionSet, + max_chars: int = 20000, +) -> tuple[str, bool]: + """Render visible document state and legal action affordances once.""" + if not items: + return "(no visible sections)", False + + full_text = _render_actionable_items( + items=items, + action_set=action_set, + include_summary=True, + ) + if len(full_text) <= max_chars: + return full_text, False + + slim_text = _render_actionable_items( + items=items, + action_set=action_set, + include_summary=False, + ) + return slim_text[:max_chars], True + + +def _render_actionable_items( + *, + items: list[dict[str, Any]], + action_set: LegalActionSet, + include_summary: bool, +) -> str: + collect_by_path = { + action.path: action + for action in action_set.collect + if action.path + } + expand_by_path = { + action.path: action + for action in action_set.expand + if action.path + } + lines = [ + "=== Actionable Observation ===", + "Each visible section appears once. 
Choose action IDs attached to the relevant line.", + ] + for item in items: + lines.extend(_render_actionable_item( + item=item, + collect_action=collect_by_path.get(str(item.get("path") or "")), + expand_action=expand_by_path.get(str(item.get("path") or "")), + include_summary=include_summary, + )) + + discovery_lines = _format_discovery_actions(action_set) + if discovery_lines: + lines.append("") + lines.append("Discovery hints:") + lines.extend(discovery_lines) + + global_actions = _format_global_actions(action_set) + if global_actions: + lines.append("") + lines.append("Global actions:") + lines.extend(global_actions) + lines.append("=== End Actionable Observation ===") + return "\n".join(lines) + + +def _render_actionable_item( + *, + item: dict[str, Any], + collect_action: LegalAction | None, + expand_action: LegalAction | None, + include_summary: bool, +) -> list[str]: + level = int(item.get("level", 1) or 1) + show_summary = bool(item.get("show_summary", True)) + has_actions = collect_action is not None or expand_action is not None + show_details = show_summary or has_actions + path = str(item.get("path") or "") + summary = str(item.get("summary") or "") + is_leaf = bool(item.get("is_leaf", False)) + indent = " " * max(level - 1, 0) + prefix = "▸" if level == 1 else "└" + level_tag = f"depth={level}" + counts = _format_counts(item) if show_details else "" + tokens = _format_token_estimate(item) if show_details else "" + leaf = " [Leaf]" if is_leaf else "" + actions = _format_node_actions( + collect_action=collect_action, + expand_action=expand_action, + ) + + lines = [ + f'{indent}{prefix} {level_tag} path="{path}"{counts}{tokens}{leaf} actions: {actions}' + ] + if include_summary and show_details and summary: + display_summary = _enrich_section_covers_summary(summary) + clipped = truncate_content_preview(display_summary, head=120, tail=0) + lines.append(f"{indent} summary: {clipped}") + return lines + + +def _format_node_actions( + *, + collect_action: 
LegalAction | None, + expand_action: LegalAction | None, +) -> str: + actions: list[str] = [] + if collect_action: + collect_name = ( + "collect_full" + if collect_action.note == "upgrade outline to full evidence" + else "collect" + ) + actions.append(f"{collect_name}={collect_action.id}") + if expand_action: + actions.append(f"expand={expand_action.id}") + return ", ".join(actions) if actions else "none" + + +def _format_discovery_actions(action_set: LegalActionSet) -> list[str]: + lines: list[str] = [] + for action in action_set.collect: + if action.source != "discovery" or not action.path: + continue + note = f" | {action.note}" if action.note else "" + lines.append(f' {action.id} -> "{action.path}"{note}') + return lines + + +def _format_global_actions(action_set: LegalActionSet) -> list[str]: + lines: list[str] = [] + for action in action_set.search: + if action.action == "SEARCH_IMAGES": + lines.append(f" search_images={action.id} ({action.note})") + elif action.action == "SEARCH_TABLES": + lines.append(f" search_tables={action.id} ({action.note})") + for action in action_set.back: + target = action.target_scope or "root" + lines.append(f" back={action.id} -> {target}") + if action_set.finish: + lines.append(f" finish={action_set.finish.id}") + return lines + + +def _format_counts(item: dict[str, Any]) -> str: + parts: list[str] = [] + chunk_count = int(item.get("chunk_count") or 0) + image_count = int(item.get("image_count") or 0) + table_count = int(item.get("table_count") or 0) + if chunk_count: + parts.append(f"text={chunk_count}") + if image_count: + parts.append(f"image={image_count}") + if table_count: + parts.append(f"table={table_count}") + return f' [{" ".join(parts)}]' if parts else "" + + +def _format_token_estimate(item: dict[str, Any]) -> str: + total_chars = int(item.get("total_chars") or 0) + if total_chars <= 0: + return "" + tokens = total_chars / 2 + if tokens >= 1000: + return f" ~{tokens / 1000:.1f}k tokens" + return f" ~{int(tokens)} 
tokens" + + +def _enrich_section_covers_summary(summary: str) -> str: + prefix = "This section covers: " + if not summary.startswith(prefix): + return summary + body = summary[len(prefix):] + sub_sections = [s.strip() for s in body.split(", ") if s.strip()] + return f"This section covers {len(sub_sections)} sub-sections: {body}" + + +def _covered_paths(collected_paths: list[dict[str, Any]]) -> set[str]: + return { + PathLedger.normalize(str(item.get("path") or "")) + for item in collected_paths + if item.get("path") and item.get("hydrate_mode") != "outline" + } + + +def _outline_paths(collected_paths: list[dict[str, Any]]) -> set[str]: + return { + str(item.get("path") or "") + for item in collected_paths + if item.get("path") and item.get("hydrate_mode") == "outline" + } + + +def _dedupe_collection_modes( + collected_paths: list[dict[str, Any]], +) -> tuple[list[str], list[str]]: + full: set[str] = set() + outline: set[str] = set() + for item in collected_paths: + path = str(item.get("path") or "") + if not path: + continue + if item.get("hydrate_mode") == "outline": + outline.add(path) + else: + full.add(path) + outline -= full + return sorted(full), sorted(outline) + + +def _discovery_note(hint: dict[str, Any]) -> str | None: + summary = str(hint.get("summary") or "").strip() + score = float(hint.get("discovery_score") or 0.0) + score_note = f"score={score:.2f}" if score > 0 else "" + if summary: + clipped = truncate_content_preview(summary, head=120, tail=0) + return f"{clipped} {score_note}".strip() + chunk_type = str(hint.get("chunk_type") or "").strip() + if chunk_type: + return f"bottom-discovery hit type={chunk_type} {score_note}".strip() + return f"bottom-discovery hit {score_note}".strip() + + +def _format_budget_state(snapshot: dict[str, Any] | None) -> str: + if not isinstance(snapshot, dict): + return "Budget mode: UNKNOWN" + planning = snapshot.get("planning") + if not isinstance(planning, dict): + return "Budget mode: UNKNOWN" + status = 
str(planning.get("status") or "UNKNOWN") + used_pct = planning.get("used_pct") + remaining = planning.get("remaining") + capacity = planning.get("capacity") + overdraft = int(planning.get("overdraft") or 0) + overdraft_note = f", overdraft={overdraft}" if overdraft > 0 else "" + if used_pct is None: + return f"Budget mode: {status}" + if remaining is not None and capacity: + return ( + f"Budget mode: {status} ({used_pct}% used, " + f"{remaining}/{capacity} tokens remaining{overdraft_note})" + ) + return f"Budget mode: {status} ({used_pct}% used{overdraft_note})" + + +def _item_note(item: dict[str, Any]) -> str | None: + parts: list[str] = [] + chunk_count = int(item.get("chunk_count") or 0) + image_count = int(item.get("image_count") or 0) + table_count = int(item.get("table_count") or 0) + if chunk_count: + parts.append(f"text={chunk_count}") + if image_count: + parts.append(f"image={image_count}") + if table_count: + parts.append(f"table={table_count}") + score = float(item.get("relevance_score") or 0.0) + if score > 0: + parts.append(f"relevance={score:.2f}") + if item.get("is_leaf"): + parts.append("leaf") + return " ".join(parts) if parts else None + + +def _discovery_scores_by_path( + discovery_hints: list[dict[str, Any]], +) -> dict[str, float]: + scores: dict[str, float] = {} + for hint in discovery_hints: + path = normalize_section_path(str(hint.get("section_path") or "")) + if not path: + continue + score = float(hint.get("discovery_score") or 0.0) + scores[path] = max(scores.get(path, 0.0), score) + return scores + + +def _score_items( + items: list[dict[str, Any]], + discovery_scores: dict[str, float], +) -> list[dict[str, Any]]: + scored: list[dict[str, Any]] = [] + for index, item in enumerate(items): + copied = dict(item) + copied["_original_index"] = index + copied["relevance_score"] = _score_item(copied, discovery_scores) + scored.append(copied) + return scored + + +def _rank_items(scored_items: list[dict[str, Any]]) -> list[dict[str, Any]]: + 
return sorted( + scored_items, + key=lambda item: ( + float(item.get("relevance_score") or 0.0), + int(item.get("chunk_count") or 0), + int(item.get("table_count") or 0) + int(item.get("image_count") or 0), + -int(item.get("_original_index") or 0), + ), + reverse=True, + ) + + +def _score_item( + item: dict[str, Any], + discovery_scores: dict[str, float], +) -> float: + path = normalize_section_path(str(item.get("path") or "")) + if not path: + return 0.0 + score = discovery_scores.get(path, 0.0) + for hint_path, hint_score in discovery_scores.items(): + if PathLedger.is_ancestor(path, hint_path): + score = max(score, float(hint_score) * 0.9) + elif PathLedger.is_ancestor(hint_path, path): + score = max(score, float(hint_score) * 0.65) + return min(score, 1.0) + + +def _expand_allowlist( + ranked_items: list[dict[str, Any]], + *, + budget_mode: str, + limit: int, +) -> set[str]: + if budget_mode != "TIGHT": + return { + normalize_section_path(str(item.get("path") or "")) + for item in ranked_items + if item.get("path") + } + candidates = [ + normalize_section_path(str(item.get("path") or "")) + for item in ranked_items + if item.get("path") + and not item.get("is_leaf") + ] + return set(candidates[:limit]) + + +def _path_has_discovery_signal( + path: str, + discovery_scores: dict[str, float], +) -> bool: + return any( + candidate == path + or PathLedger.is_ancestor(path, candidate) + or PathLedger.is_ancestor(candidate, path) + for candidate in discovery_scores + ) + + +def _asset_search_worthwhile( + ranked_items: list[dict[str, Any]], + asset_kind: Literal["image", "table"], +) -> bool: + count_key = "image_count" if asset_kind == "image" else "table_count" + return any(int(item.get(count_key) or 0) > 0 for item in ranked_items[:5]) + + +def _compact_dict(value: dict[str, Any]) -> str: + bits: list[str] = [] + for key in ("tool", "status", "matched", "candidate_count", "status_detail"): + if key in value: + bits.append(f"{key}={value[key]}") + budget = 
value.get("budget") + if isinstance(budget, dict): + delta = budget.get("delta") + after = budget.get("after") + if isinstance(delta, dict): + bits.append( + "budget_delta=" + f"used:{delta.get('used', 0)}, " + f"used_pct:{delta.get('used_pct', 0)}, " + f"overdraft:{delta.get('overdraft', 0)}" + ) + if isinstance(after, dict) and int(after.get("overdraft") or 0) > 0: + bits.append(f"budget_overdraft={after.get('overdraft')}") + return ", ".join(bits) if bits else str(value) diff --git a/packages/shared-python/shared/services/retrieval/agentic/navigation/assets.py b/packages/shared-python/shared/services/retrieval/agentic/navigation/assets.py index 48403e3c..ad6367c9 100644 --- a/packages/shared-python/shared/services/retrieval/agentic/navigation/assets.py +++ b/packages/shared-python/shared/services/retrieval/agentic/navigation/assets.py @@ -8,8 +8,12 @@ from sqlalchemy import or_, select from sqlalchemy.ext.asyncio import AsyncSession -from shared.models.database.document import Document, DocumentChunk, DocumentSection +from shared.models.database.document import DocumentChunk, DocumentSection from shared.models.database.job_result import JobResult +from shared.services.retrieval.agentic.core.budget import BudgetExceeded +from shared.services.retrieval.hydration.assets import build_retrieval_asset_url_map +from shared.services.retrieval.llm_adapter import LLMFn +from shared.utils.token_estimate import estimate_tokens def build_connected_owner_map(text_chunks: list[dict[str, Any]]) -> dict[str, str]: @@ -95,26 +99,6 @@ async def count_assets_under_scope( return total_images, total_tables -def build_asset_tools_block(total_images: int, total_tables: int) -> str: - if total_images <= 0 and total_tables <= 0: - return "" - - tools_lines = ["\nOptional asset tools (usable with NAVIGATE or STOP):\n"] - if total_images > 0: - tools_lines.append( - f" FIND_IMAGES — Extract image/chart assets under the current scope ({total_images} available).\n" - ) - if total_tables > 0: 
async def search_assets_step(
    db: AsyncSession,
    *,
    document_id: str,
    job_result_id: str,
    scope_path: str | list[str] | None,
    asset_type: str,
    query: str,
    llm_fn: LLMFn,
    vlm_fn: LLMFn | None = None,
) -> dict[str, Any]:
    """Select the assets under a scope that are relevant to ``query``.

    Candidates are loaded with ``asset_filter_step`` and then filtered:

    * every asset type other than ``"image"`` goes through the text LLM
      (``llm_fn``) via ``_search_assets_via_text_llm``;
    * images go through ``_search_images_via_vlm`` when ``vlm_fn`` is given.
      When no VLM is supplied, or the VLM helper reports an error, the text
      filter is used instead and the status is prefixed with ``fallback_``.

    Args:
        db: Session handed through to ``asset_filter_step``.
        document_id: Document whose assets are searched.
        job_result_id: Job result the assets belong to.
        scope_path: Scope (one path, several paths, or ``None``) handed to
            ``asset_filter_step``.
        asset_type: ``"image"`` selects the VLM route; anything else uses
            the text filter.
        query: User query the assets are judged against.
        llm_fn: Text LLM callable used by the text filter.
        vlm_fn: Optional multimodal callable used for images.

    Returns:
        A dict that carries the same keys on every path:
        ``status`` (``"matched"``, ``"empty"``, ``"fallback_matched"`` or
        ``"fallback_empty"``), ``status_detail`` (``""`` or the fallback
        cause), ``matched_assets``, ``verdicts`` (one entry per candidate),
        ``candidate_count`` and ``latency_ms``.
    """
    started = time.monotonic()

    def elapsed_ms() -> int:
        return int((time.monotonic() - started) * 1000)

    all_assets = await asset_filter_step(
        db,
        document_id=document_id,
        job_result_id=job_result_id,
        scope_path=scope_path,
        asset_type=asset_type,
    )
    if not all_assets:
        logger.info(f" search_assets_step: no {asset_type} assets under scope={scope_path}")
        return _empty_asset_search_result(elapsed_ms())

    # Index candidates by chunk_id; rows without an id cannot be selected.
    asset_by_id: dict[str, dict[str, Any]] = {}
    for asset in all_assets:
        chunk_id = str(asset.get("chunk_id") or "")
        if chunk_id:
            asset_by_id[chunk_id] = asset
    if not asset_by_id:
        return _empty_asset_search_result(elapsed_ms())

    candidates = list(asset_by_id.values())
    status_detail = ""
    used_text_fallback = False
    # ``None`` means "the text filter still has to run".
    selected_ids: list[str] | None = None

    if asset_type == "image":
        if vlm_fn is None:
            logger.info(" search_assets_step: VLM unavailable for image search")
            status_detail = "vlm_unavailable_text_fallback"
            used_text_fallback = True
        else:
            vlm_ids, vlm_error = await _search_images_via_vlm(
                query=query,
                assets=candidates,
                vlm_fn=vlm_fn,
            )
            if vlm_error:
                logger.info(
                    " search_assets_step: VLM image search fell back to text "
                    f"filter, reason={vlm_error}"
                )
                status_detail = "vlm_failed_text_fallback"
                used_text_fallback = True
            else:
                selected_ids = vlm_ids

    if selected_ids is None:
        selected_ids = await _search_assets_via_text_llm(
            query=query,
            asset_type=asset_type,
            assets=candidates,
            llm_fn=llm_fn,
        )

    outcome = "matched" if selected_ids else "empty"
    status = f"fallback_{outcome}" if used_text_fallback else outcome

    matched_assets, verdicts = _summarize_asset_selection(
        asset_by_id, selected_ids, status,
    )

    latency = elapsed_ms()
    logger.info(
        f" search_assets_step query=\"{query}\" type={asset_type}: "
        f"{len(matched_assets)}/{len(all_assets)} assets matched, {latency}ms"
    )
    return {
        "status": status,
        "status_detail": status_detail,
        "matched_assets": matched_assets,
        "verdicts": verdicts,
        "candidate_count": len(asset_by_id),
        "latency_ms": latency,
    }


def _empty_asset_search_result(latency_ms: int) -> dict[str, Any]:
    """Return the "no candidates" result with the full key set."""
    return {
        "status": "empty",
        "status_detail": "",
        "matched_assets": [],
        "verdicts": [],
        "candidate_count": 0,
        "latency_ms": latency_ms,
    }


def _summarize_asset_selection(
    asset_by_id: dict[str, dict[str, Any]],
    selected_ids: list[str],
    status: str,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Turn selected chunk ids into ``(matched_assets, verdicts)``.

    Repeated ids are collapsed (first occurrence wins) so an asset is never
    matched twice; ids that are not candidates are ignored.
    """
    chosen = list(dict.fromkeys(str(cid) for cid in selected_ids))
    chosen_set = set(chosen)
    matched_assets = [asset_by_id[cid] for cid in chosen if cid in asset_by_id]
    picked_reason = _selected_reason(status)
    skipped_reason = _not_selected_reason(status)
    verdicts = [
        _asset_verdict(
            asset,
            relevant=cid in chosen_set,
            reason=picked_reason if cid in chosen_set else skipped_reason,
        )
        for cid, asset in asset_by_id.items()
    ]
    return matched_assets, verdicts


def _asset_verdict(
    asset: dict[str, Any],
    *,
    relevant: bool,
    reason: str,
) -> dict[str, Any]:
    """Build the per-candidate verdict row reported back to the caller."""
    metadata = asset.get("chunk_metadata") or {}
    summary = metadata.get("summary", "")
    return {
        "chunk_id": asset.get("chunk_id", ""),
        "file_path": asset.get("file_path", ""),
        # Prefer the owning section when the asset row carries one.
        "section_path": asset.get("owner_section_path") or asset.get("section_path", ""),
        "summary": summary,
        "relevant": relevant,
        "reason": reason,
    }


def _selected_reason(status: str) -> str:
    """Return the verdict reason for a selected asset under ``status``."""
    if status.startswith("fallback_"):
        return "selected_by_text_fallback"
    return "selected_by_asset_inspector"


def _not_selected_reason(status: str) -> str:
    """Return the verdict reason for an unselected asset under ``status``."""
    if status.startswith("fallback_"):
        return "not_selected_by_text_fallback"
    return "not_selected_by_asset_inspector"
def _project_assets_for_text_filter(
    *,
    query: str,
    asset_type: str,
    assets: list[dict[str, Any]],
) -> tuple[list[dict[str, str]], set[str], dict[str, str]]:
    """Project assets into a prompt-sized text view.

    Each asset with a ``chunk_id`` becomes a row ``{"id", "file", "desc"}``.
    Row ids are ``I<n>`` for images and ``T<n>`` otherwise, where ``n`` is
    the 1-based position in ``assets``. The description is the metadata
    summary; tables without a summary fall back to their content.

    Descriptions are shortened only when the complete prompt would exceed
    ``_asset_filter_prompt_budget()``; in that case the tokens left after
    the description-free prompt are split evenly across rows.

    Returns:
        ``(rows, valid_row_ids, row_id_to_chunk_id)``.
    """
    projected: list[dict[str, str]] = []
    valid_ids: set[str] = set()
    id_to_chunk_id: dict[str, str] = {}
    prefix = "I" if asset_type == "image" else "T"
    for index, asset in enumerate(assets, start=1):
        chunk_id = str(asset.get("chunk_id") or "")
        if not chunk_id:
            continue
        row_id = f"{prefix}{index}"
        metadata = asset.get("chunk_metadata") or {}
        summary = str(metadata.get("summary") or "").strip()
        content = str(asset.get("content") or "").strip()
        projected.append({
            "id": row_id,
            "file": str(asset.get("file_path") or ""),
            "desc": summary or (content if asset_type == "table" else ""),
        })
        valid_ids.add(row_id)
        id_to_chunk_id[row_id] = chunk_id

    if not projected:
        return projected, valid_ids, id_to_chunk_id

    prompt_budget = _asset_filter_prompt_budget()
    full_prompt = _format_asset_filter_prompt(query, asset_type, projected)
    if estimate_tokens(full_prompt) <= prompt_budget:
        return projected, valid_ids, id_to_chunk_id

    # Measure the prompt with empty descriptions to learn how many tokens
    # remain for descriptive text.
    structural_prompt = _format_asset_filter_prompt(
        query,
        asset_type,
        [{"id": item["id"], "file": item["file"], "desc": ""} for item in projected],
    )
    structural_tokens = estimate_tokens(structural_prompt)
    # Always leave at least one token per row.
    desc_budget = max(prompt_budget - structural_tokens, len(projected))
    per_item_desc_tokens = max(desc_budget // len(projected), 1)
    compacted = [
        {
            "id": item["id"],
            "file": item["file"],
            "desc": _fit_text_to_token_budget(item["desc"], per_item_desc_tokens),
        }
        for item in projected
    ]
    return compacted, valid_ids, id_to_chunk_id


def _asset_filter_prompt_budget() -> int:
    """Return the token budget for one asset-filter prompt (at least 1).

    Computed from the environment-built config as
    ``max(token_budget_total - bootstrap_budget, 0) * planning_ratio``,
    truncated to an int.
    """
    from shared.services.retrieval.agentic.core.runtime import build_config_from_env

    config = build_config_from_env()
    planning_capacity = int(
        max(config.token_budget_total - config.bootstrap_budget, 0)
        * config.planning_ratio
    )
    return max(planning_capacity, 1)


def _fit_text_to_token_budget(
    text: str,
    token_budget: int,
    *,
    token_counter: Any = None,
) -> str:
    """Return the longest prefix of ``text`` that fits ``token_budget``.

    Multi-word text is cut on a word boundary (words re-joined with single
    spaces). If not even the first word fits, or the text is a single word,
    the text is cut at character level instead, so text without usable
    spaces is still truncated rather than dropped.

    Both searches bisect over the prefix length and therefore assume the
    token count does not decrease when a prefix grows.

    Args:
        text: Text to shorten; surrounding whitespace is stripped first.
        token_budget: Maximum token count of the result.
        token_counter: Optional ``(str) -> int`` callable; defaults to the
            module's ``estimate_tokens``.
    """
    count = estimate_tokens if token_counter is None else token_counter
    text = text.strip()
    if not text or count(text) <= token_budget:
        return text

    def longest_fitting(size, render):
        # Largest k in [0, size] whose rendered prefix fits the budget.
        lo, hi, best = 0, size, ""
        while lo <= hi:
            mid = (lo + hi) // 2
            candidate = render(mid)
            if count(candidate) <= token_budget:
                best = candidate
                lo = mid + 1
            else:
                hi = mid - 1
        return best

    words = text.split()
    if len(words) > 1:
        by_words = longest_fitting(len(words), lambda k: " ".join(words[:k]))
        if by_words:
            return by_words
    return longest_fitting(len(text), lambda k: text[:k].strip())
def _format_asset_filter_prompt(
    query: str,
    asset_type: str,
    candidates: list[dict[str, str]],
) -> str:
    """Render the text-LLM prompt that asks which candidate rows match.

    Used for tables and for images alike: ``asset_type == "image"`` switches
    the wording to "images" and the fallback example id to ``I1``; any other
    type is worded as "tables" with ``T1``. The example id shown to the
    model is the first candidate's id when there is one.
    """
    is_image = asset_type == "image"
    noun = "images" if is_image else "tables"
    if candidates:
        sample_id = candidates[0]["id"]
    else:
        sample_id = "I1" if is_image else "T1"
    table = _format_asset_candidates_table(candidates)

    rules = (
        "- Match the requested asset type and the requested subject. "
        "Being an image/chart/table is not enough.",
        "- Do not select assets only because they belong to the same broad "
        "domain as the query.",
        "- Do not broaden specific market, instrument, company, metric, or "
        "entity terms. Neighboring topics are not matches unless the candidate "
        "explicitly connects them to the requested subject.",
        '- Treat words like "all" as all relevant assets, not all visible '
        "candidates.",
        "- If the file name, summary, or content signal does not "
        "directly support relevance, leave it out.",
        "- If uncertain, do not select the asset.",
    )
    # One entry per output line; "" entries are the blank separator lines.
    lines = [
        "You are an asset relevance filter.",
        "",
        f"Original user query: {query}",
        "",
        f"Below are {len(candidates)} {noun} from a document. "
        "Select ONLY assets that directly satisfy the user's query.",
        "",
        "Selection policy:",
        *rules,
        "",
        f"=== Candidate {noun.title()} ===",
        table,
        "=== End ===",
        "",
        f'Return ONLY a JSON array of matching row IDs, e.g.: ["{sample_id}"]',
        "If none are relevant, return an empty array: []",
        "Do not include any explanation.",
    ]
    return "\n".join(lines)


def _format_asset_candidates_table(candidates: list[dict[str, str]]) -> str:
    """Render candidates as a three-column Markdown table (ID, file, desc)."""
    header = [
        "| ID | File | Summary / content signal |",
        "|---|---|---|",
    ]
    rows = [
        "| "
        + " | ".join(
            _markdown_cell(candidate.get(key, "")) for key in ("id", "file", "desc")
        )
        + " |"
        for candidate in candidates
    ]
    return "\n".join(header + rows)


def _markdown_cell(value: str) -> str:
    """Make ``value`` safe for one table cell.

    Line breaks become spaces, ``|`` is backslash-escaped, and surrounding
    whitespace is trimmed. Falsy values render as an empty cell.
    """
    cell_escapes = str.maketrans({"\n": " ", "\r": " ", "|": "\\|"})
    return str(value or "").translate(cell_escapes).strip()
" + f"Look at each image and select ONLY images that directly " + f"satisfy the user's query.\n\n" + f"Selection policy:\n" + f"- Match both the requested visual type and requested subject.\n" + f"- Do not select images only because they are charts or from " + f"the same broad domain.\n" + f"- Do not broaden specific market, instrument, company, metric, " + f"or entity terms. Neighboring topics are not matches unless " + f"the image explicitly connects them to the requested subject.\n" + f"- Treat words like \"all\" as all relevant images, not all " + f"visible candidates.\n" + f"- If uncertain, do not select the image.\n\n" + ), + }, + ] + + for row_id, file_path, url in candidates: + content_parts.append({ + "type": "text", + "text": f'Image {row_id} file="{file_path}":', + }) + content_parts.append({ + "type": "image_url", + "image_url": {"url": url}, + }) + + content_parts.append({ + "type": "text", + "text": ( + f"\n\nReturn ONLY a JSON array of matching image row IDs, e.g.: " + f'["{candidates[0][0]}"]\n' + f"If none are relevant, return an empty array: []\n" + f"Do not include any explanation." 
def _parse_asset_filter_response(
    text: str,
    valid_ids: set[str],
) -> list[str]:
    """Extract the selected row ids from an asset-filter LLM response.

    The response is expected to be a JSON array of row ids. Fragments are
    tried in this order, and the first one that decodes to a JSON list wins:

    1. the whole (stripped) response;
    2. the body of the first code fence;
    3. every ``[...]`` span in the response, left to right, so a bracket
       that is not JSON (e.g. ``[see notes]``) does not hide a later array.

    Items are stringified, filtered to ``valid_ids`` and de-duplicated while
    keeping their first-seen order. Returns ``[]`` when nothing decodes or
    when ``text`` is empty/``None``.
    """
    import json
    import re

    def decode(fragment):
        # Returns the filtered ids, or None when ``fragment`` is not a JSON list.
        try:
            parsed = json.loads(fragment)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return None
        if not isinstance(parsed, list):
            return None
        row_ids = (str(item) for item in parsed)
        return list(dict.fromkeys(rid for rid in row_ids if rid in valid_ids))

    text = (text or "").strip()
    fragments = [text]

    fence_match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", text, re.DOTALL)
    if fence_match:
        fragments.append(fence_match.group(1).strip())

    fragments.extend(
        match.group() for match in re.finditer(r"\[.*?\]", text, re.DOTALL)
    )

    for fragment in fragments:
        selected = decode(fragment)
        if selected is not None:
            return selected
    return []
**action**: EXPAND/BACK/SEARCH_IMAGES/SEARCH_TABLES/FINISH The ``collected_paths`` list accumulates across all steps. After navigation completes (or is interrupted), a single batch hydration @@ -31,6 +31,7 @@ AgentRunConfig, AgentState, CandidateDoc, + DecisionTraceStep, DocTreeNode, NavigateStepResult, ToolResult, @@ -38,8 +39,11 @@ from shared.services.retrieval.agentic.navigation.selection_hydration import ( hydrate_path_selections_into_node, ) -from shared.services.retrieval.agentic.discovery.selection import ( - DiscoverySelectResult, +from shared.services.retrieval.agentic.navigation.path_ledger import PathLedger +from shared.services.retrieval.agentic.navigation.state import NavigationState +from shared.services.retrieval.agentic.prompts import ( + QUERY_INTENT_PROMPT, + parse_query_intent_response, ) from shared.services.retrieval.llm_adapter import LLMFn @@ -81,15 +85,18 @@ async def navigate_selected_documents(self) -> None: logger.info( f" agentic: Phase 2 — navigating {len(self._state.selected_docs)} documents" ) + query_intent = await self._classify_query_intent() for doc in self._state.selected_docs: if self._state.elapsed_ms >= self._config.latency_budget_ms: logger.info(" agentic: latency budget hit during Phase 2, stopping") break - await self._navigate_document(doc) + await self._navigate_document(doc, query_intent=query_intent) async def _navigate_document( self, doc: CandidateDoc, + *, + query_intent: str, ) -> None: job_result_id = self._state.doc_job_map.get(doc.document_id, "") if not job_result_id: @@ -100,6 +107,14 @@ async def _navigate_document( doc_name = doc.source_file_name or self._state.doc_id_to_name.get(doc.document_id, "") root = DocTreeNode(scope_path=None) doc_pending_assets: list[dict[str, Any]] = [] + from shared.services.retrieval.agentic.navigation.section_tree import ( + load_document_section_rows, + ) + section_rows = await load_document_section_rows( + self._db, + document_id=doc.document_id, + 
job_result_id=job_result_id, + ) # Phase 2A: Collector Agent navigation (summary-only, no content hydration) doc_pending_assets, collected_paths = await self._navigate_collector( @@ -107,26 +122,21 @@ async def _navigate_document( root=root, doc_name=doc_name, job_result_id=job_result_id, + section_rows=section_rows, + query_intent=query_intent, ) - # Phase 2B: Discovery hints (independent hydration path) - await self._hydrate_discovery_hints( - doc=doc, - root=root, - doc_name=doc_name, - collected_paths=collected_paths, - ) - - # Phase 2C: Batch hydrate all collected paths + # Phase 2B: Batch hydrate all collected paths if collected_paths: await self._hydrate_collected( doc=doc, root=root, job_result_id=job_result_id, collected_paths=collected_paths, + section_rows=section_rows, ) - # Phase 2D: Reconcile assets into hydrated tree + # Phase 2C: Reconcile assets into hydrated tree if doc_pending_assets: self._reconcile_pending_assets( doc=doc, @@ -150,36 +160,58 @@ async def _navigate_collector( root: DocTreeNode, doc_name: str, job_result_id: str, + section_rows: list, + query_intent: str, ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: """Collector Agent navigation loop. Returns (doc_pending_assets, collected_paths). 
""" doc_exclude: set[str] = set() - nav_trace: list[dict[str, Any]] = [] - collected_paths: list[dict[str, Any]] = [] + doc_discovery_hints = self._discovery_by_doc.get(doc.document_id, []) doc_pending_assets: list[dict[str, Any]] = [] + nav_state = NavigationState( + document_id=doc.document_id, + document_name=doc_name, + job_result_id=job_result_id, + ) + # Context from SEARCH tools — injected into next navigate prompt + search_context: str = "" + prior_tool_result: dict[str, Any] | None = None - # Scope stack for BACK support: each entry is a scope path (None = root) - scope_stack: list[str | None] = [None] - step_count = 0 + exit_reason = "unknown" + budget_failure: dict[str, Any] | None = None - while step_count < self._config.max_nav_steps: + while nav_state.step_count < self._config.max_nav_steps: if self._state.elapsed_ms >= self._config.latency_budget_ms: + exit_reason = "latency" break - if self._llm_fn is None: - break - if self._state.ledger and self._state.ledger.status("planning") in ("CRITICAL", "EXHAUSTED"): - logger.info(" agentic: planning budget critical, ending navigation for current doc") + has_tool_context = bool(prior_tool_result or search_context) + if ( + self._state.ledger + and self._state.ledger.status("planning") == "EXHAUSTED" + and not has_tool_context + ): + logger.info(" agentic: planning budget exhausted, ending navigation for current doc") + exit_reason = "budget" break - current_scope = scope_stack[-1] - step_count += 1 + nav_state.step_count += 1 + before_scope = nav_state.current_scope + expanded_before = set(nav_state.expanded_scopes) + rejected_before = set(nav_state.rejected_paths) + rejected_collect_before = set(nav_state.rejected_collect_paths) + collected_before_count = len(nav_state.collected_paths) doc_llm_fn = self._llm_budget.for_document( cast(LLMFn, self._llm_fn), doc_id=doc.document_id, - step=step_count, + step=nav_state.step_count, + allow_overdraft=has_tool_context, + overdraft_reason=( + 
"report_tool_result_to_main_agent" + if has_tool_context else "" + ), ) try: nav_result = await tools.navigate_step( @@ -191,83 +223,325 @@ async def _navigate_collector( user_id=self._user_id, namespace=self._namespace, doc_name=doc_name, - scope_path=current_scope, + scope_path=nav_state.current_scope, exclude_paths=doc_exclude, budget_snapshot=self._state.ledger.snapshot() if self._state.ledger else None, - nav_trace=nav_trace if nav_trace else None, - collected_paths=collected_paths, + nav_trace=nav_state.nav_trace if nav_state.nav_trace else None, + collected_paths=nav_state.collected_paths, + expanded_scopes=nav_state.expanded_scopes, + rejected_paths=nav_state.rejected_paths, + rejected_collect_paths=nav_state.rejected_collect_paths, + disabled_asset_types=nav_state.blocked_asset_types_for_scope( + nav_state.current_scope + ), + discovery_hints=doc_discovery_hints, + section_rows=section_rows, + query_intent=query_intent, + search_context=search_context, + prior_tool_result=prior_tool_result, + ) + except BudgetExceeded as exc: + budget_failure = getattr(exc, "details", {}) or {} + logger.info( + " agentic: planning budget exhausted during navigation " + f"details={budget_failure}" ) - except BudgetExceeded: - logger.info(" agentic: planning budget exhausted during navigation") if self._trace_enabled: self._trace.record_budget_stop("planning_exhausted") + exit_reason = "budget" break self._state.step_count += 1 - # Asset collection runs during navigation (images/tables). 
- await self._collect_assets( + # Clear previous tool context (consumed by this step's prompt) + search_context = "" + prior_tool_result = None + + # ── Execute asset tools (SEARCH) ───────────────────────────── + asset_tool_result = await self._execute_asset_tools( doc=doc, - scope=current_scope, - step_node=nav_result.node, - asset_tools=nav_result.tools, + job_result_id=job_result_id, + scope=nav_state.current_scope, + nav_result=nav_result, pending_assets=doc_pending_assets, - round_scope="nav", + parent_step_index=len(self._decision_steps), ) + search_context = asset_tool_result.get("context", "") + prior_tool_result = asset_tool_result.get("summary") + if ( + prior_tool_result is not None + and self._should_block_asset_search( + prior_tool_result, nav_state.current_scope, + ) + ): + nav_state.block_asset_search( + nav_state.current_scope, + str(prior_tool_result.get("asset_type") or ""), + ) # Merge outline + confidence into root tree _merge_step_node(root, nav_result.node) + requested_collects = list(nav_result.collect) + collect_reconcile = self._reconcile_collects_after_tool( + nav_result=nav_result, + asset_tool_result=asset_tool_result, + ) + rejected_collects = collect_reconcile["rejected_collects"] + if rejected_collects: + nav_result.collect = collect_reconcile["accepted_collects"] + for path in rejected_collects: + nav_state.mark_rejected_collect(path) + doc_exclude.add(path) + logger.info( + " agentic: tool reconciliation rejected collects: " + f"{rejected_collects}" + ) + # ── Process COLLECT ────────────────────────────────────────── collected_in_step: list[str] = [] for coll_item in nav_result.collect: path = coll_item["path"] - coll_item["collected_at_step"] = step_count - coll_item["scope_context"] = current_scope or "root" - collected_paths.append(coll_item) + nav_state.add_collected( + coll_item, + step=nav_state.step_count, + scope_context=nav_state.current_scope, + ) collected_in_step.append(path) - # Collected paths should be excluded from 
future navigation - doc_exclude.add(path) + # Outline collections should NOT exclude children — the intent + # is "see structure, then drill deeper for full content". + if coll_item.get("hydrate_mode") != "outline": + doc_exclude.add(path) + + # ── Process navigation action ──────────────────────────────── + should_break = False + if ( + nav_result.action == "EXPAND" + and nav_result.result_status == "ok" + and nav_result.drill_into + ): + drill_path = nav_result.drill_into + # Create child node in tree for the drill target + target_parent = _find_target_node(root, drill_path) + target_parent.children.setdefault(drill_path, DocTreeNode(scope_path=drill_path)) + nav_state.mark_expanded(drill_path) + nav_state.current_scope = drill_path + + elif nav_result.action == "BACK" and nav_result.result_status == "ok": + if nav_state.current_scope is None: + logger.info(" agentic: BACK at root scope, staying at root") + nav_result.result_status = "invalid_back" + nav_result.result_note = "already_at_root" + else: + back_target = nav_result.back_to # None = root + if PathLedger.valid_back_target(nav_state.current_scope, back_target): + nav_state.mark_rejected_if_unproductive(nav_state.current_scope) + nav_state.current_scope = back_target + else: + logger.warning( + f" agentic: invalid back_to='{back_target}' " + f"from scope='{nav_state.current_scope}'" + ) + nav_result.result_status = "invalid_back" + nav_result.result_note = f"invalid_back_target: {back_target}" + + elif nav_result.action == "ERROR": + logger.warning( + f" agentic: navigation ERROR for doc={doc.document_id}: " + f"{nav_result.error_reason or nav_result.reason}" + ) + exit_reason = "error" + should_break = True + + elif nav_result.action == "FINISH" and nav_result.result_status == "ok": + exit_reason = "llm_finish" + should_break = True # ── Build trace entry ──────────────────────────────────────── + state_delta = nav_state.snapshot_delta( + before_scope=before_scope, + expanded_before=expanded_before, + 
rejected_before=rejected_before, + rejected_collect_before=rejected_collect_before, + collected_before_count=collected_before_count, + ) trace_entry: dict[str, Any] = { - "step": step_count, - "scope": current_scope or "root", + "step": nav_state.step_count, + "scope": before_scope or "root", "action": nav_result.action, "drill_into": nav_result.drill_into, + "back_to": nav_result.back_to, "collected": collected_in_step, + "tools_used": nav_result.tools, "reason": nav_result.reason, + "result_status": nav_result.result_status, + "state_delta": state_delta, } - nav_trace.append(trace_entry) + if rejected_collects: + trace_entry["requested_collects"] = [ + item.get("path", "") + for item in requested_collects + if item.get("path") + ] + trace_entry["rejected_collects"] = rejected_collects + trace_entry["tool_reconciliation"] = collect_reconcile["reason"] + # Record tool usage & results so future steps can see search history + if prior_tool_result: + trace_entry["tool_results"] = prior_tool_result + nav_state.nav_trace.append(trace_entry) # ── Record decision step ───────────────────────────────────── - self._record_navigation_step( + main_step_index = self._record_navigation_step( doc=doc, - scope=current_scope, - step_num=step_count, + scope=before_scope, + step_num=nav_state.step_count, nav_result=nav_result, collected_in_step=collected_in_step, + asset_summary=asset_tool_result.get("summary"), + rejected_collects=rejected_collects, + state_delta=state_delta, ) + asset_trace = asset_tool_result.get("asset_trace") + if asset_trace: + asset_trace["parent_step_index"] = main_step_index + self._append_decision_trace_step(DecisionTraceStep(**asset_trace)) - # ── Process navigation action ──────────────────────────────── - if nav_result.action == "DRILL" and nav_result.drill_into: - drill_path = nav_result.drill_into - # Create child node in tree for the drill target - target_parent = _find_target_node(root, drill_path) - target_parent.children.setdefault(drill_path, 
DocTreeNode(scope_path=drill_path)) - scope_stack.append(drill_path) - - elif nav_result.action == "BACK": - if len(scope_stack) > 1: - scope_stack.pop() - else: - # Already at root, treat as STOP - logger.info(" agentic: BACK at root scope, treating as STOP") - break - - elif nav_result.action == "STOP" or nav_result.is_terminal: + if should_break: break - - return doc_pending_assets, collected_paths + else: + # while loop exhausted — max_nav_steps reached + exit_reason = "max_steps" + + # Hard guard: if navigation is forcefully interrupted and collected + # nothing, collect visible leaf children under the last explored + # scope. Voluntary FINISH/BACK with empty collection is respected. + # budget – planning pool EXHAUSTED (pre-check or exception) + # latency – elapsed time exceeded latency budget + # max_steps – navigation step count limit reached + # error – unexpected exception during navigate_step + forced_exits = ("budget", "latency", "max_steps", "error") + guard_triggered = False + if not nav_state.collected_paths and nav_state.step_count > 0 and exit_reason in forced_exits: + guard_triggered = True + guard_scope = nav_state.current_scope + logger.info( + f" agentic: forced exit ({exit_reason}) with 0 collected paths, " + f"auto-collecting leaves under scope={guard_scope or 'root'}" + ) + from shared.services.retrieval.agentic.navigation.section_tree import ( + load_child_sections, + ) + guard_items = await load_child_sections( + self._db, + doc.document_id, + job_result_id, + guard_scope, + section_rows=section_rows, + ) + for item in guard_items: + if not item.get("show_summary", True): + continue + if item.get("is_leaf"): + path = item["path"] + nav_state.collected_paths.append({ + "path": path, + "confidence": 0.4, + "hydrate_mode": "chunks", + "collected_at_step": nav_state.step_count, + "scope_context": guard_scope or "root", + "guard_reason": f"forced_exit_{exit_reason}", + }) + + self._append_decision_trace_step(DecisionTraceStep( + 
async def _classify_query_intent(self) -> str:
    """Ask the LLM for the query's intent; advisory only.

    Formats ``QUERY_INTENT_PROMPT`` with ``self._query`` and sends it through
    ``self._llm_budget.call`` on the ``planning`` pool at low priority.
    Fails open: returns ``"UNKNOWN"`` when no LLM is configured or when the
    call raises; otherwise returns the parsed response.
    """
    llm_fn = self._llm_fn
    if llm_fn is None:
        return "UNKNOWN"

    intent_prompt = QUERY_INTENT_PROMPT.format(query=self._query)
    try:
        raw_intent = await self._llm_budget.call(
            cast(LLMFn, llm_fn),
            intent_prompt,
            pool="planning",
            priority="low",
        )
    except Exception as err:
        # Deliberate best-effort: a failed classification must not stop retrieval.
        logger.info(f" agentic: query intent classifier failed-open: {err}")
        return "UNKNOWN"
    else:
        return parse_query_intent_response(raw_intent)
deduped: dict[str, dict[str, Any]] = {} for item in collected_paths: path = item["path"] - if path not in deduped or item.get("confidence", 0) > deduped[path].get("confidence", 0): + if path not in deduped or _collect_rank(item) > _collect_rank(deduped[path]): deduped[path] = item unique_selections = list(deduped.values()) @@ -320,6 +597,7 @@ async def _hydrate_collected( section_items = await load_child_sections( self._db, doc.document_id, job_result_id, path, limit_depth=False, + section_rows=section_rows, ) if section_items: # Filter out the scope node itself AND ancestor/sibling @@ -351,135 +629,257 @@ async def _hydrate_collected( f"collected={len(unique_selections)} hydrated_chunks={total_chunks}" ) - async def _collect_assets( + async def _execute_asset_tools( self, *, doc: CandidateDoc, + job_result_id: str, scope: str | None, - step_node: DocTreeNode, - asset_tools: list[str], + nav_result: NavigateStepResult, pending_assets: list[dict[str, Any]], - round_scope: str, - ) -> None: - selected_asset_scopes = list(step_node.confidence.keys()) - asset_scope: str | list[str] | None = selected_asset_scopes or scope - for asset_tool in asset_tools: - if asset_tool not in ("FIND_IMAGES", "FIND_TABLES"): - continue - asset_type = "image" if asset_tool == "FIND_IMAGES" else "table" - asset_chunks = await tools.asset_filter_step( - self._db, - document_id=doc.document_id, - job_result_id=self._state.doc_job_map.get(doc.document_id, ""), - scope_path=asset_scope, - asset_type=asset_type, - ) - if asset_chunks: - pending_assets.extend(asset_chunks) + parent_step_index: int, + ) -> dict[str, Any]: + """Execute SEARCH_* and return the observation for the next loop.""" + empty_result: dict[str, Any] = { + "context": "", + "summary": None, + "asset_trace": None, + } + if not nav_result.tools: + return empty_result + + if not nav_result.search_assets_params: + return empty_result + + params = nav_result.search_assets_params + search_query = params["query"] + asset_type = 
params["asset_type"] + scope_paths = params.get("scope_paths") + tool_name = "SEARCH_IMAGES" if asset_type == "image" else "SEARCH_TABLES" + budget_before = self._state.ledger.snapshot() if self._state.ledger else None + + search_llm_fn = self._llm_budget.for_document( + cast(LLMFn, self._llm_fn), + doc_id=doc.document_id, + step=self._state.step_count, + allow_overdraft=True, + overdraft_reason=f"{tool_name}_asset_inspector", + ) if self._llm_fn else None + if search_llm_fn is None: + summary = { + "tool": tool_name, + "asset_type": asset_type, + "query": search_query, + "matched": 0, + "status": "unavailable", + "scope_paths": scope_paths if scope_paths is not None else ([scope] if scope else []), + "matched_paths": [], + "sub_agent_assessment": "LLM unavailable for asset inspection", + } + return { + "context": self._format_asset_context( + tool_name, + asset_type, + search_query, + [], + status="unavailable", + ), + "summary": summary, + "asset_trace": self._build_asset_trace_payload( + doc=doc, + scope=scope, + parent_step_index=parent_step_index, + asset_type=asset_type, + query=search_query, + candidates=[], + result={"status": "unavailable", "verdicts": [], "matched_assets": []}, + ), + } - scope_display = ( - asset_scope - if isinstance(asset_scope, list) - else (asset_scope or "root") - ) - if self._trace_enabled: - self._trace.record_step( - "asset_filter_step", - ToolResult( - status="filtered" if asset_chunks else "empty", - payload={ - "document_id": doc.document_id, - "scope": scope_display, - "navigation_scope": scope or "root", - "asset_type": asset_type, - "chunks_found": len(asset_chunks) if asset_chunks else 0, - }, - ), - decision_reason=f"asset_{round_scope}_{doc.source_file_name}", + vlm_fn: LLMFn | None = None + if asset_type == "image": + from shared.services.retrieval.llm_adapter import create_retrieval_vlm_fn + raw_vlm_fn = create_retrieval_vlm_fn() + if raw_vlm_fn is not None: + vlm_fn = self._llm_budget.for_document( + raw_vlm_fn, + 
doc_id=doc.document_id, + step=self._state.step_count, + allow_overdraft=True, + overdraft_reason=f"{tool_name}_vlm_asset_inspector", ) - logger.info( - f" agentic step {self._state.step_count}: asset_filter_step " - f'doc="{doc.source_file_name}" scope={scope_display} ' - f"type={asset_type} chunks={len(asset_chunks) if asset_chunks else 0}" - ) - async def _hydrate_discovery_hints( - self, - *, - doc: CandidateDoc, - root: DocTreeNode, - doc_name: str, - collected_paths: list[dict[str, Any]] | None = None, - ) -> None: - doc_hints = self._discovery_by_doc.get(doc.document_id, []) - if not doc_hints or self._llm_fn is None: - return - if self._state.elapsed_ms >= self._config.latency_budget_ms: - return - - discovery_exclude_paths = _build_discovery_exclude_set( - root, collected_paths or [] - ) - - doc_discovery_llm_fn = self._llm_budget.for_discovery( - cast(LLMFn, self._llm_fn), - doc_id=doc.document_id, - low_priority=root.has_content(), - ) try: - result = await tools.discovery_select_step( + asset_result = await tools.search_assets_step( self._db, document_id=doc.document_id, - query=self._query, - llm_fn=doc_discovery_llm_fn, - user_id=self._user_id, - namespace=self._namespace, - doc_name=doc_name, - discovery_hints=doc_hints, - exclude_paths=discovery_exclude_paths, - budget_snapshot=self._state.ledger.snapshot() if self._state.ledger else None, + job_result_id=job_result_id, + scope_path=scope_paths if scope_paths is not None else scope, + asset_type=asset_type, + query=search_query, + llm_fn=search_llm_fn, + vlm_fn=vlm_fn, + ) + except BudgetExceeded as exc: + failure = getattr(exc, "details", {}) or {} + logger.info( + f" agentic: {tool_name} skipped — planning budget exhausted " + f"details={failure}" ) - except BudgetExceeded: - logger.info(" agentic: planning budget exhausted during discovery selection") - if self._trace_enabled: - self._trace.record_budget_stop("planning_exhausted") - result = 
DiscoverySelectResult(node=DocTreeNode(scope_path=None)) - self._state.step_count += 1 + asset_result = { + "status": "budget_exceeded", + "status_detail": "budget_reserve_failed", + "budget_failure": failure, + "matched_assets": [], + "verdicts": [], + "candidate_count": 0, + } - discovery_node = result.node - excluded_hints = result.excluded_hints + budget_after = self._state.ledger.snapshot() if self._state.ledger else None + budget_delta = _budget_delta(budget_before, budget_after) + matched_assets = asset_result.get("matched_assets") or [] + verdicts = asset_result.get("verdicts") or [] + candidate_count = int(asset_result.get("candidate_count") or 0) + if matched_assets: + pending_assets.extend(matched_assets) + + summary = { + "tool": tool_name, + "asset_type": asset_type, + "query": search_query, + "matched": len(matched_assets), + "candidate_count": candidate_count, + "status": asset_result.get("status", "empty"), + "status_detail": asset_result.get("status_detail", ""), + "budget": { + "before": _compact_budget_snapshot(budget_before), + "after": _compact_budget_snapshot(budget_after), + "delta": budget_delta, + }, + "scope_paths": scope_paths if scope_paths is not None else ([scope] if scope else []), + "matched_paths": [ + asset.get("file_path", "") + for asset in matched_assets + if asset.get("file_path") + ], + "matched_owner_paths": [ + asset.get("owner_section_path") or asset.get("section_path") or "" + for asset in matched_assets + if asset.get("owner_section_path") or asset.get("section_path") + ], + "sub_agent_assessment": ( + f"asset inspector matched {len(matched_assets)} " + f"of {candidate_count} {asset_type} candidates " + f"(status={asset_result.get('status', 'empty')})" + ), + } + logger.info( + f" agentic step {self._state.step_count}: {tool_name} " + f'doc="{doc.source_file_name}" scope={scope or "root"} ' + f'search_scope={scope_paths if scope_paths is not None else scope or "root"} ' + f'query="{search_query}" 
matched={len(matched_assets)}' + ) + return { + "context": self._format_asset_context( + tool_name, + asset_type, + search_query, + matched_assets, + status=str(asset_result.get("status", "empty")), + status_detail=str(asset_result.get("status_detail", "")), + ), + "summary": summary, + "matched_assets": matched_assets, + "candidate_count": candidate_count, + "asset_trace": self._build_asset_trace_payload( + doc=doc, + scope=scope, + parent_step_index=parent_step_index, + asset_type=asset_type, + query=search_query, + candidates=verdicts, + result=asset_result, + budget_before=budget_before, + budget_after=budget_after, + budget_delta=budget_delta, + ), + } - if self._trace_enabled: - self._trace.record_step( - "discovery_select_step", - ToolResult( - status="selected" if discovery_node.has_content() else "empty", - payload={ - "document_id": doc.document_id, - "hints_count": len(doc_hints), - "hydrated_count": len(discovery_node.leaf_content), - "excluded_count": len(excluded_hints), - }, - ), - decision_reason=f"discovery_{doc.source_file_name}", - ) - self._decision_steps.append({ - "phase": "discovery_select", - "document": doc_name, - "document_id": doc.document_id, - "action": "select" if discovery_node.has_content() else "skip", + @staticmethod + def _should_block_asset_search( + summary: dict[str, Any] | None, + current_scope: str | None, + ) -> bool: + if not summary: + return False + if int(summary.get("matched") or 0) > 0: + return False + if not _tool_searched_current_scope(summary, current_scope): + return False + return str(summary.get("status") or "").lower() in { + "empty", + "fallback_empty", + "unavailable", + "error", + "budget_exceeded", + } + + @staticmethod + def _reconcile_collects_after_tool( + *, + nav_result: NavigateStepResult, + asset_tool_result: dict[str, Any], + ) -> dict[str, Any]: + accepted_collects = list(nav_result.collect) + empty = { + "accepted_collects": accepted_collects, + "rejected_collects": [], "reason": "", - 
"candidate_count": result.candidate_count, - "hydrated_count": len(discovery_node.leaf_content), - "selected_paths": list(discovery_node.leaf_content.keys()), - "excluded_hints": excluded_hints, - "exclude_set": sorted(discovery_exclude_paths), - }) - root.merge(discovery_node) - if self._state.ledger is not None: - self._state.ledger.mark_explored( - chunks=sum(len(chunks) for chunks in discovery_node.leaf_content.values()), + } + if not accepted_collects or not nav_result.tools: + return empty + summary = asset_tool_result.get("summary") + if not isinstance(summary, dict): + return empty + tool_name = str(summary.get("tool") or "") + if tool_name not in {"SEARCH_IMAGES", "SEARCH_TABLES"}: + return empty + status = str(summary.get("status") or "").lower() + if status not in {"empty", "fallback_empty", "matched", "fallback_matched"}: + return empty + + matched_assets = asset_tool_result.get("matched_assets") or [] + matched_owner_paths = [ + str(asset.get("owner_section_path") or asset.get("section_path") or "") + for asset in matched_assets + if asset.get("owner_section_path") or asset.get("section_path") + ] + still_accepted: list[dict[str, Any]] = [] + rejected_paths: list[str] = [] + for item in accepted_collects: + path = str(item.get("path") or "") + if not path: + continue + has_matching_asset = any( + PathLedger.is_same_or_descendant(owner_path, path) + for owner_path in matched_owner_paths ) + if has_matching_asset: + still_accepted.append(item) + else: + rejected_paths.append(path) + + if not rejected_paths: + return empty + reason = ( + f"{tool_name} returned no valid matching assets under rejected " + f"collect paths; status={status}, matched={len(matched_assets)}" + ) + return { + "accepted_collects": still_accepted, + "rejected_collects": rejected_paths, + "reason": reason, + } def _reconcile_pending_assets( self, @@ -514,6 +914,109 @@ def _reconcile_pending_assets( decision_reason=f"deferred_reconcile_{doc.source_file_name}", ) + @staticmethod + def 
_format_asset_context( + tool_name: str, + asset_type: str, + search_query: str, + matched_assets: list[dict[str, Any]], + *, + status: str = "empty", + status_detail: str = "", + ) -> str: + if not matched_assets: + detail = f" Status detail: {status_detail}." if status_detail else "" + return ( + f"=== {tool_name} Results ===\n" + f"No matching {asset_type}s found for \"{search_query}\" " + f"(status={status}).{detail}\n" + f"=== End {tool_name} Results ===" + ) + lines = [ + f"=== {tool_name} Results ===", + f'Found {len(matched_assets)} matching {asset_type}s for "{search_query}".', + "Matched assets are available as asset evidence.", + ] + for i, asset in enumerate(matched_assets): + file_path = asset.get("file_path", "") + lines.append(f" {i + 1}. {file_path}") + owner_paths = _unique_asset_owner_paths(matched_assets) + if owner_paths: + lines.append("Owner sections with matching assets:") + for owner_path in owner_paths: + lines.append(f' - "{owner_path}"') + lines.append( + "Use these asset results and owner sections to decide collect, " + "finish, back, or further navigation." 
+ ) + lines.append(f"=== End {tool_name} Results ===") + return "\n".join(lines) + + def _build_asset_trace_payload( + self, + *, + doc: CandidateDoc, + scope: str | None, + parent_step_index: int, + asset_type: str, + query: str, + candidates: list[dict[str, Any]], + result: dict[str, Any], + budget_before: dict[str, Any] | None = None, + budget_after: dict[str, Any] | None = None, + budget_delta: dict[str, Any] | None = None, + ) -> dict[str, Any]: + matched_assets = result.get("matched_assets") or [] + matched = [ + { + "chunk_id": asset.get("chunk_id", ""), + "file_path": asset.get("file_path", ""), + "section_path": asset.get("owner_section_path") + or asset.get("section_path", ""), + } + for asset in matched_assets + ] + return { + "step_index": len(self._decision_steps), + "agent": "asset_inspector", + "parent_step_index": parent_step_index, + "phase": "asset_inspect", + "document_id": doc.document_id, + "document": doc.source_file_name or "", + "scope": scope or "root", + "observation": { + "asset_type": asset_type, + "query": query, + "candidates": candidates, + }, + "decision": { + "action": "inspect_assets", + "args": {"asset_type": asset_type, "query": query}, + "reason": "judged each candidate against the requested evidence", + }, + "result": { + "status": result.get("status", "empty"), + "status_detail": result.get("status_detail", ""), + "verdicts": result.get("verdicts") or [], + "matched": matched, + "budget_failure": result.get("budget_failure"), + "budget": { + "before": _compact_budget_snapshot(budget_before), + "after": _compact_budget_snapshot(budget_after), + "delta": budget_delta or {}, + }, + }, + "budget": self._state.ledger.snapshot() if self._state.ledger else None, + "elapsed_ms": self._state.elapsed_ms, + } + + def _append_decision_trace_step(self, step: DecisionTraceStep) -> int: + step.step_index = len(self._decision_steps) + self._decision_steps.append(step.to_dict()) + if self._trace_enabled: + 
self._trace.record_decision_trace_step(step) + return step.step_index + def _record_navigation_step( self, *, @@ -522,45 +1025,80 @@ def _record_navigation_step( step_num: int, nav_result: NavigateStepResult, collected_in_step: list[str], - ) -> None: + asset_summary: dict[str, Any] | None = None, + rejected_collects: list[str] | None = None, + state_delta: dict[str, Any] | None = None, + ) -> int: action = nav_result.action reason = nav_result.reason drill_into = nav_result.drill_into - - if self._trace_enabled: - self._trace.record_step( - "navigate_step", - ToolResult( - status=f"{action.lower()}", - payload={ - "document_id": doc.document_id, - "scope": scope or "root", - "step": step_num, - "action": action, - "reason": reason, - "drill_into": drill_into, - "collected_count": len(collected_in_step), - "collected_paths": collected_in_step, - "asset_tools": nav_result.tools, - "outline_count": len(nav_result.node.outline_items), - }, - ), - decision_reason=f"nav_s{step_num}_{doc.source_file_name}", + doc_name = doc.source_file_name or self._state.doc_id_to_name.get(doc.document_id, "") + collected = [ + { + "path": item.get("path", ""), + "confidence": item.get("confidence", 0.0), + "hydrate_mode": item.get("hydrate_mode", "chunks"), + } + for item in nav_result.collect + ] + decision_args: dict[str, Any] = {} + if drill_into: + decision_args["target"] = drill_into + if action == "BACK": + decision_args["target"] = nav_result.back_to + if nav_result.search_assets_params: + decision_args["query"] = nav_result.search_assets_params.get("query", "") + decision_args["asset_type"] = nav_result.search_assets_params.get("asset_type", "") + projected_scope = scope or "root" + if action == "EXPAND" and drill_into: + projected_scope = drill_into + elif action == "BACK" and nav_result.result_status == "ok": + projected_scope = nav_result.back_to or "root" + + result_payload: dict[str, Any] = { + "status": nav_result.result_status, + "collected": collected, + "new_scope": 
projected_scope, + "note": nav_result.result_note, + } + if state_delta is not None: + result_payload["state_delta"] = state_delta + if rejected_collects: + result_payload["rejected_collects"] = rejected_collects + if nav_result.error_reason: + result_payload["error"] = nav_result.error_reason + if asset_summary: + result_payload["matched_assets"] = asset_summary.get("matched", 0) + result_payload["tool_status"] = asset_summary.get("status") + result_payload["tool_budget"] = asset_summary.get("budget") + result_payload["sub_agent_assessment"] = asset_summary.get( + "sub_agent_assessment" ) - doc_name = doc.source_file_name or self._state.doc_id_to_name.get(doc.document_id, "") - self._decision_steps.append({ - "phase": "navigate", - "document": doc_name, - "document_id": doc.document_id, - "action": action, - "reason": reason, - "step": step_num, - "drill_into": drill_into, - "collected_paths": collected_in_step, - "collected_count": len(collected_in_step), - }) + trace_step = DecisionTraceStep( + step_index=len(self._decision_steps), + agent="navigator", + phase="navigate", + document=doc_name, + document_id=doc.document_id, + scope=scope or "root", + observation=nav_result.observation, + decision={ + "action": action, + "args": decision_args, + "reason": reason, + }, + result=result_payload, + budget=self._state.ledger.snapshot() if self._state.ledger else None, + elapsed_ms=self._state.elapsed_ms, + ) + step_index = self._append_decision_trace_step(trace_step) + status_tag = ( + f" status={nav_result.result_status}" + if nav_result.result_status != "ok" + else "" + ) scope_log = scope or "root" logger.info( f" agentic step {self._state.step_count}: navigate_step " @@ -570,7 +1108,82 @@ def _record_navigation_step( f"collected={len(collected_in_step)} " f"drill_into={drill_into} " f"outline={len(nav_result.node.outline_items)}" + f"{status_tag}" ) + return step_index + + +def _tool_searched_current_scope( + summary: dict[str, Any], + current_scope: str | None, +) 
-> bool: + current = PathLedger.normalize(current_scope) or "root" + scope_paths = summary.get("scope_paths") + if not isinstance(scope_paths, list) or not scope_paths: + return current == "root" + searched = [ + PathLedger.normalize(str(path or "")) + for path in scope_paths + if PathLedger.normalize(str(path or "")) + ] + if current == "root": + return not searched + return searched == [current] + + +def _compact_budget_snapshot(snapshot: dict[str, Any] | None) -> dict[str, Any]: + if not isinstance(snapshot, dict): + return {} + planning = snapshot.get("planning") + if not isinstance(planning, dict): + return {} + compact = { + "status": planning.get("status"), + "used_pct": planning.get("used_pct"), + "remaining": planning.get("remaining"), + "capacity": planning.get("capacity"), + "overdraft": planning.get("overdraft", 0), + } + overdraft_events = snapshot.get("overdraft_events") + if isinstance(overdraft_events, list) and overdraft_events: + compact["overdraft_events"] = overdraft_events[-3:] + return compact + + +def _budget_delta( + before: dict[str, Any] | None, + after: dict[str, Any] | None, +) -> dict[str, Any]: + before_planning = ( + before.get("planning") if isinstance(before, dict) else None + ) + after_planning = ( + after.get("planning") if isinstance(after, dict) else None + ) + if not isinstance(before_planning, dict) or not isinstance(after_planning, dict): + return {} + return { + "used": int(after_planning.get("used") or 0) + - int(before_planning.get("used") or 0), + "used_pct": int(after_planning.get("used_pct") or 0) + - int(before_planning.get("used_pct") or 0), + "remaining": int(after_planning.get("remaining") or 0) + - int(before_planning.get("remaining") or 0), + "overdraft": int(after_planning.get("overdraft") or 0) + - int(before_planning.get("overdraft") or 0), + } + + +def _unique_asset_owner_paths(matched_assets: list[dict[str, Any]]) -> list[str]: + seen: set[str] = set() + owner_paths: list[str] = [] + for asset in 
matched_assets: + owner_path = str(asset.get("owner_section_path") or asset.get("section_path") or "") + if not owner_path or owner_path in seen: + continue + seen.add(owner_path) + owner_paths.append(owner_path) + return owner_paths def _find_target_node(node: DocTreeNode, path: str) -> DocTreeNode: @@ -582,7 +1195,7 @@ def _find_target_node(node: DocTreeNode, path: str) -> DocTreeNode: where a path appears in both ``children`` and ``leaf_content``. """ for child_path, child in node.children.items(): - if path.startswith(child_path + " / "): + if PathLedger.is_ancestor(child_path, path): return _find_target_node(child, path) return node @@ -601,37 +1214,11 @@ def _merge_step_node(root: DocTreeNode, step_node: DocTreeNode) -> None: target.confidence[path] = max(target.confidence.get(path, 0), conf) -def _collect_leaf_paths(node: DocTreeNode) -> set[str]: - """Collect all paths that have been hydrated (leaf_content).""" - paths = set(node.leaf_content.keys()) - for child in node.children.values(): - paths.update(_collect_leaf_paths(child)) - return paths - - -def _build_discovery_exclude_set( - root: DocTreeNode, - collected_paths: list[dict[str, Any]], -) -> set[str]: - """Build exclude set for discovery using collected navigation paths. - - If navigation COLLECT'd a parent path like "五、施工安全保证措施", - all discovery hints under that path should be excluded because - COLLECT already loads all descendants via prefix matching. - """ - # 1. Already-hydrated leaf paths - exclude = _collect_leaf_paths(root) - - # 2. Collected parent paths from navigation COLLECT decisions. - # These haven't been hydrated yet (hydrate runs after discovery), - # but we know COLLECT will load all their descendants. 
- for item in collected_paths: - path = item.get("path", "") - if path: - exclude.add(path) - - return exclude - +def _collect_rank(item: dict[str, Any]) -> tuple[int, float]: + mode = str(item.get("hydrate_mode") or "chunks") + mode_rank = 0 if mode == "outline" else 1 + confidence = float(item.get("confidence") or 0.0) + return (mode_rank, confidence) def _ensure_child_node(root: DocTreeNode, path: str) -> None: @@ -652,7 +1239,7 @@ def _ensure_child_node(root: DocTreeNode, path: str) -> None: # Walk to find the deepest existing ancestor node target = root for child_path, child in root.children.items(): - if path.startswith(child_path + " / "): + if PathLedger.is_ancestor(child_path, path): target = child break if path == child_path: @@ -683,7 +1270,7 @@ def _build_outline_subtree(node: DocTreeNode) -> None: parent_paths: set[str] = set() for path in all_paths: for other in all_paths: - if other != path and other.startswith(path + " / "): + if other != path and PathLedger.is_ancestor(path, other): parent_paths.add(path) break @@ -695,7 +1282,7 @@ def _build_outline_subtree(node: DocTreeNode) -> None: # will be created when the function recurses into "A / B". parent_paths = { pp for pp in parent_paths - if not any(pp != other and pp.startswith(other + " / ") for other in parent_paths) + if not any(pp != other and PathLedger.is_ancestor(other, pp) for other in parent_paths) } # Track children that already exist (e.g. from collected hydration). 
@@ -718,7 +1305,7 @@ def _build_outline_subtree(node: DocTreeNode) -> None: to_move = [ cp for cp in list(node.children.keys()) if cp != parent_path - and cp.startswith(parent_path + " / ") + and PathLedger.is_ancestor(parent_path, cp) ] for cp in to_move: parent_node.children[cp] = node.children.pop(cp) @@ -730,7 +1317,7 @@ def _build_outline_subtree(node: DocTreeNode) -> None: item_path = item["path"] best_parent: str | None = None for pp in parent_paths: - if item_path.startswith(pp + " / "): + if PathLedger.is_ancestor(pp, item_path): if best_parent is None or len(pp) > len(best_parent): best_parent = pp if best_parent and best_parent not in pre_existing: @@ -744,4 +1331,3 @@ def _build_outline_subtree(node: DocTreeNode) -> None: node.reparent_leaf_content() for child in node.children.values(): _build_outline_subtree(child) - diff --git a/packages/shared-python/shared/services/retrieval/agentic/navigation/path_ledger.py b/packages/shared-python/shared/services/retrieval/agentic/navigation/path_ledger.py new file mode 100644 index 00000000..bd051eb5 --- /dev/null +++ b/packages/shared-python/shared/services/retrieval/agentic/navigation/path_ledger.py @@ -0,0 +1,62 @@ +"""Path relationship helpers for document navigation state.""" +from __future__ import annotations + +from collections.abc import Iterable + +from shared.services.retrieval.search.lexical_text import normalize_section_path + + +class PathLedger: + """Small, authoritative wrapper for section path relationships.""" + + @staticmethod + def normalize(path: str | None) -> str: + return normalize_section_path(str(path or "").strip()) + + @classmethod + def is_ancestor(cls, ancestor: str | None, descendant: str | None) -> bool: + ancestor_path = cls.normalize(ancestor) + descendant_path = cls.normalize(descendant) + if not ancestor_path or not descendant_path: + return False + return descendant_path.startswith(ancestor_path + " / ") + + @classmethod + def is_same_or_descendant(cls, path: str | None, 
scope: str | None) -> bool: + candidate = cls.normalize(path) + scope_path = cls.normalize(scope) + if not candidate or not scope_path: + return False + return candidate == scope_path or candidate.startswith(scope_path + " / ") + + @classmethod + def is_covered(cls, path: str | None, covered_paths: Iterable[str]) -> bool: + candidate = cls.normalize(path) + if not candidate: + return False + return any( + candidate == covered + or candidate.startswith(covered + " / ") + for covered in (cls.normalize(item) for item in covered_paths) + if covered + ) + + @classmethod + def back_targets(cls, current_scope: str | None) -> list[str | None]: + scope = cls.normalize(current_scope) + if not scope: + return [] + parts = [part for part in scope.split(" / ") if part] + targets: list[str | None] = [ + " / ".join(parts[:index]) + for index in range(len(parts) - 1, 0, -1) + ] + targets.append(None) + return targets + + @classmethod + def valid_back_target(cls, current_scope: str | None, target: str | None) -> bool: + scope = cls.normalize(current_scope) + if not scope: + return False + return target is None or cls.is_ancestor(target, scope) diff --git a/packages/shared-python/shared/services/retrieval/agentic/navigation/section_counts.py b/packages/shared-python/shared/services/retrieval/agentic/navigation/section_counts.py index 64e2499f..2e7ae237 100644 --- a/packages/shared-python/shared/services/retrieval/agentic/navigation/section_counts.py +++ b/packages/shared-python/shared/services/retrieval/agentic/navigation/section_counts.py @@ -41,7 +41,7 @@ async def attach_section_counts( for item_path, item in items_by_path.items(): if not item["show_summary"]: continue - if chunk_path == item_path or chunk_path.startswith(item_path + " / "): + if _chunk_belongs_to_item(chunk_path, item_path): item["chunk_count"] += text_count item["image_count"] += image_count item["table_count"] += table_count @@ -55,6 +55,14 @@ async def attach_section_counts( sid_to_path=sid_to_path, ) + # 
Root is a virtual navigation container. Media availability for the + # whole document is exposed through global SEARCH actions, not as Root + # node-local images/tables. + root_item = items_by_path.get("Root") + if root_item: + root_item["image_count"] = 0 + root_item["table_count"] = 0 + async def _load_direct_chunk_counts( db: AsyncSession, @@ -173,6 +181,12 @@ async def _attach_connected_asset_counts( for item_path, item in items_by_path.items(): if not item["show_summary"]: continue - if ref_path == item_path or ref_path.startswith(item_path + " / "): + if _chunk_belongs_to_item(ref_path, item_path): item["image_count"] += referenced_images item["table_count"] += referenced_tables + + +def _chunk_belongs_to_item(chunk_path: str, item_path: str) -> bool: + if item_path == "Root": + return chunk_path == item_path + return chunk_path == item_path or chunk_path.startswith(item_path + " / ") diff --git a/packages/shared-python/shared/services/retrieval/agentic/navigation/section_prompt_projection.py b/packages/shared-python/shared/services/retrieval/agentic/navigation/section_prompt_projection.py index d7130fa3..83cb39fa 100644 --- a/packages/shared-python/shared/services/retrieval/agentic/navigation/section_prompt_projection.py +++ b/packages/shared-python/shared/services/retrieval/agentic/navigation/section_prompt_projection.py @@ -3,135 +3,16 @@ from typing import Any -from shared.utils.text_utils import truncate_content_preview - - -def format_items_for_llm( - items: list[dict], - max_chars: int = 20000, - collected_paths: set[str] | None = None, -) -> tuple[str, bool]: - """Format section items with hierarchy, token estimates, and collection marks.""" - if not items: - return "(no items available)", False - - coll = collected_paths or set() - full_text = "\n".join(_render_item(item, include_summary=True, collected=coll) for item in items) - if len(full_text) <= max_chars: - return full_text, False - - slim_text = "\n".join(_render_item(item, 
include_summary=False, collected=coll) for item in items) - return slim_text[:max_chars], True - - -def _render_item(item: dict, include_summary: bool, collected: set[str]) -> str: - level = item.get("level", 1) - show_summary = item.get("show_summary", True) - is_leaf = item.get("is_leaf", False) - path = item.get("path", "") - summary = item.get("summary") or "" - - # Check if this path (or an ancestor) is already collected - is_collected = _is_path_collected(path, collected) - collected_tag = "[✓] " if is_collected else "" - - leaf_tag = " [Leaf]" if is_leaf else "" - - # Counts and token estimate - counts_str = "" - token_str = "" - if show_summary: - count_parts: list[str] = [] - chunk_count = item.get("chunk_count", 0) - if chunk_count > 0: - count_parts.append(f"text={chunk_count}") - image_count = item.get("image_count", 0) - if image_count > 0: - count_parts.append(f"image={image_count}") - table_count = item.get("table_count", 0) - if table_count > 0: - count_parts.append(f"table={table_count}") - counts_str = f' [{" ".join(count_parts)}]' if count_parts else "" - - total_chars = item.get("total_chars", 0) - if total_chars > 0: - # Approximate tokens: Chinese ~2 chars/token, English ~4 chars/token - # Use conservative 2 chars/token for mixed content - tokens = total_chars / 2 - if tokens >= 1000: - token_str = f" ~{tokens / 1000:.1f}k tokens" - else: - token_str = f" ~{int(tokens)} tokens" - - indent = " " * (level - 1) - prefix = "▸" if level == 1 else "└" - level_tag = f"[L{level}]" - - lines = [ - f'{indent}{prefix} {collected_tag}{level_tag} path="{path}"{counts_str}{token_str}{leaf_tag}' - ] - - if include_summary and show_summary and summary: - sub_indent = " " * level - display_summary = _enrich_section_covers_summary(summary) - clipped = truncate_content_preview(display_summary, head=80, tail=0) - lines.append(f"{sub_indent}{clipped}") - - return "\n".join(lines) - - -def _enrich_section_covers_summary(summary: str) -> str: - """Inject sub-section 
count into 'This section covers:' summaries. - - Transforms: - 'This section covers: A, B, C' - into: - 'This section covers 3 sub-sections: A, B, C' - """ - prefix = "This section covers: " - if not summary.startswith(prefix): - return summary - body = summary[len(prefix):] - sub_sections = [s.strip() for s in body.split(", ") if s.strip()] - count = len(sub_sections) - return f"This section covers {count} sub-sections: {body}" - - -def _is_path_collected(path: str, collected: set[str]) -> bool: - """Check if path itself or any ancestor is in the collected set.""" - if path in collected: - return True - for coll_path in collected: - if path.startswith(coll_path + " / "): - return True - return False - - -def format_collection_status( - collected_paths: list[dict[str, Any]], -) -> str: - """Render the collection status block for the navigation prompt.""" - if not collected_paths: - return "" - - lines = [f"=== Collection Status ({len(collected_paths)} items) ==="] - for item in collected_paths: - path = item.get("path", "") - conf = item.get("confidence", 0) - step = item.get("collected_at_step", "?") - outline = item.get("outline", False) - mode_tag = " [outline]" if outline else "" - lines.append(f'✓ "{path}" (step {step}, conf={conf:.1f}{mode_tag})') - lines.append("=== End Collection ===") - return "\n".join(lines) - def format_nav_trace( nav_trace: list[dict[str, Any]], - collected_paths: list[dict[str, Any]], ) -> str: - """Render the unified navigation trace block (includes scope, actions, and collection).""" - if not nav_trace and not collected_paths: + """Render the unified navigation trace block. + + Includes compact navigation history. Current collection state is rendered + separately by the Agent State block. 
+ """ + if not nav_trace: return "" lines = ["=== Navigation Trace ==="] @@ -140,29 +21,46 @@ def format_nav_trace( scope = entry.get("scope", "root") action = entry.get("action", "?") reason = entry.get("reason", "") - action_display = action drill_into = entry.get("drill_into") - if action == "DRILL" and drill_into: - action_display = f'DRILL "{drill_into}"' + if action == "EXPAND" and drill_into: + action_display = f'EXPAND "{drill_into}"' + elif action == "BACK": + back_to = entry.get("back_to") + target = f'"{back_to}"' if back_to else "root" + action_display = f"BACK to {target}" lines.append(f"Step {step}: scope={scope} → {action_display}") + # Show tool usage and results so LLM can avoid repeating searches + tool_results = entry.get("tool_results", {}) + if tool_results: + tool_name = tool_results.get("tool", "") + tool_query = tool_results.get("query", "") + matched = int(tool_results.get("matched") or 0) + tool_status = str(tool_results.get("status") or "") + status = ( + f"found {matched} match(es)" + if matched + else f"no matches ({tool_status})" + if tool_status + else "no matches" + ) + lines.append(f' 🔧 {tool_name}("{tool_query}") → {status}') + # Show what was collected in this step step_collected = entry.get("collected", []) if step_collected: paths_display = ", ".join(f'"{c}"' for c in step_collected) lines.append(f" collected: {paths_display}") + result_status = entry.get("result_status") + if result_status and result_status != "ok": + lines.append(f" result_status: {result_status}") + if reason: lines.append(f" reason: {reason}") lines.append("") - # Append current collection summary - if collected_paths: - total = len(collected_paths) - lines.append(f"[Current] collection: {total} items") - lines.append("Do NOT re-collect paths marked [✓] below.") - lines.append("=== End Trace ===") return "\n".join(lines) diff --git a/packages/shared-python/shared/services/retrieval/agentic/navigation/section_tree.py 
b/packages/shared-python/shared/services/retrieval/agentic/navigation/section_tree.py index ece877f2..92592d4f 100644 --- a/packages/shared-python/shared/services/retrieval/agentic/navigation/section_tree.py +++ b/packages/shared-python/shared/services/retrieval/agentic/navigation/section_tree.py @@ -10,15 +10,11 @@ from shared.services.retrieval.search.lexical_text import normalize_section_path, split_section_path -async def load_child_sections( +async def load_document_section_rows( db: AsyncSession, document_id: str, job_result_id: str, - scope_path: str | list[str] | None = None, - exclude_paths: set[str] | None = None, - limit_depth: bool = True, -) -> list[dict]: - """Load the continuous context tree for a navigation scope.""" +) -> list: stmt = ( select( DocumentSection.section_id, @@ -31,7 +27,25 @@ async def load_child_sections( .where(DocumentSection.job_result_id == job_result_id) .order_by(DocumentSection.sort_order) ) - section_rows = (await db.execute(stmt)).all() + return list((await db.execute(stmt)).all()) + + +async def load_child_sections( + db: AsyncSession, + document_id: str, + job_result_id: str, + scope_path: str | list[str] | None = None, + exclude_paths: set[str] | None = None, + limit_depth: bool = True, + section_rows: list | None = None, +) -> list[dict]: + """Load the continuous context tree for a navigation scope.""" + if section_rows is None: + section_rows = await load_document_section_rows( + db, + document_id=document_id, + job_result_id=job_result_id, + ) if not section_rows: return [] @@ -215,5 +229,3 @@ def _resolve_allowed_depths(items_by_path: dict[str, dict], scope_list: list[str if child_depths: allowed_set.update(sorted(child_depths)[:2]) return allowed_set - - diff --git a/packages/shared-python/shared/services/retrieval/agentic/navigation/state.py b/packages/shared-python/shared/services/retrieval/agentic/navigation/state.py new file mode 100644 index 00000000..841059e5 --- /dev/null +++ 
b/packages/shared-python/shared/services/retrieval/agentic/navigation/state.py @@ -0,0 +1,99 @@ +"""Per-document navigation state for the collector runtime.""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from shared.services.retrieval.agentic.navigation.path_ledger import PathLedger + + +@dataclass +class NavigationState: + """Mutable state for one document navigation loop.""" + + document_id: str + document_name: str + job_result_id: str + current_scope: str | None = None + expanded_scopes: set[str] = field(default_factory=set) + rejected_paths: set[str] = field(default_factory=set) + rejected_collect_paths: set[str] = field(default_factory=set) + collected_paths: list[dict[str, Any]] = field(default_factory=list) + nav_trace: list[dict[str, Any]] = field(default_factory=list) + tool_history: list[dict[str, Any]] = field(default_factory=list) + blocked_asset_searches: set[str] = field(default_factory=set) + step_count: int = 0 + + def snapshot_delta( + self, + *, + before_scope: str | None, + expanded_before: set[str], + rejected_before: set[str], + rejected_collect_before: set[str], + collected_before_count: int, + ) -> dict[str, Any]: + return { + "current_scope_before": before_scope or "root", + "current_scope_after": self.current_scope or "root", + "expanded_added": sorted(self.expanded_scopes - expanded_before), + "rejected_added": sorted(self.rejected_paths - rejected_before), + "rejected_collect_added": sorted( + self.rejected_collect_paths - rejected_collect_before + ), + "collected_added": [ + item.get("path", "") + for item in self.collected_paths[collected_before_count:] + if item.get("path") + ], + } + + def add_collected( + self, + item: dict[str, Any], + *, + step: int, + scope_context: str | None, + ) -> dict[str, Any]: + enriched = dict(item) + enriched["collected_at_step"] = step + enriched["scope_context"] = scope_context or "root" + self.collected_paths.append(enriched) + return 
enriched + + def mark_expanded(self, path: str | None) -> None: + normalized = PathLedger.normalize(path) + if normalized: + self.expanded_scopes.add(normalized) + + def mark_rejected_collect(self, path: str | None) -> None: + normalized = PathLedger.normalize(path) + if normalized: + self.rejected_paths.add(normalized) + self.rejected_collect_paths.add(normalized) + + def mark_rejected_if_unproductive(self, path: str | None) -> None: + normalized = PathLedger.normalize(path) + if not normalized: + return + has_full_collect = any( + item.get("hydrate_mode") != "outline" + and PathLedger.is_same_or_descendant(item.get("path"), normalized) + for item in self.collected_paths + ) + if not has_full_collect: + self.rejected_paths.add(normalized) + + def blocked_asset_types_for_scope(self, scope: str | None) -> set[str]: + prefix = f"{PathLedger.normalize(scope) or 'root'}:" + return { + key.split(":", 1)[1] + for key in self.blocked_asset_searches + if key.startswith(prefix) + } + + def block_asset_search(self, scope: str | None, asset_type: str) -> None: + normalized_scope = PathLedger.normalize(scope) or "root" + normalized_type = asset_type.strip().lower() + if normalized_type: + self.blocked_asset_searches.add(f"{normalized_scope}:{normalized_type}") diff --git a/packages/shared-python/shared/services/retrieval/agentic/navigation/tools.py b/packages/shared-python/shared/services/retrieval/agentic/navigation/tools.py index 53077776..2873a2e2 100644 --- a/packages/shared-python/shared/services/retrieval/agentic/navigation/tools.py +++ b/packages/shared-python/shared/services/retrieval/agentic/navigation/tools.py @@ -1,17 +1,4 @@ -"""Agentic retrieval navigation tools — Collector Agent model. - -Collector Agent architecture -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Each ``navigate_step`` returns two independent decisions: - -- **collect**: paths the agent adds to its evidence collection. - Collected paths are hydrated with full content after navigation completes. 
-- **action + drill_into**: navigation direction (DRILL into a section, - BACK to parent, or STOP). - -Asset collection (images/tables) still runs during navigation so LLM -tool requests are honoured, but assets are reconciled after hydration. -""" +"""Agentic retrieval navigation tools — observe-act collector model.""" from __future__ import annotations from typing import Any @@ -20,24 +7,31 @@ from sqlalchemy.ext.asyncio import AsyncSession from shared.services.retrieval.agentic.navigation.assets import ( - build_asset_tools_block, count_assets_under_scope, ) -from shared.services.retrieval.agentic.core.budget import BudgetExceeded +from shared.services.retrieval.agentic.core.budget import ( + BudgetExceeded, + budget_status_from_snapshot, +) from shared.services.retrieval.agentic.prompts import ( COLLECTOR_PROMPT, - format_budget_block, + adjust_budget_snapshot, parse_collector_response, ) from shared.services.retrieval.agentic.navigation.section_prompt_projection import ( - format_items_for_llm, format_nav_trace, ) +from shared.services.retrieval.agentic.navigation.actions import ( + build_legal_actions, + format_actionable_observation, + format_agent_state_block, +) from shared.services.retrieval.agentic.navigation.section_tree import load_child_sections from shared.services.retrieval.agentic.core.types import DocTreeNode, NavigateStepResult from shared.services.retrieval.llm_adapter import LLMFn + async def navigate_step( db: AsyncSession, *, @@ -53,16 +47,17 @@ async def navigate_step( budget_snapshot: dict | None = None, nav_trace: list[dict[str, Any]] | None = None, collected_paths: list[dict[str, Any]] | None = None, + expanded_scopes: set[str] | None = None, + rejected_paths: set[str] | None = None, + rejected_collect_paths: set[str] | None = None, + disabled_asset_types: set[str] | None = None, + discovery_hints: list[dict[str, Any]] | None = None, + section_rows: list | None = None, + query_intent: str = "UNKNOWN", + search_context: str = "", + 
prior_tool_result: dict[str, Any] | None = None, ) -> NavigateStepResult: - """Navigate one document scope using the Collector Agent model. - - Returns a ``NavigateStepResult`` with: - - ``collect``: paths to add to the evidence collection - - ``action``: DRILL/BACK/STOP - - ``drill``: the single drill target (if action == DRILL) - - ``tools``: asset tool invocations - - ``node``: outline tree node for rendering context - """ + """Navigate one document scope with a single observe-act decision.""" scope_paths = [scope_path] if scope_path else [] try: @@ -72,122 +67,298 @@ async def navigate_step( job_result_id, scope_path, exclude_paths=exclude_paths, + section_rows=section_rows, ) if not items: - return NavigateStepResult.stop(scope_paths[0] if scope_paths else None) + return NavigateStepResult.stop( + scope_paths[0] if scope_paths else None, + reason="No visible sections in the current scope.", + ) - visible_items = { - item["path"]: item for item in items if item.get("show_summary", True) - } - total_images, total_tables = await count_assets_under_scope( - db, - document_id=document_id, - job_result_id=job_result_id, - scope_paths=scope_paths, - ) - tools_block = build_asset_tools_block(total_images, total_tables) + budget_status = budget_status_from_snapshot(budget_snapshot) + if budget_status in {"CRITICAL", "EXHAUSTED"}: + total_images, total_tables = 0, 0 + else: + total_images, total_tables = await count_assets_under_scope( + db, + document_id=document_id, + job_result_id=job_result_id, + scope_paths=scope_paths, + ) - # Build collected path set for [✓] marking on tree - collected_path_set = { - item.get("path", "") for item in (collected_paths or []) - } - items_text, overflowed = format_items_for_llm( - items, - collected_paths=collected_path_set, + expanded_path_set = set(expanded_scopes or _expanded_paths_from_trace(nav_trace or [])) + if scope_path: + expanded_path_set.add(scope_path) + provisional_action_set = build_legal_actions( + items=items, + 
current_scope=scope_path, + collected_paths=collected_paths or [], + expanded_scopes=expanded_path_set, + discovery_hints=discovery_hints, + rejected_paths=rejected_paths or set(), + rejected_collect_paths=rejected_collect_paths or set(), + total_images=total_images, + total_tables=total_tables, + disabled_asset_types=disabled_asset_types or set(), + budget_snapshot=budget_snapshot, ) + provisional_observation_text, provisional_overflowed = ( + format_actionable_observation( + items=items, + action_set=provisional_action_set, + ) + ) + + trace_block = format_nav_trace(nav_trace or []) - # Build trace block (unified: scope + actions + collection) - trace_block = format_nav_trace( - nav_trace or [], - collected_paths or [], + # Estimate this call's prompt token cost and adjust the budget + # snapshot so the LLM sees post-call budget, not pre-call. + # This prevents the LLM from seeing misleadingly low percentages + # (e.g. 63% when it will actually be 89% after this call). + prompt_tokens_est = ( + len(provisional_observation_text) + + len(trace_block) + + 800 + ) // 2 # rough chars-to-tokens ratio + adjusted_snapshot = adjust_budget_snapshot( + budget_snapshot, prompt_tokens_est, + ) + if ( + budget_status_from_snapshot(adjusted_snapshot) + == budget_status_from_snapshot(budget_snapshot) + ): + action_set = provisional_action_set + actionable_observation = provisional_observation_text + overflowed = provisional_overflowed + else: + action_set = build_legal_actions( + items=items, + current_scope=scope_path, + collected_paths=collected_paths or [], + expanded_scopes=expanded_path_set, + discovery_hints=discovery_hints, + rejected_paths=rejected_paths or set(), + rejected_collect_paths=rejected_collect_paths or set(), + total_images=total_images, + total_tables=total_tables, + disabled_asset_types=disabled_asset_types or set(), + budget_snapshot=adjusted_snapshot, + ) + actionable_observation, overflowed = format_actionable_observation( + items=items, + 
action_set=action_set, + ) + observation = { + "visible_sections": [ + item.get("path", "") + for item in items + if item.get("path") + ][:50], + "available_images": total_images, + "available_tables": total_tables, + "prior_tool_result": prior_tool_result, + "current_scope": scope_path or "root", + "query_intent": query_intent, + "legal_actions": { + "expand": [item.id for item in action_set.expand], + "collect": [item.id for item in action_set.collect], + "back": [item.id for item in action_set.back], + "search": [item.id for item in action_set.search], + "finish": [action_set.finish.id] if action_set.finish else [], + }, + "rejected_paths": sorted(rejected_paths or set()), + "rejected_collect_paths": sorted(rejected_collect_paths or set()), + } + agent_state_block = format_agent_state_block( + current_scope=scope_path, + query_intent=query_intent, + expanded_scopes=expanded_path_set, + rejected_paths=rejected_paths or set(), + collected_paths=collected_paths or [], + rejected_collect_paths=rejected_collect_paths or set(), + prior_tool_result=prior_tool_result, + search_context=search_context, + budget_snapshot=adjusted_snapshot, ) prompt = COLLECTOR_PROMPT.format( doc_name=doc_name or document_id, doc_id=document_id, - budget_block=format_budget_block(budget_snapshot), + agent_state_block=agent_state_block, trace_block=trace_block, - items_overview=items_text, query=query, - tools_block=tools_block, + actionable_observation=actionable_observation, ) response = await llm_fn(prompt) parsed = parse_collector_response(response) - action = parsed["action"] + requested_action = parsed["action"] selected_tools = parsed["tools"] + tool_params = parsed.get("tool_params", {}) reason = parsed.get("reason", "") raw_collect = parsed.get("collect", []) - drill_into = parsed.get("drill_into") + action_id = parsed.get("action_id") + legal_main = action_set.get(action_id) + action = ( + legal_main.action + if legal_main and legal_main.action != "COLLECT" + else requested_action 
+ ) + if action != requested_action: + reason = ( + f"Action field '{requested_action}' did not match legal action " + f"ID '{action_id}'; executing ID-defined action '{action}'. " + + reason + ).strip()[:500] + selected_tools = [action] if action in ("SEARCH_IMAGES", "SEARCH_TABLES") else [] scope_label = scope_path or "root" logger.info( f" navigate_step scope={scope_label}: " f"action={action} collect={len(raw_collect)} " - f"drill_into={drill_into} tools={selected_tools} " + f"action_id={action_id} tools={selected_tools} " + f"tool_params={tool_params} " f"overflowed={overflowed}" ) node = DocTreeNode(scope_path=scope_paths[0] if scope_paths else None) node.outline_items = [item for item in items if item.get("show_summary", True)] - # Validate collect paths: must be visible and not already collected + # Resolve COLLECT side effects from legal action IDs. valid_collect: list[dict[str, Any]] = [] + invalid_collect: list[str] = [] + existing_collect_modes = _existing_collect_modes(collected_paths or []) for item in raw_collect: - path = item.get("path", "") - if path in visible_items and path not in collected_path_set: + collect_id = item.get("id") + legal_collect = action_set.get(collect_id) + if legal_collect and legal_collect.action == "COLLECT" and legal_collect.path: confidence = item.get("confidence", 0.7) - outline = item.get("outline", False) - node.confidence[path] = confidence + outline = bool(item.get("outline", False)) and ( + budget_status_from_snapshot(adjusted_snapshot) != "CRITICAL" + or query_intent in {"MACRO_SUMMARY", "STRUCTURE_OVERVIEW"} + ) + if ( + legal_main + and legal_main.action == "EXPAND" + and legal_main.path == legal_collect.path + ): + outline = True + hydrate_mode = "outline" if outline else "chunks" + existing_mode = existing_collect_modes.get(legal_collect.path) + if existing_mode == "chunks": + continue + if existing_mode == "outline" and hydrate_mode == "outline": + continue + node.confidence[legal_collect.path] = confidence 
valid_collect.append({ - "path": path, + "path": legal_collect.path, "confidence": confidence, - "hydrate_mode": "outline" if outline else "chunks", + "hydrate_mode": hydrate_mode, }) + elif collect_id: + invalid_collect.append(str(collect_id)) - # Validate drill target: must be visible, not collected, not a leaf valid_drill: list[dict[str, Any]] = [] - if action == "DRILL" and drill_into: - if drill_into in visible_items and drill_into not in collected_path_set: - drill_item = visible_items[drill_into] - if drill_item.get("is_leaf"): - # Leaf nodes can't be drilled — auto-collect instead - logger.info( - f" navigate_step: drill target '{drill_into}' is a leaf, " - f"auto-collecting instead" - ) - if not any(c["path"] == drill_into for c in valid_collect): - node.confidence[drill_into] = 0.7 - valid_collect.append({ - "path": drill_into, - "confidence": 0.7, - "hydrate_mode": "chunks", - }) - action = "STOP" # no valid drill target - else: - valid_drill.append({ - "path": drill_into, - "confidence": 0.8, - }) + result_status = "ok" + result_note: str | None = None + drill_into: str | None = None + back_to: str | None = None + if requested_action == "ERROR": + result_status = "invalid_response" + result_note = reason or "invalid model response" + elif action == "EXPAND": + if legal_main and legal_main.action == "EXPAND" and legal_main.path: + drill_into = legal_main.path + valid_drill.append({ + "path": drill_into, + "confidence": 0.8, + }) else: - logger.warning( - f" navigate_step: drill target '{drill_into}' invalid " - f"(not visible or already collected), falling back to STOP" - ) - action = "STOP" + result_status = "invalid_action_id" + result_note = f"invalid_expand_id: {action_id}" + elif action == "BACK": + if legal_main and legal_main.action == "BACK": + back_to = legal_main.target_scope + else: + result_status = "invalid_action_id" + result_note = f"invalid_back_id: {action_id}" + elif action in ("SEARCH_IMAGES", "SEARCH_TABLES"): + if legal_main is None 
or legal_main.action != action: + result_status = "invalid_action_id" + result_note = f"invalid_search_id: {action_id}" + elif action == "FINISH": + if legal_main is None or legal_main.action != "FINISH": + result_status = "invalid_action_id" + result_note = f"invalid_finish_id: {action_id}" + + if invalid_collect and result_status == "ok": + result_status = "invalid_collect" + result_note = "invalid_collect_ids: " + ", ".join(invalid_collect[:5]) + + # Parse tool parameters for SEARCH + search_assets_params: dict[str, Any] | None = None + + if action in ("SEARCH_IMAGES", "SEARCH_TABLES") and result_status == "ok": + asset_type = "image" if action == "SEARCH_IMAGES" else "table" + collected_scope_paths = [ + str(item.get("path") or "") + for item in valid_collect + if item.get("path") + ] + search_assets_params = { + "query": query.strip(), + "asset_type": asset_type, + "scope_paths": collected_scope_paths or scope_paths, + } + elif action in ("SEARCH_IMAGES", "SEARCH_TABLES"): + selected_tools = [] return NavigateStepResult( action=action, collect=valid_collect, drill=valid_drill, + back_to=back_to, tools=selected_tools, node=node, reason=reason, + search_assets_params=search_assets_params, + observation=observation, + result_status=result_status, + result_note=result_note, ) except BudgetExceeded: raise except Exception as exc: logger.error(f" navigate_step failed for doc={document_id}: {exc}") - return NavigateStepResult.stop(scope_paths[0] if scope_paths else None) + return NavigateStepResult.error( + scope_paths[0] if scope_paths else None, + reason=str(exc), + ) + + +def _expanded_paths_from_trace(nav_trace: list[dict[str, Any]]) -> set[str]: + expanded: set[str] = set() + for entry in nav_trace: + if entry.get("action") != "EXPAND": + continue + if entry.get("result_status", "ok") != "ok": + continue + drill_into = entry.get("drill_into") + if isinstance(drill_into, str) and drill_into: + expanded.add(drill_into) + return expanded + +def 
_existing_collect_modes(collected_paths: list[dict[str, Any]]) -> dict[str, str]: + modes: dict[str, str] = {} + for item in collected_paths: + path = str(item.get("path") or "") + if not path: + continue + hydrate_mode = str(item.get("hydrate_mode") or "chunks") + if hydrate_mode != "outline": + modes[path] = "chunks" + elif modes.get(path) != "chunks": + modes[path] = "outline" + return modes diff --git a/packages/shared-python/shared/services/retrieval/agentic/orchestrator.py b/packages/shared-python/shared/services/retrieval/agentic/orchestrator.py index 59030460..4b30012e 100644 --- a/packages/shared-python/shared/services/retrieval/agentic/orchestrator.py +++ b/packages/shared-python/shared/services/retrieval/agentic/orchestrator.py @@ -5,10 +5,10 @@ Phase 2: Per-document navigation (iterative BFS via navigate_step) Phase 3: Render evidence text for downstream agents -The orchestrator drives navigation via an iterative BFS queue per document, -calling navigate_step at each level. Each navigate_step is a single LLM call -that decides action (NAVIGATE/STOP), asset tools (FIND_IMAGES/FIND_TABLES), -and section selections. STOP terminates the drill-down for that scope. +The orchestrator drives navigation through a per-document observe-act loop. +Each navigate_step is a single LLM call that chooses one main action +(EXPAND/BACK/SEARCH_IMAGES/SEARCH_TABLES/FINISH) plus optional collection +side effects. FINISH explicitly terminates navigation for that document. KNOWHERE does not generate final answers. Downstream agents decide whether the returned evidence is sufficient for their task and may call retrieval again. 
@@ -39,6 +39,7 @@ AgentRunConfig, AgentState, AgenticResult, + DecisionTraceStep, ) from shared.services.retrieval.llm_adapter import LLMFn from shared.services.retrieval.settings import DEFAULT_TOP_K @@ -164,6 +165,34 @@ async def run( # If no LLM or no docs selected, return discovery rows directly if not state.selected_docs: logger.info('agentic: no documents selected — returning discovery results') + no_docs_trace: list[dict[str, Any]] = [] + kg_select_step = DecisionTraceStep( + step_index=0, + agent='doc_selector', + parent_step_index=None, + phase='kg_select', + document_id=None, + document=None, + scope='corpus', + observation={ + 'query': query, + 'candidate_documents': len(state.doc_id_to_name), + }, + decision={ + 'action': 'select_documents', + 'args': {}, + 'reason': 'No documents selected for navigation', + }, + result={ + 'status': 'empty', + 'collected': [], + }, + budget=state.ledger.snapshot() if state.ledger else None, + elapsed_ms=state.elapsed_ms, + ) + no_docs_trace.append(kg_select_step.to_dict()) + if trace_enabled: + trace.record_decision_trace_step(kg_select_step) discovery_refs = [ { 'chunk_id': r.get('chunk_id', ''), @@ -175,10 +204,40 @@ async def run( else r.get('section_path', '') ), 'file_path': r.get('file_path', ''), + 'job_id': r.get('job_id', ''), } for r in discovery_rows[:top_k] if r.get('chunk_id') ] + terminal_step = DecisionTraceStep( + step_index=len(no_docs_trace), + agent='retrieval_agent', + parent_step_index=None, + phase='terminal', + document_id=None, + document=None, + scope='retrieve_step', + observation={ + 'router_used': 'agentic_discovery_only', + 'referenced_chunks': len(discovery_refs), + 'evidence_chars': 0, + }, + decision={ + 'action': 'complete', + 'args': {}, + 'reason': 'no_documents_selected', + }, + result={ + 'status': 'ok', + 'stop_reason': 'no_documents_selected', + 'failure_reason': '', + }, + budget=state.ledger.snapshot() if state.ledger else None, + elapsed_ms=state.elapsed_ms, + ) + 
no_docs_trace.append(terminal_step.to_dict()) + if trace_enabled: + trace.record_decision_trace_step(terminal_step) if trace_enabled: await trace.complete( discovery_rows, @@ -190,6 +249,9 @@ async def run( answer_text='', referenced_chunks=discovery_refs, router_used='agentic_discovery_only', + budget_snapshot=state.ledger.snapshot() if state.ledger else None, + stop_reason='no_documents_selected', + decision_trace=no_docs_trace, ) discovery_by_doc: dict[str, list[dict[str, Any]]] = {} @@ -218,21 +280,42 @@ async def run( # Record KG document selection as the first decision trace entry if state.selected_docs: - decision_trace.append({ - 'phase': 'kg_select', - 'action': 'select', - 'reason': f'{len(state.selected_docs)} document(s) selected for navigation', - 'documents': [ - { - 'document': doc.source_file_name, - 'document_id': doc.document_id, - 'confidence': doc.confidence, - 'reason': doc.reason, - 'source': doc.source, - } - for doc in state.selected_docs - ], - }) + kg_select_step = DecisionTraceStep( + step_index=0, + agent='doc_selector', + parent_step_index=None, + phase='kg_select', + document_id=None, + document=None, + scope='corpus', + observation={ + 'query': query, + 'candidate_documents': len(state.doc_id_to_name), + }, + decision={ + 'action': 'select_documents', + 'args': {}, + 'reason': f'{len(state.selected_docs)} document(s) selected for navigation', + }, + result={ + 'status': 'ok', + 'collected': [ + { + 'document': doc.source_file_name, + 'document_id': doc.document_id, + 'confidence': doc.confidence, + 'reason': doc.reason, + 'source': doc.source, + } + for doc in state.selected_docs + ], + }, + budget=state.ledger.snapshot() if state.ledger else None, + elapsed_ms=state.elapsed_ms, + ) + decision_trace.append(kg_select_step.to_dict()) + if trace_enabled: + trace.record_decision_trace_step(kg_select_step) if state.elapsed_ms >= config.latency_budget_ms: stop_reason = 'latency_budget' @@ -251,7 +334,13 @@ async def run( 
llm_budget=llm_budget, ) await navigation_runner.navigate_selected_documents() - decision_trace.extend(navigation_runner.decision_steps) + decision_trace.extend( + _offset_decision_trace( + navigation_runner.decision_steps, + offset=len(decision_trace), + ) + ) + context_remaining = state.ledger.remaining('context') if state.ledger else config.token_budget_total evidence_text = await _trim_evidence_to_budget( db, @@ -282,8 +371,35 @@ async def run( seen_ref_ids.add(cid) all_refs.append(ref) - - + terminal_step = DecisionTraceStep( + step_index=len(decision_trace), + agent='retrieval_agent', + parent_step_index=None, + phase='terminal', + document_id=None, + document=None, + scope='retrieve_step', + observation={ + 'router_used': router_used, + 'referenced_chunks': len(all_refs), + 'evidence_chars': len(evidence_text), + }, + decision={ + 'action': 'complete', + 'args': {}, + 'reason': stop_reason or failure_reason or 'retrieval_complete', + }, + result={ + 'status': 'error' if failure_reason else 'ok', + 'stop_reason': stop_reason, + 'failure_reason': failure_reason, + }, + budget=state.ledger.snapshot() if state.ledger else None, + elapsed_ms=state.elapsed_ms, + ) + decision_trace.append(terminal_step.to_dict()) + if trace_enabled: + trace.record_decision_trace_step(terminal_step) result = AgenticResult( evidence_text=evidence_text, @@ -312,3 +428,20 @@ async def run( ) return result + + +def _offset_decision_trace( + steps: list[dict[str, Any]], + *, + offset: int, +) -> list[dict[str, Any]]: + adjusted: list[dict[str, Any]] = [] + for step in steps: + copied = dict(step) + old_index = int(copied.get('step_index') or 0) + copied['step_index'] = old_index + offset + parent = copied.get('parent_step_index') + if parent is not None: + copied['parent_step_index'] = int(parent) + offset + adjusted.append(copied) + return adjusted diff --git a/packages/shared-python/shared/services/retrieval/agentic/prompts.py 
b/packages/shared-python/shared/services/retrieval/agentic/prompts.py index 1d4cbf60..3a7c5b2c 100644 --- a/packages/shared-python/shared/services/retrieval/agentic/prompts.py +++ b/packages/shared-python/shared/services/retrieval/agentic/prompts.py @@ -5,6 +5,8 @@ import re from typing import Any +from shared.services.retrieval.agentic.core.budget import project_budget_snapshot + FILE_SELECT_PROMPT = """\ You are a document routing assistant. @@ -28,150 +30,172 @@ """ -DISCOVERY_SELECT_PROMPT = """\ -You are a document navigation assistant. - -Document: "{doc_name}" - -{budget_block} -After navigating the document's section tree, the following section paths -were additionally discovered via keyword and semantic search. -They may contain relevant evidence not found through hierarchical navigation. - -=== Discovery Candidates === -{items} -=== End Discovery Candidates === - -User query: {query} -Select section paths whose content is needed to answer the query. -If none are relevant, return an EMPTY list []. - -Return ONLY a JSON object: -{{"selections": [{{"path": "...", "confidence": }}, ...]}} -Do not include any explanation. -""" - - COLLECTOR_PROMPT = """\ -You are a document navigation agent. +You are a document navigation agent running an observe-act loop. Document: "{doc_name}" (id: {doc_id}) -{budget_block} -{trace_block} -Below is the document's section tree. -Nodes marked [Leaf] have no further sub-sections. -Nodes marked [✓] are already in your collection — do not re-collect them. -Token estimates (e.g. ~1.2k) show approximate content size. +{agent_state_block} -=== Section Tree === -{items_overview} -=== End Section Tree === +{trace_block} User query: {query} -=== Behavioral Rules === +{actionable_observation} -Each step you make TWO independent decisions: +=== Rules === -1. COLLECT — Add sections to your evidence collection (optional, can be empty). - - COLLECT includes the section AND ALL its descendant content. 
- - If a node is [Leaf] or has ≤500 tokens, prefer COLLECT over DRILL. - - Do NOT re-collect paths marked [✓]. +Each step chooses exactly ONE main action, plus optional COLLECT side effects. -2. Navigate action — Where to go next (required, choose ONE): - - DRILL — Open one section to see its children in the next step. - Use when a section has >1000 tokens and you need to be selective. - You cannot DRILL into a path you just COLLECTed (already fully included). - - BACK — Return to parent scope to explore other branches. - - STOP — End navigation. Use when you have enough evidence or nothing relevant remains. +Action semantics: + - EXPAND observes a listed section's children in the next step. + - COLLECT adds a listed section and all descendant content to evidence. + - BACK only changes current scope; it does not collect evidence. + - SEARCH_IMAGES and SEARCH_TABLES inspect assets in the current scope. + Use only the listed SEARCH action ID. The asset inspector receives the + user's original query directly. + After a SEARCH result returns matches, use the matched assets and owner + sections to decide whether more owner context is needed; avoid repeating + the same asset search unless the current scope has changed and the prior + result is insufficient for the query. + - FINISH ends navigation for this document. -{tools_block} +COLLECT side effect: + - COLLECT includes the section AND ALL its descendant content. + - Set "outline": true to collect only structure (titles + summaries), + keeping children available for further EXPAND or COLLECT. + - If you COLLECT the same section you EXPAND as the main action, use + "outline": true so the section remains open for child exploration. + - If the advisory query intent is MACRO_SUMMARY or STRUCTURE_OVERVIEW + (document overview, chapter map, high-level summary), prefer outline + collection; outline evidence can be sufficient final evidence. 
+ - If the advisory query intent is FACTUAL_DETAIL, NUMERIC_DETAIL, or + ASSET_LOOKUP, prefer full evidence collection ("outline": false), or + SEARCH_IMAGES/SEARCH_TABLES when visual/table evidence is central. + - If the advisory query intent is UNKNOWN, decide from the user's wording: + broad summaries can use outline, specific facts/numbers/assets need full + evidence. + - FINISH only when the collected evidence is sufficient for the user's query. + The system will not infer missing evidence for you. + - In CRITICAL budget mode, exploration is closed. Prefer the smallest + sufficient COLLECT side effects, then FINISH. + - In EXHAUSTED or overdraft budget mode, do not explore or search again. + Use current observations/tool results to FINISH, or collect only + indispensable visible evidence before FINISH. + - For [Leaf] nodes or small sections, prefer COLLECT over EXPAND. + +=== End Rules === Return ONLY a JSON object: -{{"collect": [{{"path": "...", "confidence": , "outline": false}}, ...], - "action": "DRILL", - "drill_into": "section/path", - "tools": [...], +{{"collect": [{{"id": "C1", "confidence": , "outline": false}}], + "action": "", + "action_args": {{"id": "
"}}, "reason": "..."}} -or -{{"collect": [...], "action": "BACK", "tools": [...], "reason": "..."}} -or -{{"collect": [...], "action": "STOP", "tools": [...], "reason": "..."}} - -Set "outline": true on a collect entry to collect only the section structure -(titles and summaries) without full chunk content. Use for overview/structure queries. Do not include any explanation outside the JSON. IMPORTANT: 1. All agent-generated text (e.g., "reason" and other free-text fields) MUST be written in English. 2. Document content and section paths MUST remain in their original language. +3. Use only action IDs from Actionable Observation. Never invent IDs or write raw section paths as action targets. +4. The action value MUST match the chosen ID group: E*=EXPAND, B*=BACK, S*=SEARCH, F*=FINISH. +5. When Budget mode is CRITICAL or EXHAUSTED, choose the best sufficient COLLECT side effects and then FINISH. """ +QUERY_INTENT_PROMPT = """\ +Classify the user's retrieval query for document navigation. + +Return ONLY a JSON object: {{"intent": "