diff --git a/src/cai/util.py b/src/cai/util.py index b5acca68c..85a2c3e99 100644 --- a/src/cai/util.py +++ b/src/cai/util.py @@ -1244,9 +1244,19 @@ def fix_message_list(messages): # pylint: disable=R0914,R0915,R0912 # If this isn't the first message, check if the previous message is a matching assistant message if i > 0: - prev_msg = processed_messages[i - 1] - - # Check if the previous message is an assistant message with matching tool_call_id + # Walk backward past sibling tool messages to find the nearest + # non-tool message. This avoids an infinite loop when an + # assistant has multiple tool_calls: every response after the + # first is preceded by a sibling tool response rather than by + # the parent assistant message, which is still valid. + k = i - 1 + while k >= 0 and processed_messages[k].get("role") == "tool": + k -= 1 + + prev_msg = processed_messages[k] if k >= 0 else {} + + # Check if the nearest non-tool ancestor is an assistant message + # with a matching tool_call_id is_valid_sequence = ( prev_msg.get("role") == "assistant" and prev_msg.get("tool_calls") diff --git a/tests/test_fix_message_list.py b/tests/test_fix_message_list.py new file mode 100644 index 000000000..6e9348a41 --- /dev/null +++ b/tests/test_fix_message_list.py @@ -0,0 +1,45 @@ +import signal + +import pytest + +from cai.util import fix_message_list + + +def test_fix_message_list_handles_multiple_tool_results_without_loop(): + if not hasattr(signal, "SIGALRM"): + pytest.skip("SIGALRM is required for this infinite-loop regression test") + + messages = [ + {"role": "user", "content": "Run both tools"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_one", + "type": "function", + "function": {"name": "first_tool", "arguments": "{}"}, + }, + { + "id": "call_two", + "type": "function", + "function": {"name": "second_tool", "arguments": "{}"}, + }, + ], + }, + {"role": "tool", "tool_call_id": "call_one", "content": "first result"}, + {"role": 
"tool", "tool_call_id": "call_two", "content": "second result"}, + ] + + def fail_on_timeout(_signum, _frame): + raise TimeoutError("fix_message_list did not terminate") + + previous_handler = signal.signal(signal.SIGALRM, fail_on_timeout) + signal.alarm(2) + try: + fixed_messages = fix_message_list(messages) + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, previous_handler) + + assert fixed_messages == messages