diff --git a/src/uipath_langchain/agent/react/llm_node.py b/src/uipath_langchain/agent/react/llm_node.py index 26ae0268..5103c4b6 100644 --- a/src/uipath_langchain/agent/react/llm_node.py +++ b/src/uipath_langchain/agent/react/llm_node.py @@ -30,11 +30,17 @@ def _filter_control_flow_tool_calls( tool_calls: list[ToolCall], ) -> list[ToolCall]: - """Remove control flow tools when multiple tool calls exist.""" + """Remove control flow tools only when regular tool calls exist alongside them.""" if len(tool_calls) <= 1: return tool_calls - return [tc for tc in tool_calls if tc.get("name") not in FLOW_CONTROL_TOOLS] + non_control_flow_tool_calls = [ + tc for tc in tool_calls if tc.get("name") not in FLOW_CONTROL_TOOLS + ] + if not non_control_flow_tool_calls: + return tool_calls + + return non_control_flow_tool_calls StateT = TypeVar("StateT", bound=AgentGraphState) diff --git a/src/uipath_langchain/agent/react/terminate_node.py b/src/uipath_langchain/agent/react/terminate_node.py index c2aa1d46..fd860c14 100644 --- a/src/uipath_langchain/agent/react/terminate_node.py +++ b/src/uipath_langchain/agent/react/terminate_node.py @@ -107,7 +107,26 @@ def terminate_node(state: AgentGraphState): category=UiPathErrorCategory.SYSTEM, ) - for tool_call in last_message.tool_calls: + control_flow_tool_calls = [ + tool_call + for tool_call in last_message.tool_calls + if tool_call["name"] in {END_EXECUTION_TOOL.name, RAISE_ERROR_TOOL.name} + ] + + if len(control_flow_tool_calls) > 1: + tool_names = ", ".join( + tool_call["name"] for tool_call in control_flow_tool_calls + ) + raise AgentRuntimeError( + code=AgentRuntimeErrorCode.ROUTING_ERROR, + title="Multiple control flow tool calls found in terminate node.", + detail="The terminate node received more than one control flow tool call " + f"in a single AIMessage: {tool_names}. The LLM must return exactly one " + f"of {END_EXECUTION_TOOL.name} or {RAISE_ERROR_TOOL.name}.", + category=UiPathErrorCategory.SYSTEM, + ) + + for tool_call in control_flow_tool_calls: tool_name = tool_call["name"] if tool_name == END_EXECUTION_TOOL.name: diff --git a/tests/agent/react/test_llm_node.py b/tests/agent/react/test_llm_node.py index e7517582..3977a701 100644 --- a/tests/agent/react/test_llm_node.py +++ b/tests/agent/react/test_llm_node.py @@ -263,3 +263,43 @@ async def test_multiple_flow_control_calls_all_filtered(self): ] assert len(tool_call_blocks) == 1 assert tool_call_blocks[0]["name"] == "regular_tool" + + @pytest.mark.asyncio + async def test_multiple_flow_control_calls_only_are_preserved(self): + """Multiple control-flow calls without regular tools should remain intact.""" + mock_response = AIMessage( + content_blocks=[ + create_tool_call( + name=END_EXECUTION_TOOL.name, + args={"result": "done"}, + id="call_1", + ), + create_tool_call( + name=RAISE_ERROR_TOOL.name, + args={"message": "conflict"}, + id="call_2", + ), + ], + tool_calls=[ + { + "name": END_EXECUTION_TOOL.name, + "args": {"result": "done"}, + "id": "call_1", + }, + { + "name": RAISE_ERROR_TOOL.name, + "args": {"message": "conflict"}, + "id": "call_2", + }, + ], + ) + self.mock_model.ainvoke = AsyncMock(return_value=mock_response) + + llm_node = create_llm_node(self.mock_model, [self.regular_tool]) + + result = await llm_node(self.test_state) + + response_message = result["messages"][0] + assert len(response_message.tool_calls) == 2 + assert response_message.tool_calls[0]["name"] == END_EXECUTION_TOOL.name + assert response_message.tool_calls[1]["name"] == RAISE_ERROR_TOOL.name diff --git a/tests/agent/react/test_terminate_node.py b/tests/agent/react/test_terminate_node.py index 088f5d1b..7b197cb2 100644 --- a/tests/agent/react/test_terminate_node.py +++ b/tests/agent/react/test_terminate_node.py @@ -261,6 +261,29 @@ def state_with_no_control_flow_tool(self): ) return MockAgentGraphState(messages=[ai_message]) + @pytest.fixture + def state_with_conflicting_control_flow_tools(self): + """Fixture for state with conflicting control-flow tool calls.""" + ai_message = AIMessage( + content="", + tool_calls=[ + { + "name": END_EXECUTION_TOOL.name, + "args": {"success": True, "message": "done"}, + "id": "call_1", + }, + { + "name": RAISE_ERROR_TOOL.name, + "args": { + "message": "Something went wrong", + "details": "Additional info", + }, + "id": "call_2", + }, + ], + ) + return MockAgentGraphState(messages=[ai_message]) + def test_non_conversational_handles_end_execution( self, terminate_node, state_with_end_execution ): @@ -304,6 +327,18 @@ def test_non_conversational_raises_on_no_control_flow_tool( AgentRuntimeErrorCode.ROUTING_ERROR ) + def test_non_conversational_raises_on_conflicting_control_flow_tools( + self, terminate_node, state_with_conflicting_control_flow_tools + ): + """Non-conversational mode should reject conflicting control-flow tool calls.""" + with pytest.raises(AgentRuntimeError) as exc_info: + terminate_node(state_with_conflicting_control_flow_tools) + + assert exc_info.value.error_info.code == AgentRuntimeError.full_code( + AgentRuntimeErrorCode.ROUTING_ERROR + ) + assert "Multiple control flow tool calls" in exc_info.value.error_info.title + class TestTerminateNodeWithResponseSchema: """Test cases for terminate node with custom response schema."""