diff --git a/src/cai/util.py b/src/cai/util.py index b5acca68c..2fc64688c 100644 --- a/src/cai/util.py +++ b/src/cai/util.py @@ -1244,13 +1244,39 @@ def fix_message_list(messages): # pylint: disable=R0914,R0915,R0912 # If this isn't the first message, check if the previous message is a matching assistant message if i > 0: - prev_msg = processed_messages[i - 1] + # Previous logic only checked the immediately preceding message. + # That fails when one assistant message calls multiple tools: + # + # prev_msg = processed_messages[i - 1] + # is_valid_sequence = ( + # prev_msg.get("role") == "assistant" + # and prev_msg.get("tool_calls") + # and any(tc.get("id") == tool_id for tc in prev_msg.get("tool_calls", [])) + # ) + # + # Valid multi-tool sequences look like: + # assistant(tool_calls=[a, b, c]), tool(a), tool(b), tool(c). + # So walk back over contiguous tool results and validate against + # the nearest preceding assistant tool_calls block. + sequence_assistant_idx = i - 1 + while ( + sequence_assistant_idx >= 0 + and processed_messages[sequence_assistant_idx].get("role") == "tool" + ): + sequence_assistant_idx -= 1 - # Check if the previous message is an assistant message with matching tool_call_id + sequence_assistant = ( + processed_messages[sequence_assistant_idx] + if sequence_assistant_idx >= 0 + else {} + ) is_valid_sequence = ( - prev_msg.get("role") == "assistant" - and prev_msg.get("tool_calls") - and any(tc.get("id") == tool_id for tc in prev_msg.get("tool_calls", [])) + sequence_assistant.get("role") == "assistant" + and sequence_assistant.get("tool_calls") + and any( + tc.get("id") == tool_id + for tc in sequence_assistant.get("tool_calls", []) + ) ) if not is_valid_sequence: @@ -1273,18 +1299,33 @@ def fix_message_list(messages): # pylint: disable=R0914,R0915,R0912 # Remember to save the tool message tool_msg = processed_messages.pop(i) - # Insert right after the assistant message - processed_messages.insert(assistant_idx + 1, tool_msg) + # If the assistant was after the tool message, its index + # shifts left after pop(i). + if assistant_idx > i: + assistant_idx -= 1 - # Adjust i to account for the move - if assistant_idx < i: - # We moved the message backward, so i should point to the next message - # which is now at position i (since we removed a message before it) - continue - else: - # We moved the message forward, so i should now point to the message - # that is now at position i - continue + assistant_tool_ids = { + tc.get("id") + for tc in processed_messages[assistant_idx].get("tool_calls", []) + } + insert_idx = assistant_idx + 1 + while ( + insert_idx < len(processed_messages) + and processed_messages[insert_idx].get("role") == "tool" + and processed_messages[insert_idx].get("tool_call_id") + in assistant_tool_ids + ): + insert_idx += 1 + + # Insert after the assistant's existing tool-result block + # instead of always at assistant_idx + 1, which would reorder + # multiple tool results forever. + processed_messages.insert(insert_idx, tool_msg) + + # Move past the current position to avoid reprocessing + # the same slot after mutating the list. + i += 1 + continue else: # No matching assistant message found - create one assistant_msg = { diff --git a/tests/cli/test_cli_streaming.py b/tests/cli/test_cli_streaming.py index 21e1482b2..7926bbcca 100644 --- a/tests/cli/test_cli_streaming.py +++ b/tests/cli/test_cli_streaming.py @@ -341,6 +341,53 @@ def test_fix_message_list_with_interrupted_tools(self): # No need to clean up _Converter state since it's instance-based return False + def test_fix_message_list_allows_multiple_tool_results(self): + """Test one assistant message can be followed by multiple tool results.""" + from cai.util import fix_message_list + + messages = [ + {"role": "user", "content": "Run three checks"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_a", + "type": "function", + "function": {"name": "tool_a", "arguments": "{}"}, + }, + { + "id": "call_b", + "type": "function", + "function": {"name": "tool_b", "arguments": "{}"}, + }, + { + "id": "call_c", + "type": "function", + "function": {"name": "tool_c", "arguments": "{}"}, + }, + ], + }, + {"role": "tool", "tool_call_id": "call_a", "content": "result a"}, + {"role": "tool", "tool_call_id": "call_b", "content": "result b"}, + {"role": "tool", "tool_call_id": "call_c", "content": "result c"}, + ] + + fixed_messages = fix_message_list(messages) + + assert [msg["role"] for msg in fixed_messages] == [ + "user", + "assistant", + "tool", + "tool", + "tool", + ] + assert [ + msg.get("tool_call_id") + for msg in fixed_messages + if msg.get("role") == "tool" + ] == ["call_a", "call_b", "call_c"] + def test_generic_linux_command_interrupt_simulation(self): """Test generic_linux_command behavior during interruption."""