diff --git a/effectful/handlers/llm/synthesis.py b/effectful/handlers/llm/synthesis.py
index 3db32fd7..b016291a 100644
--- a/effectful/handlers/llm/synthesis.py
+++ b/effectful/handlers/llm/synthesis.py
@@ -1,4 +1,30 @@
-from effectful.ops.syntax import ObjectInterpretation
+import collections
+import collections.abc
+import inspect
+import linecache
+import textwrap
+import typing
+from collections import ChainMap
+from collections.abc import Callable
+from typing import Any
+
+import pydantic
+from pydantic import Field
+
+from effectful.handlers.llm import Template
+from effectful.handlers.llm.completions import (
+    InstructionHandler,
+    OpenAIMessageContentListBlock,
+)
+from effectful.handlers.llm.encoding import EncodableAs, type_to_encodable_type
+from effectful.ops.semantics import NotHandled, fwd, handler
+from effectful.ops.syntax import ObjectInterpretation, defop, implements
+
+
+@defop
+def get_synthesis_context() -> ChainMap[str, Any] | None:
+    """Get the current synthesis context for decoding synthesized code."""
+    raise NotHandled
 
 
 class SynthesisError(Exception):
@@ -9,6 +35,181 @@ def __init__(self, message, code=None):
         self.code = code
 
 
+class SynthesizedFunction(pydantic.BaseModel):
+    """Structured output for function synthesis.
+
+    Pydantic model representing synthesized code with function name and module code.
+    """
+
+    function_name: str = Field(
+        ...,
+        description="The name of the main function that satisfies the specification",
+    )
+    module_code: str = Field(
+        ...,
+        description="Complete Python module code (no imports needed)",
+    )
+
+
+@type_to_encodable_type.register(collections.abc.Callable)
+class EncodableSynthesizedFunction(
+    EncodableAs[Callable, SynthesizedFunction],
+):
+    """Encodes Callable to SynthesizedFunction and vice versa."""
+
+    t = SynthesizedFunction
+
+    @classmethod
+    def encode(
+        cls, vl: Callable, context: ChainMap[str, Any] | None = None
+    ) -> SynthesizedFunction:
+        """Encode a Callable to a SynthesizedFunction.
+
+        Extracts the function name and source code. When the source cannot be
+        retrieved, a stub definition carrying the signature is emitted instead.
+        ``context`` is accepted but not used by the encoder.
+        """
+        # functools.partial objects and callable instances have no __name__.
+        func_name = getattr(vl, "__name__", type(vl).__name__)
+        try:
+            source = inspect.getsource(vl)
+        except (OSError, TypeError):
+            # If we can't get source, create a minimal representation
+            try:
+                sig = inspect.signature(vl)
+                source = f"def {func_name}{sig}:\n    pass  # Source unavailable"
+            except (ValueError, TypeError):
+                # "(...)" is not valid Python; keep the stub parseable.
+                source = (
+                    f"def {func_name}(*args, **kwargs):\n"
+                    "    pass  # Source unavailable"
+                )
+
+        return SynthesizedFunction(
+            function_name=func_name, module_code=textwrap.dedent(source).strip()
+        )
+
+    # Counter for unique filenames
+    _decode_counter: typing.ClassVar[int] = 0
+
+    @classmethod
+    def decode(cls, vl: SynthesizedFunction) -> Callable:
+        """Decode a SynthesizedFunction to a Callable.
+
+        Executes the module code and returns the named function.
+
+        Raises:
+            SynthesisError: If the code fails to compile or run, or if
+                ``function_name`` is missing or not callable afterwards.
+        """
+        context: ChainMap[str, Any] | None = get_synthesis_context()
+        func_name = vl.function_name
+        module_code = textwrap.dedent(vl.module_code).strip()
+
+        cls._decode_counter += 1
+        # Each decode gets its own pseudo-filename; a shared (empty) name would
+        # make every call overwrite the previous linecache entry.
+        filename = f"<synthesis:{cls._decode_counter}>"
+        lines = module_code.splitlines(keepends=True)
+        # Ensure last line has newline for linecache
+        if lines and not lines[-1].endswith("\n"):
+            lines[-1] += "\n"
+        linecache.cache[filename] = (
+            len(module_code),
+            None,
+            lines,
+            filename,
+        )
+
+        # Start with provided context or empty dict
+        exec_globals: dict[str, typing.Any] = {}
+        if context is not None:
+            exec_globals.update(context)
+
+        # NOTE(security): this executes model-generated code in-process with
+        # no sandboxing; only use it with trusted providers and prompts.
+        try:
+            code_obj = compile(module_code, filename, "exec")
+            exec(code_obj, exec_globals)
+        except SyntaxError as exc:
+            raise SynthesisError(
+                f"Syntax error in generated code: {exc}", module_code
+            ) from exc
+        except Exception as exc:
+            raise SynthesisError(f"Evaluation failed: {exc!r}", module_code) from exc
+
+        if func_name not in exec_globals:
+            available = [k for k in exec_globals if not k.startswith("_")]
+            raise SynthesisError(
+                f"Function '{func_name}' not found after execution. "
+                f"Available names: {available}",
+                module_code,
+            )
+
+        func = exec_globals[func_name]
+        if not callable(func):
+            raise SynthesisError(
+                f"'{func_name}' is not callable (got {type(func).__name__})",
+                module_code,
+            )
+        # Also attach source code directly for convenience
+        try:
+            func.__source__ = module_code
+            func.__synthesized__ = vl
+        except (AttributeError, TypeError):
+            # Builtins and slotted callables reject new attributes; the
+            # metadata is a convenience, so skip it rather than fail.
+            pass
+        return func
+
+    @classmethod
+    def serialize(cls, vl: SynthesizedFunction) -> list[OpenAIMessageContentListBlock]:
+        return [{"type": "text", "text": vl.model_dump_json()}]
+
+
 class ProgramSynthesis(ObjectInterpretation):
