Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docling_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
ModelConfig,
OllamaBackend,
OutputConfig,
RAGIteration,
RAGResult,
RAGTask,
RAGTrace,
WriteTask,
create_backend,
load_task,
Expand All @@ -41,7 +44,10 @@
"ModelConfig",
"OllamaBackend",
"OutputConfig",
"RAGIteration",
"RAGResult",
"RAGTask",
"RAGTrace",
"WriteTask",
"create_backend",
"load_task",
Expand Down
39 changes: 28 additions & 11 deletions docling_agent/agent/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
AnswerAttempt,
RAGIteration,
RAGResult,
RAGTrace,
SectionSelection,
)
from docling_agent.logging import log_debug, log_info, log_warning
Expand Down Expand Up @@ -89,30 +90,46 @@ def run(
sources: list[DoclingDocument | Path] = [],
**kwargs,
) -> DoclingDocument:
trace = self.run_with_trace(task, document=document, sources=sources, **kwargs)

answer_doc = DoclingDocument(name="rag_answer")
answer_doc.add_title(text="Answer", parent=answer_doc.body)
answer_doc.add_text(label=DocItemLabel.TEXT, text=trace.final_answer, parent=answer_doc.body)
return answer_doc

def run_with_trace(
self,
task: str,
document: DoclingDocument | None = None,
sources: list[DoclingDocument | Path] = [],
**kwargs,
) -> RAGTrace:
"""Run the RAG loop and return the full reasoning trace.

Same orchestration as run(), but returns a typed RAGTrace exposing the
per-document RAGResult (selections, reasons, convergence) and the merged
final_answer. run() is a thin wrapper around this method.
"""
docs = [s for s in sources if isinstance(s, DoclingDocument)]
if not docs and document is not None:
docs = [document]
if not docs:
raise ValueError("DoclingRAGAgent requires at least one DoclingDocument.")

per_doc_answers: list[str] = []
all_iterations: list[RAGIteration] = []

per_document: list[RAGResult] = []
for doc in docs:
result = self._rag_loop(query=task, doc=doc)
per_doc_answers.append(result.answer)
all_iterations.extend(result.iterations)
per_document.append(result)
log_info(f"RAG loop finished: converged={result.converged}, iterations={len(result.iterations)}")

if len(docs) > 1:
self._rprint(Rule(f"[bold cyan]Merging answers from {len(docs)} documents[/bold cyan]"))

final_answer = self._merge_answers(query=task, answers=per_doc_answers)

answer_doc = DoclingDocument(name="rag_answer")
answer_doc.add_title(text="Answer", parent=answer_doc.body)
answer_doc.add_text(label=DocItemLabel.TEXT, text=final_answer, parent=answer_doc.body)
return answer_doc
final_answer = self._merge_answers(
query=task,
answers=[r.answer for r in per_document],
)
return RAGTrace(query=task, per_document=per_document, final_answer=final_answer)

# ------------------------------------------------------------------
# RAG loop
Expand Down
8 changes: 8 additions & 0 deletions docling_agent/agent/rag_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,11 @@ class RAGResult(BaseModel):
answer: str
iterations: list[RAGIteration]
converged: bool # True if can_answer was reached; False if max_iterations hit


class RAGTrace(BaseModel):
"""Full reasoning trace of a RAG run across one or more documents."""

