From 32622b4be90bd4247ee602e607ff316be76738ec Mon Sep 17 00:00:00 2001 From: Pier-Jean Malandrino Date: Wed, 10 Jun 2026 09:45:45 +0200 Subject: [PATCH] feat(rag): expose reasoning trace via public run_with_trace() Add RAGTrace model (query, per_document, final_answer) and a public run_with_trace() method on DoclingRAGAgent returning the full per-iteration trace that was previously discarded inside run(). run() becomes a thin wrapper around run_with_trace(): single source of truth for the loop, no behavior change for existing callers. RAGTrace, RAGResult and RAGIteration are re-exported at package level so downstream consumers can type-hint against the trace. Closes #26 Signed-off-by: Pier-Jean Malandrino --- docling_agent/__init__.py | 6 ++ docling_agent/agent/rag.py | 39 +++++++--- docling_agent/agent/rag_models.py | 8 ++ docling_agent/agents.py | 4 + tests/test_rag_trace.py | 117 ++++++++++++++++++++++++++++++ 5 files changed, 163 insertions(+), 11 deletions(-) create mode 100644 tests/test_rag_trace.py diff --git a/docling_agent/__init__.py b/docling_agent/__init__.py index cad8b02..d4a0c44 100644 --- a/docling_agent/__init__.py +++ b/docling_agent/__init__.py @@ -16,7 +16,10 @@ ModelConfig, OllamaBackend, OutputConfig, + RAGIteration, + RAGResult, RAGTask, + RAGTrace, WriteTask, create_backend, load_task, @@ -41,7 +44,10 @@ "ModelConfig", "OllamaBackend", "OutputConfig", + "RAGIteration", + "RAGResult", "RAGTask", + "RAGTrace", "WriteTask", "create_backend", "load_task", diff --git a/docling_agent/agent/rag.py b/docling_agent/agent/rag.py index 8df95a2..d27618d 100644 --- a/docling_agent/agent/rag.py +++ b/docling_agent/agent/rag.py @@ -35,6 +35,7 @@ AnswerAttempt, RAGIteration, RAGResult, + RAGTrace, SectionSelection, ) from docling_agent.logging import log_debug, log_info, log_warning @@ -89,30 +90,46 @@ def run( sources: list[DoclingDocument | Path] = [], **kwargs, ) -> DoclingDocument: + trace = self.run_with_trace(task, document=document, sources=sources, **kwargs) + + answer_doc = DoclingDocument(name="rag_answer") + answer_doc.add_title(text="Answer", parent=answer_doc.body) + answer_doc.add_text(label=DocItemLabel.TEXT, text=trace.final_answer, parent=answer_doc.body) + return answer_doc + + def run_with_trace( + self, + task: str, + document: DoclingDocument | None = None, + sources: list[DoclingDocument | Path] = [], + **kwargs, + ) -> RAGTrace: + """Run the RAG loop and return the full reasoning trace. + + Same orchestration as run(), but returns a typed RAGTrace exposing the + per-document RAGResult (selections, reasons, convergence) and the merged + final_answer. run() is a thin wrapper around this method. + """ docs = [s for s in sources if isinstance(s, DoclingDocument)] if not docs and document is not None: docs = [document] if not docs: raise ValueError("DoclingRAGAgent requires at least one DoclingDocument.") - per_doc_answers: list[str] = [] - all_iterations: list[RAGIteration] = [] - + per_document: list[RAGResult] = [] for doc in docs: result = self._rag_loop(query=task, doc=doc) - per_doc_answers.append(result.answer) - all_iterations.extend(result.iterations) + per_document.append(result) log_info(f"RAG loop finished: converged={result.converged}, iterations={len(result.iterations)}") if len(docs) > 1: self._rprint(Rule(f"[bold cyan]Merging answers from {len(docs)} documents[/bold cyan]")) - final_answer = self._merge_answers(query=task, answers=per_doc_answers) - - answer_doc = DoclingDocument(name="rag_answer") - answer_doc.add_title(text="Answer", parent=answer_doc.body) - answer_doc.add_text(label=DocItemLabel.TEXT, text=final_answer, parent=answer_doc.body) - return answer_doc + final_answer = self._merge_answers( + query=task, + answers=[r.answer for r in per_document], + ) + return RAGTrace(query=task, per_document=per_document, final_answer=final_answer) # ------------------------------------------------------------------ # RAG loop diff --git a/docling_agent/agent/rag_models.py b/docling_agent/agent/rag_models.py index f82d57e..c884299 100644 --- a/docling_agent/agent/rag_models.py +++ b/docling_agent/agent/rag_models.py @@ -34,3 +34,11 @@ class RAGResult(BaseModel): answer: str iterations: list[RAGIteration] converged: bool # True if can_answer was reached; False if max_iterations hit + + +class RAGTrace(BaseModel): + """Full reasoning trace of a RAG run across one or more documents.""" + + query: str + per_document: list[RAGResult] # one entry per source DoclingDocument, in input order + final_answer: str # merged answer — identical to what run() wraps into a DoclingDocument diff --git a/docling_agent/agents.py b/docling_agent/agents.py index 4ae6ef7..2cd83b8 100644 --- a/docling_agent/agents.py +++ b/docling_agent/agents.py @@ -4,6 +4,7 @@ from docling_agent.agent.extractor import DoclingExtractingAgent from docling_agent.agent.orchestrator import DoclingOrchestratorAgent from docling_agent.agent.rag import DoclingRAGAgent +from docling_agent.agent.rag_models import RAGIteration, RAGResult, RAGTrace from docling_agent.agent.writer import DoclingWritingAgent from docling_agent.backends import ( BaseBackend, @@ -44,7 +45,10 @@ "ModelConfig", "OllamaBackend", "OutputConfig", + "RAGIteration", + "RAGResult", "RAGTask", + "RAGTrace", "WriteTask", "create_backend", "load_task", diff --git a/tests/test_rag_trace.py b/tests/test_rag_trace.py new file mode 100644 index 0000000..9438a8b --- /dev/null +++ b/tests/test_rag_trace.py @@ -0,0 +1,117 @@ +"""Tests for DoclingRAGAgent.run_with_trace and the run() wrapper equivalence.""" + +import pytest +from docling_core.types.doc.document import DoclingDocument + +from docling_agent import RAGIteration, RAGResult, RAGTrace +from docling_agent.agent.rag import DoclingRAGAgent + +from .test_utils import MockBackend + + +def _make_doc(name: str) -> DoclingDocument: + doc = DoclingDocument(name=name) + doc.add_title(text=f"Title of {name}", parent=doc.body) + doc.add_text(label="text", text=f"Body content of {name}.", parent=doc.body) + return doc + + +def _make_result(answer: str) -> RAGResult: + return RAGResult( + answer=answer, + iterations=[ + RAGIteration( + iteration=1, + section_ref="#/texts/0", + reason="picked first section", + section_text_length=42, + can_answer=True, + response=answer, + ) + ], + converged=True, + ) + + +@pytest.fixture +def rag_agent() -> DoclingRAGAgent: + return DoclingRAGAgent(backend=MockBackend(), tools=[]) + + +def test_run_with_trace_single_doc_returns_typed_trace(monkeypatch, rag_agent): + doc = _make_doc("doc_a") + monkeypatch.setattr( + DoclingRAGAgent, + "_rag_loop", + lambda self, *, query, doc: _make_result("answer for " + doc.name), + ) + + trace = rag_agent.run_with_trace("what is X?", sources=[doc]) + + assert isinstance(trace, RAGTrace) + assert trace.query == "what is X?" + assert len(trace.per_document) == 1 + assert trace.per_document[0].answer == "answer for doc_a" + assert trace.per_document[0].iterations # non-empty + assert trace.per_document[0].converged is True + # Single-doc path: _merge_answers returns the lone answer unchanged + assert trace.final_answer == "answer for doc_a" + + +def test_run_with_trace_multi_doc_preserves_order_and_merges(monkeypatch, rag_agent): + doc_a = _make_doc("doc_a") + doc_b = _make_doc("doc_b") + + monkeypatch.setattr( + DoclingRAGAgent, + "_rag_loop", + lambda self, *, query, doc: _make_result("answer for " + doc.name), + ) + monkeypatch.setattr( + DoclingRAGAgent, + "_merge_answers", + lambda self, *, query, answers: "MERGED:" + "|".join(answers), + ) + + trace = rag_agent.run_with_trace("q", sources=[doc_a, doc_b]) + + assert len(trace.per_document) == 2 + assert trace.per_document[0].answer == "answer for doc_a" + assert trace.per_document[1].answer == "answer for doc_b" + assert trace.final_answer == "MERGED:answer for doc_a|answer for doc_b" + + +def test_run_wraps_final_answer_from_trace(monkeypatch, rag_agent): + doc = _make_doc("doc_a") + monkeypatch.setattr( + DoclingRAGAgent, + "_rag_loop", + lambda self, *, query, doc: _make_result("the answer"), + ) + + answer_doc = rag_agent.run("q", sources=[doc]) + trace = rag_agent.run_with_trace("q", sources=[doc]) + + # run() must wrap the same final_answer that run_with_trace exposes + assert answer_doc.name == "rag_answer" + body_texts = [item.text for item, _ in answer_doc.iterate_items() if hasattr(item, "text")] + assert trace.final_answer in body_texts + + +def test_run_with_trace_accepts_legacy_document_kwarg(monkeypatch, rag_agent): + doc = _make_doc("legacy") + monkeypatch.setattr( + DoclingRAGAgent, + "_rag_loop", + lambda self, *, query, doc: _make_result("legacy answer"), + ) + + trace = rag_agent.run_with_trace("q", document=doc) + + assert len(trace.per_document) == 1 + assert trace.final_answer == "legacy answer" + + +def test_run_with_trace_raises_on_empty_sources(rag_agent): + with pytest.raises(ValueError, match="at least one DoclingDocument"): + rag_agent.run_with_trace("q")