Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions agent_cli/rag/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def _convert_messages(

for m in history_msgs:
if m.role == "system":
pyd_messages.append(ModelRequest(parts=[SystemPromptPart(content=m.content)]))
if m.content: # Skip empty system messages (rejected by some providers)
pyd_messages.append(ModelRequest(parts=[SystemPromptPart(content=m.content)]))
elif m.role == "user":
pyd_messages.append(ModelRequest(parts=[UserPromptPart(content=m.content)]))
elif m.role == "assistant":
Expand Down Expand Up @@ -228,13 +229,11 @@ def read_full_document(file_path: str) -> str:
# - If CLI flag `enable_rag_tools` is False, tools are disabled globally.
# - If CLI flag is True, check request.rag_enable_tools (default True).
tools_allowed = enable_rag_tools and (request.rag_enable_tools is not False)

system_prompt: str | tuple[()] = () # No system prompt by default
if retrieval and retrieval.context:
truncated = truncate_context(retrieval.context)
template = RAG_PROMPT_WITH_TOOLS if tools_allowed else RAG_PROMPT_NO_TOOLS
system_prompt = template.format(context=truncated)
else:
system_prompt = ""

# 4. Setup Agent
from pydantic_ai import Agent # noqa: PLC0415
Expand Down
39 changes: 36 additions & 3 deletions tests/rag/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

from agent_cli.rag import engine
from agent_cli.rag.engine import _is_path_safe, truncate_context
from agent_cli.rag.engine import _convert_messages, _is_path_safe, truncate_context
from agent_cli.rag.models import ChatRequest, Message


Expand Down Expand Up @@ -82,6 +82,35 @@ def test_is_path_safe_symlink_escape(tmp_path: Path) -> None:
pass


def test_convert_messages_skips_empty_system() -> None:
    """An empty system message must not be carried into the converted history."""
    blank_system = Message(role="system", content="")
    earlier_turn = Message(role="user", content="Hello")
    final_turn = Message(role="user", content="Question")

    converted, prompt = _convert_messages([blank_system, earlier_turn, final_turn])

    # The trailing user message is expected back as the prompt.
    assert prompt == "Question"
    # Exactly one entry remains in history: the blank system message was
    # dropped, so only the earlier user turn is left.
    assert len(converted) == 1


def test_convert_messages_keeps_nonempty_system() -> None:
    """A system message that has content must survive conversion."""
    system_turn = Message(role="system", content="You are helpful.")
    question_turn = Message(role="user", content="Question")

    converted, prompt = _convert_messages([system_turn, question_turn])

    # The user message is expected back as the prompt.
    assert prompt == "Question"
    # One history entry remains, so the non-empty system message was kept.
    assert len(converted) == 1


def test_retrieve_context_direct() -> None:
"""Test direct usage of _retrieve_context without async/HTTP."""
mock_collection = MagicMock()
Expand Down Expand Up @@ -132,6 +161,7 @@ async def test_process_chat_request_no_rag(tmp_path: Path) -> None:
with (
patch("pydantic_ai.Agent.run", new_callable=AsyncMock) as mock_run,
patch("agent_cli.rag.engine.search_context") as mock_search,
patch("pydantic_ai.Agent.__init__", return_value=None) as mock_agent_init,
):
mock_run.return_value = mock_run_result
# Mock retrieval to return empty
Expand All @@ -151,11 +181,14 @@ async def test_process_chat_request_no_rag(tmp_path: Path) -> None:
)

assert resp["choices"][0]["message"]["content"] == "Response"
# Should check if search was called
mock_search.assert_called_once()
# Verify Agent.run was called
mock_run.assert_called_once()

# Verify system_prompt is an empty tuple, not an empty string
# (empty strings cause errors on providers like Vertex AI)
init_kwargs = mock_agent_init.call_args
assert init_kwargs.kwargs["system_prompt"] == ()


@pytest.mark.asyncio
async def test_process_chat_request_with_rag(tmp_path: Path) -> None:
Expand Down
Loading