diff --git a/effectful/handlers/llm/completions.py b/effectful/handlers/llm/completions.py
index 5d77c881..daf88b4b 100644
--- a/effectful/handlers/llm/completions.py
+++ b/effectful/handlers/llm/completions.py
@@ -1,9 +1,11 @@
 import collections
 import collections.abc
+import dataclasses
 import functools
 import inspect
 import string
 import textwrap
+import traceback
 import typing
 
 import litellm
@@ -20,8 +22,9 @@
     OpenAIMessageContentListBlock,
 )
 
-from effectful.handlers.llm import Template, Tool
 from effectful.handlers.llm.encoding import Encodable
+from effectful.handlers.llm.template import Template, Tool
+from effectful.ops.semantics import fwd
 from effectful.ops.syntax import ObjectInterpretation, implements
 from effectful.ops.types import Operation
@@ -36,6 +39,83 @@
 type ToolCallID = str
 
 
+@dataclasses.dataclass
+class ToolCallDecodingError(Exception):
+    """Error raised when decoding a tool call fails."""
+
+    tool_name: str
+    tool_call_id: str
+    original_error: Exception
+    raw_message: Message
+
+    def __str__(self) -> str:
+        return f"Error decoding tool call '{self.tool_name}': {self.original_error}. Please provide a valid response and try again."
+
+    def to_feedback_message(self, include_traceback: bool) -> Message:
+        error_message = f"{self}"
+        if include_traceback:
+            tb = traceback.format_exc()
+            error_message = f"{error_message}\n\nTraceback:\n```\n{tb}```"
+        return typing.cast(
+            Message,
+            {
+                "role": "tool",
+                "tool_call_id": self.tool_call_id,
+                "content": error_message,
+            },
+        )
+
+
+@dataclasses.dataclass
+class ResultDecodingError(Exception):
+    """Error raised when decoding the LLM response result fails."""
+
+    original_error: Exception
+    raw_message: Message
+
+    def __str__(self) -> str:
+        return f"Error decoding response: {self.original_error}. Please provide a valid response and try again."
+
+    def to_feedback_message(self, include_traceback: bool) -> Message:
+        error_message = f"{self}"
+        if include_traceback:
+            tb = traceback.format_exc()
+            error_message = f"{error_message}\n\nTraceback:\n```\n{tb}```"
+        return typing.cast(
+            Message,
+            {
+                "role": "user",
+                "content": error_message,
+            },
+        )
+
+
+@dataclasses.dataclass
+class ToolCallExecutionError(Exception):
+    """Error raised when a tool execution fails at runtime."""
+
+    tool_name: str
+    tool_call_id: str
+    original_error: BaseException
+
+    def __str__(self) -> str:
+        return f"Tool execution failed: Error executing tool '{self.tool_name}': {self.original_error}"
+
+    def to_feedback_message(self, include_traceback: bool) -> Message:
+        error_message = f"{self}"
+        if include_traceback:
+            tb = traceback.format_exc()
+            error_message = f"{error_message}\n\nTraceback:\n```\n{tb}```"
+        return typing.cast(
+            Message,
+            {
+                "role": "tool",
+                "tool_call_id": self.tool_call_id,
+                "content": error_message,
+            },
+        )
+
+
 class DecodedToolCall[T](typing.NamedTuple):
     tool: Tool[..., T]
     bound_args: inspect.BoundArguments
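[Editor's note: illustrative sketch, not part of the diff. The three error classes above share the same feedback shape; the example below shows what `to_feedback_message` produces for a decoding failure. All values are hypothetical. Note that `include_traceback=True` only yields a useful traceback when the method is called inside the active `except` block, since it relies on `traceback.format_exc()`.]

```python
err = ToolCallDecodingError(
    tool_name="add_numbers",
    tool_call_id="call_1",
    original_error=ValueError("invalid literal for int()"),
    raw_message={"role": "assistant", "content": None},
)

err.to_feedback_message(include_traceback=False)
# {'role': 'tool',
#  'tool_call_id': 'call_1',
#  'content': "Error decoding tool call 'add_numbers': invalid literal for int(). "
#             "Please provide a valid response and try again."}
```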
@@ -77,26 +157,49 @@ def _function_model(tool: Tool) -> ChatCompletionToolParam:
 def decode_tool_call(
     tool_call: ChatCompletionMessageToolCall,
     tools: collections.abc.Mapping[str, Tool],
+    raw_message: Message,
 ) -> DecodedToolCall:
-    """Decode a tool call from the LLM response into a DecodedToolCall."""
-    assert tool_call.function.name is not None
-    tool = tools[tool_call.function.name]
-    json_str = tool_call.function.arguments
+    """Decode a tool call from the LLM response into a DecodedToolCall.
+
+    Args:
+        tool_call: The tool call to decode.
+        tools: Mapping of tool names to Tool objects.
+        raw_message: The raw assistant message, attached to decoding errors
+            for retry handling.
+
+    Raises:
+        ToolCallDecodingError: If the tool call cannot be decoded.
+    """
+    tool_name = tool_call.function.name
+    assert tool_name is not None
+
+    try:
+        tool = tools[tool_name]
+    except KeyError as e:
+        raise ToolCallDecodingError(
+            tool_name, tool_call.id, e, raw_message=raw_message
+        ) from e
+
+    json_str = tool_call.function.arguments
     sig = inspect.signature(tool)
 
-    # build dict of raw encodable types U
-    raw_args = _param_model(tool).model_validate_json(json_str)
+    try:
+        # build dict of raw encodable types U
+        raw_args = _param_model(tool).model_validate_json(json_str)
+
+        # use encoders to decode Us to python types T
+        bound_sig: inspect.BoundArguments = sig.bind(
+            **{
+                param_name: Encodable.define(
+                    sig.parameters[param_name].annotation, {}
+                ).decode(getattr(raw_args, param_name))
+                for param_name in raw_args.model_fields_set
+            }
+        )
+    except (pydantic.ValidationError, TypeError, ValueError) as e:
+        raise ToolCallDecodingError(
+            tool_name, tool_call.id, e, raw_message=raw_message
+        ) from e
 
-    # use encoders to decode Us to python types T
-    bound_sig: inspect.BoundArguments = sig.bind(
-        **{
-            param_name: Encodable.define(
-                sig.parameters[param_name].annotation, {}
-            ).decode(getattr(raw_args, param_name))
-            for param_name in raw_args.model_fields_set
-        }
-    )
     return DecodedToolCall(tool, bound_sig, tool_call.id)
@@ -125,6 +228,11 @@ def call_assistant[T, U](
     This effect is emitted for model request/response rounds so handlers can
     observe/log requests.
 
+    Raises:
+        ToolCallDecodingError: If a tool call cannot be decoded. The error
+            includes the raw assistant message for retry handling.
+        ResultDecodingError: If the result cannot be decoded. The error
+            includes the raw assistant message for retry handling.
     """
     tool_specs = {k: _function_model(t) for k, t in tools.items()}
     response_model = pydantic.create_model(
@@ -144,11 +252,15 @@
     message: litellm.Message = choice.message
     assert message.role == "assistant"
 
+    raw_message = typing.cast(Message, message.model_dump(mode="json"))
+
     tool_calls: list[DecodedToolCall] = []
     raw_tool_calls = message.get("tool_calls") or []
-    for tool_call in raw_tool_calls:
-        tool_call = ChatCompletionMessageToolCall.model_validate(tool_call)
-        decoded_tool_call = decode_tool_call(tool_call, tools)
+    for raw_tool_call in raw_tool_calls:
+        validated_tool_call = ChatCompletionMessageToolCall.model_validate(
+            raw_tool_call
+        )
+        decoded_tool_call = decode_tool_call(validated_tool_call, tools, raw_message)
         tool_calls.append(decoded_tool_call)
 
     result = None
@@ -158,10 +270,13 @@
         assert isinstance(serialized_result, str), (
             "final response from the model should be a string"
         )
-        raw_result = response_model.model_validate_json(serialized_result)
-        result = response_format.decode(raw_result.value)  # type: ignore
+        try:
+            raw_result = response_model.model_validate_json(serialized_result)
+            result = response_format.decode(raw_result.value)  # type: ignore
+        except pydantic.ValidationError as e:
+            raise ResultDecodingError(e, raw_message=raw_message) from e
 
-    return (typing.cast(Message, message.model_dump(mode="json")), tool_calls, result)
+    return (raw_message, tool_calls, result)
 
 
 @Operation.define
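[Editor's note: illustrative sketch, not part of the diff. `call_assistant` now has an explicit error contract; a caller that does not install the `RetryLLMHandler` defined below can implement its own recovery, since both errors carry the malformed assistant message. The model name is hypothetical.]

```python
try:
    message, tool_calls, result = call_assistant(
        messages=[{"role": "user", "content": "What is 2 + 2?"}],
        tools={},
        response_format=Encodable.define(int),
        model="gpt-4o-mini",  # hypothetical model name
    )
except (ToolCallDecodingError, ResultDecodingError) as e:
    # e.raw_message is the assistant message that failed to decode; appending
    # it plus e.to_feedback_message(...) to the conversation and re-issuing
    # the request reproduces the retry loop that RetryLLMHandler automates.
    feedback = e.to_feedback_message(include_traceback=False)
```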
@@ -239,6 +354,95 @@ def call_system(template: Template) -> collections.abc.Sequence[Message]:
     return ()
 
 
+class RetryLLMHandler(ObjectInterpretation):
+    """Retries LLM requests if tool call or result decoding fails.
+
+    This handler intercepts `call_assistant` and catches `ToolCallDecodingError`
+    and `ResultDecodingError`. When these errors occur, it appends error feedback
+    to the messages and retries the request. Malformed messages from retry attempts
+    are pruned from the final result.
+
+    For runtime tool execution failures (handled via `call_tool`), errors are
+    captured and returned as tool response messages.
+
+    Args:
+        num_retries: The maximum number of retries (default: 3).
+        include_traceback: If True, include the full traceback in error feedback
+            for better debugging context (default: False).
+        catch_tool_errors: Exception type(s) to catch during tool execution.
+            Can be a single exception class or a tuple of exception classes.
+            Defaults to Exception (catches all exceptions).
+    """
+
+    def __init__(
+        self,
+        num_retries: int = 3,
+        include_traceback: bool = False,
+        catch_tool_errors: type[BaseException]
+        | tuple[type[BaseException], ...] = Exception,
+    ):
+        self.num_retries = num_retries
+        self.include_traceback = include_traceback
+        self.catch_tool_errors = catch_tool_errors
+
+    @implements(call_assistant)
+    def _call_assistant[T, U](
+        self,
+        messages: collections.abc.Sequence[Message],
+        tools: collections.abc.Mapping[str, Tool],
+        response_format: Encodable[T, U],
+        model: str,
+        **kwargs,
+    ) -> MessageResult[T]:
+        messages_list = list(messages)
+        last_attempt = self.num_retries
+
+        for attempt in range(self.num_retries + 1):
+            try:
+                message, tool_calls, result = fwd(
+                    messages_list, tools, response_format, model, **kwargs
+                )
+
+                # Success! The returned message is the final successful response.
+                # Malformed messages from retries are only in messages_list,
+                # not in the returned result.
+                return (message, tool_calls, result)
+
+            except (ToolCallDecodingError, ResultDecodingError) as e:
+                # On the last attempt, re-raise to preserve the full traceback
+                if attempt == last_attempt:
+                    raise
+
+                # Add the malformed assistant message
+                messages_list.append(e.raw_message)
+
+                # Add error feedback (a tool message for tool-call failures,
+                # a user message for result-decoding failures)
+                error_feedback: Message = e.to_feedback_message(self.include_traceback)
+                messages_list.append(error_feedback)
+
+        # Should never reach here: either we return on success or raise on final failure
+        raise AssertionError("Unreachable: retry loop exited without return or raise")
+
+    @implements(completion)
+    def _completion(self, *args, **kwargs) -> typing.Any:
+        """Inject num_retries for litellm's built-in network error handling."""
+        return fwd(*args, num_retries=self.num_retries, **kwargs)
+
+    @implements(call_tool)
+    def _call_tool(self, tool_call: DecodedToolCall) -> Message:
+        """Handle tool execution with runtime error capture.
+
+        Runtime errors from tool execution are captured and returned as
+        error messages to the LLM. Only exceptions matching `catch_tool_errors`
+        are caught; others propagate up.
+        """
+        try:
+            return fwd(tool_call)
+        except self.catch_tool_errors as e:
+            error = ToolCallExecutionError(tool_call.tool.__name__, tool_call.id, e)
+            return error.to_feedback_message(self.include_traceback)
+
+
 class LiteLLMProvider(ObjectInterpretation):
     """Implements templates using the LiteLLM API."""
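[Editor's note: illustrative sketch, not part of the diff. Typical composition of the new handler with the existing provider. `handler` is assumed to be effectful's handler context manager, imported here from `effectful.ops.semantics` alongside `fwd` (an assumption; the tests below use `handler` the same way). The nesting mirrors the tests: `RetryLLMHandler` outermost, so its `call_assistant` interception wraps the provider.]

```python
from effectful.ops.semantics import handler  # assumed import path

with handler(RetryLLMHandler(num_retries=3, include_traceback=True)), handler(
    LiteLLMProvider()
):
    ...  # run templates: decoding failures are fed back and retried up to 3
    ...  # times, and runtime tool errors come back as "tool" feedback messages
```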
diff --git a/tests/test_handlers_llm_provider.py b/tests/test_handlers_llm_provider.py
index 0a766231..f557811d 100644
--- a/tests/test_handlers_llm_provider.py
+++ b/tests/test_handlers_llm_provider.py
@@ -22,8 +22,14 @@
 from effectful.handlers.llm import Template
 from effectful.handlers.llm.completions import (
+    DecodedToolCall,
     LiteLLMProvider,
+    ResultDecodingError,
+    RetryLLMHandler,
+    Tool,
+    ToolCallDecodingError,
     call_assistant,
+    call_tool,
     completion,
 )
 from effectful.handlers.llm.encoding import Encodable, SynthesizedFunction
@@ -50,9 +56,8 @@
 
 REBUILD_FIXTURES = os.getenv("REBUILD_FIXTURES") == "true"
 
-# ============================================================================
-
+# ============================================================================
 # Test Fixtures and Mock Data
 # ============================================================================
 def retry_on_error(error: type[Exception], n: int):
@@ -348,6 +353,545 @@ def test_litellm_caching_selective(request):
     assert p1 != p2, "when caching is not enabled, llm outputs should be different"
 
 
+# ============================================================================
+# RetryLLMHandler Tests
+# ============================================================================
+
+
+class MockCompletionHandler(ObjectInterpretation):
+    """Mock handler that returns pre-configured completion responses."""
+
+    def __init__(self, responses: list[ModelResponse]):
+        self.responses = responses
+        self.call_count = 0
+        self.received_messages: list = []
+
+    @implements(completion)
+    def _completion(self, model, messages=None, **kwargs):
+        self.received_messages.append(list(messages) if messages else [])
+        response = self.responses[min(self.call_count, len(self.responses) - 1)]
+        self.call_count += 1
+        return response
+
+
+def make_tool_call_response(
+    tool_name: str, tool_args: str, tool_call_id: str = "call_1"
+) -> ModelResponse:
+    """Create a ModelResponse with a tool call."""
+    return ModelResponse(
+        id="test",
+        choices=[
+            {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": None,
+                    "tool_calls": [
+                        {
+                            "id": tool_call_id,
+                            "type": "function",
+                            "function": {"name": tool_name, "arguments": tool_args},
+                        }
+                    ],
+                },
+                "finish_reason": "tool_calls",
+            }
+        ],
+        model="test-model",
+    )
+
+
+def make_text_response(content: str) -> ModelResponse:
+    """Create a ModelResponse with text content."""
+    return ModelResponse(
+        id="test",
+        choices=[
+            {
+                "index": 0,
+                "message": {"role": "assistant", "content": content},
+                "finish_reason": "stop",
+            }
+        ],
+        model="test-model",
+    )
+
+
+@Tool.define
+def add_numbers(a: int, b: int) -> int:
+    """Add two numbers together."""
+    return a + b
+
+
+class TestRetryLLMHandler:
+    """Tests for RetryLLMHandler functionality."""
+
+    def test_retry_handler_succeeds_on_first_attempt(self):
+        """Test that RetryLLMHandler passes through when no error occurs."""
+        # Response with a valid text result
+        responses = [make_text_response('{"value": "hello"}')]
+
+        mock_handler = MockCompletionHandler(responses)
+
+        with handler(RetryLLMHandler(num_retries=3)), handler(mock_handler):
+            message, tool_calls, result = call_assistant(
+                messages=[{"role": "user", "content": "test"}],
+                tools={},
+                response_format=Encodable.define(str),
+                model="test-model",
+            )
+
+        assert mock_handler.call_count == 1
+        assert result == "hello"
+
+    def test_retry_handler_retries_on_invalid_tool_call(self):
+        """Test that RetryLLMHandler retries when tool call decoding fails."""
+        # First response has invalid tool args, second has a valid response
+        responses = [
+            make_tool_call_response(
+                "add_numbers", '{"a": "not_an_int", "b": 2}'
+            ),  # Invalid
+            make_text_response('{"value": "success"}'),  # Valid
+        ]
+
+        mock_handler = MockCompletionHandler(responses)
+
+        with handler(RetryLLMHandler(num_retries=3)), handler(mock_handler):
+            message, tool_calls, result = call_assistant(
+                messages=[{"role": "user", "content": "test"}],
+                tools={"add_numbers": add_numbers},
+                response_format=Encodable.define(str),
+                model="test-model",
+            )
+
+        assert mock_handler.call_count == 2
+        assert result == "success"
+        # Check that the second call included error feedback
+        assert len(mock_handler.received_messages[1]) > len(
+            mock_handler.received_messages[0]
+        )
+
+    def test_retry_handler_retries_on_unknown_tool(self):
+        """Test that RetryLLMHandler retries when the tool is not found."""
+        # First response has an unknown tool, second has a valid response
+        responses = [
+            make_tool_call_response("unknown_tool", '{"x": 1}'),  # Unknown tool
+            make_text_response('{"value": "success"}'),  # Valid
+        ]
+
+        mock_handler = MockCompletionHandler(responses)
+
+        with handler(RetryLLMHandler(num_retries=3)), handler(mock_handler):
+            message, tool_calls, result = call_assistant(
+                messages=[{"role": "user", "content": "test"}],
+                tools={"add_numbers": add_numbers},
+                response_format=Encodable.define(str),
+                model="test-model",
+            )
+
+        assert mock_handler.call_count == 2
+        assert result == "success"
+
+    def test_retry_handler_exhausts_retries(self):
+        """Test that RetryLLMHandler raises after exhausting all retries."""
+        # All responses have invalid tool calls
+        responses = [
+            make_tool_call_response("add_numbers", '{"a": "bad", "b": "bad"}'),
+        ]
+
+        mock_handler = MockCompletionHandler(responses)
+
+        with pytest.raises(Exception):  # Will raise the underlying decoding error
+            with handler(RetryLLMHandler(num_retries=2)), handler(mock_handler):
+                call_assistant(
+                    messages=[{"role": "user", "content": "test"}],
+                    tools={"add_numbers": add_numbers},
+                    response_format=Encodable.define(str),
+                    model="test-model",
+                )
+
+        # Should have attempted 3 times (1 initial + 2 retries)
+        assert mock_handler.call_count == 3
+
+    def test_retry_handler_with_zero_retries(self):
+        """Test RetryLLMHandler with num_retries=0 fails immediately on error."""
+        responses = [
+            make_tool_call_response("add_numbers", '{"a": "bad", "b": "bad"}'),
+        ]
+
+        mock_handler = MockCompletionHandler(responses)
+
+        with pytest.raises(Exception):
+            with handler(RetryLLMHandler(num_retries=0)), handler(mock_handler):
+                call_assistant(
+                    messages=[{"role": "user", "content": "test"}],
+                    tools={"add_numbers": add_numbers},
+                    response_format=Encodable.define(str),
+                    model="test-model",
+                )
+
+        assert mock_handler.call_count == 1
+
+    def test_retry_handler_valid_tool_call_passes_through(self):
+        """Test that valid tool calls are decoded and returned."""
+        responses = [
+            make_tool_call_response("add_numbers", '{"a": 1, "b": 2}'),
+        ]
+
+        mock_handler = MockCompletionHandler(responses)
+
+        with handler(RetryLLMHandler(num_retries=3)), handler(mock_handler):
+            message, tool_calls, result = call_assistant(
+                messages=[{"role": "user", "content": "test"}],
tools={"add_numbers": add_numbers}, + response_format=Encodable.define(str), + model="test-model", + ) + + assert mock_handler.call_count == 1 + assert len(tool_calls) == 1 + assert tool_calls[0].tool == add_numbers + assert result is None # No result when there are tool calls + + def test_retry_handler_retries_on_invalid_result(self): + """Test that RetryLLMHandler retries when result decoding fails.""" + # First response has invalid JSON, second has valid response + responses = [ + make_text_response('{"value": "not valid for int"}'), # Invalid for int + make_text_response('{"value": 42}'), # Valid + ] + + mock_handler = MockCompletionHandler(responses) + + with handler(RetryLLMHandler(num_retries=3)), handler(mock_handler): + message, tool_calls, result = call_assistant( + messages=[{"role": "user", "content": "test"}], + tools={}, + response_format=Encodable.define(int), + model="test-model", + ) + + assert mock_handler.call_count == 2 + assert result == 42 + # Check that the second call included error feedback + assert len(mock_handler.received_messages[1]) > len( + mock_handler.received_messages[0] + ) + + def test_retry_handler_exhausts_retries_on_result_decoding(self): + """Test that RetryLLMHandler raises after exhausting retries on result decoding.""" + # All responses have invalid results for int type + responses = [ + make_text_response('{"value": "not an int"}'), + ] + + mock_handler = MockCompletionHandler(responses) + + with pytest.raises(Exception): # Will raise the underlying decoding error + with handler(RetryLLMHandler(num_retries=2)), handler(mock_handler): + call_assistant( + messages=[{"role": "user", "content": "test"}], + tools={}, + response_format=Encodable.define(int), + model="test-model", + ) + + # Should have attempted 3 times (1 initial + 2 retries) + assert mock_handler.call_count == 3 + + def test_retry_handler_raises_tool_call_decoding_error(self): + """Test that RetryLLMHandler raises ToolCallDecodingError with correct attributes.""" + responses = [ + make_tool_call_response("add_numbers", '{"a": "bad", "b": "bad"}'), + ] + + mock_handler = MockCompletionHandler(responses) + + with pytest.raises(ToolCallDecodingError) as exc_info: + with handler(RetryLLMHandler(num_retries=0)), handler(mock_handler): + call_assistant( + messages=[{"role": "user", "content": "test"}], + tools={"add_numbers": add_numbers}, + response_format=Encodable.define(str), + model="test-model", + ) + + error = exc_info.value + assert error.tool_name == "add_numbers" + assert error.tool_call_id == "call_1" + assert error.raw_message is not None + assert "add_numbers" in str(error) + + def test_retry_handler_raises_result_decoding_error(self): + """Test that RetryLLMHandler raises ResultDecodingError with correct attributes.""" + responses = [ + make_text_response('{"value": "not an int"}'), + ] + + mock_handler = MockCompletionHandler(responses) + + with pytest.raises(ResultDecodingError) as exc_info: + with handler(RetryLLMHandler(num_retries=0)), handler(mock_handler): + call_assistant( + messages=[{"role": "user", "content": "test"}], + tools={}, + response_format=Encodable.define(int), + model="test-model", + ) + + error = exc_info.value + assert error.raw_message is not None + assert error.original_error is not None + + def test_retry_handler_error_feedback_contains_tool_name(self): + """Test that error feedback messages contain the tool name.""" + responses = [ + make_tool_call_response("add_numbers", '{"a": "bad", "b": 2}'), + make_text_response('{"value": "success"}'), + ] + + 
+        mock_handler = MockCompletionHandler(responses)
+
+        with handler(RetryLLMHandler(num_retries=3)), handler(mock_handler):
+            call_assistant(
+                messages=[{"role": "user", "content": "test"}],
+                tools={"add_numbers": add_numbers},
+                response_format=Encodable.define(str),
+                model="test-model",
+            )
+
+        # Check that the error feedback in the second call mentions the tool name
+        second_call_messages = mock_handler.received_messages[1]
+        tool_feedback = [m for m in second_call_messages if m.get("role") == "tool"]
+        assert len(tool_feedback) == 1
+        assert "add_numbers" in tool_feedback[0]["content"]
+
+    def test_retry_handler_unknown_tool_error_contains_tool_name(self):
+        """Test that unknown tool errors contain the tool name in the feedback."""
+        responses = [
+            make_tool_call_response("nonexistent_tool", '{"x": 1}'),
+            make_text_response('{"value": "success"}'),
+        ]
+
+        mock_handler = MockCompletionHandler(responses)
+
+        with handler(RetryLLMHandler(num_retries=3)), handler(mock_handler):
+            call_assistant(
+                messages=[{"role": "user", "content": "test"}],
+                tools={"add_numbers": add_numbers},
+                response_format=Encodable.define(str),
+                model="test-model",
+            )
+
+        # Check that the error feedback mentions the unknown tool
+        second_call_messages = mock_handler.received_messages[1]
+        tool_feedback = [m for m in second_call_messages if m.get("role") == "tool"]
+        assert len(tool_feedback) == 1
+        assert "nonexistent_tool" in tool_feedback[0]["content"]
+
+    def test_retry_handler_include_traceback_in_error_feedback(self):
+        """Test that include_traceback=True adds traceback to error messages."""
+        responses = [
+            make_tool_call_response("add_numbers", '{"a": "bad", "b": 2}'),
+            make_text_response('{"value": "success"}'),
+        ]
+
+        mock_handler = MockCompletionHandler(responses)
+
+        with (
+            handler(RetryLLMHandler(num_retries=3, include_traceback=True)),
+            handler(mock_handler),
+        ):
+            call_assistant(
+                messages=[{"role": "user", "content": "test"}],
+                tools={"add_numbers": add_numbers},
+                response_format=Encodable.define(str),
+                model="test-model",
+            )
+
+        # Check that the error feedback includes a traceback
+        second_call_messages = mock_handler.received_messages[1]
+        tool_feedback = [m for m in second_call_messages if m.get("role") == "tool"]
+        assert len(tool_feedback) == 1
+        assert "Traceback:" in tool_feedback[0]["content"]
+        assert "```" in tool_feedback[0]["content"]
+
+    def test_retry_handler_no_traceback_by_default(self):
+        """Test that include_traceback=False (default) doesn't add traceback."""
+        responses = [
+            make_tool_call_response("add_numbers", '{"a": "bad", "b": 2}'),
+            make_text_response('{"value": "success"}'),
+        ]
+
+        mock_handler = MockCompletionHandler(responses)
+
+        with handler(RetryLLMHandler(num_retries=3)), handler(mock_handler):
+            call_assistant(
+                messages=[{"role": "user", "content": "test"}],
+                tools={"add_numbers": add_numbers},
+                response_format=Encodable.define(str),
+                model="test-model",
+            )
+
+        # Check that the error feedback does NOT include a traceback
+        second_call_messages = mock_handler.received_messages[1]
+        tool_feedback = [m for m in second_call_messages if m.get("role") == "tool"]
+        assert len(tool_feedback) == 1
+        assert "Traceback:" not in tool_feedback[0]["content"]
+
+
+# ============================================================================
+# Tool Execution Error Tests
+# ============================================================================
+
+
+@Tool.define
+def failing_tool(x: int) -> int:
+    """A tool that always raises an exception."""
+    raise ValueError(f"Tool failed with input {x}")
+
+
+@Tool.define
+def divide_tool(a: int, b: int) -> int:
+    """Divide a by b."""
+    return a // b
+
+
+class TestToolExecutionErrorHandling:
+    """Tests for runtime tool execution error handling."""
+
+    def test_retry_handler_catches_tool_runtime_error(self):
+        """Test that RetryLLMHandler catches tool runtime errors and returns an error message."""
+
+        # Create a decoded tool call for failing_tool
+        sig = inspect.signature(failing_tool)
+        bound_args = sig.bind(x=42)
+        tool_call = DecodedToolCall(failing_tool, bound_args, "call_1")
+
+        with handler(RetryLLMHandler(num_retries=3)):
+            result = call_tool(tool_call)
+
+        # The result should be an error message, not an exception
+        assert result["role"] == "tool"
+        assert result["tool_call_id"] == "call_1"
+        assert "Tool execution failed" in result["content"]
+        assert "failing_tool" in result["content"]
+        assert "42" in result["content"]
+
+    def test_retry_handler_catches_division_by_zero(self):
+        """Test that RetryLLMHandler catches division by zero errors."""
+
+        sig = inspect.signature(divide_tool)
+        bound_args = sig.bind(a=10, b=0)
+        tool_call = DecodedToolCall(divide_tool, bound_args, "call_div")
+
+        with handler(RetryLLMHandler(num_retries=3)):
+            result = call_tool(tool_call)
+
+        assert result["role"] == "tool"
+        assert result["tool_call_id"] == "call_div"
+        assert "Tool execution failed" in result["content"]
+        assert "divide_tool" in result["content"]
+
+    def test_successful_tool_execution_returns_result(self):
+        """Test that successful tool executions return normal results."""
+
+        sig = inspect.signature(add_numbers)
+        bound_args = sig.bind(a=3, b=4)
+        tool_call = DecodedToolCall(add_numbers, bound_args, "call_add")
+
+        with handler(RetryLLMHandler(num_retries=3)):
+            result = call_tool(tool_call)
+
+        assert result["role"] == "tool"
+        assert result["tool_call_id"] == "call_add"
+        # The result should be the serialized return value, not an error
+        assert "Tool execution failed" not in result["content"]
+
+    def test_tool_execution_error_not_pruned_from_messages(self):
+        """Test that tool execution errors are NOT pruned (they're legitimate failures)."""
+        # This test verifies the docstring claim that tool execution errors
+        # should be kept in the message history, unlike decoding errors
+
+        # First call: valid tool call that will fail at runtime
+        # Second call: successful text response
+        responses = [
+            make_tool_call_response("failing_tool", '{"x": 42}'),
+            make_text_response('{"value": "handled the error"}'),
+        ]
+
+        mock_handler = MockCompletionHandler(responses)
+
+        # A pass-through provider so call_assistant reaches the mock completion handler
+        class TestProvider(ObjectInterpretation):
+            @implements(call_assistant)
+            def _call_assistant(
+                self, messages, tools, response_format, model, **kwargs
+            ):
+                return fwd(messages, tools, response_format, model, **kwargs)
+
+        with (
+            handler(RetryLLMHandler(num_retries=3)),
+            handler(TestProvider()),
+            handler(mock_handler),
+        ):
+            message, tool_calls, result = call_assistant(
+                messages=[{"role": "user", "content": "test"}],
+                tools={"failing_tool": failing_tool},
+                response_format=Encodable.define(str),
+                model="test-model",
+            )
+
+        # The first call should succeed (the tool call itself is valid)
+        assert mock_handler.call_count == 1
+        assert len(tool_calls) == 1
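[Editor's note: illustrative sketch, not part of the diff. The tests above always use the default catch_tool_errors=Exception; the parameter also accepts a single class or a narrower tuple, in which case only matching runtime errors are converted into tool feedback and everything else propagates.]

```python
# Only ValueError and ZeroDivisionError become "tool" feedback messages here;
# e.g. a KeyError raised inside a tool would propagate to the caller instead.
retry = RetryLLMHandler(
    num_retries=3, catch_tool_errors=(ValueError, ZeroDivisionError)
)
```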
+
+
+# ============================================================================
+# Error Class Tests
+# ============================================================================
+
+
+class TestErrorClasses:
+    """Tests for the error class definitions."""
+
+    def test_tool_call_decoding_error_string_representation(self):
+        """Test ToolCallDecodingError string includes relevant info."""
+        original = ValueError("invalid value")
+        error = ToolCallDecodingError(
+            "my_function", "call_abc", original, raw_message={"role": "assistant"}
+        )
+
+        error_str = str(error)
+        assert "my_function" in error_str
+        assert "invalid value" in error_str
+
+    def test_result_decoding_error_string_representation(self):
+        """Test ResultDecodingError string includes relevant info."""
+        original = ValueError("parse error")
+        error = ResultDecodingError(original, raw_message={"role": "assistant"})
+
+        error_str = str(error)
+        assert "parse error" in error_str
+        assert "decoding response" in error_str.lower()
+
+    def test_error_classes_preserve_original_error(self):
+        """Test that all error classes preserve the original exception."""
+        original = TypeError("type mismatch")
+        mock_message = {"role": "assistant", "content": "test"}
+
+        tool_decode_err = ToolCallDecodingError("fn", "id", original, mock_message)
+        assert tool_decode_err.original_error is original
+
+        result_decode_err = ResultDecodingError(original, mock_message)
+        assert result_decode_err.original_error is original
+
+    def test_tool_call_decoding_error_includes_raw_message(self):
+        """Test that ToolCallDecodingError includes the raw message."""
+        mock_message = {"role": "assistant", "content": "test"}
+        error = ToolCallDecodingError("fn", "id", ValueError("test"), mock_message)
+        assert error.raw_message == mock_message
+
+
 # ============================================================================
 # Callable Synthesis Tests
 # ============================================================================