Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 175 additions & 1 deletion effectful/handlers/llm/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,15 @@ def to_feedback_message(self, include_traceback: bool) -> Message:
# Alias for a (Message, sequence of DecodedToolCall, T-or-None) triple.
# NOTE(review): presumably the outcome of a single assistant turn — confirm
# against the functions that produce this type.
type MessageResult[T] = tuple[Message, typing.Sequence[DecodedToolCall], T | None]


# prevent :func:`functools.wraps` from copying over `litellm.completion`'s type annotations,
# which contains unresolvable forward references without import aiohttp
_assn_without_type = [
attr for attr in functools.WRAPPER_ASSIGNMENTS if attr != "__annotations__"
]


@Operation.define
@functools.wraps(litellm.completion)
@functools.wraps(litellm.completion, assigned=_assn_without_type)
def completion(*args, **kwargs) -> typing.Any:
"""Low-level LLM request. Handlers may log/modify requests and delegate via fwd().

Expand All @@ -172,6 +179,9 @@ def completion(*args, **kwargs) -> typing.Any:
return litellm.completion(*args, **kwargs)


# The filtered attribute list is only needed while decorating `completion`;
# remove it so it does not linger in the module namespace.
del _assn_without_type


@Operation.define
def call_assistant[T, U](
tools: collections.abc.Mapping[str, Tool],
Expand Down Expand Up @@ -490,3 +500,167 @@ def _call[**P, T](
history.clear()
history.update(history_copy)
return typing.cast(T, result)


class LoggingListener(typing.Protocol):
    """Interface for observing :class:`Tool`, :class:`Template`, and
    completion call events.

    All methods are no-ops by default, allowing subclasses to
    subscribe to only the events they care about.
    """

    def enter_tool_call[**P, Q](self, tool: Tool[P, Q]) -> None:
        """Hook for the start of a :class:`Tool` call; ``tool`` is the callee."""
        # can't just `pass` because that would mark the method as abstract
        return None

    def exit_tool_call[**P, Q](self, tool: Tool[P, Q], result: Q | None) -> None:
        """Hook for the end of a :class:`Tool` call.

        ``result`` is typed ``Q | None``; a caller may pass ``None`` when no
        result is available (for example because the call raised).
        """
        return None

    def enter_template_call[**P, Q](self, template: Template[P, Q]) -> None:
        """Hook for the start of a :class:`Template` call."""
        return None

    def exit_template_call[**P, Q](
        self, template: Template[P, Q], result: Q | None
    ) -> None:
        """Hook for the end of a :class:`Template` call.

        ``result`` is typed ``Q | None``; a caller may pass ``None`` when no
        result is available (for example because the call raised).
        """
        return None

    def enter_completion(self) -> None:
        """Hook for the start of a completion request."""
        return None

    def exit_completion(self, resp: typing.Any) -> None:
        """Hook for the end of a completion request; ``resp`` is the response
        object, or ``None`` when none was obtained.
        """
        return None


class LoggingHandler(ObjectInterpretation):
"""Effect handler that wraps :class:`Tool`, :class:`Template`, and completion calls
and invokes callback functions registered in :attr:`listener`.

Compose with a provider via :func:`coproduct` or nested :func:`handler`
context managers to add logging without modifying the provider logic::

listener = ThinkingElapsedListener()
obs = LoggingHandler(listener)
with handler(provider), handler(obs):
result = my_template()
"""

def __init__(self, listener: LoggingListener):
self.listener = listener

@implements(completion)
def _completion(self, *args, **kwargs) -> typing.Any:
self.listener.enter_completion()
response: typing.Any = None
try:
response = fwd(*args, **kwargs)
return response
finally:
self.listener.exit_completion(response)

@implements(Tool.__apply__)
def _call_tool[**P, T](
self, tool: Tool[P, T], *args: P.args, **kwargs: P.kwargs
) -> T:
result_opt: T | None = None
try:
self.listener.enter_tool_call(tool)
result = typing.cast(T, fwd(tool, *args, **kwargs))
result_opt = result
return result
finally:
self.listener.exit_tool_call(tool, result_opt)

@implements(Template.__apply__)
def _call_template[**P, T](
self, template: Template[P, T], *args: P.args, **kwargs: P.kwargs
) -> T:
result_opt: T | None = None
try:
self.listener.enter_template_call(template)
result = typing.cast(T, fwd(template, *args, **kwargs))
result_opt = result
return result
finally:
self.listener.exit_template_call(template, result_opt)


class EmptyCallStackException(Exception):
    """Raised when the call stack is accessed while it is empty."""


class NoTemplateException(Exception):
    """Raised when the call stack is searched for a :class:`Template` and none is on it."""


@dataclasses.dataclass(frozen=True)
class CallInfo[F: Tool[..., typing.Any]]:
    """Pairs a callable with a dictionary of data attached to it.

    The dataclass is frozen, so ``func`` and ``info`` cannot be rebound after
    construction. The ``info`` dictionary object itself is an ordinary
    ``dict`` and can still be mutated in place.
    """

    # The callable this entry describes; the type parameter F is bounded by Tool.
    func: F
    # Free-form key/value storage associated with ``func``.
    info: dict[typing.Any, typing.Any]


class CallStackListener(LoggingListener):
"""Listener that maintains a call stack of active Tool and Template calls.

