diff --git a/effectful/handlers/llm/completions.py b/effectful/handlers/llm/completions.py
index fc6ca47a..fd8e9eeb 100644
--- a/effectful/handlers/llm/completions.py
+++ b/effectful/handlers/llm/completions.py
@@ -160,8 +160,15 @@ def to_feedback_message(self, include_traceback: bool) -> Message:
 type MessageResult[T] = tuple[Message, typing.Sequence[DecodedToolCall], T | None]
 
 
+# Keep :func:`functools.wraps` from copying ``litellm.completion``'s annotations:
+# they contain forward references that cannot be resolved unless aiohttp is imported.
+_assn_without_type = [
+    attr for attr in functools.WRAPPER_ASSIGNMENTS if attr != "__annotations__"
+]
+
+
 @Operation.define
-@functools.wraps(litellm.completion)
+@functools.wraps(litellm.completion, assigned=_assn_without_type)
 def completion(*args, **kwargs) -> typing.Any:
     """Low-level LLM request. Handlers may log/modify requests and delegate via fwd().
 
@@ -172,6 +179,9 @@ def completion(*args, **kwargs) -> typing.Any:
     return litellm.completion(*args, **kwargs)
 
 
+del _assn_without_type
+
+
 @Operation.define
 def call_assistant[T, U](
     tools: collections.abc.Mapping[str, Tool],
@@ -490,3 +500,175 @@ def _call[**P, T](
             history.clear()
             history.update(history_copy)
         return typing.cast(T, result)
+
+
+class LoggingListener(typing.Protocol):
+    """Interface for observing :class:`Tool`, :class:`Template`, and
+    completion call events.
+
+    All methods are no-ops by default, allowing subclasses to
+    subscribe to only the events they care about.
+    """
+
+    def enter_tool_call[**P, Q](self, tool: Tool[P, Q]) -> None:
+        # can't just `pass` because that would mark the method as abstract
+        return None
+
+    def exit_tool_call[**P, Q](self, tool: Tool[P, Q], result: Q | None) -> None:
+        return None
+
+    def enter_template_call[**P, Q](self, template: Template[P, Q]) -> None:
+        return None
+
+    def exit_template_call[**P, Q](
+        self, template: Template[P, Q], result: Q | None
+    ) -> None:
+        return None
+
+    def enter_completion(self) -> None:
+        return None
+
+    def exit_completion(self, resp: typing.Any) -> None:
+        return None
+
+
+class LoggingHandler(ObjectInterpretation):
+    """Effect handler that wraps :class:`Tool`, :class:`Template`, and completion calls
+    and invokes callback functions registered in :attr:`listener`.
+
+    Compose with a provider via :func:`coproduct` or nested :func:`handler`
+    context managers to add logging without modifying the provider logic::
+
+        listener = ThinkingElapsedListener()
+        obs = LoggingHandler(listener)
+        with handler(provider), handler(obs):
+            result = my_template()
+    """
+
+    def __init__(self, listener: LoggingListener) -> None:
+        self.listener = listener
+
+    @implements(completion)
+    def _completion(self, *args, **kwargs) -> typing.Any:
+        """Forward a completion; ``exit_completion`` gets ``None`` if it raises."""
+        self.listener.enter_completion()
+        response: typing.Any = None
+        try:
+            response = fwd(*args, **kwargs)
+            return response
+        finally:
+            self.listener.exit_completion(response)
+
+    @implements(Tool.__apply__)
+    def _call_tool[**P, T](
+        self, tool: Tool[P, T], *args: P.args, **kwargs: P.kwargs
+    ) -> T:
+        """Forward a tool call; ``exit_tool_call`` gets ``None`` if it raises."""
+        # Enter outside the ``try`` (as in ``_completion``) so that a failing
+        # enter hook is never followed by an exit hook for an unregistered call.
+        self.listener.enter_tool_call(tool)
+        result_opt: T | None = None
+        try:
+            result = typing.cast(T, fwd(tool, *args, **kwargs))
+            result_opt = result
+            return result
+        finally:
+            self.listener.exit_tool_call(tool, result_opt)
+
+    @implements(Template.__apply__)
+    def _call_template[**P, T](
+        self, template: Template[P, T], *args: P.args, **kwargs: P.kwargs
+    ) -> T:
+        """Forward a template call; ``exit_template_call`` gets ``None`` if it raises."""
+        # Enter outside the ``try`` so exit only runs for calls that were entered.
+        self.listener.enter_template_call(template)
+        result_opt: T | None = None
+        try:
+            result = typing.cast(T, fwd(template, *args, **kwargs))
+            result_opt = result
+            return result
+        finally:
+            self.listener.exit_template_call(template, result_opt)
+
+
+class EmptyCallStackException(Exception):
+    """Raised when the call stack is accessed while it is empty."""
+
+    pass
+
+
+class NoTemplateException(Exception):
+    """Raised when no :class:`Template` is on the call stack."""
+
+    pass
+
+
+@dataclasses.dataclass(frozen=True)
+class CallInfo[F: Tool[..., typing.Any]]:
+    """A call-stack frame: the callee and a mutable per-call scratch dict."""
+
+    func: F
+    info: dict[typing.Any, typing.Any]
+
+
+class CallStackListener(LoggingListener):
+    """Listener that maintains a call stack of active Tool and Template calls.
+
+    The call stack can be accessed directly through :attr:`callstack`.
+    The methods :meth:`current_func_info` and :meth:`current_template_info`
+    are provided for convenience to access the function (including
+    both templates and tools) or template that is currently executing
+    (i.e. on top of the call stack).
+    """
+
+    def __init__(self) -> None:
+        self.callstack: list[CallInfo[Tool[..., typing.Any]]] = []
+
+    @typing.override
+    def enter_tool_call[**P, Q](self, tool: Tool[P, Q]) -> None:
+        super().enter_tool_call(tool)
+        self.callstack.append(CallInfo(tool, {}))
+
+    @typing.override
+    def exit_tool_call[**P, Q](self, tool: Tool[P, Q], result: Q | None) -> None:
+        assert len(self.callstack) > 0 and tool is self.callstack[-1].func
+        self.callstack.pop()
+        super().exit_tool_call(tool, result)
+
+    @typing.override
+    def enter_template_call[**P, Q](self, template: Template[P, Q]) -> None:
+        super().enter_template_call(template)
+        self.callstack.append(CallInfo(template, {}))
+
+    @typing.override
+    def exit_template_call[**P, Q](
+        self, template: Template[P, Q], result: Q | None
+    ) -> None:
+        assert len(self.callstack) > 0 and template is self.callstack[-1].func
+        self.callstack.pop()
+        super().exit_template_call(template, result)
+
+    def current_func_info(self) -> CallInfo[Tool[..., typing.Any]]:
+        """Return the innermost active :class:`Tool` or :class:`Template`.
+
+        :raises EmptyCallStackException: if the call stack is empty.
+        """
+        try:
+            return self.callstack[-1]
+        except IndexError:
+            raise EmptyCallStackException() from None
+
+    def current_template_info(self) -> CallInfo[Template[..., typing.Any]]:
+        """Return the innermost active :class:`Template`, skipping nested :class:`Tool` calls.
+
+        :raises NoTemplateException: if no :class:`Template` is on the stack.
+        """
+        try:
+            # need to repack CallInfo to make the typechecker happy
+            return next(
+                CallInfo(ci.func, ci.info)
+                for ci in reversed(self.callstack)
+                if isinstance(ci.func, Template)
+            )
+        except StopIteration:
+            raise NoTemplateException() from None
diff --git a/obs-examples/thinking_elapsed.py b/obs-examples/thinking_elapsed.py
new file mode 100644
index 00000000..73c2cd9d
--- /dev/null
+++ b/obs-examples/thinking_elapsed.py
@@ -0,0 +1,125 @@
+import dataclasses
+
+from collections import defaultdict
+from functools import reduce
+from collections.abc import Callable
+from typing import Any, Hashable, override
+
+from effectful.handlers.llm import Tool, Template
+from effectful.ops.types import NotHandled
+from effectful.handlers.llm.completions import (
+    LiteLLMProvider, LoggingHandler, LoggingListener,
+    CallStackListener, CallInfo,
+)
+from effectful.ops.semantics import coproduct, handler
+from time import perf_counter
+
+
+@dataclasses.dataclass
+class ThinkingRecord:
+    """A single thinking/reasoning extraction paired with its source template."""
+    template: Template[..., Any]
+    reasoning_content: str | None
+    thinking_blocks: list[Any] | None
+
+
+class ThinkingListener(LoggingListener):
+    """Extracts thinking and reasoning content from litellm completion responses."""
+
+    def __init__(
+        self,
+        get_template_info: Callable[[], CallInfo[Template[..., Any]]],
+    ) -> None:
+        self.thinking_records: list[ThinkingRecord] = []
+        self._get_template_info = get_template_info
+
+    @override
+    def exit_completion(self, resp: Any) -> None:
+        if resp is not None:
+            message = resp.choices[0].message
+            reasoning_content = message.get("reasoning_content")
+            thinking_blocks = message.get("thinking_blocks")
+            if reasoning_content or thinking_blocks:
+                self.thinking_records.append(
+                    ThinkingRecord(
+                        template=self._get_template_info().func,
+                        reasoning_content=reasoning_content,
+                        thinking_blocks=thinking_blocks,
+                    )
+                )
+
+
+class ElapsedListener(LoggingListener):
+    """Accumulates time spent in completion calls, keyed by the current function."""
+
+    def __init__(
+        self,
+        get_func_info: Callable[[], CallInfo[Tool[..., Any]]],
+    ) -> None:
+        self.elapsed: defaultdict[Hashable, float] = defaultdict(float)
+        self._get_func_info = get_func_info
+
+    @override
+    def enter_completion(self) -> None:
+        self._get_func_info().info['time'] = perf_counter()
+
+    @override
+    def exit_completion(self, resp: Any) -> None:
+        func_info = self._get_func_info()
+        time_elapsed = perf_counter() - func_info.info['time']
+        self.elapsed[func_info.func] += time_elapsed
+
+
+@Template.define
+def find_treasure() -> str:
+    """Ask Bob to find where the treasure is."""
+    raise NotHandled
+
+@Template.define
+def bob() -> str:
+    """Ask Alice to find where the treasure is."""
+    raise NotHandled
+
+@Tool.define
+def alice() -> str:
+    """Returns where the treasure is."""
+    return "school"
+
+@Template.define
+def pick_fruit() -> str:
+    """Return the name of a fruit."""
+    raise NotHandled
+
+
+def test_handler():
+    provider = LiteLLMProvider(
+        model='anthropic/claude-sonnet-4-20250514',
+        thinking={"type": "enabled", "budget_tokens": 1024}
+    )
+
+    callstack = CallStackListener()
+    thinking = ThinkingListener(callstack.current_template_info)
+    elapsed = ElapsedListener(callstack.current_func_info)
+
+    combined = reduce(coproduct, [
+        provider,
+        LoggingHandler(thinking),
+        LoggingHandler(elapsed),
+        LoggingHandler(callstack),
+    ])
+
+    with handler(combined):
+        print(pick_fruit())
+        print(find_treasure())
+
+    print('----------------------------------------')
+    for record in thinking.thinking_records:
+        print(record)
+
+    print('----------------------------------------')
+    for func, elapsed_time in elapsed.elapsed.items():
+        print(f"{func}:{elapsed_time:.2f}s")
+
+
+if __name__ == '__main__':
+    test_handler()
diff --git a/obs-examples/thinking_elapsed_multi.py b/obs-examples/thinking_elapsed_multi.py
new file mode 100644
index 00000000..2c249867
--- /dev/null
+++ b/obs-examples/thinking_elapsed_multi.py
@@ -0,0 +1,115 @@
+import dataclasses
+
+from collections import defaultdict
+from typing import Any, Hashable, override
+
+from effectful.handlers.llm import Tool, Template
+from effectful.ops.types import NotHandled
+from effectful.handlers.llm.completions import completion
+from effectful.ops.semantics import fwd, coproduct, handler
+from effectful.handlers.llm.completions import (
+    LiteLLMProvider, LoggingHandler, CallStackListener
+)
+from time import perf_counter
+
+
+@dataclasses.dataclass
+class ThinkingRecord:
+    """A single thinking/reasoning extraction paired with its source template."""
+    template: Template[..., Any]
+    reasoning_content: str | None
+    thinking_blocks: list[Any] | None
+
+
+class ThinkingListener(CallStackListener):
+    """Extracts thinking and reasoning content from litellm completion responses."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.thinking_records: list[ThinkingRecord] = []
+
+    @override
+    def exit_completion(self, resp: Any) -> None:
+        if resp is not None:
+            message = resp.choices[0].message
+            reasoning_content = message.get("reasoning_content")
+            thinking_blocks = message.get("thinking_blocks")
+            if reasoning_content or thinking_blocks:
+                self.thinking_records.append(
+                    ThinkingRecord(
+                        template=self.current_template_info().func,
+                        reasoning_content=reasoning_content,
+                        thinking_blocks=thinking_blocks,
+                    )
+                )
+        super().exit_completion(resp)
+
+class ElapsedListener(CallStackListener):
+    """Accumulates completion time per innermost :class:`Tool` or :class:`Template`."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.elapsed: defaultdict[Hashable, float] = defaultdict(float)
+
+    @override
+    def enter_completion(self) -> None:
+        super().enter_completion()
+        self.current_func_info().info['time'] = perf_counter()
+
+    @override
+    def exit_completion(self, resp: Any) -> None:
+        func_info = self.current_func_info()
+        self.elapsed[func_info.func] += perf_counter() - func_info.info['time']
+        super().exit_completion(resp)
+
+
+class ThinkingElapsedListener(ThinkingListener, ElapsedListener):
+    """Combines thinking extraction and elapsed time tracking."""
+    def __init__(self) -> None:
+        super().__init__()
+
+
+@Template.define
+def find_treasure() -> str:
+    """Ask Bob to find where the treasure is."""
+    raise NotHandled
+
+@Template.define
+def bob() -> str:
+    """Ask Alice to find where the treasure is."""
+    raise NotHandled
+
+@Tool.define
+def alice() -> str:
+    """Returns where the treasure is."""
+    return "school"
+
+@Template.define
+def pick_fruit() -> str:
+    """Return the name of a fruit."""
+    raise NotHandled
+
+def test_handler():
+    provider = LiteLLMProvider(
+        model='anthropic/claude-sonnet-4-20250514',
+        thinking={"type": "enabled", "budget_tokens": 1024}
+    )
+
+    listener = ThinkingElapsedListener()
+    obsprovider = LoggingHandler(listener)
+
+
+    with handler(provider), handler(obsprovider):
+        print(pick_fruit())
+        print(find_treasure())
+
+    print('----------------------------------------')
+    for thinking in listener.thinking_records:
+        print(thinking)
+
+    print('----------------------------------------')
+    for func, elapsed_time in listener.elapsed.items():
+        print(f"{func}:{elapsed_time:.2f}s")
+
+if __name__ == '__main__':
+    test_handler()