diff --git a/effectful/handlers/llm/langfuse.py b/effectful/handlers/llm/langfuse.py
new file mode 100644
index 00000000..5b687b45
--- /dev/null
+++ b/effectful/handlers/llm/langfuse.py
@@ -0,0 +1,104 @@
+import functools
+import typing
+
+import litellm
+from langfuse import get_client, observe
+
+from effectful.handlers.llm import Template, Tool
+from effectful.handlers.llm.completions import (
+    call_assistant,
+    call_system,
+    call_user,
+    completion,
+)
+from effectful.ops.semantics import fwd
+from effectful.ops.syntax import ObjectInterpretation, implements
+
+
+def _extract_generation_meta(result) -> dict[str, typing.Any]:
+    # helper function for populating usage metadata so they render
+    # more nicely on langfuse
+    usage = getattr(result, "usage", None)
+    if usage is None:
+        return {}
+    meta: dict[str, typing.Any] = {"model": getattr(result, "model", None)}
+    usage_details: dict[str, int] = {}
+    for k in ("prompt_tokens", "completion_tokens", "total_tokens"):
+        v = getattr(usage, k, None)
+        if v is not None:
+            usage_details[k] = v
+    if usage_details:
+        meta["usage_details"] = usage_details
+    try:
+        cost = litellm.completion_cost(completion_response=result)
+        meta["cost_details"] = {"total": cost}
+    except Exception:
+        pass  # best-effort: litellm may not know pricing for this model
+    meta["metadata"] = {"response_id": getattr(result, "id", None)}
+    return meta
+
+
+def _make_instrumented(op, as_type):
+    @observe(as_type=as_type)
+    @functools.wraps(op)
+    def wrapper(*args, **kwargs):
+        return fwd(op, *args, **kwargs)
+
+    return wrapper
+
+
+class LangfuseProvider(ObjectInterpretation):
+    """Traces Tool, Template, and completion calls with Langfuse.
+
+    Compose with a provider via :func:`~effectful.ops.semantics.handler`
+    to add tracing::
+
+        with handler(provider), handler(LangfuseProvider()):
+            print(limerick(theme))
+    """
+
+    def __init__(self):
+        self.langfuse = get_client()
+        # cache each template instead of repeatedly instrumenting it
+        self._get_instrumented = functools.cache(_make_instrumented)
+
+    @implements(completion)
+    @observe(as_type="generation")
+    def completion(self, *args, **kwargs):
+        messages = kwargs.get("messages")
+        result = fwd(*args, **kwargs)
+        meta = _extract_generation_meta(result)
+
+        # populate messages as part of the langfuse metadata so we get
+        # the nice rendering of "Assistant", "User", and "System"
+        # messages
+        if messages is not None:
+            meta["input"] = messages
+        choice = result.choices[0] if result.choices else None
+        if choice is not None:
+            meta["output"] = choice.message.model_dump(mode="json", exclude_none=True)
+        self.langfuse.update_current_generation(**meta)
+        return result
+
+    @implements(call_user)
+    @observe()
+    def call_user(self, template, env):
+        return fwd(template, env)
+
+    @implements(call_system)
+    @observe()
+    def call_system(self, template):
+        return fwd(template)
+
+    @implements(call_assistant)
+    @observe()
+    def call_assistant(self, tools, response_format, model, **kwargs):
+        return fwd(tools, response_format, model, **kwargs)
+
+    @implements(Tool.__apply__)
+    def call_tool(self, tool, *args, **kwargs):
+        return self._get_instrumented(tool, "tool")(*args, **kwargs)
+
+    @implements(Template.__apply__)
+    def call_template(self, template, *args, **kwargs):
+        return self._get_instrumented(template, "generation")(*args, **kwargs)
diff --git a/effectful/handlers/llm/weave.py b/effectful/handlers/llm/weave.py
new file mode 100644
index 00000000..fe9e7dca
--- /dev/null
+++ b/effectful/handlers/llm/weave.py
@@ -0,0 +1,60 @@
+import functools
+
+import weave
+
+from effectful.handlers.llm import Template, Tool
+from effectful.handlers.llm.completions import (
+    call_assistant,
+    call_system,
+    call_user,
+)
+from effectful.ops.semantics import fwd
+from effectful.ops.syntax import ObjectInterpretation, implements
+
+
+def _make_instrumented(op):
+    @weave.op()
+    @functools.wraps(op)
+    def wrapper(*args, **kwargs):
+        return fwd(op, *args, **kwargs)
+
+    return wrapper
+
+
+class WeaveProvider(ObjectInterpretation):
+    """Traces Tool, Template, and message-level calls with Weights & Biases Weave.
+
+    Compose with a provider via :func:`~effectful.ops.semantics.handler`
+    to add tracing::
+
+        weave.init("my-project")
+        with handler(provider), handler(WeaveProvider()):
+            print(limerick(theme))
+    """
+
+    def __init__(self):
+        # cache each template instead of repeatedly instrumenting it
+        self._get_instrumented = functools.cache(_make_instrumented)
+
+    @implements(call_user)
+    @weave.op()
+    def call_user(self, template, env):
+        return fwd(template, env)
+
+    @implements(call_system)
+    @weave.op()
+    def call_system(self, template):
+        return fwd(template)
+
+    @implements(call_assistant)
+    @weave.op()
+    def call_assistant(self, tools, response_format, model, **kwargs):
+        return fwd(tools, response_format, model, **kwargs)
+
+    @implements(Tool.__apply__)
+    def call_tool(self, tool, *args, **kwargs):
+        return self._get_instrumented(tool)(*args, **kwargs)
+
+    @implements(Template.__apply__)
+    def call_template(self, template, *args, **kwargs):
+        return self._get_instrumented(template)(*args, **kwargs)
diff --git a/pyproject.toml b/pyproject.toml
index d6219f62..cb5adff7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,6 +50,8 @@ llm = [
     "typing_extensions",
     "restrictedpython>=8.1"
 ]
+llm-weave = ["effectful[llm]", "weave"]
+llm-langfuse = ["effectful[llm]", "langfuse"]
 prettyprinter = ["prettyprinter"]
 docs = [
     "effectful[torch,pyro,jax,numpyro,llm,prettyprinter]",
@@ -74,7 +76,7 @@ test = [
 ]
 
 [dependency-groups]
-dev = ["effectful[torch,pyro,jax,numpyro,llm,docs,test]"]
+dev = ["effectful[torch,pyro,jax,numpyro,llm,llm-weave,llm-langfuse,docs,test]"]
 
 [tool.ruff]
 target-version = "py312"