Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/uipath_langchain/agent/react/llm_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,17 @@
def _filter_control_flow_tool_calls(
tool_calls: list[ToolCall],
) -> list[ToolCall]:
"""Remove control flow tools when multiple tool calls exist."""
"""Remove control flow tools only when regular tool calls exist alongside them."""
if len(tool_calls) <= 1:
return tool_calls

return [tc for tc in tool_calls if tc.get("name") not in FLOW_CONTROL_TOOLS]
non_control_flow_tool_calls = [
tc for tc in tool_calls if tc.get("name") not in FLOW_CONTROL_TOOLS
]
if not non_control_flow_tool_calls:
return tool_calls

return non_control_flow_tool_calls


StateT = TypeVar("StateT", bound=AgentGraphState)
Expand Down
21 changes: 20 additions & 1 deletion src/uipath_langchain/agent/react/terminate_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,26 @@ def terminate_node(state: AgentGraphState):
category=UiPathErrorCategory.SYSTEM,
)

for tool_call in last_message.tool_calls:
control_flow_tool_calls = [
tool_call
for tool_call in last_message.tool_calls
if tool_call["name"] in {END_EXECUTION_TOOL.name, RAISE_ERROR_TOOL.name}
]

if len(control_flow_tool_calls) > 1:
tool_names = ", ".join(
tool_call["name"] for tool_call in control_flow_tool_calls
)
raise AgentRuntimeError(
code=AgentRuntimeErrorCode.ROUTING_ERROR,
title="Multiple control flow tool calls found in terminate node.",
detail="The terminate node received more than one control flow tool call "
f"in a single AIMessage: {tool_names}. The LLM must return exactly one "
f"of {END_EXECUTION_TOOL.name} or {RAISE_ERROR_TOOL.name}.",
category=UiPathErrorCategory.SYSTEM,
)

for tool_call in control_flow_tool_calls:
tool_name = tool_call["name"]

if tool_name == END_EXECUTION_TOOL.name:
Expand Down
40 changes: 40 additions & 0 deletions tests/agent/react/test_llm_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,43 @@ async def test_multiple_flow_control_calls_all_filtered(self):
]
assert len(tool_call_blocks) == 1
assert tool_call_blocks[0]["name"] == "regular_tool"

@pytest.mark.asyncio
async def test_multiple_flow_control_calls_only_are_preserved(self):
"""Multiple control-flow calls without regular tools should remain intact."""
mock_response = AIMessage(
content_blocks=[
create_tool_call(
name=END_EXECUTION_TOOL.name,
args={"result": "done"},
id="call_1",
),
create_tool_call(
name=RAISE_ERROR_TOOL.name,
args={"message": "conflict"},
id="call_2",
),
],
tool_calls=[
{
"name": END_EXECUTION_TOOL.name,
"args": {"result": "done"},
"id": "call_1",
},
{
"name": RAISE_ERROR_TOOL.name,
"args": {"message": "conflict"},
"id": "call_2",
},
],
)
self.mock_model.ainvoke = AsyncMock(return_value=mock_response)

llm_node = create_llm_node(self.mock_model, [self.regular_tool])

result = await llm_node(self.test_state)

response_message = result["messages"][0]
assert len(response_message.tool_calls) == 2
assert response_message.tool_calls[0]["name"] == END_EXECUTION_TOOL.name
assert response_message.tool_calls[1]["name"] == RAISE_ERROR_TOOL.name
35 changes: 35 additions & 0 deletions tests/agent/react/test_terminate_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,29 @@ def state_with_no_control_flow_tool(self):
)
return MockAgentGraphState(messages=[ai_message])

@pytest.fixture
def state_with_conflicting_control_flow_tools(self):
"""Fixture for state with conflicting control-flow tool calls."""
ai_message = AIMessage(
content="",
tool_calls=[
{
"name": END_EXECUTION_TOOL.name,
"args": {"success": True, "message": "done"},
"id": "call_1",
},
{
"name": RAISE_ERROR_TOOL.name,
"args": {
"message": "Something went wrong",
"details": "Additional info",
},
"id": "call_2",
},
],
)
return MockAgentGraphState(messages=[ai_message])

def test_non_conversational_handles_end_execution(
self, terminate_node, state_with_end_execution
):
Expand Down Expand Up @@ -304,6 +327,18 @@ def test_non_conversational_raises_on_no_control_flow_tool(
AgentRuntimeErrorCode.ROUTING_ERROR
)

def test_non_conversational_raises_on_conflicting_control_flow_tools(
self, terminate_node, state_with_conflicting_control_flow_tools
):
"""Non-conversational mode should reject conflicting control-flow tool calls."""
with pytest.raises(AgentRuntimeError) as exc_info:
terminate_node(state_with_conflicting_control_flow_tools)

assert exc_info.value.error_info.code == AgentRuntimeError.full_code(
AgentRuntimeErrorCode.ROUTING_ERROR
)
assert "Multiple control flow tool calls" in exc_info.value.error_info.title


class TestTerminateNodeWithResponseSchema:
"""Test cases for terminate node with custom response schema."""
Expand Down