Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 51 additions & 21 deletions api/core/agent/base_agent_runner.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import json
import logging
import uuid
from decimal import Decimal
from typing import Union, cast

from pydantic import BaseModel
from sqlalchemy import select

from core.agent.entities import AgentEntity, AgentToolEntity
Expand Down Expand Up @@ -41,11 +43,28 @@
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from factories import file_factory
from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile

logger = logging.getLogger(__name__)


class AgentThoughtValidation(BaseModel):
    """
    Validation model for agent thought data before database persistence.

    Required fields identify the thought's parent message and ordering;
    the remaining fields are optional because a runner may persist a
    thought before the tool call or observation has been produced.
    """

    message_id: str  # ID of the Message this thought belongs to
    position: int  # 1-based ordering of the thought within the message
    thought: str | None = None  # free-form reasoning text, if any
    tool: str | None = None  # tool name(s); multiple names are ";"-joined elsewhere
    tool_input: str | None = None  # JSON-encoded tool input payload
    observation: str | None = None  # JSON-encoded tool output/observation

    # Pydantic v2 configuration (dict form, no ConfigDict import needed).
    # Extra keys are deliberately tolerated so callers can pass through
    # additional thought metadata without raising a ValidationError.
    model_config = {"extra": "allow"}
Comment on lines +52 to +65
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find Pydantic version in project dependencies
find . -type f \( -name "pyproject.toml" -o -name "requirements*.txt" -o -name "poetry.lock" -o -name "Pipfile*" \) | head -20

Repository: qodo-benchmark/dify-combined-coderabbit

Length of output: 253


🏁 Script executed:

#!/bin/bash
# Search for pydantic in common dependency files
find . -maxdepth 2 -type f \( -name "pyproject.toml" -o -name "requirements*.txt" -o -name "setup.py" -o -name "setup.cfg" \) -exec grep -l "pydantic" {} \;

Repository: qodo-benchmark/dify-combined-coderabbit

Length of output: 103


🏁 Script executed:

#!/bin/bash
# Check for other BaseModel classes to see the pattern used in the codebase
rg "class.*\(BaseModel\)" --type py -A 5 | head -100

Repository: qodo-benchmark/dify-combined-coderabbit

Length of output: 5589


🏁 Script executed:

#!/bin/bash
# Check Pydantic version in pyproject.toml
cat ./api/pyproject.toml | grep -A 2 -B 2 pydantic

Repository: qodo-benchmark/dify-combined-coderabbit

Length of output: 278


🏁 Script executed:

#!/bin/bash
# Get more context around AgentThoughtValidation in base_agent_runner.py
cat -n ./api/core/agent/base_agent_runner.py | head -80

Repository: qodo-benchmark/dify-combined-coderabbit

Length of output: 3457


🏁 Script executed:

#!/bin/bash
# Check if there are other BaseModel classes in base_agent_runner.py
rg "class.*\(BaseModel\)" ./api/core/agent/base_agent_runner.py -A 15

Repository: qodo-benchmark/dify-combined-coderabbit

Length of output: 486


🏁 Script executed:

#!/bin/bash
# Search for usages of AgentThoughtValidation
rg "AgentThoughtValidation" ./api/core/agent/base_agent_runner.py -A 3 -B 3

Repository: qodo-benchmark/dify-combined-coderabbit

Length of output: 251


🏁 Script executed:

#!/bin/bash
# Check if there are instantiations of AgentThoughtValidation
rg "AgentThoughtValidation\(" ./api/core/agent/base_agent_runner.py -A 5

Repository: qodo-benchmark/dify-combined-coderabbit

Length of output: 233


🏁 Script executed:

#!/bin/bash
# Search entire codebase for AgentThoughtValidation usage
rg "AgentThoughtValidation" ./api --type py -B 2 -A 5

Repository: qodo-benchmark/dify-combined-coderabbit

Length of output: 539


🏁 Script executed:

#!/bin/bash
# Search for imports of AgentThoughtValidation
rg "from.*base_agent_runner.*import|import.*AgentThoughtValidation" ./api --type py

