From 374840656767281fa2f7f9674746ee4d69f74f0e Mon Sep 17 00:00:00 2001 From: Kiran Gopinathan Date: Fri, 30 Jan 2026 20:31:31 -0500 Subject: [PATCH 01/12] implements retryhandler --- effectful/handlers/llm/completions.py | 133 +++++++++++++- tests/test_handlers_llm_provider.py | 249 ++++++++++++++++++++++++++ 2 files changed, 381 insertions(+), 1 deletion(-) diff --git a/effectful/handlers/llm/completions.py b/effectful/handlers/llm/completions.py index 5d77c881..b5c805e1 100644 --- a/effectful/handlers/llm/completions.py +++ b/effectful/handlers/llm/completions.py @@ -20,8 +20,8 @@ OpenAIMessageContentListBlock, ) -from effectful.handlers.llm import Template, Tool from effectful.handlers.llm.encoding import Encodable +from effectful.handlers.llm.template import Template, Tool from effectful.ops.syntax import ObjectInterpretation, implements from effectful.ops.types import Operation @@ -239,6 +239,137 @@ def call_system(template: Template) -> collections.abc.Sequence[Message]: return () +class RetryHandler(ObjectInterpretation): + """Retries LLM requests if tool call or result decoding fails. + + Args: + num_retries: The maximum number of retries (default: 3). + """ + + def __init__(self, num_retries: int = 3): + self.num_retries = num_retries + + @implements(call_assistant) + def _call_assistant[T, U]( + self, + messages: collections.abc.Sequence[Message], + tools: collections.abc.Mapping[str, Tool], + response_format: Encodable[T, U], + model: str, + **kwargs, + ) -> MessageResult[T]: + messages_list = list(messages) + last_error: Exception | None = None + + tool_specs = {k: _function_model(t) for k, t in tools.items()} + response_model = pydantic.create_model( + "Response", value=response_format.enc, __config__={"extra": "forbid"} + ) + + for _attempt in range(self.num_retries + 1): + response: litellm.types.utils.ModelResponse = completion( + model, + messages=messages_list, + response_format=response_model, + tools=list(tool_specs.values()), + **kwargs, + ) + choice = response.choices[0] + assert isinstance(choice, litellm.types.utils.Choices) + + message: litellm.Message = choice.message + assert message.role == "assistant" + + raw_tool_calls = message.get("tool_calls") or [] + + # Try to decode tool calls, catching any decoding errors + tool_calls: list[DecodedToolCall] = [] + decoding_errors: list[tuple[ChatCompletionMessageToolCall, Exception]] = [] + + for raw_tool_call in raw_tool_calls: + validated_tool_call = ChatCompletionMessageToolCall.model_validate( + raw_tool_call + ) + try: + decoded_tool_call = decode_tool_call(validated_tool_call, tools) + tool_calls.append(decoded_tool_call) + except Exception as e: + decoding_errors.append((validated_tool_call, e)) + + # If there were tool call decoding errors, add error feedback and retry + if decoding_errors: + # Add the malformed assistant message + messages_list.append( + typing.cast(Message, message.model_dump(mode="json")) + ) + + # Add error feedback for each failed tool call + for failed_tool_call, error in decoding_errors: + last_error = error + error_msg = ( + f"Error decoding tool call '{failed_tool_call.function.name}': {error}. " + f"Please fix the tool call arguments and try again." 
+ ) + error_feedback: Message = typing.cast( + Message, + { + "role": "tool", + "tool_call_id": failed_tool_call.id, + "content": error_msg, + }, + ) + messages_list.append(error_feedback) + continue + + # If there are tool calls, return them without decoding result + if tool_calls: + return ( + typing.cast(Message, message.model_dump(mode="json")), + tool_calls, + None, + ) + + # No tool calls - try to decode the result + serialized_result = message.get("content") or message.get( + "reasoning_content" + ) + assert isinstance(serialized_result, str), ( + "final response from the model should be a string" + ) + + try: + raw_result = response_model.model_validate_json(serialized_result) + result = response_format.decode(raw_result.value) # type: ignore + return ( + typing.cast(Message, message.model_dump(mode="json")), + tool_calls, + result, + ) + except Exception as e: + last_error = e + # Add the assistant message and error feedback for result decoding failure + messages_list.append( + typing.cast(Message, message.model_dump(mode="json")) + ) + error_msg = ( + f"Error decoding response: {e}. " + f"Please provide a valid response and try again." + ) + result_error_feedback: Message = typing.cast( + Message, + { + "role": "user", + "content": error_msg, + }, + ) + messages_list.append(result_error_feedback) + continue + + # If all retries failed, raise the last error + assert last_error is not None + raise last_error + + class LiteLLMProvider(ObjectInterpretation): """Implements templates using the LiteLLM API.""" diff --git a/tests/test_handlers_llm_provider.py b/tests/test_handlers_llm_provider.py index 66c7af7b..6a9297dd 100644 --- a/tests/test_handlers_llm_provider.py +++ b/tests/test_handlers_llm_provider.py @@ -22,9 +22,12 @@ from effectful.handlers.llm import Template from effectful.handlers.llm.completions import ( LiteLLMProvider, + RetryHandler, + Tool, call_assistant, completion, ) +from effectful.handlers.llm.encoding import Encodable from effectful.handlers.llm.synthesis import ProgramSynthesis, SynthesisError from effectful.ops.semantics import fwd, handler from effectful.ops.syntax import ObjectInterpretation, implements @@ -367,3 +370,249 @@ def test_litellm_caching_selective(request): p1 = simple_prompt("apples") p2 = simple_prompt("apples") assert p1 != p2, "when caching is not enabled, llm outputs should be different" + + +# ============================================================================ +# RetryHandler Tests +# ============================================================================ + + +class MockCompletionHandler(ObjectInterpretation): + """Mock handler that returns pre-configured completion responses.""" + + def __init__(self, responses: list[ModelResponse]): + self.responses = responses + self.call_count = 0 + self.received_messages: list = [] + + @implements(completion) + def _completion(self, model, messages=None, **kwargs): + self.received_messages.append(list(messages) if messages else []) + response = self.responses[min(self.call_count, len(self.responses) - 1)] + self.call_count += 1 + return response + + +def make_tool_call_response( + tool_name: str, tool_args: str, tool_call_id: str = "call_1" +) -> ModelResponse: + """Create a ModelResponse with a tool call.""" + return ModelResponse( + id="test", + choices=[ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": tool_call_id, + "type": "function", + "function": {"name": tool_name, "arguments": tool_args}, + } + ], + }, + "finish_reason": 
"tool_calls", + } + ], + model="test-model", + ) + + +def make_text_response(content: str) -> ModelResponse: + """Create a ModelResponse with text content.""" + return ModelResponse( + id="test", + choices=[ + { + "index": 0, + "message": {"role": "assistant", "content": content}, + "finish_reason": "stop", + } + ], + model="test-model", + ) + + +@Tool.define +def add_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b + + +class TestRetryHandler: + """Tests for RetryHandler functionality.""" + + def test_retry_handler_succeeds_on_first_attempt(self): + """Test that RetryHandler passes through when no error occurs.""" + # Response with valid tool call + responses = [make_text_response('{"value": "hello"}')] + + mock_handler = MockCompletionHandler(responses) + + with handler(RetryHandler(num_retries=3)), handler(mock_handler): + message, tool_calls, result = call_assistant( + messages=[{"role": "user", "content": "test"}], + tools={}, + response_format=Encodable.define(str), + model="test-model", + ) + + assert mock_handler.call_count == 1 + assert result == "hello" + + def test_retry_handler_retries_on_invalid_tool_call(self): + """Test that RetryHandler retries when tool call decoding fails.""" + # First response has invalid tool args, second has valid response + responses = [ + make_tool_call_response( + "add_numbers", '{"a": "not_an_int", "b": 2}' + ), # Invalid + make_text_response('{"value": "success"}'), # Valid + ] + + mock_handler = MockCompletionHandler(responses) + + with handler(RetryHandler(num_retries=3)), handler(mock_handler): + message, tool_calls, result = call_assistant( + messages=[{"role": "user", "content": "test"}], + tools={"add_numbers": add_numbers}, + response_format=Encodable.define(str), + model="test-model", + ) + + assert mock_handler.call_count == 2 + assert result == "success" + # Check that the second call included error feedback + assert len(mock_handler.received_messages[1]) > len( + mock_handler.received_messages[0] + ) + + def test_retry_handler_retries_on_unknown_tool(self): + """Test that RetryHandler retries when tool is not found.""" + # First response has unknown tool, second has valid response + responses = [ + make_tool_call_response("unknown_tool", '{"x": 1}'), # Unknown tool + make_text_response('{"value": "success"}'), # Valid + ] + + mock_handler = MockCompletionHandler(responses) + + with handler(RetryHandler(num_retries=3)), handler(mock_handler): + message, tool_calls, result = call_assistant( + messages=[{"role": "user", "content": "test"}], + tools={"add_numbers": add_numbers}, + response_format=Encodable.define(str), + model="test-model", + ) + + assert mock_handler.call_count == 2 + assert result == "success" + + def test_retry_handler_exhausts_retries(self): + """Test that RetryHandler raises after exhausting all retries.""" + # All responses have invalid tool calls + responses = [ + make_tool_call_response("add_numbers", '{"a": "bad", "b": "bad"}'), + ] + + mock_handler = MockCompletionHandler(responses) + + with pytest.raises(Exception): # Will raise the underlying decoding error + with handler(RetryHandler(num_retries=2)), handler(mock_handler): + call_assistant( + messages=[{"role": "user", "content": "test"}], + tools={"add_numbers": add_numbers}, + response_format=Encodable.define(str), + model="test-model", + ) + + # Should have attempted 3 times (1 initial + 2 retries) + assert mock_handler.call_count == 3 + + def test_retry_handler_with_zero_retries(self): + """Test RetryHandler with 
num_retries=0 fails immediately on error.""" + responses = [ + make_tool_call_response("add_numbers", '{"a": "bad", "b": "bad"}'), + ] + + mock_handler = MockCompletionHandler(responses) + + with pytest.raises(Exception): + with handler(RetryHandler(num_retries=0)), handler(mock_handler): + call_assistant( + messages=[{"role": "user", "content": "test"}], + tools={"add_numbers": add_numbers}, + response_format=Encodable.define(str), + model="test-model", + ) + + assert mock_handler.call_count == 1 + + def test_retry_handler_valid_tool_call_passes_through(self): + """Test that valid tool calls are decoded and returned.""" + responses = [ + make_tool_call_response("add_numbers", '{"a": 1, "b": 2}'), + ] + + mock_handler = MockCompletionHandler(responses) + + with handler(RetryHandler(num_retries=3)), handler(mock_handler): + message, tool_calls, result = call_assistant( + messages=[{"role": "user", "content": "test"}], + tools={"add_numbers": add_numbers}, + response_format=Encodable.define(str), + model="test-model", + ) + + assert mock_handler.call_count == 1 + assert len(tool_calls) == 1 + assert tool_calls[0].tool == add_numbers + assert result is None # No result when there are tool calls + + def test_retry_handler_retries_on_invalid_result(self): + """Test that RetryHandler retries when result decoding fails.""" + # First response has invalid JSON, second has valid response + responses = [ + make_text_response('{"value": "not valid for int"}'), # Invalid for int + make_text_response('{"value": 42}'), # Valid + ] + + mock_handler = MockCompletionHandler(responses) + + with handler(RetryHandler(num_retries=3)), handler(mock_handler): + message, tool_calls, result = call_assistant( + messages=[{"role": "user", "content": "test"}], + tools={}, + response_format=Encodable.define(int), + model="test-model", + ) + + assert mock_handler.call_count == 2 + assert result == 42 + # Check that the second call included error feedback + assert len(mock_handler.received_messages[1]) > len( + mock_handler.received_messages[0] + ) + + def test_retry_handler_exhausts_retries_on_result_decoding(self): + """Test that RetryHandler raises after exhausting retries on result decoding.""" + # All responses have invalid results for int type + responses = [ + make_text_response('{"value": "not an int"}'), + ] + + mock_handler = MockCompletionHandler(responses) + + with pytest.raises(Exception): # Will raise the underlying decoding error + with handler(RetryHandler(num_retries=2)), handler(mock_handler): + call_assistant( + messages=[{"role": "user", "content": "test"}], + tools={}, + response_format=Encodable.define(int), + model="test-model", + ) + + # Should have attempted 3 times (1 initial + 2 retries) + assert mock_handler.call_count == 3 From bc1fcc0c838204ea1790d6713832bca443bd416e Mon Sep 17 00:00:00 2001 From: Kiran Gopinathan Date: Fri, 30 Jan 2026 20:39:30 -0500 Subject: [PATCH 02/12] made exception catches more specific --- effectful/handlers/llm/completions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/effectful/handlers/llm/completions.py b/effectful/handlers/llm/completions.py index b5c805e1..b49ee4d0 100644 --- a/effectful/handlers/llm/completions.py +++ b/effectful/handlers/llm/completions.py @@ -293,7 +293,7 @@ def _call_assistant[T, U]( try: decoded_tool_call = decode_tool_call(validated_tool_call, tools) tool_calls.append(decoded_tool_call) - except Exception as e: + except (KeyError, pydantic.ValidationError) as e: decoding_errors.append((validated_tool_call, e)) # 
If there were tool call decoding errors, add error feedback and retry @@ -345,7 +345,7 @@ def _call_assistant[T, U]( tool_calls, result, ) - except Exception as e: + except pydantic.ValidationError as e: last_error = e # Add the assistant message and error feedback for result decoding failure messages_list.append( From 5234c501c41864ff5bdf7e3418058c9af4698224 Mon Sep 17 00:00:00 2001 From: Kiran Gopinathan Date: Fri, 30 Jan 2026 21:32:37 -0500 Subject: [PATCH 03/12] updated RetryLLMHandler implementation to handle tool calls and delegate implementation to fwd --- docs/source/llm.ipynb | 191 --------------- effectful/handlers/llm/completions.py | 294 ++++++++++++++--------- tests/test_handlers_llm_provider.py | 321 ++++++++++++++++++++++++-- 3 files changed, 491 insertions(+), 315 deletions(-) diff --git a/docs/source/llm.ipynb b/docs/source/llm.ipynb index f40a6fdb..be92c345 100644 --- a/docs/source/llm.ipynb +++ b/docs/source/llm.ipynb @@ -547,197 +547,6 @@ " print(\"=== Funny story ===\")\n", " print(write_story(\"a curious cat\", \"funny\"))" ] - }, - { - "cell_type": "markdown", - "id": "bd25826d", - "metadata": {}, - "source": [ - "### Retrying LLM Requests\n", - "LLM calls can sometimes fail due to transient errors or produce invalid outputs. The `RetryLLMHandler` automatically retries failed template calls:\n", - "\n", - "- `max_retries`: Maximum number of retry attempts (default: 3)\n", - "- `add_error_feedback`: When `True`, appends the error message to the prompt on retry, helping the LLM correct its output.\n", - "- `exception_cls`: RetryHandler will only attempt to try again when a specific type of `Exception` is thrown.\n" - ] - }, - { - "cell_type": "markdown", - "id": "bafc0a96", - "metadata": {}, - "source": [ - "Example usage: having an unstable service that seldomly fail." - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "4334d07a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "> Use the unstable_service tool to fetch data.\n", - "None\n", - "> Use the unstable_service tool to fetch data.\n", - "> None\n", - "> {'status': 'failure', 'exception': 'Service unavailable! Attempt 1/3. Please retry.'}\n", - "None\n", - "> Use the unstable_service tool to fetch data.\n", - "> None\n", - "> {'status': 'failure', 'exception': 'Service unavailable! Attempt 1/3. Please retry.'}\n", - "> None\n", - "> {'status': 'failure', 'exception': 'Service unavailable! Attempt 2/3. Please retry.'}\n", - "None\n", - "> Use the unstable_service tool to fetch data.\n", - "> None\n", - "> {'status': 'failure', 'exception': 'Service unavailable! Attempt 1/3. Please retry.'}\n", - "> None\n", - "> {'status': 'failure', 'exception': 'Service unavailable! Attempt 2/3. Please retry.'}\n", - "> None\n", - "> { 'status': 'ok', 'data': [1, 2, 3] }\n", - "The data fetched from the unstable service is: [1, 2, 3].\n", - "Result: The data fetched from the unstable service is: [1, 2, 3]. Retries: 3\n" - ] - } - ], - "source": [ - "call_count = 0\n", - "REQUIRED_RETRIES = 3\n", - "\n", - "\n", - "@Tool.define\n", - "def unstable_service() -> str:\n", - " \"\"\"Fetch data from an unstable external service. May require retries.\"\"\"\n", - " global call_count\n", - " call_count += 1\n", - " if call_count < REQUIRED_RETRIES:\n", - " raise ConnectionError(\n", - " f\"Service unavailable! Attempt {call_count}/{REQUIRED_RETRIES}. 
Please retry.\"\n", - " )\n", - " return \"{ 'status': 'ok', 'data': [1, 2, 3] }\"\n", - "\n", - "\n", - "@Template.define # unstable_service auto-captured from lexical scope\n", - "def fetch_data() -> str:\n", - " \"\"\"Use the unstable_service tool to fetch data.\"\"\"\n", - " raise NotHandled\n", - "\n", - "\n", - "with handler(provider), handler({call_assistant: log_llm}):\n", - " result = fetch_data()\n", - " print(f\"Result: {result}\", \"Retries:\", call_count)" - ] - }, - { - "cell_type": "markdown", - "id": "4ac00e01", - "metadata": {}, - "source": [ - "### Retrying with Validation Errors\n", - "As noted above, the `RetryHandler` can also be used to retry on runtime/validation error:" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "39b2b225", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "> Give a rating for Die Hard. The explanation MUST include the numeric score. Do not use any tools.\n", - "{\"value\":{\"score\":9,\"explanation\":\"Die Hard is widely regarded as one of the greatest action films ever made. It features a gripping narrative, iconic performances (especially by Bruce Willis as John McClane), and groundbreaking action sequences. The movie's blend of intense action, humor, and suspense has set a high standard for the genre. For these reasons, Die Hard receives a score of 9 out of 10.\"}}\n", - "> Give a rating for Die Hard. The explanation MUST include the numeric score. Do not use any tools.\n", - "> \n", - "Error from previous attempt:\n", - "```\n", - "Traceback (most recent call last):\n", - " File \"/Users/nguyendat/Marc/effectful/effectful/handlers/llm/providers.py\", line 374, in _retry_completion\n", - " return fwd()\n", - " ^^^^^\n", - " File \"/Users/nguyendat/Marc/effectful/effectful/ops/types.py\", line 488, in __call__\n", - " return self_handler(*args, **kwargs)\n", - " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", - " File \"/Users/nguyendat/.local/share/uv/python/cpython-3.12.11-macos-aarch64-none/lib/python3.12/contextlib.py\", line 81, in inner\n", - " return func(*args, **kwds)\n", - " ^^^^^^^^^^^^^^^^^^^\n", - " File \"/Users/nguyendat/Marc/effectful/effectful/internals/runtime.py\", line 45, in _cont_wrapper\n", - " return fn(*a, **k)\n", - " ^^^^^^^^^^^\n", - " File \"/Users/nguyendat/Marc/effectful/effectful/internals/runtime.py\", line 56, in _cont_wrapper\n", - " return fn(*a, **k)\n", - " ^^^^^^^^^^^\n", - " File \"/Users/nguyendat/Marc/effectful/effectful/internals/runtime.py\", line 70, in bound_body\n", - " return body(*a, **k)\n", - " ^^^^^^^^^^^^^\n", - " File \"/Users/nguyendat/Marc/effectful/effectful/internals/runtime.py\", line 56, in _cont_wrapper\n", - " return fn(*a, **k)\n", - " ^^^^^^^^^^^\n", - " File \"/Users/nguyendat/Marc/effectful/effectful/handlers/llm/providers.py\", line 406, in _call\n", - " return decode_response(template, resp)\n", - " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", - " File \"/Users/nguyendat/Marc/effectful/effectful/handlers/llm/providers.py\", line 279, in decode_response\n", - " result = Result.model_validate_json(result_str)\n", - " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", - " File \"/Users/nguyendat/Marc/effectful/.venv/lib/python3.12/site-packages/pydantic/main.py\", line 766, in model_validate_json\n", - " return cls.__pydantic_validator__.validate_json(\n", - " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", - "pydantic_core._pydantic_core.ValidationError: 1 validation error for Result\n", - "value.score\n", - " score must be 1–5, got 9 
[type=invalid_score, input_value=9, input_type=int]\n", - "```\n", - "{\"value\":{\"score\":5,\"explanation\":\"Die Hard is a classic action film that is widely regarded as one of the best in its genre, earning a score of 5. The movie's combination of thrilling action sequences, memorable performances, particularly by Bruce Willis and Alan Rickman, and its clever script have cemented its status as a must-watch. It's a perfect blend of tension, humor, and drama, making it a favorite among audiences and critics alike.\"}}\n", - "Score: 5/5\n", - "Explanation: Die Hard is a classic action film that is widely regarded as one of the best in its genre, earning a score of 5. The movie's combination of thrilling action sequences, memorable performances, particularly by Bruce Willis and Alan Rickman, and its clever script have cemented its status as a must-watch. It's a perfect blend of tension, humor, and drama, making it a favorite among audiences and critics alike.\n" - ] - } - ], - "source": [ - "@pydantic.dataclasses.dataclass\n", - "class Rating:\n", - " score: int\n", - " explanation: str\n", - "\n", - " @field_validator(\"score\")\n", - " @classmethod\n", - " def check_score(cls, v):\n", - " if v < 1 or v > 5:\n", - " raise PydanticCustomError(\n", - " \"invalid_score\",\n", - " \"score must be 1–5, got {v}\",\n", - " {\"v\": v},\n", - " )\n", - " return v\n", - "\n", - " @field_validator(\"explanation\")\n", - " @classmethod\n", - " def check_explanation_contains_score(cls, v, info):\n", - " score = info.data.get(\"score\", None)\n", - " if score is not None and str(score) not in v:\n", - " raise PydanticCustomError(\n", - " \"invalid_explanation\",\n", - " \"explanation must mention the score {score}, got '{explanation}'\",\n", - " {\"score\": score, \"explanation\": v},\n", - " )\n", - " return v\n", - "\n", - "\n", - "@Template.define\n", - "def give_rating_for_movie(movie_name: str) -> Rating:\n", - " \"\"\"Give a rating for {movie_name}. The explanation MUST include the numeric score. Do not use any tools.\"\"\"\n", - " raise NotHandled\n", - "\n", - "\n", - "with handler(provider), handler({call_assistant: log_llm}):\n", - " rating = give_rating_for_movie(\"Die Hard\")\n", - " print(f\"Score: {rating.score}/5\")\n", - " print(f\"Explanation: {rating.explanation}\")" - ] } ], "metadata": { diff --git a/effectful/handlers/llm/completions.py b/effectful/handlers/llm/completions.py index b49ee4d0..ac1894f7 100644 --- a/effectful/handlers/llm/completions.py +++ b/effectful/handlers/llm/completions.py @@ -22,9 +22,68 @@ from effectful.handlers.llm.encoding import Encodable from effectful.handlers.llm.template import Template, Tool +from effectful.ops.semantics import fwd from effectful.ops.syntax import ObjectInterpretation, implements from effectful.ops.types import Operation + +class ToolCallDecodingError(Exception): + """Error raised when decoding a tool call fails. + + Attributes: + tool_name: Name of the tool that failed to decode. + tool_call_id: ID of the tool call that failed. + original_error: The underlying exception that caused the failure. + raw_message: The raw assistant message containing the failed tool call. 
+ """ + + def __init__( + self, + tool_name: str, + tool_call_id: str, + original_error: Exception, + raw_message: typing.Any = None, + ): + self.tool_name = tool_name + self.tool_call_id = tool_call_id + self.original_error = original_error + self.raw_message = raw_message + super().__init__(f"Error decoding tool call '{tool_name}': {original_error}") + + +class ResultDecodingError(Exception): + """Error raised when decoding the LLM response result fails. + + Attributes: + original_error: The underlying exception that caused the failure. + raw_message: The raw assistant message containing the failed result. + """ + + def __init__( + self, + original_error: Exception, + raw_message: typing.Any = None, + ): + self.original_error = original_error + self.raw_message = raw_message + super().__init__(f"Error decoding response: {original_error}") + + +class ToolExecutionError(Exception): + """Error raised when a tool execution fails at runtime.""" + + def __init__( + self, + tool_name: str, + tool_call_id: str, + original_error: Exception, + ): + self.tool_name = tool_name + self.tool_call_id = tool_call_id + self.original_error = original_error + super().__init__(f"Error executing tool '{tool_name}': {original_error}") + + Message = ( OpenAIChatCompletionAssistantMessage | ChatCompletionToolMessage @@ -77,26 +136,49 @@ def _function_model(tool: Tool) -> ChatCompletionToolParam: def decode_tool_call( tool_call: ChatCompletionMessageToolCall, tools: collections.abc.Mapping[str, Tool], + raw_message: Message | None = None, ) -> DecodedToolCall: - """Decode a tool call from the LLM response into a DecodedToolCall.""" - assert tool_call.function.name is not None - tool = tools[tool_call.function.name] - json_str = tool_call.function.arguments + """Decode a tool call from the LLM response into a DecodedToolCall. + + Args: + tool_call: The tool call to decode. + tools: Mapping of tool names to Tool objects. + raw_message: Optional raw assistant message for error context. + Raises: + ToolCallDecodingError: If the tool call cannot be decoded. + """ + tool_name = tool_call.function.name + assert tool_name is not None + + try: + tool = tools[tool_name] + except KeyError as e: + raise ToolCallDecodingError( + tool_name, tool_call.id, e, raw_message=raw_message + ) from e + + json_str = tool_call.function.arguments sig = inspect.signature(tool) - # build dict of raw encodable types U - raw_args = _param_model(tool).model_validate_json(json_str) + try: + # build dict of raw encodable types U + raw_args = _param_model(tool).model_validate_json(json_str) + + # use encoders to decode Us to python types T + bound_sig: inspect.BoundArguments = sig.bind( + **{ + param_name: Encodable.define( + sig.parameters[param_name].annotation, {} + ).decode(getattr(raw_args, param_name)) + for param_name in raw_args.model_fields_set + } + ) + except (pydantic.ValidationError, TypeError, ValueError) as e: + raise ToolCallDecodingError( + tool_name, tool_call.id, e, raw_message=raw_message + ) from e - # use encoders to decode Us to python types T - bound_sig: inspect.BoundArguments = sig.bind( - **{ - param_name: Encodable.define( - sig.parameters[param_name].annotation, {} - ).decode(getattr(raw_args, param_name)) - for param_name in raw_args.model_fields_set - } - ) return DecodedToolCall(tool, bound_sig, tool_call.id) @@ -125,6 +207,11 @@ def call_assistant[T, U]( This effect is emitted for model request/response rounds so handlers can observe/log requests. 
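+
+    A minimal retry-aware invocation might look like this sketch, assuming
+    ``provider`` is a configured provider handler such as ``LiteLLMProvider``
+    (``RetryLLMHandler`` is defined later in this module)::
+
+        with handler(RetryLLMHandler(num_retries=3)), handler(provider):
+            message, tool_calls, result = call_assistant(
+                messages=[{"role": "user", "content": "Say hello"}],
+                tools={},
+                response_format=Encodable.define(str),
+                model="gpt-4o-mini",  # illustrative; any LiteLLM model id
+            )
+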
+ Raises: + ToolCallDecodingError: If a tool call cannot be decoded. The error + includes the raw assistant message for retry handling. + ResultDecodingError: If the result cannot be decoded. The error + includes the raw assistant message for retry handling. """ tool_specs = {k: _function_model(t) for k, t in tools.items()} response_model = pydantic.create_model( @@ -144,11 +231,15 @@ def call_assistant[T, U]( message: litellm.Message = choice.message assert message.role == "assistant" + raw_message = typing.cast(Message, message.model_dump(mode="json")) + tool_calls: list[DecodedToolCall] = [] raw_tool_calls = message.get("tool_calls") or [] - for tool_call in raw_tool_calls: - tool_call = ChatCompletionMessageToolCall.model_validate(tool_call) - decoded_tool_call = decode_tool_call(tool_call, tools) + for raw_tool_call in raw_tool_calls: + validated_tool_call = ChatCompletionMessageToolCall.model_validate( + raw_tool_call + ) + decoded_tool_call = decode_tool_call(validated_tool_call, tools, raw_message) tool_calls.append(decoded_tool_call) result = None @@ -158,10 +249,13 @@ def call_assistant[T, U]( assert isinstance(serialized_result, str), ( "final response from the model should be a string" ) - raw_result = response_model.model_validate_json(serialized_result) - result = response_format.decode(raw_result.value) # type: ignore + try: + raw_result = response_model.model_validate_json(serialized_result) + result = response_format.decode(raw_result.value) # type: ignore + except pydantic.ValidationError as e: + raise ResultDecodingError(e, raw_message=raw_message) from e - return (typing.cast(Message, message.model_dump(mode="json")), tool_calls, result) + return (raw_message, tool_calls, result) @Operation.define @@ -239,9 +333,17 @@ def call_system(template: Template) -> collections.abc.Sequence[Message]: return () -class RetryHandler(ObjectInterpretation): +class RetryLLMHandler(ObjectInterpretation): """Retries LLM requests if tool call or result decoding fails. + This handler intercepts `call_assistant` and catches `ToolCallDecodingError` + and `ResultDecodingError`. When these errors occur, it appends error feedback + to the messages and retries the request. Malformed messages from retry attempts + are pruned from the final result. + + For runtime tool execution failures (handled via `call_tool`), errors are + captured and returned as tool response messages. + Args: num_retries: The maximum number of retries (default: 3). """ @@ -261,100 +363,52 @@ def _call_assistant[T, U]( messages_list = list(messages) last_error: Exception | None = None - tool_specs = {k: _function_model(t) for k, t in tools.items()} - response_model = pydantic.create_model( - "Response", value=response_format.enc, __config__={"extra": "forbid"} - ) - for _attempt in range(self.num_retries + 1): - response: litellm.types.utils.ModelResponse = completion( - model, - messages=messages_list, - response_format=response_model, - tools=list(tool_specs.values()), - **kwargs, - ) - choice = response.choices[0] - assert isinstance(choice, litellm.types.utils.Choices) - - message: litellm.Message = choice.message - assert message.role == "assistant" - - raw_tool_calls = message.get("tool_calls") or [] + try: + message, tool_calls, result = fwd( + messages_list, tools, response_format, model, **kwargs + ) - # Try to decode tool calls, catching any decoding errors - tool_calls: list[DecodedToolCall] = [] - decoding_errors: list[tuple[ChatCompletionMessageToolCall, Exception]] = [] + # Success! 
The returned message is the final successful response. + # Malformed messages from retries are only in messages_list, + # not in the returned result. + return (message, tool_calls, result) - for raw_tool_call in raw_tool_calls: - validated_tool_call = ChatCompletionMessageToolCall.model_validate( - raw_tool_call + except ToolCallDecodingError as e: + last_error = e + # The error includes the raw message from the failed attempt + assert e.raw_message is not None, ( + "ToolCallDecodingError should include raw_message" ) - try: - decoded_tool_call = decode_tool_call(validated_tool_call, tools) - tool_calls.append(decoded_tool_call) - except (KeyError, pydantic.ValidationError) as e: - decoding_errors.append((validated_tool_call, e)) - - # If there were tool call decoding errors, add error feedback and retry - if decoding_errors: + # Add the malformed assistant message - messages_list.append( - typing.cast(Message, message.model_dump(mode="json")) - ) + messages_list.append(e.raw_message) - # Add error feedback for each failed tool call - for failed_tool_call, error in decoding_errors: - last_error = error - error_msg = ( - f"Error decoding tool call '{failed_tool_call.function.name}': {error}. " - f"Please fix the tool call arguments and try again." - ) - error_feedback: Message = typing.cast( - Message, - { - "role": "tool", - "tool_call_id": failed_tool_call.id, - "content": error_msg, - }, - ) - messages_list.append(error_feedback) + # Add error feedback as a tool response + error_msg = f"{e}. Please fix the tool call arguments and try again." + error_feedback: Message = typing.cast( + Message, + { + "role": "tool", + "tool_call_id": e.tool_call_id, + "content": error_msg, + }, + ) + messages_list.append(error_feedback) continue - # If there are tool calls, return them without decoding result - if tool_calls: - return ( - typing.cast(Message, message.model_dump(mode="json")), - tool_calls, - None, + except ResultDecodingError as e: + last_error = e + # The error includes the raw message from the failed attempt + assert e.raw_message is not None, ( + "ResultDecodingError should include raw_message" ) - # No tool calls - try to decode the result - serialized_result = message.get("content") or message.get( - "reasoning_content" - ) - assert isinstance(serialized_result, str), ( - "final response from the model should be a string" - ) + # Add the malformed assistant message + messages_list.append(e.raw_message) - try: - raw_result = response_model.model_validate_json(serialized_result) - result = response_format.decode(raw_result.value) # type: ignore - return ( - typing.cast(Message, message.model_dump(mode="json")), - tool_calls, - result, - ) - except pydantic.ValidationError as e: - last_error = e - # Add the assistant message and error feedback for result decoding failure - messages_list.append( - typing.cast(Message, message.model_dump(mode="json")) - ) - error_msg = ( - f"Error decoding response: {e}. " - f"Please provide a valid response and try again." - ) + # Add error feedback as a user message + error_msg = f"{e}. Please provide a valid response and try again." 
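+                # Note the asymmetry with the tool-call branch above: a failed
+                # tool call is answered with a "tool" message tied to its
+                # tool_call_id, but a malformed final result has no tool call
+                # to respond to, so the feedback goes back as a "user" message.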
result_error_feedback: Message = typing.cast( Message, { @@ -369,6 +423,36 @@ def _call_assistant[T, U]( assert last_error is not None raise last_error + @implements(completion) + def _completion(self, *args, **kwargs) -> typing.Any: + """Inject num_retries for litellm's built-in network error handling.""" + return fwd(*args, num_retries=self.num_retries, **kwargs) + + @implements(call_tool) + def _call_tool(self, tool_call: DecodedToolCall) -> Message: + """Handle tool execution with runtime error capture. + + Runtime errors from tool execution are captured and returned as + error messages to the LLM. + """ + try: + return fwd(tool_call) + except Exception as e: + # Wrap runtime errors and return as a tool message + error = ToolExecutionError( + tool_call.tool.__name__, + tool_call.id, + e, + ) + return typing.cast( + Message, + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": f"Tool execution failed: {error}", + }, + ) + class LiteLLMProvider(ObjectInterpretation): """Implements templates using the LiteLLM API.""" diff --git a/tests/test_handlers_llm_provider.py b/tests/test_handlers_llm_provider.py index 6a9297dd..f94cea89 100644 --- a/tests/test_handlers_llm_provider.py +++ b/tests/test_handlers_llm_provider.py @@ -22,9 +22,13 @@ from effectful.handlers.llm import Template from effectful.handlers.llm.completions import ( LiteLLMProvider, - RetryHandler, + ResultDecodingError, + RetryLLMHandler, Tool, + ToolCallDecodingError, + ToolExecutionError, call_assistant, + call_tool, completion, ) from effectful.handlers.llm.encoding import Encodable @@ -373,7 +377,7 @@ def test_litellm_caching_selective(request): # ============================================================================ -# RetryHandler Tests +# RetryLLMHandler Tests # ============================================================================ @@ -441,17 +445,17 @@ def add_numbers(a: int, b: int) -> int: return a + b -class TestRetryHandler: - """Tests for RetryHandler functionality.""" +class TestRetryLLMHandler: + """Tests for RetryLLMHandler functionality.""" def test_retry_handler_succeeds_on_first_attempt(self): - """Test that RetryHandler passes through when no error occurs.""" + """Test that RetryLLMHandler passes through when no error occurs.""" # Response with valid tool call responses = [make_text_response('{"value": "hello"}')] mock_handler = MockCompletionHandler(responses) - with handler(RetryHandler(num_retries=3)), handler(mock_handler): + with handler(RetryLLMHandler(num_retries=3)), handler(mock_handler): message, tool_calls, result = call_assistant( messages=[{"role": "user", "content": "test"}], tools={}, @@ -463,7 +467,7 @@ def test_retry_handler_succeeds_on_first_attempt(self): assert result == "hello" def test_retry_handler_retries_on_invalid_tool_call(self): - """Test that RetryHandler retries when tool call decoding fails.""" + """Test that RetryLLMHandler retries when tool call decoding fails.""" # First response has invalid tool args, second has valid response responses = [ make_tool_call_response( @@ -474,7 +478,7 @@ def test_retry_handler_retries_on_invalid_tool_call(self): mock_handler = MockCompletionHandler(responses) - with handler(RetryHandler(num_retries=3)), handler(mock_handler): + with handler(RetryLLMHandler(num_retries=3)), handler(mock_handler): message, tool_calls, result = call_assistant( messages=[{"role": "user", "content": "test"}], tools={"add_numbers": add_numbers}, @@ -490,7 +494,7 @@ def test_retry_handler_retries_on_invalid_tool_call(self): ) def 
test_retry_handler_retries_on_unknown_tool(self): - """Test that RetryHandler retries when tool is not found.""" + """Test that RetryLLMHandler retries when tool is not found.""" # First response has unknown tool, second has valid response responses = [ make_tool_call_response("unknown_tool", '{"x": 1}'), # Unknown tool @@ -499,7 +503,7 @@ def test_retry_handler_retries_on_unknown_tool(self): mock_handler = MockCompletionHandler(responses) - with handler(RetryHandler(num_retries=3)), handler(mock_handler): + with handler(RetryLLMHandler(num_retries=3)), handler(mock_handler): message, tool_calls, result = call_assistant( messages=[{"role": "user", "content": "test"}], tools={"add_numbers": add_numbers}, @@ -511,7 +515,7 @@ def test_retry_handler_retries_on_unknown_tool(self): assert result == "success" def test_retry_handler_exhausts_retries(self): - """Test that RetryHandler raises after exhausting all retries.""" + """Test that RetryLLMHandler raises after exhausting all retries.""" # All responses have invalid tool calls responses = [ make_tool_call_response("add_numbers", '{"a": "bad", "b": "bad"}'), @@ -520,7 +524,7 @@ def test_retry_handler_exhausts_retries(self): mock_handler = MockCompletionHandler(responses) with pytest.raises(Exception): # Will raise the underlying decoding error - with handler(RetryHandler(num_retries=2)), handler(mock_handler): + with handler(RetryLLMHandler(num_retries=2)), handler(mock_handler): call_assistant( messages=[{"role": "user", "content": "test"}], tools={"add_numbers": add_numbers}, @@ -532,7 +536,7 @@ def test_retry_handler_exhausts_retries(self): assert mock_handler.call_count == 3 def test_retry_handler_with_zero_retries(self): - """Test RetryHandler with num_retries=0 fails immediately on error.""" + """Test RetryLLMHandler with num_retries=0 fails immediately on error.""" responses = [ make_tool_call_response("add_numbers", '{"a": "bad", "b": "bad"}'), ] @@ -540,7 +544,7 @@ def test_retry_handler_with_zero_retries(self): mock_handler = MockCompletionHandler(responses) with pytest.raises(Exception): - with handler(RetryHandler(num_retries=0)), handler(mock_handler): + with handler(RetryLLMHandler(num_retries=0)), handler(mock_handler): call_assistant( messages=[{"role": "user", "content": "test"}], tools={"add_numbers": add_numbers}, @@ -558,7 +562,7 @@ def test_retry_handler_valid_tool_call_passes_through(self): mock_handler = MockCompletionHandler(responses) - with handler(RetryHandler(num_retries=3)), handler(mock_handler): + with handler(RetryLLMHandler(num_retries=3)), handler(mock_handler): message, tool_calls, result = call_assistant( messages=[{"role": "user", "content": "test"}], tools={"add_numbers": add_numbers}, @@ -572,7 +576,7 @@ def test_retry_handler_valid_tool_call_passes_through(self): assert result is None # No result when there are tool calls def test_retry_handler_retries_on_invalid_result(self): - """Test that RetryHandler retries when result decoding fails.""" + """Test that RetryLLMHandler retries when result decoding fails.""" # First response has invalid JSON, second has valid response responses = [ make_text_response('{"value": "not valid for int"}'), # Invalid for int @@ -581,7 +585,7 @@ def test_retry_handler_retries_on_invalid_result(self): mock_handler = MockCompletionHandler(responses) - with handler(RetryHandler(num_retries=3)), handler(mock_handler): + with handler(RetryLLMHandler(num_retries=3)), handler(mock_handler): message, tool_calls, result = call_assistant( messages=[{"role": "user", "content": 
"test"}], tools={}, @@ -597,7 +601,7 @@ def test_retry_handler_retries_on_invalid_result(self): ) def test_retry_handler_exhausts_retries_on_result_decoding(self): - """Test that RetryHandler raises after exhausting retries on result decoding.""" + """Test that RetryLLMHandler raises after exhausting retries on result decoding.""" # All responses have invalid results for int type responses = [ make_text_response('{"value": "not an int"}'), @@ -606,7 +610,7 @@ def test_retry_handler_exhausts_retries_on_result_decoding(self): mock_handler = MockCompletionHandler(responses) with pytest.raises(Exception): # Will raise the underlying decoding error - with handler(RetryHandler(num_retries=2)), handler(mock_handler): + with handler(RetryLLMHandler(num_retries=2)), handler(mock_handler): call_assistant( messages=[{"role": "user", "content": "test"}], tools={}, @@ -616,3 +620,282 @@ def test_retry_handler_exhausts_retries_on_result_decoding(self): # Should have attempted 3 times (1 initial + 2 retries) assert mock_handler.call_count == 3 + + def test_retry_handler_raises_tool_call_decoding_error(self): + """Test that RetryLLMHandler raises ToolCallDecodingError with correct attributes.""" + responses = [ + make_tool_call_response("add_numbers", '{"a": "bad", "b": "bad"}'), + ] + + mock_handler = MockCompletionHandler(responses) + + with pytest.raises(ToolCallDecodingError) as exc_info: + with handler(RetryLLMHandler(num_retries=0)), handler(mock_handler): + call_assistant( + messages=[{"role": "user", "content": "test"}], + tools={"add_numbers": add_numbers}, + response_format=Encodable.define(str), + model="test-model", + ) + + error = exc_info.value + assert error.tool_name == "add_numbers" + assert error.tool_call_id == "call_1" + assert error.raw_message is not None + assert "add_numbers" in str(error) + + def test_retry_handler_raises_result_decoding_error(self): + """Test that RetryLLMHandler raises ResultDecodingError with correct attributes.""" + responses = [ + make_text_response('{"value": "not an int"}'), + ] + + mock_handler = MockCompletionHandler(responses) + + with pytest.raises(ResultDecodingError) as exc_info: + with handler(RetryLLMHandler(num_retries=0)), handler(mock_handler): + call_assistant( + messages=[{"role": "user", "content": "test"}], + tools={}, + response_format=Encodable.define(int), + model="test-model", + ) + + error = exc_info.value + assert error.raw_message is not None + assert error.original_error is not None + + def test_retry_handler_error_feedback_contains_tool_name(self): + """Test that error feedback messages contain the tool name.""" + responses = [ + make_tool_call_response("add_numbers", '{"a": "bad", "b": 2}'), + make_text_response('{"value": "success"}'), + ] + + mock_handler = MockCompletionHandler(responses) + + with handler(RetryLLMHandler(num_retries=3)), handler(mock_handler): + call_assistant( + messages=[{"role": "user", "content": "test"}], + tools={"add_numbers": add_numbers}, + response_format=Encodable.define(str), + model="test-model", + ) + + # Check that the error feedback in the second call mentions the tool name + second_call_messages = mock_handler.received_messages[1] + tool_feedback = [m for m in second_call_messages if m.get("role") == "tool"] + assert len(tool_feedback) == 1 + assert "add_numbers" in tool_feedback[0]["content"] + + def test_retry_handler_unknown_tool_error_contains_tool_name(self): + """Test that unknown tool errors contain the tool name in the feedback.""" + responses = [ + 
make_tool_call_response("nonexistent_tool", '{"x": 1}'), + make_text_response('{"value": "success"}'), + ] + + mock_handler = MockCompletionHandler(responses) + + with handler(RetryLLMHandler(num_retries=3)), handler(mock_handler): + call_assistant( + messages=[{"role": "user", "content": "test"}], + tools={"add_numbers": add_numbers}, + response_format=Encodable.define(str), + model="test-model", + ) + + # Check that the error feedback mentions the unknown tool + second_call_messages = mock_handler.received_messages[1] + tool_feedback = [m for m in second_call_messages if m.get("role") == "tool"] + assert len(tool_feedback) == 1 + assert "nonexistent_tool" in tool_feedback[0]["content"] + + +# ============================================================================ +# Tool Execution Error Tests +# ============================================================================ + + +@Tool.define +def failing_tool(x: int) -> int: + """A tool that always raises an exception.""" + raise ValueError(f"Tool failed with input {x}") + + +@Tool.define +def divide_tool(a: int, b: int) -> int: + """Divide a by b.""" + return a // b + + +class TestToolExecutionErrorHandling: + """Tests for runtime tool execution error handling.""" + + def test_tool_execution_error_attributes(self): + """Test ToolExecutionError has correct attributes.""" + original = ValueError("something went wrong") + error = ToolExecutionError("my_tool", "call_123", original) + + assert error.tool_name == "my_tool" + assert error.tool_call_id == "call_123" + assert error.original_error is original + assert "my_tool" in str(error) + assert "something went wrong" in str(error) + + def test_retry_handler_catches_tool_runtime_error(self): + """Test that RetryLLMHandler catches tool runtime errors and returns error message.""" + import inspect + + from effectful.handlers.llm.completions import DecodedToolCall + + # Create a decoded tool call for failing_tool + sig = inspect.signature(failing_tool) + bound_args = sig.bind(x=42) + tool_call = DecodedToolCall(failing_tool, bound_args, "call_1") + + with handler(RetryLLMHandler(num_retries=3)): + result = call_tool(tool_call) + + # The result should be an error message, not an exception + assert result["role"] == "tool" + assert result["tool_call_id"] == "call_1" + assert "Tool execution failed" in result["content"] + assert "failing_tool" in result["content"] + assert "42" in result["content"] + + def test_retry_handler_catches_division_by_zero(self): + """Test that RetryLLMHandler catches division by zero errors.""" + import inspect + + from effectful.handlers.llm.completions import DecodedToolCall + + sig = inspect.signature(divide_tool) + bound_args = sig.bind(a=10, b=0) + tool_call = DecodedToolCall(divide_tool, bound_args, "call_div") + + with handler(RetryLLMHandler(num_retries=3)): + result = call_tool(tool_call) + + assert result["role"] == "tool" + assert result["tool_call_id"] == "call_div" + assert "Tool execution failed" in result["content"] + assert "divide_tool" in result["content"] + + def test_successful_tool_execution_returns_result(self): + """Test that successful tool executions return normal results.""" + import inspect + + from effectful.handlers.llm.completions import DecodedToolCall + + sig = inspect.signature(add_numbers) + bound_args = sig.bind(a=3, b=4) + tool_call = DecodedToolCall(add_numbers, bound_args, "call_add") + + with handler(RetryLLMHandler(num_retries=3)): + result = call_tool(tool_call) + + assert result["role"] == "tool" + assert 
result["tool_call_id"] == "call_add" + # The result should be the serialized return value, not an error + assert "Tool execution failed" not in result["content"] + + def test_tool_execution_error_not_pruned_from_messages(self): + """Test that tool execution errors are NOT pruned (they're legitimate failures).""" + # This test verifies the docstring claim that tool execution errors + # should be kept in the message history, unlike decoding errors + + # First call: valid tool call that will fail at runtime + # Second call: successful text response + responses = [ + make_tool_call_response("failing_tool", '{"x": 42}'), + make_text_response('{"value": "handled the error"}'), + ] + + mock_handler = MockCompletionHandler(responses) + + # We need a custom provider that actually calls call_tool + class TestProvider(ObjectInterpretation): + @implements(call_assistant) + def _call_assistant( + self, messages, tools, response_format, model, **kwargs + ): + return fwd(messages, tools, response_format, model, **kwargs) + + with ( + handler(RetryLLMHandler(num_retries=3)), + handler(TestProvider()), + handler(mock_handler), + ): + message, tool_calls, result = call_assistant( + messages=[{"role": "user", "content": "test"}], + tools={"failing_tool": failing_tool}, + response_format=Encodable.define(str), + model="test-model", + ) + + # First call should succeed (tool call is valid) + assert mock_handler.call_count == 1 + assert len(tool_calls) == 1 + + +# ============================================================================ +# Error Class Tests +# ============================================================================ + + +class TestErrorClasses: + """Tests for the error class definitions.""" + + def test_tool_call_decoding_error_string_representation(self): + """Test ToolCallDecodingError string includes relevant info.""" + original = ValueError("invalid value") + error = ToolCallDecodingError( + "my_function", "call_abc", original, raw_message={"role": "assistant"} + ) + + error_str = str(error) + assert "my_function" in error_str + assert "invalid value" in error_str + + def test_result_decoding_error_string_representation(self): + """Test ResultDecodingError string includes relevant info.""" + original = ValueError("parse error") + error = ResultDecodingError(original, raw_message={"role": "assistant"}) + + error_str = str(error) + assert "parse error" in error_str + assert "decoding response" in error_str.lower() + + def test_tool_execution_error_string_representation(self): + """Test ToolExecutionError string includes relevant info.""" + original = RuntimeError("runtime failure") + error = ToolExecutionError("compute", "call_xyz", original) + + error_str = str(error) + assert "compute" in error_str + assert "runtime failure" in error_str + + def test_error_classes_preserve_original_error(self): + """Test that all error classes preserve the original exception.""" + original = TypeError("type mismatch") + + tool_decode_err = ToolCallDecodingError("fn", "id", original) + assert tool_decode_err.original_error is original + + result_decode_err = ResultDecodingError(original) + assert result_decode_err.original_error is original + + tool_exec_err = ToolExecutionError("fn", "id", original) + assert tool_exec_err.original_error is original + + def test_tool_call_decoding_error_raw_message_optional(self): + """Test that raw_message can be None initially (set later in call_assistant).""" + error = ToolCallDecodingError("fn", "id", ValueError("test")) + assert error.raw_message is None + + # Can 
be set after creation + error_with_msg = ToolCallDecodingError( + "fn", "id", ValueError("test"), raw_message={"role": "assistant"} + ) + assert error_with_msg.raw_message is not None From e802db1781cc69080761dea6f7ca9721df851b45 Mon Sep 17 00:00:00 2001 From: Kiran Gopinathan Date: Fri, 30 Jan 2026 21:36:24 -0500 Subject: [PATCH 04/12] restored documentation on retry using litellm --- docs/source/llm.ipynb | 191 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 191 insertions(+) diff --git a/docs/source/llm.ipynb b/docs/source/llm.ipynb index be92c345..f40a6fdb 100644 --- a/docs/source/llm.ipynb +++ b/docs/source/llm.ipynb @@ -547,6 +547,197 @@ " print(\"=== Funny story ===\")\n", " print(write_story(\"a curious cat\", \"funny\"))" ] + }, + { + "cell_type": "markdown", + "id": "bd25826d", + "metadata": {}, + "source": [ + "### Retrying LLM Requests\n", + "LLM calls can sometimes fail due to transient errors or produce invalid outputs. The `RetryLLMHandler` automatically retries failed template calls:\n", + "\n", + "- `max_retries`: Maximum number of retry attempts (default: 3)\n", + "- `add_error_feedback`: When `True`, appends the error message to the prompt on retry, helping the LLM correct its output.\n", + "- `exception_cls`: RetryHandler will only attempt to try again when a specific type of `Exception` is thrown.\n" + ] + }, + { + "cell_type": "markdown", + "id": "bafc0a96", + "metadata": {}, + "source": [ + "Example usage: having an unstable service that seldomly fail." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "4334d07a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> Use the unstable_service tool to fetch data.\n", + "None\n", + "> Use the unstable_service tool to fetch data.\n", + "> None\n", + "> {'status': 'failure', 'exception': 'Service unavailable! Attempt 1/3. Please retry.'}\n", + "None\n", + "> Use the unstable_service tool to fetch data.\n", + "> None\n", + "> {'status': 'failure', 'exception': 'Service unavailable! Attempt 1/3. Please retry.'}\n", + "> None\n", + "> {'status': 'failure', 'exception': 'Service unavailable! Attempt 2/3. Please retry.'}\n", + "None\n", + "> Use the unstable_service tool to fetch data.\n", + "> None\n", + "> {'status': 'failure', 'exception': 'Service unavailable! Attempt 1/3. Please retry.'}\n", + "> None\n", + "> {'status': 'failure', 'exception': 'Service unavailable! Attempt 2/3. Please retry.'}\n", + "> None\n", + "> { 'status': 'ok', 'data': [1, 2, 3] }\n", + "The data fetched from the unstable service is: [1, 2, 3].\n", + "Result: The data fetched from the unstable service is: [1, 2, 3]. Retries: 3\n" + ] + } + ], + "source": [ + "call_count = 0\n", + "REQUIRED_RETRIES = 3\n", + "\n", + "\n", + "@Tool.define\n", + "def unstable_service() -> str:\n", + " \"\"\"Fetch data from an unstable external service. May require retries.\"\"\"\n", + " global call_count\n", + " call_count += 1\n", + " if call_count < REQUIRED_RETRIES:\n", + " raise ConnectionError(\n", + " f\"Service unavailable! Attempt {call_count}/{REQUIRED_RETRIES}. 
Please retry.\"\n", + " )\n", + " return \"{ 'status': 'ok', 'data': [1, 2, 3] }\"\n", + "\n", + "\n", + "@Template.define # unstable_service auto-captured from lexical scope\n", + "def fetch_data() -> str:\n", + " \"\"\"Use the unstable_service tool to fetch data.\"\"\"\n", + " raise NotHandled\n", + "\n", + "\n", + "with handler(provider), handler({call_assistant: log_llm}):\n", + " result = fetch_data()\n", + " print(f\"Result: {result}\", \"Retries:\", call_count)" + ] + }, + { + "cell_type": "markdown", + "id": "4ac00e01", + "metadata": {}, + "source": [ + "### Retrying with Validation Errors\n", + "As noted above, the `RetryHandler` can also be used to retry on runtime/validation error:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "39b2b225", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> Give a rating for Die Hard. The explanation MUST include the numeric score. Do not use any tools.\n", + "{\"value\":{\"score\":9,\"explanation\":\"Die Hard is widely regarded as one of the greatest action films ever made. It features a gripping narrative, iconic performances (especially by Bruce Willis as John McClane), and groundbreaking action sequences. The movie's blend of intense action, humor, and suspense has set a high standard for the genre. For these reasons, Die Hard receives a score of 9 out of 10.\"}}\n", + "> Give a rating for Die Hard. The explanation MUST include the numeric score. Do not use any tools.\n", + "> \n", + "Error from previous attempt:\n", + "```\n", + "Traceback (most recent call last):\n", + " File \"/Users/nguyendat/Marc/effectful/effectful/handlers/llm/providers.py\", line 374, in _retry_completion\n", + " return fwd()\n", + " ^^^^^\n", + " File \"/Users/nguyendat/Marc/effectful/effectful/ops/types.py\", line 488, in __call__\n", + " return self_handler(*args, **kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/Users/nguyendat/.local/share/uv/python/cpython-3.12.11-macos-aarch64-none/lib/python3.12/contextlib.py\", line 81, in inner\n", + " return func(*args, **kwds)\n", + " ^^^^^^^^^^^^^^^^^^^\n", + " File \"/Users/nguyendat/Marc/effectful/effectful/internals/runtime.py\", line 45, in _cont_wrapper\n", + " return fn(*a, **k)\n", + " ^^^^^^^^^^^\n", + " File \"/Users/nguyendat/Marc/effectful/effectful/internals/runtime.py\", line 56, in _cont_wrapper\n", + " return fn(*a, **k)\n", + " ^^^^^^^^^^^\n", + " File \"/Users/nguyendat/Marc/effectful/effectful/internals/runtime.py\", line 70, in bound_body\n", + " return body(*a, **k)\n", + " ^^^^^^^^^^^^^\n", + " File \"/Users/nguyendat/Marc/effectful/effectful/internals/runtime.py\", line 56, in _cont_wrapper\n", + " return fn(*a, **k)\n", + " ^^^^^^^^^^^\n", + " File \"/Users/nguyendat/Marc/effectful/effectful/handlers/llm/providers.py\", line 406, in _call\n", + " return decode_response(template, resp)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/Users/nguyendat/Marc/effectful/effectful/handlers/llm/providers.py\", line 279, in decode_response\n", + " result = Result.model_validate_json(result_str)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/Users/nguyendat/Marc/effectful/.venv/lib/python3.12/site-packages/pydantic/main.py\", line 766, in model_validate_json\n", + " return cls.__pydantic_validator__.validate_json(\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + "pydantic_core._pydantic_core.ValidationError: 1 validation error for Result\n", + "value.score\n", + " score must be 1–5, got 9 
[type=invalid_score, input_value=9, input_type=int]\n", + "```\n", + "{\"value\":{\"score\":5,\"explanation\":\"Die Hard is a classic action film that is widely regarded as one of the best in its genre, earning a score of 5. The movie's combination of thrilling action sequences, memorable performances, particularly by Bruce Willis and Alan Rickman, and its clever script have cemented its status as a must-watch. It's a perfect blend of tension, humor, and drama, making it a favorite among audiences and critics alike.\"}}\n", + "Score: 5/5\n", + "Explanation: Die Hard is a classic action film that is widely regarded as one of the best in its genre, earning a score of 5. The movie's combination of thrilling action sequences, memorable performances, particularly by Bruce Willis and Alan Rickman, and its clever script have cemented its status as a must-watch. It's a perfect blend of tension, humor, and drama, making it a favorite among audiences and critics alike.\n" + ] + } + ], + "source": [ + "@pydantic.dataclasses.dataclass\n", + "class Rating:\n", + " score: int\n", + " explanation: str\n", + "\n", + " @field_validator(\"score\")\n", + " @classmethod\n", + " def check_score(cls, v):\n", + " if v < 1 or v > 5:\n", + " raise PydanticCustomError(\n", + " \"invalid_score\",\n", + " \"score must be 1–5, got {v}\",\n", + " {\"v\": v},\n", + " )\n", + " return v\n", + "\n", + " @field_validator(\"explanation\")\n", + " @classmethod\n", + " def check_explanation_contains_score(cls, v, info):\n", + " score = info.data.get(\"score\", None)\n", + " if score is not None and str(score) not in v:\n", + " raise PydanticCustomError(\n", + " \"invalid_explanation\",\n", + " \"explanation must mention the score {score}, got '{explanation}'\",\n", + " {\"score\": score, \"explanation\": v},\n", + " )\n", + " return v\n", + "\n", + "\n", + "@Template.define\n", + "def give_rating_for_movie(movie_name: str) -> Rating:\n", + " \"\"\"Give a rating for {movie_name}. The explanation MUST include the numeric score. Do not use any tools.\"\"\"\n", + " raise NotHandled\n", + "\n", + "\n", + "with handler(provider), handler({call_assistant: log_llm}):\n", + " rating = give_rating_for_movie(\"Die Hard\")\n", + " print(f\"Score: {rating.score}/5\")\n", + " print(f\"Explanation: {rating.explanation}\")" + ] } ], "metadata": { From e7f872123c4bce0612aa94ee1051887a38572803 Mon Sep 17 00:00:00 2001 From: Kiran Gopinathan Date: Fri, 30 Jan 2026 21:42:33 -0500 Subject: [PATCH 05/12] switched exception classes to cleaner dataclasses --- effectful/handlers/llm/completions.py | 64 +++++++++------------------ 1 file changed, 22 insertions(+), 42 deletions(-) diff --git a/effectful/handlers/llm/completions.py b/effectful/handlers/llm/completions.py index ac1894f7..0bdc106c 100644 --- a/effectful/handlers/llm/completions.py +++ b/effectful/handlers/llm/completions.py @@ -1,5 +1,6 @@ import collections import collections.abc +import dataclasses import functools import inspect import string @@ -27,61 +28,40 @@ from effectful.ops.types import Operation +@dataclasses.dataclass class ToolCallDecodingError(Exception): - """Error raised when decoding a tool call fails. + """Error raised when decoding a tool call fails.""" - Attributes: - tool_name: Name of the tool that failed to decode. - tool_call_id: ID of the tool call that failed. - original_error: The underlying exception that caused the failure. - raw_message: The raw assistant message containing the failed tool call. 
- """ + tool_name: str + tool_call_id: str + original_error: Exception + raw_message: typing.Any = None - def __init__( - self, - tool_name: str, - tool_call_id: str, - original_error: Exception, - raw_message: typing.Any = None, - ): - self.tool_name = tool_name - self.tool_call_id = tool_call_id - self.original_error = original_error - self.raw_message = raw_message - super().__init__(f"Error decoding tool call '{tool_name}': {original_error}") + def __str__(self) -> str: + return f"Error decoding tool call '{self.tool_name}': {self.original_error}" +@dataclasses.dataclass class ResultDecodingError(Exception): - """Error raised when decoding the LLM response result fails. + """Error raised when decoding the LLM response result fails.""" - Attributes: - original_error: The underlying exception that caused the failure. - raw_message: The raw assistant message containing the failed result. - """ + original_error: Exception + raw_message: typing.Any = None - def __init__( - self, - original_error: Exception, - raw_message: typing.Any = None, - ): - self.original_error = original_error - self.raw_message = raw_message - super().__init__(f"Error decoding response: {original_error}") + def __str__(self) -> str: + return f"Error decoding response: {self.original_error}" +@dataclasses.dataclass class ToolExecutionError(Exception): """Error raised when a tool execution fails at runtime.""" - def __init__( - self, - tool_name: str, - tool_call_id: str, - original_error: Exception, - ): - self.tool_name = tool_name - self.tool_call_id = tool_call_id - self.original_error = original_error - super().__init__(f"Error executing tool '{tool_name}': {original_error}") + tool_name: str + tool_call_id: str + original_error: Exception + + def __str__(self) -> str: + return f"Error executing tool '{self.tool_name}': {self.original_error}" Message = ( From 49a35f51505c505ca997634665b23162fd55db08 Mon Sep 17 00:00:00 2001 From: Kiran Gopinathan Date: Fri, 30 Jan 2026 21:58:18 -0500 Subject: [PATCH 06/12] dropped redundant ToolExecutionError --- effectful/handlers/llm/completions.py | 20 +--------------- tests/test_handlers_llm_provider.py | 34 +-------------------------- 2 files changed, 2 insertions(+), 52 deletions(-) diff --git a/effectful/handlers/llm/completions.py b/effectful/handlers/llm/completions.py index 0bdc106c..e85f5455 100644 --- a/effectful/handlers/llm/completions.py +++ b/effectful/handlers/llm/completions.py @@ -52,18 +52,6 @@ def __str__(self) -> str: return f"Error decoding response: {self.original_error}" -@dataclasses.dataclass -class ToolExecutionError(Exception): - """Error raised when a tool execution fails at runtime.""" - - tool_name: str - tool_call_id: str - original_error: Exception - - def __str__(self) -> str: - return f"Error executing tool '{self.tool_name}': {self.original_error}" - - Message = ( OpenAIChatCompletionAssistantMessage | ChatCompletionToolMessage @@ -418,18 +406,12 @@ def _call_tool(self, tool_call: DecodedToolCall) -> Message: try: return fwd(tool_call) except Exception as e: - # Wrap runtime errors and return as a tool message - error = ToolExecutionError( - tool_call.tool.__name__, - tool_call.id, - e, - ) return typing.cast( Message, { "role": "tool", "tool_call_id": tool_call.id, - "content": f"Tool execution failed: {error}", + "content": f"Tool execution failed: Error executing tool '{tool_call.tool.__name__}': {e}", }, ) diff --git a/tests/test_handlers_llm_provider.py b/tests/test_handlers_llm_provider.py index ee0380af..70a40947 100644 --- 
a/tests/test_handlers_llm_provider.py +++ b/tests/test_handlers_llm_provider.py @@ -22,12 +22,12 @@ from effectful.handlers.llm import Template from effectful.handlers.llm.completions import ( + DecodedToolCall, LiteLLMProvider, ResultDecodingError, RetryLLMHandler, Tool, ToolCallDecodingError, - ToolExecutionError, call_assistant, call_tool, completion, @@ -709,22 +709,8 @@ def divide_tool(a: int, b: int) -> int: class TestToolExecutionErrorHandling: """Tests for runtime tool execution error handling.""" - def test_tool_execution_error_attributes(self): - """Test ToolExecutionError has correct attributes.""" - original = ValueError("something went wrong") - error = ToolExecutionError("my_tool", "call_123", original) - - assert error.tool_name == "my_tool" - assert error.tool_call_id == "call_123" - assert error.original_error is original - assert "my_tool" in str(error) - assert "something went wrong" in str(error) - def test_retry_handler_catches_tool_runtime_error(self): """Test that RetryLLMHandler catches tool runtime errors and returns error message.""" - import inspect - - from effectful.handlers.llm.completions import DecodedToolCall # Create a decoded tool call for failing_tool sig = inspect.signature(failing_tool) @@ -743,9 +729,6 @@ def test_retry_handler_catches_tool_runtime_error(self): def test_retry_handler_catches_division_by_zero(self): """Test that RetryLLMHandler catches division by zero errors.""" - import inspect - - from effectful.handlers.llm.completions import DecodedToolCall sig = inspect.signature(divide_tool) bound_args = sig.bind(a=10, b=0) @@ -761,9 +744,6 @@ def test_retry_handler_catches_division_by_zero(self): def test_successful_tool_execution_returns_result(self): """Test that successful tool executions return normal results.""" - import inspect - - from effectful.handlers.llm.completions import DecodedToolCall sig = inspect.signature(add_numbers) bound_args = sig.bind(a=3, b=4) @@ -844,15 +824,6 @@ def test_result_decoding_error_string_representation(self): assert "parse error" in error_str assert "decoding response" in error_str.lower() - def test_tool_execution_error_string_representation(self): - """Test ToolExecutionError string includes relevant info.""" - original = RuntimeError("runtime failure") - error = ToolExecutionError("compute", "call_xyz", original) - - error_str = str(error) - assert "compute" in error_str - assert "runtime failure" in error_str - def test_error_classes_preserve_original_error(self): """Test that all error classes preserve the original exception.""" original = TypeError("type mismatch") @@ -863,9 +834,6 @@ def test_error_classes_preserve_original_error(self): result_decode_err = ResultDecodingError(original) assert result_decode_err.original_error is original - tool_exec_err = ToolExecutionError("fn", "id", original) - assert tool_exec_err.original_error is original - def test_tool_call_decoding_error_raw_message_optional(self): """Test that raw_message can be None initially (set later in call_assistant).""" error = ToolCallDecodingError("fn", "id", ValueError("test")) From 63239b687dd4f3846fde435284eeac1270929435 Mon Sep 17 00:00:00 2001 From: Kiran Gopinathan Date: Fri, 30 Jan 2026 22:07:05 -0500 Subject: [PATCH 07/12] add parameter to include traceback in calls --- effectful/handlers/llm/completions.py | 24 +++++++++---- tests/test_handlers_llm_provider.py | 50 +++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 6 deletions(-) diff --git a/effectful/handlers/llm/completions.py 
b/effectful/handlers/llm/completions.py index e85f5455..f47fc4f9 100644 --- a/effectful/handlers/llm/completions.py +++ b/effectful/handlers/llm/completions.py @@ -5,6 +5,7 @@ import inspect import string import textwrap +import traceback import typing import litellm @@ -314,10 +315,20 @@ class RetryLLMHandler(ObjectInterpretation): Args: num_retries: The maximum number of retries (default: 3). + include_traceback: If True, include full traceback in error feedback + for better debugging context (default: False). """ - def __init__(self, num_retries: int = 3): + def __init__(self, num_retries: int = 3, include_traceback: bool = False): self.num_retries = num_retries + self.include_traceback = include_traceback + + def _format_error(self, error: Exception, base_msg: str) -> str: + """Format an error message, optionally including traceback.""" + if self.include_traceback: + tb = traceback.format_exc() + return f"{base_msg}\n\nTraceback:\n```\n{tb}```" + return base_msg @implements(call_assistant) def _call_assistant[T, U]( @@ -353,13 +364,13 @@ def _call_assistant[T, U]( messages_list.append(e.raw_message) # Add error feedback as a tool response - error_msg = f"{e}. Please fix the tool call arguments and try again." + base_msg = f"{e}. Please fix the tool call arguments and try again." error_feedback: Message = typing.cast( Message, { "role": "tool", "tool_call_id": e.tool_call_id, - "content": error_msg, + "content": self._format_error(e, base_msg), }, ) messages_list.append(error_feedback) @@ -376,12 +387,12 @@ def _call_assistant[T, U]( messages_list.append(e.raw_message) # Add error feedback as a user message - error_msg = f"{e}. Please provide a valid response and try again." + base_msg = f"{e}. Please provide a valid response and try again." result_error_feedback: Message = typing.cast( Message, { "role": "user", - "content": error_msg, + "content": self._format_error(e, base_msg), }, ) messages_list.append(result_error_feedback) @@ -406,12 +417,13 @@ def _call_tool(self, tool_call: DecodedToolCall) -> Message: try: return fwd(tool_call) except Exception as e: + base_msg = f"Tool execution failed: Error executing tool '{tool_call.tool.__name__}': {e}" return typing.cast( Message, { "role": "tool", "tool_call_id": tool_call.id, - "content": f"Tool execution failed: Error executing tool '{tool_call.tool.__name__}': {e}", + "content": self._format_error(e, base_msg), }, ) diff --git a/tests/test_handlers_llm_provider.py b/tests/test_handlers_llm_provider.py index 70a40947..c9a63f8d 100644 --- a/tests/test_handlers_llm_provider.py +++ b/tests/test_handlers_llm_provider.py @@ -688,6 +688,56 @@ def test_retry_handler_unknown_tool_error_contains_tool_name(self): assert len(tool_feedback) == 1 assert "nonexistent_tool" in tool_feedback[0]["content"] + def test_retry_handler_include_traceback_in_error_feedback(self): + """Test that include_traceback=True adds traceback to error messages.""" + responses = [ + make_tool_call_response("add_numbers", '{"a": "bad", "b": 2}'), + make_text_response('{"value": "success"}'), + ] + + mock_handler = MockCompletionHandler(responses) + + with ( + handler(RetryLLMHandler(num_retries=3, include_traceback=True)), + handler(mock_handler), + ): + call_assistant( + messages=[{"role": "user", "content": "test"}], + tools={"add_numbers": add_numbers}, + response_format=Encodable.define(str), + model="test-model", + ) + + # Check that the error feedback includes traceback + second_call_messages = mock_handler.received_messages[1] + tool_feedback = [m for m in 
second_call_messages if m.get("role") == "tool"] + assert len(tool_feedback) == 1 + assert "Traceback:" in tool_feedback[0]["content"] + assert "```" in tool_feedback[0]["content"] + + def test_retry_handler_no_traceback_by_default(self): + """Test that include_traceback=False (default) doesn't add traceback.""" + responses = [ + make_tool_call_response("add_numbers", '{"a": "bad", "b": 2}'), + make_text_response('{"value": "success"}'), + ] + + mock_handler = MockCompletionHandler(responses) + + with handler(RetryLLMHandler(num_retries=3)), handler(mock_handler): + call_assistant( + messages=[{"role": "user", "content": "test"}], + tools={"add_numbers": add_numbers}, + response_format=Encodable.define(str), + model="test-model", + ) + + # Check that the error feedback does NOT include traceback + second_call_messages = mock_handler.received_messages[1] + tool_feedback = [m for m in second_call_messages if m.get("role") == "tool"] + assert len(tool_feedback) == 1 + assert "Traceback:" not in tool_feedback[0]["content"] + # ============================================================================ # Tool Execution Error Tests From c528b5096004d8558e94eedc879af96aa98971aa Mon Sep 17 00:00:00 2001 From: Kiran Gopinathan Date: Fri, 30 Jan 2026 22:14:58 -0500 Subject: [PATCH 08/12] moved formatting to exception classes --- effectful/handlers/llm/completions.py | 63 +++++++++++++-------------- 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/effectful/handlers/llm/completions.py b/effectful/handlers/llm/completions.py index f47fc4f9..9c855c9b 100644 --- a/effectful/handlers/llm/completions.py +++ b/effectful/handlers/llm/completions.py @@ -28,6 +28,16 @@ from effectful.ops.syntax import ObjectInterpretation, implements from effectful.ops.types import Operation +Message = ( + OpenAIChatCompletionAssistantMessage + | ChatCompletionToolMessage + | ChatCompletionFunctionMessage + | OpenAIChatCompletionSystemMessage + | OpenAIChatCompletionUserMessage +) + +type ToolCallID = str + @dataclasses.dataclass class ToolCallDecodingError(Exception): @@ -36,10 +46,10 @@ class ToolCallDecodingError(Exception): tool_name: str tool_call_id: str original_error: Exception - raw_message: typing.Any = None + raw_message: Message def __str__(self) -> str: - return f"Error decoding tool call '{self.tool_name}': {self.original_error}" + return f"Error decoding tool call '{self.tool_name}': {self.original_error}. Please provide a valid response and try again." @dataclasses.dataclass @@ -47,21 +57,21 @@ class ResultDecodingError(Exception): """Error raised when decoding the LLM response result fails.""" original_error: Exception - raw_message: typing.Any = None + raw_message: Message def __str__(self) -> str: - return f"Error decoding response: {self.original_error}" + return f"Error decoding response: {self.original_error}. Please provide a valid response and try again." 
-Message = ( - OpenAIChatCompletionAssistantMessage - | ChatCompletionToolMessage - | ChatCompletionFunctionMessage - | OpenAIChatCompletionSystemMessage - | OpenAIChatCompletionUserMessage -) +@dataclasses.dataclass +class ToolCallExecutionError(Exception): + tool_name: str + e: Exception -type ToolCallID = str + def __str__(self) -> str: + return ( + f"Tool execution failed: Error executing tool '{self.tool_name}': {self.e}" + ) class DecodedToolCall[T](typing.NamedTuple): @@ -105,7 +115,7 @@ def _function_model(tool: Tool) -> ChatCompletionToolParam: def decode_tool_call( tool_call: ChatCompletionMessageToolCall, tools: collections.abc.Mapping[str, Tool], - raw_message: Message | None = None, + raw_message: Message, ) -> DecodedToolCall: """Decode a tool call from the LLM response into a DecodedToolCall. @@ -323,12 +333,12 @@ def __init__(self, num_retries: int = 3, include_traceback: bool = False): self.num_retries = num_retries self.include_traceback = include_traceback - def _format_error(self, error: Exception, base_msg: str) -> str: + def _format_error(self, error: Exception) -> str: """Format an error message, optionally including traceback.""" if self.include_traceback: tb = traceback.format_exc() - return f"{base_msg}\n\nTraceback:\n```\n{tb}```" - return base_msg + return f"{error}\n\nTraceback:\n```\n{tb}```" + return f"{error}" @implements(call_assistant) def _call_assistant[T, U]( @@ -355,22 +365,17 @@ def _call_assistant[T, U]( except ToolCallDecodingError as e: last_error = e - # The error includes the raw message from the failed attempt - assert e.raw_message is not None, ( - "ToolCallDecodingError should include raw_message" - ) # Add the malformed assistant message messages_list.append(e.raw_message) # Add error feedback as a tool response - base_msg = f"{e}. Please fix the tool call arguments and try again." error_feedback: Message = typing.cast( Message, { "role": "tool", "tool_call_id": e.tool_call_id, - "content": self._format_error(e, base_msg), + "content": self._format_error(e), }, ) messages_list.append(error_feedback) @@ -378,21 +383,14 @@ def _call_assistant[T, U]( except ResultDecodingError as e: last_error = e - # The error includes the raw message from the failed attempt - assert e.raw_message is not None, ( - "ResultDecodingError should include raw_message" - ) - # Add the malformed assistant message messages_list.append(e.raw_message) - # Add error feedback as a user message - base_msg = f"{e}. Please provide a valid response and try again." 
result_error_feedback: Message = typing.cast( Message, { "role": "user", - "content": self._format_error(e, base_msg), + "content": self._format_error(e), }, ) messages_list.append(result_error_feedback) @@ -417,13 +415,14 @@ def _call_tool(self, tool_call: DecodedToolCall) -> Message: try: return fwd(tool_call) except Exception as e: - base_msg = f"Tool execution failed: Error executing tool '{tool_call.tool.__name__}': {e}" return typing.cast( Message, { "role": "tool", "tool_call_id": tool_call.id, - "content": self._format_error(e, base_msg), + "content": self._format_error( + ToolCallExecutionError(f"{tool_call.tool.__name__}", e) + ), }, ) From e4bdae956ff05726544e90f591688d6f79fc5c50 Mon Sep 17 00:00:00 2001 From: Kiran Gopinathan Date: Fri, 30 Jan 2026 22:19:38 -0500 Subject: [PATCH 09/12] fixed failing tests --- tests/test_handlers_llm_provider.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/tests/test_handlers_llm_provider.py b/tests/test_handlers_llm_provider.py index c9a63f8d..f557811d 100644 --- a/tests/test_handlers_llm_provider.py +++ b/tests/test_handlers_llm_provider.py @@ -877,23 +877,19 @@ def test_result_decoding_error_string_representation(self): def test_error_classes_preserve_original_error(self): """Test that all error classes preserve the original exception.""" original = TypeError("type mismatch") + mock_message = {"role": "assistant", "content": "test"} - tool_decode_err = ToolCallDecodingError("fn", "id", original) + tool_decode_err = ToolCallDecodingError("fn", "id", original, mock_message) assert tool_decode_err.original_error is original - result_decode_err = ResultDecodingError(original) + result_decode_err = ResultDecodingError(original, mock_message) assert result_decode_err.original_error is original - def test_tool_call_decoding_error_raw_message_optional(self): - """Test that raw_message can be None initially (set later in call_assistant).""" - error = ToolCallDecodingError("fn", "id", ValueError("test")) - assert error.raw_message is None - - # Can be set after creation - error_with_msg = ToolCallDecodingError( - "fn", "id", ValueError("test"), raw_message={"role": "assistant"} - ) - assert error_with_msg.raw_message is not None + def test_tool_call_decoding_error_includes_raw_message(self): + """Test that ToolCallDecodingError includes the raw message.""" + mock_message = {"role": "assistant", "content": "test"} + error = ToolCallDecodingError("fn", "id", ValueError("test"), mock_message) + assert error.raw_message == mock_message # ============================================================================ From 423cfcd2b49c19a9b49e239bca9982555e7609b8 Mon Sep 17 00:00:00 2001 From: Kiran Gopinathan Date: Fri, 30 Jan 2026 22:21:25 -0500 Subject: [PATCH 10/12] raised inside except to preserve backtrace --- effectful/handlers/llm/completions.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/effectful/handlers/llm/completions.py b/effectful/handlers/llm/completions.py index 9c855c9b..ff362040 100644 --- a/effectful/handlers/llm/completions.py +++ b/effectful/handlers/llm/completions.py @@ -350,9 +350,9 @@ def _call_assistant[T, U]( **kwargs, ) -> MessageResult[T]: messages_list = list(messages) - last_error: Exception | None = None + last_attempt = self.num_retries - for _attempt in range(self.num_retries + 1): + for attempt in range(self.num_retries + 1): try: message, tool_calls, result = fwd( messages_list, tools, response_format, model, **kwargs @@ -364,7 +364,9 @@ 
def _call_assistant[T, U]( return (message, tool_calls, result) except ToolCallDecodingError as e: - last_error = e + # On last attempt, re-raise to preserve full traceback + if attempt == last_attempt: + raise # Add the malformed assistant message messages_list.append(e.raw_message) @@ -379,10 +381,12 @@ def _call_assistant[T, U]( }, ) messages_list.append(error_feedback) - continue except ResultDecodingError as e: - last_error = e + # On last attempt, re-raise to preserve full traceback + if attempt == last_attempt: + raise + # Add the malformed assistant message messages_list.append(e.raw_message) # Add error feedback as a user message @@ -394,11 +398,9 @@ def _call_assistant[T, U]( }, ) messages_list.append(result_error_feedback) - continue - # If all retries failed, raise the last error - assert last_error is not None - raise last_error + # Should never reach here - either we return on success or raise on final failure + raise AssertionError("Unreachable: retry loop exited without return or raise") @implements(completion) def _completion(self, *args, **kwargs) -> typing.Any: From c836f94c991b8aa3ce6c20f201d0d2712544da3b Mon Sep 17 00:00:00 2001 From: Kiran Gopinathan Date: Fri, 30 Jan 2026 22:31:56 -0500 Subject: [PATCH 11/12] unified exception message formatting --- effectful/handlers/llm/completions.py | 95 ++++++++++++++------------- 1 file changed, 49 insertions(+), 46 deletions(-) diff --git a/effectful/handlers/llm/completions.py b/effectful/handlers/llm/completions.py index ff362040..72768a69 100644 --- a/effectful/handlers/llm/completions.py +++ b/effectful/handlers/llm/completions.py @@ -51,6 +51,20 @@ class ToolCallDecodingError(Exception): def __str__(self) -> str: return f"Error decoding tool call '{self.tool_name}': {self.original_error}. Please provide a valid response and try again." + def to_feedback_message(self, include_traceback: bool) -> Message: + error_message = f"{self}" + if include_traceback: + tb = traceback.format_exc() + error_message = f"{error_message}\n\nTraceback:\n```\n{tb}```" + return typing.cast( + Message, + { + "role": "tool", + "tool_call_id": self.tool_call_id, + "content": error_message, + }, + ) + @dataclasses.dataclass class ResultDecodingError(Exception): @@ -62,15 +76,43 @@ class ResultDecodingError(Exception): def __str__(self) -> str: return f"Error decoding response: {self.original_error}. Please provide a valid response and try again." 
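+
+    # NOTE: traceback.format_exc() formats the exception currently being
+    # handled, so to_feedback_message is only meaningful when called from
+    # inside the `except` block that caught `original_error`.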
+ def to_feedback_message(self, include_traceback: bool) -> Message: + error_message = f"{self}" + if include_traceback: + tb = traceback.format_exc() + error_message = f"{error_message}\n\nTraceback:\n```\n{tb}```" + return typing.cast( + Message, + { + "role": "user", + "content": error_message, + }, + ) + @dataclasses.dataclass class ToolCallExecutionError(Exception): + """Error raised when a tool execution fails at runtime.""" + tool_name: str - e: Exception + tool_call_id: str + original_error: Exception def __str__(self) -> str: - return ( - f"Tool execution failed: Error executing tool '{self.tool_name}': {self.e}" + return f"Tool execution failed: Error executing tool '{self.tool_name}': {self.original_error}" + + def to_feedback_message(self, include_traceback: bool) -> Message: + error_message = f"{self}" + if include_traceback: + tb = traceback.format_exc() + error_message = f"{error_message}\n\nTraceback:\n```\n{tb}```" + return typing.cast( + Message, + { + "role": "tool", + "tool_call_id": self.tool_call_id, + "content": error_message, + }, ) @@ -333,13 +375,6 @@ def __init__(self, num_retries: int = 3, include_traceback: bool = False): self.num_retries = num_retries self.include_traceback = include_traceback - def _format_error(self, error: Exception) -> str: - """Format an error message, optionally including traceback.""" - if self.include_traceback: - tb = traceback.format_exc() - return f"{error}\n\nTraceback:\n```\n{tb}```" - return f"{error}" - @implements(call_assistant) def _call_assistant[T, U]( self, @@ -363,7 +398,7 @@ def _call_assistant[T, U]( # not in the returned result. return (message, tool_calls, result) - except ToolCallDecodingError as e: + except (ToolCallDecodingError, ResultDecodingError) as e: # On last attempt, re-raise to preserve full traceback if attempt == last_attempt: raise @@ -372,33 +407,9 @@ def _call_assistant[T, U]( messages_list.append(e.raw_message) # Add error feedback as a tool response - error_feedback: Message = typing.cast( - Message, - { - "role": "tool", - "tool_call_id": e.tool_call_id, - "content": self._format_error(e), - }, - ) + error_feedback: Message = e.to_feedback_message(self.include_traceback) messages_list.append(error_feedback) - except ResultDecodingError as e: - # On last attempt, re-raise to preserve full traceback - if attempt == last_attempt: - raise - - # Add the malformed assistant message - messages_list.append(e.raw_message) - # Add error feedback as a user message - result_error_feedback: Message = typing.cast( - Message, - { - "role": "user", - "content": self._format_error(e), - }, - ) - messages_list.append(result_error_feedback) - # Should never reach here - either we return on success or raise on final failure raise AssertionError("Unreachable: retry loop exited without return or raise") @@ -417,16 +428,8 @@ def _call_tool(self, tool_call: DecodedToolCall) -> Message: try: return fwd(tool_call) except Exception as e: - return typing.cast( - Message, - { - "role": "tool", - "tool_call_id": tool_call.id, - "content": self._format_error( - ToolCallExecutionError(f"{tool_call.tool.__name__}", e) - ), - }, - ) + error = ToolCallExecutionError(tool_call.tool.__name__, tool_call.id, e) + return error.to_feedback_message(self.include_traceback) class LiteLLMProvider(ObjectInterpretation): From 45d4bc7ae0167b6c0c71c6183f6d59df274e2b15 Mon Sep 17 00:00:00 2001 From: Kiran Gopinathan Date: Fri, 30 Jan 2026 22:34:49 -0500 Subject: [PATCH 12/12] made retryllmhandler parametric on errors it catches --- 
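Note: catch_tool_errors follows the semantics of a Python `except` clause,
so it accepts a single exception class or a tuple of classes. A minimal
sketch of the intended narrowing (the specific exception types below are
illustrative assumptions, not mandated by the patch):

    # Tool failures of these types become tool-role feedback messages for
    # the model; any other exception propagates to the caller unchanged.
    retry = RetryLLMHandler(catch_tool_errors=(ValueError, ConnectionError))
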
effectful/handlers/llm/completions.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/effectful/handlers/llm/completions.py b/effectful/handlers/llm/completions.py index 72768a69..daf88b4b 100644 --- a/effectful/handlers/llm/completions.py +++ b/effectful/handlers/llm/completions.py @@ -96,7 +96,7 @@ class ToolCallExecutionError(Exception): tool_name: str tool_call_id: str - original_error: Exception + original_error: BaseException def __str__(self) -> str: return f"Tool execution failed: Error executing tool '{self.tool_name}': {self.original_error}" @@ -369,11 +369,21 @@ class RetryLLMHandler(ObjectInterpretation): num_retries: The maximum number of retries (default: 3). include_traceback: If True, include full traceback in error feedback for better debugging context (default: False). + catch_tool_errors: Exception type(s) to catch during tool execution. + Can be a single exception class or a tuple of exception classes. + Defaults to Exception (catches all exceptions). """ - def __init__(self, num_retries: int = 3, include_traceback: bool = False): + def __init__( + self, + num_retries: int = 3, + include_traceback: bool = False, + catch_tool_errors: type[BaseException] + | tuple[type[BaseException], ...] = Exception, + ): self.num_retries = num_retries self.include_traceback = include_traceback + self.catch_tool_errors = catch_tool_errors @implements(call_assistant) def _call_assistant[T, U]( @@ -423,11 +433,12 @@ def _call_tool(self, tool_call: DecodedToolCall) -> Message: """Handle tool execution with runtime error capture. Runtime errors from tool execution are captured and returned as - error messages to the LLM. + error messages to the LLM. Only exceptions matching `catch_tool_errors` + are caught; others propagate up. """ try: return fwd(tool_call) - except Exception as e: + except self.catch_tool_errors as e: error = ToolCallExecutionError(tool_call.tool.__name__, tool_call.id, e) return error.to_feedback_message(self.include_traceback)
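
Usage sketch for the series as a whole: composing the final RetryLLMHandler
with a provider. This is a minimal sketch, not part of the patches; the
provider's constructor arguments and the handler nesting order are
assumptions (the nesting mirrors the unit tests above), while the
RetryLLMHandler parameters are taken from the patches themselves.

    from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler
    from effectful.ops.semantics import handler

    # num_retries bounds how many extra model calls are attempted, and
    # include_traceback echoes the full Python traceback back to the model.
    retry = RetryLLMHandler(
        num_retries=3,
        include_traceback=True,
        catch_tool_errors=(ValueError, ConnectionError),  # illustrative tuple
    )

    # Assumption: LiteLLMProvider is constructible with no arguments here;
    # real configuration (model selection, API keys) is elided.
    provider = LiteLLMProvider()

    with handler(retry), handler(provider):
        # Template calls made in this scope are retried on decoding
        # failures, and caught tool errors are fed back to the model
        # as tool-role messages rather than raised immediately.
        ...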