query: str
per_document: list[RAGResult] # one entry per source DoclingDocument, in input order
final_answer: str # merged answer — identical to what run() wraps into a DoclingDocument
4 changes: 4 additions & 0 deletions docling_agent/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from docling_agent.agent.extractor import DoclingExtractingAgent
from docling_agent.agent.orchestrator import DoclingOrchestratorAgent
from docling_agent.agent.rag import DoclingRAGAgent
from docling_agent.agent.rag_models import RAGIteration, RAGResult, RAGTrace
from docling_agent.agent.writer import DoclingWritingAgent
from docling_agent.backends import (
BaseBackend,
Expand Down Expand Up @@ -44,7 +45,10 @@
"ModelConfig",
"OllamaBackend",
"OutputConfig",
"RAGIteration",
"RAGResult",
"RAGTask",
"RAGTrace",
"WriteTask",
"create_backend",
"load_task",
Expand Down
117 changes: 117 additions & 0 deletions tests/test_rag_trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Tests for DoclingRAGAgent.run_with_trace and the run() wrapper equivalence."""

import pytest
from docling_core.types.doc.document import DoclingDocument

from docling_agent import RAGIteration, RAGResult, RAGTrace
from docling_agent.agent.rag import DoclingRAGAgent

from .test_utils import MockBackend


def _make_doc(name: str) -> DoclingDocument:
doc = DoclingDocument(name=name)
doc.add_title(text=f"Title of {name}", parent=doc.body)
doc.add_text(label="text", text=f"Body content of {name}.", parent=doc.body)
return doc


def _make_result(answer: str) -> RAGResult:
return RAGResult(
answer=answer,
iterations=[
RAGIteration(
iteration=1,
section_ref="#/texts/0",
reason="picked first section",
section_text_length=42,
can_answer=True,
response=answer,
)
],
converged=True,
)


@pytest.fixture
def rag_agent() -> DoclingRAGAgent:
return DoclingRAGAgent(backend=MockBackend(), tools=[])


def test_run_with_trace_single_doc_returns_typed_trace(monkeypatch, rag_agent):
doc = _make_doc("doc_a")
monkeypatch.setattr(
DoclingRAGAgent,
"_rag_loop",
lambda self, *, query, doc: _make_result("answer for " + doc.name),
)

trace = rag_agent.run_with_trace("what is X?", sources=[doc])

assert isinstance(trace, RAGTrace)
assert trace.query == "what is X?"
assert len(trace.per_document) == 1
assert trace.per_document[0].answer == "answer for doc_a"
assert trace.per_document[0].iterations # non-empty
assert trace.per_document[0].converged is True
# Single-doc path: _merge_answers returns the lone answer unchanged
assert trace.final_answer == "answer for doc_a"


def test_run_with_trace_multi_doc_preserves_order_and_merges(monkeypatch, rag_agent):
doc_a = _make_doc("doc_a")
doc_b = _make_doc("doc_b")

monkeypatch.setattr(
DoclingRAGAgent,
"_rag_loop",
lambda self, *, query, doc: _make_result("answer for " + doc.name),
)
monkeypatch.setattr(
DoclingRAGAgent,
"_merge_answers",
lambda self, *, query, answers: "MERGED:" + "|".join(answers),
)

trace = rag_agent.run_with_trace("q", sources=[doc_a, doc_b])

assert len(trace.per_document) == 2
assert trace.per_document[0].answer == "answer for doc_a"
assert trace.per_document[1].answer == "answer for doc_b"
assert trace.final_answer == "MERGED:answer for doc_a|answer for doc_b"


def test_run_wraps_final_answer_from_trace(monkeypatch, rag_agent):
doc = _make_doc("doc_a")
monkeypatch.setattr(
DoclingRAGAgent,
"_rag_loop",
lambda self, *, query, doc: _make_result("the answer"),
)

answer_doc = rag_agent.run("q", sources=[doc])
trace = rag_agent.run_with_trace("q", sources=[doc])

# run() must wrap the same final_answer that run_with_trace exposes
assert answer_doc.name == "rag_answer"
body_texts = [item.text for item, _ in answer_doc.iterate_items() if hasattr(item, "text")]
assert trace.final_answer in body_texts


def test_run_with_trace_accepts_legacy_document_kwarg(monkeypatch, rag_agent):
doc = _make_doc("legacy")
monkeypatch.setattr(
DoclingRAGAgent,
"_rag_loop",
lambda self, *, query, doc: _make_result("legacy answer"),
)

trace = rag_agent.run_with_trace("q", document=doc)

assert len(trace.per_document) == 1
assert trace.final_answer == "legacy answer"


def test_run_with_trace_raises_on_empty_sources(rag_agent):
with pytest.raises(ValueError, match="at least one DoclingDocument"):
rag_agent.run_with_trace("q")
Loading