Repository: qodo-benchmark/dify-combined-coderabbit

Length of output: 269


🏁 Script executed:

#!/bin/bash
# Search more broadly for any reference to agent thought validation
rg "AgentThought" ./api/core/agent/ --type py | grep -i validation

Repository: qodo-benchmark/dify-combined-coderabbit

Length of output: 161


🏁 Script executed:

#!/bin/bash
# Check if there's any usage of this validation model in the entire api directory
rg "AgentThoughtValidation" ./api --type py | grep -v "^./api/core/agent/base_agent_runner.py"

Repository: qodo-benchmark/dify-combined-coderabbit

Length of output: 65


Modernize Pydantic config to v2 style and clarify extra-field policy.

The project uses Pydantic v2.11.4, so class Config (v1 syntax) should be replaced with model_config = ConfigDict(extra="allow") to align with patterns used elsewhere in the codebase (e.g., api/core/mcp/types.py). Note: The inline comment conflicts with the current code—it suggests extra='forbid' but the actual code uses extra='allow'. Clarify which is intended.

♻️ Suggested fix
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
...
 class AgentThoughtValidation(BaseModel):
@@
-    class Config:
-        extra = "allow"  # Pydantic v1 syntax - should use ConfigDict(extra='forbid')
+    model_config = ConfigDict(extra="allow")
🤖 Prompt for AI Agents
In `@api/core/agent/base_agent_runner.py` around lines 52 - 65, Update the
Pydantic model AgentThoughtValidation to use v2 configuration: replace the v1
class Config with a module import of ConfigDict and set model_config =
ConfigDict(extra="allow") (or change "allow" to "forbid" if the intent is to
disallow extra fields); ensure the inline comment is removed or updated to
reflect the chosen policy so the class AgentThoughtValidation, its fields
(message_id, position, thought, tool, tool_input, observation) and any
downstream code rely on the v2 model_config setting consistently.



