12 changes: 4 additions & 8 deletions docs/source/llm.ipynb
@@ -14,7 +14,7 @@
 "\n",
 "import dotenv\n",
 "import pydantic\n",
-"from pydantic import ValidationError, field_validator\n",
+"from pydantic import field_validator\n",
 "from pydantic_core import PydanticCustomError\n",
 "\n",
 "from effectful.handlers.llm import Template, Tool\n",
@@ -660,9 +660,7 @@
 "### Retrying LLM Requests\n",
 "LLM calls can sometimes fail due to transient errors or produce invalid outputs. The `RetryLLMHandler` automatically retries failed template calls:\n",
 "\n",
-"- `max_retries`: Maximum number of retry attempts (default: 3)\n",
-"- `add_error_feedback`: When `True`, appends the error message to the prompt on retry, helping the LLM correct its output.\n",
-"- `exception_cls`: RetryHandler will only attempt to try again when a specific type of `Exception` is thrown.\n"
+"- `num_retries`: Maximum number of retry attempts (default: 3)\n"
 ]
 },
 {
@@ -730,7 +728,7 @@
 " raise NotHandled\n",
 "\n",
 "\n",
-"retry_handler = RetryLLMHandler(max_retries=5, add_error_feedback=True)\n",
+"retry_handler = RetryLLMHandler(num_retries=5)\n",
 "\n",
 "with handler(provider), handler(retry_handler), handler({completion: log_llm}):\n",
 " result = fetch_data()\n",
@@ -842,9 +840,7 @@
 "# RetryLLMHandler with error feedback - the traceback helps LLM correct validation errors\n",
 "# Note: Pydantic wraps PydanticCustomError inside ValidationError, so we catch ValidationError instead\n",
 "retry_handler = RetryLLMHandler(\n",
-" max_retries=3,\n",
-" add_error_feedback=True,\n",
-" exception_cls=ValidationError, # Catch validation errors\n",
+" num_retries=3,\n",
 ")\n",
 "\n",
 "with handler(provider), handler(retry_handler), handler({completion: log_llm}):\n",
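With exception_cls gone from RetryLLMHandler, callers that only want to retry on validation failures can filter at the call site instead. Below is a minimal sketch of that pattern, assuming the provider and the fetch_data template configured in earlier notebook cells (neither is defined in this diff); the loop is ordinary caller code, not a library API:

from pydantic import ValidationError

from effectful.ops.semantics import handler

num_attempts = 3
with handler(provider):  # provider as configured earlier in the notebook (assumed)
    for attempt in range(num_attempts):
        try:
            result = fetch_data()  # assumed Template from the notebook
            break
        except ValidationError:
            if attempt == num_attempts - 1:
                raise  # out of attempts: let the validation error propagate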
57 changes: 5 additions & 52 deletions effectful/handlers/llm/completions.py
@@ -1,9 +1,7 @@
-import contextlib
 import functools
 import inspect
 import logging
 import string
-import traceback
 import typing
 from collections.abc import Callable, Hashable
 from typing import Any
@@ -21,7 +19,7 @@

 from effectful.handlers.llm import Template, Tool
 from effectful.handlers.llm.encoding import type_to_encodable_type
-from effectful.ops.semantics import fwd, handler
+from effectful.ops.semantics import fwd
 from effectful.ops.syntax import ObjectInterpretation, defop, implements
 from effectful.ops.types import Operation

@@ -315,9 +313,10 @@ def format_model_input[**P, T](
 class InstructionHandler(ObjectInterpretation):
     """Scoped handler that injects additional instructions into model input.

-    This handler appends instruction messages to the formatted model input.
-    It's designed to be used as a scoped handler within RetryLLMHandler to
-    provide error feedback without polluting shared state.
+    This handler appends instruction messages to the formatted model
+    input. It's designed to be used as a scoped handler to provide
+    error feedback without polluting shared state.
+
     """

     def __init__(self, instruction: str):
@@ -337,52 +336,6 @@ def _inject_instruction(self, template: Template, *args, **kwargs) -> list[Any]:
         ]


-class RetryLLMHandler(ObjectInterpretation):
-    """Retries LLM requests if they fail.
-
-    If the request fails, the handler retries with optional error feedback injected
-    into the prompt via scoped InstructionHandler instances. This ensures nested
-    template calls maintain independent error tracking.
-
-    Args:
-        max_retries: The maximum number of retries.
-        add_error_feedback: Whether to add error feedback to the prompt on retry.
-        exception_cls: The exception class to catch and retry on.
-    """
-
-    def __init__(
-        self,
-        max_retries: int = 3,
-        add_error_feedback: bool = False,
-        exception_cls: type[BaseException] = Exception,
-    ):
-        self.max_retries = max_retries
-        self.add_error_feedback = add_error_feedback
-        self.exception_cls = exception_cls
-
-    @implements(Template.__apply__)
-    def _retry_completion(self, template: Template, *args, **kwargs) -> Any:
-        """Retry template execution with error feedback injection via scoped handlers."""
-        failures: list[str] = []
-
-        for attempt in range(self.max_retries):
-            try:
-                # Install scoped handlers for each accumulated failure
-                with contextlib.ExitStack() as stack:
-                    for failure in failures:
-                        stack.enter_context(handler(InstructionHandler(failure)))
-                    return fwd()
-            except self.exception_cls:
-                if attempt == self.max_retries - 1:
-                    raise  # Last attempt, re-raise the exception
-                if self.add_error_feedback:
-                    tb = traceback.format_exc()
-                    failures.append(f"\nError from previous attempt:\n```\n{tb}```")
-
-        # This should not be reached, but just in case
-        return fwd()
-
-
 class LiteLLMProvider(ObjectInterpretation):
     """Implements templates using the LiteLLM API."""

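The deleted class shows that error feedback worked by installing one scoped InstructionHandler per previous failure before re-dispatching the template call. Since InstructionHandler survives this change, a caller can still recover that behavior with an ordinary retry loop. The sketch below mirrors the deleted logic, but call_with_feedback is a hypothetical helper, not part of the library:

import contextlib
import traceback

from effectful.handlers.llm.completions import InstructionHandler
from effectful.ops.semantics import handler


def call_with_feedback(template_call, num_retries: int = 3):
    """Retry a zero-argument template call, feeding each failure's traceback back as an instruction."""
    failures: list[str] = []
    for attempt in range(num_retries):
        try:
            with contextlib.ExitStack() as stack:
                for failure in failures:  # one scoped handler per accumulated failure
                    stack.enter_context(handler(InstructionHandler(failure)))
                return template_call()
        except Exception:
            if attempt == num_retries - 1:
                raise  # out of retries, re-raise the last error
            failures.append(f"\nError from previous attempt:\n```\n{traceback.format_exc()}```")

Used inside an active provider scope, for example call_with_feedback(lambda: fetch_data(), num_retries=5), this keeps each nested template call's failure history independent, which is the same property the deleted docstring emphasized.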
101 changes: 0 additions & 101 deletions tests/test_handlers_llm.py
@@ -4,11 +4,6 @@
 import pytest

 from effectful.handlers.llm import Template
-from effectful.handlers.llm.completions import (
-    RetryLLMHandler,
-    compute_response,
-    format_model_input,
-)
 from effectful.handlers.llm.synthesis import ProgramSynthesis
 from effectful.handlers.llm.template import IsRecursive
 from effectful.ops.semantics import NotHandled, handler
@@ -171,102 +166,6 @@ def _call[**P](
         return self.success_response


-def test_retry_handler_succeeds_after_failures():
-    """Test that RetryLLMHandler retries and eventually succeeds."""
-    provider = FailingThenSucceedingProvider(
-        fail_count=2,
-        success_response="Success after retries!",
-        exception_factory=lambda: ValueError("Temporary failure"),
-    )
-    retry_handler = RetryLLMHandler(max_retries=3, exception_cls=ValueError)
-
-    with handler(provider), handler(retry_handler):
-        result = limerick("test")
-        assert result == "Success after retries!"
-        assert provider.call_count == 3  # 2 failures + 1 success
-
-
-def test_retry_handler_exhausts_retries():
-    """Test that RetryLLMHandler raises after max retries exhausted."""
-    provider = FailingThenSucceedingProvider(
-        fail_count=5,  # More failures than retries
-        success_response="Never reached",
-        exception_factory=lambda: ValueError("Persistent failure"),
-    )
-    retry_handler = RetryLLMHandler(max_retries=3, exception_cls=ValueError)
-
-    with pytest.raises(ValueError, match="Persistent failure"):
-        with handler(provider), handler(retry_handler):
-            limerick("test")
-
-    assert provider.call_count == 3  # Should have tried 3 times
-
-
-def test_retry_handler_only_catches_specified_exception():
-    """Test that RetryLLMHandler only catches the specified exception class."""
-    provider = FailingThenSucceedingProvider(
-        fail_count=1,
-        success_response="Success",
-        exception_factory=lambda: TypeError("Wrong type"),  # Different exception type
-    )
-    retry_handler = RetryLLMHandler(max_retries=3, exception_cls=ValueError)
-
-    # TypeError should not be caught, should propagate immediately
-    with pytest.raises(TypeError, match="Wrong type"):
-        with handler(provider), handler(retry_handler):
-            limerick("test")
-
-    assert provider.call_count == 1  # Should have only tried once
-
-
-def test_retry_handler_with_error_feedback():
-    """Test that RetryLLMHandler includes error feedback when enabled."""
-
-    captured_messages: list[list] = []
-
-    class MessageCapturingProvider(ObjectInterpretation):
-        """Provider that captures formatted messages and fails once."""
-
-        def __init__(self):
-            self.call_count = 0
-
-        @implements(compute_response)
-        def _capture_and_respond(self, template: Template, messages: list):
-            """Capture messages at compute_response level (after error injection)."""
-            self.call_count += 1
-            captured_messages.append(messages)
-            if self.call_count == 1:
-                raise ValueError("First attempt failed")
-            # Return a mock response - not used since we return directly
-            return None
-
-        @implements(Template.__apply__)
-        def _call(self, template: Template, *args, **kwargs):
-            # Call the format/compute chain but return directly
-            messages = format_model_input(template, *args, **kwargs)
-            compute_response(template, messages)
-            return "Success on retry"
-
-    provider = MessageCapturingProvider()
-    retry_handler = RetryLLMHandler(
-        max_retries=2, add_error_feedback=True, exception_cls=ValueError
-    )
-
-    with handler(provider), handler(retry_handler):
-        result = limerick("test")
-        assert result == "Success on retry"
-
-    assert len(captured_messages) == 2
-    # First call has original prompt only
-    first_msg_content = str(captured_messages[0])
-    assert (
-        "limerick" in first_msg_content.lower() or "theme" in first_msg_content.lower()
-    )
-    # Second call should include error feedback
-    second_msg_content = str(captured_messages[1])
-    assert "First attempt failed" in second_msg_content
-
-
 def test_template_captures_other_templates_in_lexical_context():
     """Test that Templates defined in lexical scope are captured (orchestrator pattern)."""

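These deletions leave the simplified retry behavior without direct test coverage. A minimal replacement might look like the sketch below; it reuses the FailingThenSucceedingProvider helper and the limerick template that remain in this file, and it assumes both that RetryLLMHandler is still importable after this change (its new location is not shown in the diff) and that the simplified handler retries on any exception:

def test_retry_handler_succeeds_with_num_retries():
    """Sketch: the simplified handler retries until the provider succeeds."""
    provider = FailingThenSucceedingProvider(
        fail_count=2,
        success_response="Success after retries!",
        exception_factory=lambda: ValueError("Temporary failure"),
    )
    retry_handler = RetryLLMHandler(num_retries=3)  # import path after this change is assumed

    with handler(provider), handler(retry_handler):
        assert limerick("test") == "Success after retries!"
    assert provider.call_count == 3  # 2 failures + 1 success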