Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion effectful/handlers/llm/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,9 @@ def format_model_input[**P, T](
origin = typing.get_origin(ret_type)
ret_type = ret_type if origin is None else origin
ret_type_encoder = type_to_encodable_type(ret_type)
prompt_prefix = ret_type_encoder.encoding_instructions()
prompt_prefix = "\n".join(
ret_type_encoder.encoding_instructions(template.__context__)
)

if prompt_prefix:
prefix: list[ChatCompletionTextObject] = [
Expand Down
25 changes: 22 additions & 3 deletions effectful/handlers/llm/encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import io
import typing
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping
from collections.abc import Callable, Mapping, Sequence
from typing import Any

import pydantic
Expand Down Expand Up @@ -43,9 +43,9 @@ def decode(cls, vl: U, env: Mapping[str, Any] | None = None) -> T:
pass

@classmethod
def encoding_instructions(cls) -> str | None:
def encoding_instructions(cls, env: Mapping[str, Any]) -> Sequence[str]:
"""Optional instructions to be prefixed onto synthesis prompts to tune the encoding of the result."""
return None
return []

@classmethod
def serialize(cls, value: U) -> list[OpenAIMessageContentListBlock]:
Expand Down Expand Up @@ -191,6 +191,16 @@ def decode(cls, t: typing.Any, env: Mapping[str, Any] | None = None) -> T:
]
return typing.cast(T, tuple(decoded_elements))

@classmethod
def encoding_instructions(cls, env: Mapping[str, Any]) -> Sequence[str]:
return list(
{
instruction
for enc in element_encoders
for instruction in enc.encoding_instructions(env)
}
)

@classmethod
def serialize(cls, value: typing.Any) -> list[OpenAIMessageContentListBlock]:
if has_image:
Expand Down Expand Up @@ -248,6 +258,15 @@ def decode(cls, t: typing.Any, env: Mapping[str, Any] | None = None) -> T:
]
return typing.cast(T, decoded_elements)

@classmethod
def encoding_instructions(cls, env: Mapping[str, Any]) -> Sequence[str]:
return list(
{
instruction
for instruction in element_encoder.encoding_instructions(env)
}
)

@classmethod
def serialize(cls, value: typing.Any) -> list[OpenAIMessageContentListBlock]:
if has_image:
Expand Down
26 changes: 26 additions & 0 deletions tests/test_handlers_llm_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,3 +718,29 @@ class Person(pydantic.BaseModel):
assert decoded_from_model == person
assert isinstance(decoded_from_model, Person)
assert isinstance(decoded_from_model.address, Address)


@pytest.mark.parametrize(
"ty,env",
[
(int, {}),
(str, {"key": "value"}),
(tuple[int, str], {}),
(list[str], {}),
(tuple[int, str, bool], {}),
(tuple[int, int], {}),
],
)
def test_encoding_instructions(ty, env):
"""Test that encoding_instructions accepts env parameter and returns a list.

Tests various types including primitives, tuples, and lists to ensure:
- The method accepts an env parameter
- Returns a Sequence[str] (list)
- Returns empty list for default encoders
"""
encodable = type_to_encodable_type(ty)
instructions = encodable.encoding_instructions(env)
assert isinstance(instructions, list)
# Default implementation should return empty list regardless of env or type
assert instructions == []
Loading