The call stack can be accessed directly through :attr:`callstack`.
The methods :meth:`current_function` and :meth:`current_template`
are provided for convenience to access the function (including
both templates and tools) or template that is currently executing
(i.e. on top of the call stack).
"""

def __init__(self) -> None:
self.callstack: list[CallInfo[Tool[..., typing.Any]]] = []

@typing.override
def enter_tool_call[**P, Q](self, tool: Tool[P, Q]) -> None:
super().enter_tool_call(tool)
self.callstack.append(CallInfo(tool, {}))

@typing.override
def exit_tool_call[**P, Q](self, tool: Tool[P, Q], result: Q | None) -> None:
assert len(self.callstack) > 0 and tool is self.callstack[-1].func
self.callstack.pop()
super().exit_tool_call(tool, result)

@typing.override
def enter_template_call[**P, Q](self, template: Template[P, Q]) -> None:
super().enter_template_call(template)
self.callstack.append(CallInfo(template, {}))

@typing.override
def exit_template_call[**P, Q](
self, template: Template[P, Q], result: Q | None
) -> None:
assert len(self.callstack) > 0 and template is self.callstack[-1].func
self.callstack.pop()
super().exit_template_call(template, result)

def current_func_info(self) -> CallInfo[Tool[..., typing.Any]]:
"""Return the innermost active :class:`Tool` or :class:`Template`.

:raises EmptyCallStackException: if the call stack is empty.
"""
try:
return self.callstack[-1]
except IndexError:
raise EmptyCallStackException()

def current_template_info(self) -> CallInfo[Template[..., typing.Any]]:
"""Return the innermost active :class:`Template`, skipping any nested :class:`Tool`s.

