248 changes: 226 additions & 22 deletions effectful/handlers/llm/completions.py
@@ -1,9 +1,11 @@
import collections
import collections.abc
import dataclasses
import functools
import inspect
import string
import textwrap
import traceback
import typing

import litellm
@@ -20,8 +22,9 @@
    OpenAIMessageContentListBlock,
)

from effectful.handlers.llm import Template, Tool
from effectful.handlers.llm.encoding import Encodable
from effectful.handlers.llm.template import Template, Tool
from effectful.ops.semantics import fwd
from effectful.ops.syntax import ObjectInterpretation, implements
from effectful.ops.types import Operation

@@ -36,6 +39,83 @@
type ToolCallID = str


@dataclasses.dataclass
class ToolCallDecodingError(Exception):
    """Error raised when decoding a tool call fails."""

    tool_name: str
    tool_call_id: str
    original_error: Exception
    raw_message: Message

    def __str__(self) -> str:
        return f"Error decoding tool call '{self.tool_name}': {self.original_error}. Please provide a valid response and try again."

    def to_feedback_message(self, include_traceback: bool) -> Message:
        error_message = f"{self}"
        if include_traceback:
            tb = traceback.format_exc()
            error_message = f"{error_message}\n\nTraceback:\n```\n{tb}```"
        return typing.cast(
            Message,
            {
                "role": "tool",
                "tool_call_id": self.tool_call_id,
                "content": error_message,
            },
        )


@dataclasses.dataclass
class ResultDecodingError(Exception):
    """Error raised when decoding the LLM response result fails."""

    original_error: Exception
    raw_message: Message

    def __str__(self) -> str:
        return f"Error decoding response: {self.original_error}. Please provide a valid response and try again."

    def to_feedback_message(self, include_traceback: bool) -> Message:
        error_message = f"{self}"
        if include_traceback:
            tb = traceback.format_exc()
            error_message = f"{error_message}\n\nTraceback:\n```\n{tb}```"
        return typing.cast(
            Message,
            {
                "role": "user",
                "content": error_message,
            },
        )


@dataclasses.dataclass
class ToolCallExecutionError(Exception):
    """Error raised when a tool execution fails at runtime."""

    tool_name: str
    tool_call_id: str
    original_error: BaseException

    def __str__(self) -> str:
        return f"Tool execution failed: Error executing tool '{self.tool_name}': {self.original_error}"

    def to_feedback_message(self, include_traceback: bool) -> Message:
        error_message = f"{self}"
        if include_traceback:
            tb = traceback.format_exc()
            error_message = f"{error_message}\n\nTraceback:\n```\n{tb}```"
        return typing.cast(
            Message,
            {
                "role": "tool",
                "tool_call_id": self.tool_call_id,
                "content": error_message,
            },
        )
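

# Illustrative sketch (not part of this diff): how one of these error classes
# renders as feedback for the model. The tool name and call id below are
# hypothetical.
#
#     err = ToolCallExecutionError(
#         tool_name="lookup_weather",
#         tool_call_id="call_123",
#         original_error=KeyError("city"),
#     )
#     err.to_feedback_message(include_traceback=False)
#     # -> {"role": "tool", "tool_call_id": "call_123",
#     #     "content": "Tool execution failed: Error executing tool "
#     #                "'lookup_weather': 'city'"}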


class DecodedToolCall[T](typing.NamedTuple):
    tool: Tool[..., T]
    bound_args: inspect.BoundArguments
@@ -77,26 +157,49 @@ def _function_model(tool: Tool) -> ChatCompletionToolParam:
def decode_tool_call(
    tool_call: ChatCompletionMessageToolCall,
    tools: collections.abc.Mapping[str, Tool],
    raw_message: Message,
) -> DecodedToolCall:
    """Decode a tool call from the LLM response into a DecodedToolCall."""
    assert tool_call.function.name is not None
    tool = tools[tool_call.function.name]
    json_str = tool_call.function.arguments
    """Decode a tool call from the LLM response into a DecodedToolCall.

    Args:
        tool_call: The tool call to decode.
        tools: Mapping of tool names to Tool objects.
        raw_message: The raw assistant message, attached to decoding errors
            for retry handling.

    Raises:
        ToolCallDecodingError: If the tool call cannot be decoded.
    """
    tool_name = tool_call.function.name
    assert tool_name is not None

    try:
        tool = tools[tool_name]
    except KeyError as e:
        raise ToolCallDecodingError(
            tool_name, tool_call.id, e, raw_message=raw_message
        ) from e

    json_str = tool_call.function.arguments
    sig = inspect.signature(tool)

    # build dict of raw encodable types U
    raw_args = _param_model(tool).model_validate_json(json_str)
    try:
        # build dict of raw encodable types U
        raw_args = _param_model(tool).model_validate_json(json_str)

        # use encoders to decode Us to python types T
        bound_sig: inspect.BoundArguments = sig.bind(
            **{
                param_name: Encodable.define(
                    sig.parameters[param_name].annotation, {}
                ).decode(getattr(raw_args, param_name))
                for param_name in raw_args.model_fields_set
            }
        )
    except (pydantic.ValidationError, TypeError, ValueError) as e:
        raise ToolCallDecodingError(
            tool_name, tool_call.id, e, raw_message=raw_message
        ) from e

    # use encoders to decode Us to python types T
    bound_sig: inspect.BoundArguments = sig.bind(
        **{
            param_name: Encodable.define(
                sig.parameters[param_name].annotation, {}
            ).decode(getattr(raw_args, param_name))
            for param_name in raw_args.model_fields_set
        }
    )
    return DecodedToolCall(tool, bound_sig, tool_call.id)
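
# Illustrative sketch (not part of this diff): decoding and invoking a
# hypothetical tool call. `tools` and `raw_message` are assumed to be in
# scope, and "add" is a hypothetical registered tool.
#
#     raw_call = ChatCompletionMessageToolCall.model_validate(
#         {"id": "call_0", "type": "function",
#          "function": {"name": "add", "arguments": '{"x": 1, "y": 2}'}}
#     )
#     tool, bound_args, call_id = decode_tool_call(raw_call, tools, raw_message)
#     result = tool(*bound_args.args, **bound_args.kwargs)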


@@ -125,6 +228,11 @@ def call_assistant[T, U](
    This effect is emitted for model request/response rounds so handlers can
    observe/log requests.

    Raises:
        ToolCallDecodingError: If a tool call cannot be decoded. The error
            includes the raw assistant message for retry handling.
        ResultDecodingError: If the result cannot be decoded. The error
            includes the raw assistant message for retry handling.
    """
    tool_specs = {k: _function_model(t) for k, t in tools.items()}
    response_model = pydantic.create_model(
@@ -144,11 +252,15 @@
    message: litellm.Message = choice.message
    assert message.role == "assistant"

    raw_message = typing.cast(Message, message.model_dump(mode="json"))

    tool_calls: list[DecodedToolCall] = []
    raw_tool_calls = message.get("tool_calls") or []
    for tool_call in raw_tool_calls:
        tool_call = ChatCompletionMessageToolCall.model_validate(tool_call)
        decoded_tool_call = decode_tool_call(tool_call, tools)
    for raw_tool_call in raw_tool_calls:
        validated_tool_call = ChatCompletionMessageToolCall.model_validate(
            raw_tool_call
        )
        decoded_tool_call = decode_tool_call(validated_tool_call, tools, raw_message)
        tool_calls.append(decoded_tool_call)

    result = None
@@ -158,10 +270,13 @@
        assert isinstance(serialized_result, str), (
            "final response from the model should be a string"
        )
        raw_result = response_model.model_validate_json(serialized_result)
        result = response_format.decode(raw_result.value)  # type: ignore
        try:
            raw_result = response_model.model_validate_json(serialized_result)
            result = response_format.decode(raw_result.value)  # type: ignore
        except pydantic.ValidationError as e:
            raise ResultDecodingError(e, raw_message=raw_message) from e

    return (typing.cast(Message, message.model_dump(mode="json")), tool_calls, result)
    return (raw_message, tool_calls, result)


@Operation.define
@@ -239,6 +354,95 @@ def call_system(template: Template) -> collections.abc.Sequence[Message]:
    return ()


class RetryLLMHandler(ObjectInterpretation):
    """Retries LLM requests when tool-call or result decoding fails.

    This handler intercepts `call_assistant` and catches `ToolCallDecodingError`
    and `ResultDecodingError`. When these errors occur, it appends error feedback
    to the messages and retries the request. Malformed messages from retry
    attempts are pruned from the final result.

    For runtime tool execution failures (handled via `call_tool`), errors are
    captured and returned as tool response messages.

    Args:
        num_retries: The maximum number of retries (default: 3).
        include_traceback: If True, include the full traceback in error feedback
            for better debugging context (default: False).
        catch_tool_errors: Exception type(s) to catch during tool execution.
            Can be a single exception class or a tuple of exception classes.
            Defaults to Exception (catches all exceptions).
    """

    def __init__(
        self,
        num_retries: int = 3,
        include_traceback: bool = False,
        catch_tool_errors: type[BaseException]
        | tuple[type[BaseException], ...] = Exception,
    ):
        self.num_retries = num_retries
        self.include_traceback = include_traceback
        self.catch_tool_errors = catch_tool_errors

    @implements(call_assistant)
    def _call_assistant[T, U](
        self,
        messages: collections.abc.Sequence[Message],
        tools: collections.abc.Mapping[str, Tool],
        response_format: Encodable[T, U],
        model: str,
        **kwargs,
    ) -> MessageResult[T]:
        messages_list = list(messages)
        last_attempt = self.num_retries

        for attempt in range(self.num_retries + 1):
            try:
                message, tool_calls, result = fwd(
                    messages_list, tools, response_format, model, **kwargs
                )

                # Success: the returned message is the final, well-formed response.
                # Malformed messages from earlier retries live only in
                # messages_list, not in the returned result.
                return (message, tool_calls, result)

            except (ToolCallDecodingError, ResultDecodingError) as e:
                # On the last attempt, re-raise to preserve the full traceback
                if attempt == last_attempt:
                    raise

                # Add the malformed assistant message
                messages_list.append(e.raw_message)

                # Add error feedback (a tool message for tool-call failures,
                # a user message for result-decoding failures)
                error_feedback: Message = e.to_feedback_message(self.include_traceback)
                messages_list.append(error_feedback)

        # Unreachable: the loop either returns on success or raises on the final failure
        raise AssertionError("Unreachable: retry loop exited without return or raise")

    @implements(completion)
    def _completion(self, *args, **kwargs) -> typing.Any:
        """Inject num_retries for litellm's built-in network error handling."""
        return fwd(*args, num_retries=self.num_retries, **kwargs)

    @implements(call_tool)
    def _call_tool(self, tool_call: DecodedToolCall) -> Message:
        """Handle tool execution with runtime error capture.

        Runtime errors from tool execution are captured and returned as
        error messages to the LLM. Only exceptions matching `catch_tool_errors`
        are caught; others propagate up.
        """
        try:
            return fwd(tool_call)
        except self.catch_tool_errors as e:
            error = ToolCallExecutionError(tool_call.tool.__name__, tool_call.id, e)
            return error.to_feedback_message(self.include_traceback)
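
# Usage sketch (not part of this diff), assuming effectful exposes a `handler`
# context manager alongside `fwd` in effectful.ops.semantics; the nesting
# order and provider arguments below are assumptions:
#
#     from effectful.ops.semantics import handler
#
#     with handler(LiteLLMProvider(...)):  # provider configuration elided
#         with handler(RetryLLMHandler(num_retries=2, include_traceback=True)):
#             ...  # decode failures here are fed back to the model and retried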


class LiteLLMProvider(ObjectInterpretation):
"""Implements templates using the LiteLLM API."""