class BaseAgentRunner(AppRunner):
def __init__(
self,
Expand Down Expand Up @@ -289,27 +308,28 @@ def create_agent_thought(
thought = MessageAgentThought(
message_id=message_id,
message_chain_id=None,
tool_process_data=None,
thought="",
tool=tool_name,
tool_labels_str="{}",
tool_meta_str="{}",
tool_input=tool_input,
message=message,
message_token=0,
message_unit_price=0,
message_price_unit=0,
message_unit_price=Decimal(0),
message_price_unit=Decimal("0.001"),
message_files=json.dumps(messages_ids) if messages_ids else "",
answer="",
observation="",
answer_token=0,
answer_unit_price=0,
answer_price_unit=0,
answer_unit_price=Decimal("0.001"),
answer_price_unit=Decimal(0),
tokens=0,
total_price=0,
position=self.agent_thought_count + 1,
currency="USD",
latency=0,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=self.user_id,
Comment on lines +319 to 333
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Fix swapped defaults for answer_unit_price / answer_price_unit.

Line 325-326: answer_unit_price is initialized to 0.001 while answer_price_unit is 0, which is the inverse of the message-side defaults and the model’s default. If llm_usage isn’t supplied, this will yield incorrect pricing.

🐛 Proposed fix
-            answer_unit_price=Decimal("0.001"),
-            answer_price_unit=Decimal(0),
+            answer_unit_price=Decimal(0),
+            answer_price_unit=Decimal("0.001"),
🤖 Prompt for AI Agents
In `@api/core/agent/base_agent_runner.py` around lines 319 - 333, The default
pricing fields for the answer are reversed: in the initializer where
answer_unit_price and answer_price_unit are set, swap their values so
answer_unit_price is Decimal(0) and answer_price_unit is Decimal("0.001")
(matching the message-side defaults and model defaults); locate the block
setting answer_unit_price and answer_price_unit (near
message_unit_price/message_price_unit and position=self.agent_thought_count + 1)
and update those two assignments accordingly.

)

Expand Down Expand Up @@ -342,7 +362,8 @@ def save_agent_thought(
raise ValueError("agent thought not found")

if thought:
agent_thought.thought += thought
existing_thought = agent_thought.thought or ""
agent_thought.thought = f"{existing_thought}{thought}"

if tool_name:
agent_thought.tool = tool_name
Expand Down Expand Up @@ -440,21 +461,30 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
tools = agent_thought.tool
if tools:
tools = tools.split(";")
tool_names_raw = agent_thought.tool
if tool_names_raw:
tool_names = tool_names_raw.split(";")
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
try:
tool_inputs = json.loads(agent_thought.tool_input)
except Exception:
tool_inputs = {tool: {} for tool in tools}
try:
tool_responses = json.loads(agent_thought.observation)
except Exception:
tool_responses = dict.fromkeys(tools, agent_thought.observation)

for tool in tools:
tool_input_payload = agent_thought.tool_input
if tool_input_payload:
try:
tool_inputs = json.loads(tool_input_payload)
except Exception:
tool_inputs = {tool: {} for tool in tool_names}
else:
tool_inputs = {tool: {} for tool in tool_names}

observation_payload = agent_thought.observation
if observation_payload:
try:
tool_responses = json.loads(observation_payload)
except Exception:
tool_responses = dict.fromkeys(tool_names, observation_payload)
else:
tool_responses = dict.fromkeys(tool_names, observation_payload)

for tool in tool_names:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
tool_calls.append(
Expand All @@ -469,7 +499,7 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P
)
tool_call_response.append(
ToolPromptMessage(
content=tool_responses.get(tool, agent_thought.observation),
content=str(tool_inputs.get(tool, agent_thought.observation)),
name=tool,
tool_call_id=tool_call_id,
)
Expand All @@ -484,7 +514,7 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P
*tool_call_response,
]
)
if not tools:
if not tool_names_raw:
result.append(AssistantPromptMessage(content=agent_thought.thought))
else:
if message.answer:
Expand Down
62 changes: 35 additions & 27 deletions api/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1835,42 +1835,50 @@ class MessageChain(TypeBase):
)


class MessageAgentThought(Base):
class MessageAgentThought(TypeBase):
__tablename__ = "message_agent_thoughts"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
sa.Index("message_agent_thought_message_id_idx", "message_id"),
sa.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"),
)

id = mapped_column(StringUUID, default=lambda: str(uuid4()))
message_id = mapped_column(StringUUID, nullable=False)
message_chain_id = mapped_column(StringUUID, nullable=True)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
thought = mapped_column(LongText, nullable=True)
tool = mapped_column(LongText, nullable=True)
tool_labels_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
tool_meta_str = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
tool_input = mapped_column(LongText, nullable=True)
observation = mapped_column(LongText, nullable=True)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_chain_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
thought: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
tool: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
tool_labels_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
tool_meta_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
tool_input: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
Comment on lines +1856 to +1858
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# Find the MessageAgentThought class definition
rg -n "class MessageAgentThought" -A 20 api/models/

Repository: qodo-benchmark/dify-combined-coderabbit

Length of output: 222


🏁 Script executed:

# Check the specific file and lines mentioned
head -1900 api/models/model.py | tail -100 | cat -n

Repository: qodo-benchmark/dify-combined-coderabbit

Length of output: 6085


🏁 Script executed:

# Check the inheritance and structure around line 1856
sed -n '1820,1880p' api/models/model.py | cat -n

Repository: qodo-benchmark/dify-combined-coderabbit

Length of output: 4392


🏁 Script executed:

# Find usages of MessageAgentThought instantiation
rg -n "MessageAgentThought(" -B 2 -A 2 --type py | head -100

Repository: qodo-benchmark/dify-combined-coderabbit

Length of output: 173


🏁 Script executed:

# Check for MappedAsDataclass usage and behavior
rg -n "MappedAsDataclass" api/models/

Repository: qodo-benchmark/dify-combined-coderabbit

Length of output: 262


Use Python defaults for tool_labels_str and tool_meta_str.