:raises NoTemplateException: if no :class:`Template` is on the stack.
"""
try:
# need to repack CallInfo to make the typechecker happy
return next(
CallInfo(ci.func, ci.info)
for ci in reversed(self.callstack)
if isinstance(ci.func, Template)
)
except StopIteration:
raise NoTemplateException()
125 changes: 125 additions & 0 deletions obs-examples/thinking_elapsed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import dataclasses

from collections import defaultdict
from functools import reduce
from collections.abc import Callable
from typing import Any, Hashable, override

from effectful.handlers.llm import Tool, Template
from effectful.ops.types import NotHandled
from effectful.handlers.llm.completions import (
LiteLLMProvider, LoggingHandler, LoggingListener,
CallStackListener, CallInfo,
)
from effectful.ops.semantics import coproduct, handler
from time import time


@dataclasses.dataclass
class ThinkingRecord:
    """A single thinking/reasoning extraction paired with its source template."""

    # Template that was active when the content was captured.
    template: Template[..., Any]
    # Value of the response message's "reasoning_content" entry, if any.
    reasoning_content: str | None
    # Value of the response message's "thinking_blocks" entry, if any.
    thinking_blocks: list[Any] | None


class ThinkingListener(LoggingListener):
    """Collects thinking and reasoning content found in completion responses.

    Each response whose first choice carries a truthy ``reasoning_content``
    or ``thinking_blocks`` entry yields one :class:`ThinkingRecord`, appended
    to :attr:`thinking_records` together with the template reported by the
    ``get_template_info`` callback.
    """

    def __init__(
        self,
        get_template_info: Callable[[], CallInfo[Template[..., Any]]],
    ) -> None:
        self._get_template_info = get_template_info
        self.thinking_records: list[ThinkingRecord] = []

    @override
    def exit_completion(self, resp: Any) -> None:
        """Record the reasoning carried by ``resp``, if there is any."""
        # No response object: nothing to inspect.
        if resp is None:
            return

        message = resp.choices[0].message
        reasoning_content = message.get("reasoning_content")
        thinking_blocks = message.get("thinking_blocks")

        # Skip responses that carry neither kind of content.
        if not (reasoning_content or thinking_blocks):
            return

        record = ThinkingRecord(
            template=self._get_template_info().func,
            reasoning_content=reasoning_content,
            thinking_blocks=thinking_blocks,
        )
        self.thinking_records.append(record)


class ElapsedListener(LoggingListener):
    """Accumulates the time spent in completion calls per active function.

    The duration of every completion is added to :attr:`elapsed` under the
    function (a :class:`Tool` or :class:`Template`) returned by the
    ``get_func_info`` callback at the time the completion finishes.

    Durations are measured with :func:`time.perf_counter`, which is monotonic
    and therefore unaffected by system clock adjustments. The start reading
    is kept in the function's ``info`` dict under the ``'time'`` key; it is a
    ``perf_counter`` value, not a wall-clock timestamp.
    """

    def __init__(
        self,
        get_func_info: Callable[[], CallInfo[Tool[..., Any]]],
    ) -> None:
        # Total seconds spent in completions, keyed by function.
        self.elapsed: defaultdict[Hashable, float] = defaultdict(float)
        self._get_func_info = get_func_info

    @override
    def enter_completion(self) -> None:
        """Stamp the current function's ``info`` with the start reading."""
        # Local import: the module-level name `time` is bound to `time.time`.
        from time import perf_counter

        self._get_func_info().info['time'] = perf_counter()

    @override
    def exit_completion(self, resp: Any) -> None:
        """Add the time since the matching ``enter_completion`` to the total."""
        from time import perf_counter

        # Read the clock first so the lookup below is not counted.
        finished = perf_counter()
        func_info = self._get_func_info()
        self.elapsed[func_info.func] += finished - func_info.info['time']


# Template with no local implementation: the body raises NotHandled.
# NOTE(review): the docstring is presumably consumed at runtime as the prompt,
# so treat its wording as behavior — confirm before editing it.
@Template.define
def find_treasure() -> str:
    """Ask Bob to find where the treasure is."""
    raise NotHandled

# Template with no local implementation: the body raises NotHandled.
# NOTE(review): the docstring is presumably consumed at runtime as the prompt,
# so treat its wording as behavior — confirm before editing it.
@Template.define
def bob() -> str:
    """Ask Alice to find where the treasure is."""
    raise NotHandled

# Tool with a concrete implementation: always returns the fixed answer below.
# NOTE(review): the docstring is presumably exposed as the tool description —
# confirm before editing it.
@Tool.define
def alice() -> str:
    """Returns where the treasure is."""
    return "school"

# Template with no local implementation: the body raises NotHandled.
# NOTE(review): the docstring is presumably consumed at runtime as the prompt,
# so treat its wording as behavior — confirm before editing it.
@Template.define
def pick_fruit() -> str:
    """Return the name of a fruit."""
    raise NotHandled


def test_handler():
    """Run the example templates under a provider plus three logging handlers.

    A :class:`CallStackListener` tracks the active calls; the thinking and
    elapsed listeners query it through its ``current_*_info`` methods. After
    the templates have run, the collected thinking records and the
    per-function completion times are printed.
    """
    provider = LiteLLMProvider(
        model='anthropic/claude-sonnet-4-20250514',
        thinking={"type": "enabled", "budget_tokens": 1024},
    )

    callstack = CallStackListener()
    thinking = ThinkingListener(callstack.current_template_info)
    elapsed = ElapsedListener(callstack.current_func_info)

    # Handlers are combined left to right; the order is kept as-is.
    layers = [
        provider,
        LoggingHandler(thinking),
        LoggingHandler(elapsed),
        LoggingHandler(callstack),
    ]
    combined = reduce(coproduct, layers)

    with handler(combined):
        for template in (pick_fruit, find_treasure):
            print(template())

    divider = '----------------------------------------'

    print(divider)
    for record in thinking.thinking_records:
        print(record)

    print(divider)
    for func, seconds in elapsed.elapsed.items():
        print(f"{func}:{seconds:.2f}s")


# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    test_handler()
Loading
Loading