diff --git a/docs/source/llm.ipynb b/docs/source/llm.ipynb index a6b778c0..c1370547 100644 --- a/docs/source/llm.ipynb +++ b/docs/source/llm.ipynb @@ -14,7 +14,7 @@ "\n", "import dotenv\n", "import pydantic\n", - "from pydantic import ValidationError, field_validator\n", + "from pydantic import field_validator\n", "from pydantic_core import PydanticCustomError\n", "\n", "from effectful.handlers.llm import Template, Tool\n", @@ -660,9 +660,7 @@ "### Retrying LLM Requests\n", "LLM calls can sometimes fail due to transient errors or produce invalid outputs. The `RetryLLMHandler` automatically retries failed template calls:\n", "\n", - "- `max_retries`: Maximum number of retry attempts (default: 3)\n", - "- `add_error_feedback`: When `True`, appends the error message to the prompt on retry, helping the LLM correct its output.\n", - "- `exception_cls`: RetryHandler will only attempt to try again when a specific type of `Exception` is thrown.\n" + "- `num_retries`: Maximum number of retry attempts (default: 3)\n" ] }, { @@ -730,7 +728,7 @@ " raise NotHandled\n", "\n", "\n", - "retry_handler = RetryLLMHandler(max_retries=5, add_error_feedback=True)\n", + "retry_handler = RetryLLMHandler(num_retries=5)\n", "\n", "with handler(provider), handler(retry_handler), handler({completion: log_llm}):\n", " result = fetch_data()\n", @@ -842,9 +840,7 @@ "# RetryLLMHandler with error feedback - the traceback helps LLM correct validation errors\n", "# Note: Pydantic wraps PydanticCustomError inside ValidationError, so we catch ValidationError instead\n", "retry_handler = RetryLLMHandler(\n", - " max_retries=3,\n", - " add_error_feedback=True,\n", - " exception_cls=ValidationError, # Catch validation errors\n", + " num_retries=3,\n", ")\n", "\n", "with handler(provider), handler(retry_handler), handler({completion: log_llm}):\n", diff --git a/effectful/handlers/llm/completions.py b/effectful/handlers/llm/completions.py index 7f00f323..e9c1dd09 100644 --- a/effectful/handlers/llm/completions.py +++ b/effectful/handlers/llm/completions.py @@ -1,9 +1,7 @@ -import contextlib import functools import inspect import logging import string -import traceback import typing from collections.abc import Callable, Hashable from typing import Any @@ -21,7 +19,7 @@ from effectful.handlers.llm import Template, Tool from effectful.handlers.llm.encoding import type_to_encodable_type -from effectful.ops.semantics import fwd, handler +from effectful.ops.semantics import fwd from effectful.ops.syntax import ObjectInterpretation, defop, implements from effectful.ops.types import Operation @@ -315,9 +313,10 @@ def format_model_input[**P, T]( class InstructionHandler(ObjectInterpretation): """Scoped handler that injects additional instructions into model input. - This handler appends instruction messages to the formatted model input. - It's designed to be used as a scoped handler within RetryLLMHandler to - provide error feedback without polluting shared state. + This handler appends instruction messages to the formatted model + input. It's designed to be used as a scoped handler to provide + error feedback without polluting shared state. + """ def __init__(self, instruction: str): @@ -337,52 +336,6 @@ def _inject_instruction(self, template: Template, *args, **kwargs) -> list[Any]: ] -class RetryLLMHandler(ObjectInterpretation): - """Retries LLM requests if they fail. - - If the request fails, the handler retries with optional error feedback injected - into the prompt via scoped InstructionHandler instances. This ensures nested - template calls maintain independent error tracking. - - Args: - max_retries: The maximum number of retries. - add_error_feedback: Whether to add error feedback to the prompt on retry. - exception_cls: The exception class to catch and retry on. - """ - - def __init__( - self, - max_retries: int = 3, - add_error_feedback: bool = False, - exception_cls: type[BaseException] = Exception, - ): - self.max_retries = max_retries - self.add_error_feedback = add_error_feedback - self.exception_cls = exception_cls - - @implements(Template.__apply__) - def _retry_completion(self, template: Template, *args, **kwargs) -> Any: - """Retry template execution with error feedback injection via scoped handlers.""" - failures: list[str] = [] - - for attempt in range(self.max_retries): - try: - # Install scoped handlers for each accumulated failure - with contextlib.ExitStack() as stack: - for failure in failures: - stack.enter_context(handler(InstructionHandler(failure))) - return fwd() - except self.exception_cls: - if attempt == self.max_retries - 1: - raise # Last attempt, re-raise the exception - if self.add_error_feedback: - tb = traceback.format_exc() - failures.append(f"\nError from previous attempt:\n```\n{tb}```") - - # This should not be reached, but just in case - return fwd() - - class LiteLLMProvider(ObjectInterpretation): """Implements templates using the LiteLLM API.""" diff --git a/tests/test_handlers_llm.py b/tests/test_handlers_llm.py index 4ad3d81c..2c98a650 100644 --- a/tests/test_handlers_llm.py +++ b/tests/test_handlers_llm.py @@ -4,11 +4,6 @@ import pytest from effectful.handlers.llm import Template -from effectful.handlers.llm.completions import ( - RetryLLMHandler, - compute_response, - format_model_input, -) from effectful.handlers.llm.synthesis import ProgramSynthesis from effectful.handlers.llm.template import IsRecursive from effectful.ops.semantics import NotHandled, handler @@ -171,102 +166,6 @@ def _call[**P]( return self.success_response -def test_retry_handler_succeeds_after_failures(): - """Test that RetryLLMHandler retries and eventually succeeds.""" - provider = FailingThenSucceedingProvider( - fail_count=2, - success_response="Success after retries!", - exception_factory=lambda: ValueError("Temporary failure"), - ) - retry_handler = RetryLLMHandler(max_retries=3, exception_cls=ValueError) - - with handler(provider), handler(retry_handler): - result = limerick("test") - assert result == "Success after retries!" - assert provider.call_count == 3 # 2 failures + 1 success - - -def test_retry_handler_exhausts_retries(): - """Test that RetryLLMHandler raises after max retries exhausted.""" - provider = FailingThenSucceedingProvider( - fail_count=5, # More failures than retries - success_response="Never reached", - exception_factory=lambda: ValueError("Persistent failure"), - ) - retry_handler = RetryLLMHandler(max_retries=3, exception_cls=ValueError) - - with pytest.raises(ValueError, match="Persistent failure"): - with handler(provider), handler(retry_handler): - limerick("test") - - assert provider.call_count == 3 # Should have tried 3 times - - -def test_retry_handler_only_catches_specified_exception(): - """Test that RetryLLMHandler only catches the specified exception class.""" - provider = FailingThenSucceedingProvider( - fail_count=1, - success_response="Success", - exception_factory=lambda: TypeError("Wrong type"), # Different exception type - ) - retry_handler = RetryLLMHandler(max_retries=3, exception_cls=ValueError) - - # TypeError should not be caught, should propagate immediately - with pytest.raises(TypeError, match="Wrong type"): - with handler(provider), handler(retry_handler): - limerick("test") - - assert provider.call_count == 1 # Should have only tried once - - -def test_retry_handler_with_error_feedback(): - """Test that RetryLLMHandler includes error feedback when enabled.""" - - captured_messages: list[list] = [] - - class MessageCapturingProvider(ObjectInterpretation): - """Provider that captures formatted messages and fails once.""" - - def __init__(self): - self.call_count = 0 - - @implements(compute_response) - def _capture_and_respond(self, template: Template, messages: list): - """Capture messages at compute_response level (after error injection).""" - self.call_count += 1 - captured_messages.append(messages) - if self.call_count == 1: - raise ValueError("First attempt failed") - # Return a mock response - not used since we return directly - return None - - @implements(Template.__apply__) - def _call(self, template: Template, *args, **kwargs): - # Call the format/compute chain but return directly - messages = format_model_input(template, *args, **kwargs) - compute_response(template, messages) - return "Success on retry" - - provider = MessageCapturingProvider() - retry_handler = RetryLLMHandler( - max_retries=2, add_error_feedback=True, exception_cls=ValueError - ) - - with handler(provider), handler(retry_handler): - result = limerick("test") - assert result == "Success on retry" - - assert len(captured_messages) == 2 - # First call has original prompt only - first_msg_content = str(captured_messages[0]) - assert ( - "limerick" in first_msg_content.lower() or "theme" in first_msg_content.lower() - ) - # Second call should include error feedback - second_msg_content = str(captured_messages[1]) - assert "First attempt failed" in second_msg_content - - def test_template_captures_other_templates_in_lexical_context(): """Test that Templates defined in lexical scope are captured (orchestrator pattern)."""