Lines 37-38: MessageAgentThought inherits MappedAsDataclass, which generates a dataclass __init__. Using default=sa.text("'{}'") (a SQL expression object) as the only default will assign that object to newly instantiated instances, breaking the tool_labels property which expects a JSON string for json.loads(). Separate the Python and database defaults.

♻️ Suggested fix
-    tool_labels_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
-    tool_meta_str: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("'{}'"))
+    tool_labels_str: Mapped[str] = mapped_column(
+        LongText, nullable=False, default_factory=lambda: "{}", server_default=sa.text("'{}'")
+    )
+    tool_meta_str: Mapped[str] = mapped_column(
+        LongText, nullable=False, default_factory=lambda: "{}", server_default=sa.text("'{}'")
+    )
🤖 Prompt for AI Agents
In `@api/models/model.py` around lines 1856 - 1858, The SQL defaults for
tool_labels_str and tool_meta_str are currently SQL expression objects which get
assigned to new Python instances and break json.loads; change their
mapped_column declarations to use a Python default (e.g. default="{}" or
default=lambda: "{}") and move the SQL expression into
server_default=sa.text("'{}'") so the DB still gets the same default; update the
fields tool_labels_str and tool_meta_str accordingly (leave tool_input as-is) so
MessageAgentThought instances receive a real JSON string by default.

observation: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
# plugin_id = mapped_column(StringUUID, nullable=True) ## for future design
tool_process_data = mapped_column(LongText, nullable=True)
message = mapped_column(LongText, nullable=True)
message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
message_unit_price = mapped_column(sa.Numeric, nullable=True)
message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
message_files = mapped_column(LongText, nullable=True)
answer = mapped_column(LongText, nullable=True)
answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
answer_unit_price = mapped_column(sa.Numeric, nullable=True)
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
total_price = mapped_column(sa.Numeric, nullable=True)
currency = mapped_column(String(255), nullable=True)
latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
created_by_role = mapped_column(String(255), nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
tool_process_data: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
message: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
message_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
message_price_unit: Mapped[Decimal] = mapped_column(
sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001")
)
message_files: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
answer: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
answer_unit_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
answer_price_unit: Mapped[Decimal] = mapped_column(
sa.Numeric(10, 7), nullable=False, default=Decimal("0.001"), server_default=sa.text("0.001")
)
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric, nullable=True, default=None)
currency: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, init=False, server_default=sa.func.current_timestamp()
)

@property
def files(self) -> list[Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ def _create_test_agent_thoughts(self, db_session_with_containers, message):

# Create first agent thought
thought1 = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to analyze the user's request",
Expand All @@ -257,7 +256,6 @@ def _create_test_agent_thoughts(self, db_session_with_containers, message):

# Create second agent thought
thought2 = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=2,
thought="Based on the analysis, I can provide a response",
Expand Down Expand Up @@ -545,7 +543,6 @@ def test_get_agent_logs_with_tool_error(self, db_session_with_containers, mock_e

# Create agent thought with tool error
thought_with_error = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to analyze the user's request",
Expand Down Expand Up @@ -759,7 +756,6 @@ def test_get_agent_logs_with_complex_tool_data(

# Create agent thought with multiple tools
complex_thought = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to use multiple tools to complete this task",
Expand Down Expand Up @@ -877,7 +873,6 @@ def test_get_agent_logs_with_files(self, db_session_with_containers, mock_extern

# Create agent thought with files
thought_with_files = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to process some files",
Expand Down Expand Up @@ -957,7 +952,6 @@ def test_get_agent_logs_with_empty_tool_data(self, db_session_with_containers, m

# Create agent thought with empty tool data
empty_thought = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to analyze the user's request",
Expand Down Expand Up @@ -999,7 +993,6 @@ def test_get_agent_logs_with_malformed_json(self, db_session_with_containers, mo

# Create agent thought with malformed JSON
malformed_thought = MessageAgentThought(
id=fake.uuid4(),
message_id=message.id,
position=1,
thought="I need to analyze the user's request",
Expand Down