-    def __init__(self, *args, **kwargs):
-        raise NotImplementedError
+    """Provides a `template` handler to instruct the LLM to generate code of the
+    right form and with the right type.
+
+    """
+
+    @implements(Template.__apply__)
+    def _call(self, template, *args, **kwargs) -> Any:
+        """Add synthesis instructions when the template returns a callable.
+
+        Templates with any other return annotation are forwarded unchanged.
+        """
+        annotation = template.__signature__.return_annotation
+        # Unwrap generic aliases (e.g. Callable[[str], int]) for the class check
+        # only; the prompt keeps the full annotation.
+        origin = typing.get_origin(annotation)
+        ret_type = annotation if origin is None else origin
+
+        # issubclass raises TypeError for non-class annotations (e.g. Literal).
+        if not (
+            isinstance(ret_type, type)
+            and issubclass(ret_type, collections.abc.Callable)  # type: ignore[arg-type]
+        ):
+            return fwd()
+
+        # NOTE(review): instruction 3 ("in tags") reads as truncated -- confirm
+        # the intended wording.
+        prompt_ext = textwrap.dedent(f"""
+        Given the specification above, generate a Python function satisfying the following specification and type signature.
+
+        {str(annotation)}
+
+
+        1. Produce one block of Python code.
+        2. Do not include usage examples.
+        3. Return your response in tags.
+        4. Do not return your response in markdown blocks.
+        5. Your output function def must be the final statement in the code block.
+
+        """).strip()
+
+        with (
+            handler(InstructionHandler(prompt_ext)),
+            handler({get_synthesis_context: lambda: template.__context__}),
+        ):
+            return fwd()
diff --git a/tests/test_handlers_llm.py b/tests/test_handlers_llm.py
index 4ad3d81c..624bea57 100644
--- a/tests/test_handlers_llm.py
+++ b/tests/test_handlers_llm.py
@@ -1,11 +1,16 @@
+import json
 from collections.abc import Callable
 from typing import Annotated
 
 import pytest
+from litellm import Choices, Message
+from litellm.types.utils import ModelResponse
 
 from effectful.handlers.llm import Template
 from effectful.handlers.llm.completions import (
+    LiteLLMProvider,
     RetryLLMHandler,
+    completion,
     compute_response,
     format_model_input,
 )
@@ -44,22 +49,24 @@ def _call[**P](
         return response
 
 
-class SingleResponseLLMProvider[T](ObjectInterpretation):
-    """Simplified mock provider that returns a single response for any prompt."""
+class SingleResponseLLMProvider[T](LiteLLMProvider):
+    """Mock provider that reuses LiteLLMProvider and overrides completion."""
 
     def __init__(self, response: T):
-        """Initialize with a single response string.
-
-        Args:
-            response: The response to return for any template call
-        """
+        """Initialize with a response value."""
+        super().__init__(model_name="mock")
         self.response = response
 
-    @implements(Template.__apply__)
-    def _call[**P](
-        self, template: Template[P, T], *args: P.args, **kwargs: P.kwargs
-    ) -> T:
-        return self.response
+    @implements(completion)
+    def _completion(self, *args, **kwargs) -> ModelResponse:
+        result = (
+            self.response
+            if isinstance(self.response, str)
+            else json.dumps({"value": self.response})
+        )
+        message = Message(role="assistant", content=result)
+        choice = Choices(index=0, message=message, finish_reason="stop")
+        return ModelResponse(model="mock", choices=[choice])
 
 
 # Test templates from the notebook examples
@@ -124,18 +131,18 @@ def test_primes_decode_int():
     assert isinstance(result, int)
 
 
-@pytest.mark.xfail(reason="Synthesis handler not yet implemented")
 def test_count_char_with_program_synthesis():
     """Test the count_char template with program synthesis."""
-    mock_code = """
-def count_occurrences(s):
-    return s.count('a')
-"""
-    mock_provider = SingleResponseLLMProvider(mock_code)
+    mock_provider = SingleResponseLLMProvider(
+        {
+            "function_name": "count_occurrences",
+            "module_code": "def count_occurrences(s):\n return s.count('a')",
+        }
+    )
 
     with handler(mock_provider), handler(ProgramSynthesis()):
         count_a = count_char("a")
-        assert callable(count_a)
+        assert callable(count_a), f"count_a is not callable: {count_a}"
         assert count_a("banana") == 3
         assert count_a("cherry") == 0
 