Skip to content

Commit 6ca21dc

Browse files
feat: tool call serialization, deprecated param shims, and OpenResponses protocol
1 parent 1e9b2be commit 6ca21dc

File tree

17 files changed

+311
-140
lines changed

17 files changed

+311
-140
lines changed

src/celeste/modalities/text/client.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""Text modality client."""
22

3-
from typing import Any, Unpack
3+
import warnings
4+
from typing import Any, ClassVar, Unpack
45

56
from asgiref.sync import async_to_sync
67

78
from celeste.client import ModalityClient
89
from celeste.core import InputType, Modality
10+
from celeste.tools import CodeExecution, WebSearch, XSearch
911
from celeste.types import AudioContent, ImageContent, Message, TextContent, VideoContent
1012

1113
from .io import TextChunk, TextFinishReason, TextInput, TextOutput, TextUsage
@@ -25,11 +27,45 @@ class TextClient(
2527
_usage_class = TextUsage
2628
_finish_reason_class = TextFinishReason
2729

30+
# Deprecated param → Tool class mapping.
31+
# TODO(deprecation): Remove on 2026-06-07.
32+
_DEPRECATED_TOOL_PARAMS: ClassVar[dict[str, type]] = {
33+
"web_search": WebSearch,
34+
"x_search": XSearch,
35+
"code_execution": CodeExecution,
36+
}
37+
2838
@classmethod
2939
def _output_class(cls) -> type[TextOutput]:
3040
"""Return the Output class for text modality."""
3141
return TextOutput
3242

43+
def _build_request(
    self,
    inputs: TextInput,
    extra_body: dict[str, Any] | None = None,
    streaming: bool = False,
    **parameters: Unpack[TextParameters],
) -> dict[str, Any]:
    """Build request, migrating deprecated boolean tool params first.

    Every key of ``_DEPRECATED_TOOL_PARAMS`` is removed from ``parameters``.
    For each one that was truthy, a ``DeprecationWarning`` is emitted and an
    instance of the mapped tool class is added after any tools the caller
    already supplied.

    The caller's ``tools`` list is never mutated: the migrated tools are
    placed in a new list. An explicit ``tools=None`` is treated as "no
    tools" instead of failing on ``None.append``.

    Args:
        inputs: Text input forwarded unchanged to the parent implementation.
        extra_body: Optional extra request fields, forwarded unchanged.
        streaming: Streaming flag, forwarded unchanged.
        **parameters: Text parameters, possibly containing deprecated flags.

    Returns:
        Whatever the parent ``_build_request`` returns.

    TODO(deprecation): Remove this override on 2026-06-07.
    """
    migrated_tools: list[Any] = []
    for old_param, tool_cls in self._DEPRECATED_TOOL_PARAMS.items():
        value = parameters.pop(old_param, None)  # type: ignore[misc]
        if not value:
            continue
        # NOTE(review): stacklevel=4 presumably targets the user's call site
        # through the internal call chain — confirm if that chain changes.
        warnings.warn(
            f"'{old_param}=True' is deprecated, "
            f"use tools=[{tool_cls.__name__}()] instead. "
            "Will be removed on 2026-06-07.",
            DeprecationWarning,
            stacklevel=4,
        )
        migrated_tools.append(tool_cls())

    if migrated_tools:
        # Copy rather than append in place: ``parameters`` is a fresh dict,
        # but the list stored under "tools" is the caller's own object.
        existing_tools = parameters.get("tools") or []
        parameters["tools"] = [*existing_tools, *migrated_tools]

    return super()._build_request(
        inputs, extra_body=extra_body, streaming=streaming, **parameters
    )
68+
3369
def _check_media_support(
3470
self,
3571
image: ImageContent | None,

src/celeste/modalities/text/parameters.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ class TextParameter(StrEnum):
2727
TOOLS = "tools"
2828
VERBOSITY = "verbosity"
2929

30+
# Deprecated: use tools=[WebSearch()], tools=[XSearch()], tools=[CodeExecution()] instead.
31+
# TODO(deprecation): Remove on 2026-06-07.
32+
WEB_SEARCH = "web_search"
33+
X_SEARCH = "x_search"
34+
CODE_EXECUTION = "code_execution"
35+
3036
# Media input declarations (for optional_input_types)
3137
IMAGE = "image"
3238
VIDEO = "video"
@@ -48,6 +54,12 @@ class TextParameters(Parameters):
4854
tools: list[ToolDefinition]
4955
verbosity: str
5056

57+
# Deprecated: use tools=[WebSearch()], tools=[XSearch()], tools=[CodeExecution()] instead.
58+
# TODO(deprecation): Remove on 2026-06-07.
59+
web_search: bool
60+
x_search: bool
61+
code_execution: bool
62+
5163

5264
__all__ = [
5365
"TextParameter",
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Text modality protocol implementations."""
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""OpenResponses protocol for text modality."""
2+
3+
from .client import OpenResponsesTextClient, OpenResponsesTextStream
4+
5+
__all__ = ["OpenResponsesTextClient", "OpenResponsesTextStream"]
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
"""OpenResponses text client."""
2+
3+
from typing import Any, Unpack
4+
5+
from celeste.parameters import ParameterMapper
6+
from celeste.protocols.openresponses.client import (
7+
OpenResponsesClient as OpenResponsesMixin,
8+
)
9+
from celeste.protocols.openresponses.streaming import (
10+
OpenResponsesStream as _OpenResponsesStream,
11+
)
12+
from celeste.protocols.openresponses.tools import (
13+
parse_content,
14+
parse_tool_calls,
15+
serialize_messages,
16+
)
17+
from celeste.tools import ToolCall
18+
from celeste.types import ImageContent, Message, TextContent, VideoContent
19+
from celeste.utils import build_image_data_url
20+
21+
from ...client import TextClient
22+
from ...io import (
23+
TextChunk,
24+
TextInput,
25+
TextOutput,
26+
)
27+
from ...parameters import TextParameters
28+
from ...streaming import TextStream
29+
from .parameters import OPENRESPONSES_PARAMETER_MAPPERS
30+
31+
32+
class OpenResponsesTextStream(_OpenResponsesStream, TextStream):
    """Text-modality stream for the OpenResponses protocol.

    Keeps the ``response`` payload of the ``response.completed`` event so it
    can be reused when events and tool calls are aggregated.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Set by _parse_chunk once a "response.completed" event is seen.
        self._response_data: dict[str, Any] | None = None

    def _parse_chunk(self, event_data: dict[str, Any]) -> TextChunk | None:
        """Record a completed response (if any), then defer to the parent parser."""
        if event_data.get("type") == "response.completed":
            completed = event_data.get("response")
            # Only a dict payload is stored; anything else is ignored.
            if isinstance(completed, dict):
                self._response_data = completed
        return super()._parse_chunk(event_data)

    def _aggregate_event_data(self, chunks: list[TextChunk]) -> list[dict[str, Any]]:
        """Return the stored response (when present) followed by the parent's events."""
        leading: list[dict[str, Any]] = (
            [] if self._response_data is None else [self._response_data]
        )
        return [*leading, *super()._aggregate_event_data(chunks)]

    def _aggregate_tool_calls(
        self, chunks: list[TextChunk], raw_events: list[dict[str, Any]]
    ) -> list[ToolCall]:
        """Parse tool calls from the stored response; empty if none was captured."""
        completed = self._response_data
        return [] if completed is None else parse_tool_calls(completed)
63+
64+
65+
class OpenResponsesTextClient(OpenResponsesMixin, TextClient):
    """Text client for the OpenResponses protocol (Responses API)."""

    @classmethod
    def parameter_mappers(cls) -> list[ParameterMapper[TextContent]]:
        """Return the OpenResponses parameter mappers for text."""
        return OPENRESPONSES_PARAMETER_MAPPERS

    async def generate(
        self,
        prompt: str | None = None,
        *,
        messages: list[Message] | None = None,
        base_url: str | None = None,
        extra_body: dict[str, Any] | None = None,
        **parameters: Unpack[TextParameters],
    ) -> TextOutput:
        """Generate text from a prompt or a list of messages."""
        text_input = TextInput(prompt=prompt, messages=messages)
        result = await self._predict(
            text_input, base_url=base_url, extra_body=extra_body, **parameters
        )
        return result

    async def analyze(
        self,
        prompt: str | None = None,
        *,
        messages: list[Message] | None = None,
        image: ImageContent | None = None,
        video: VideoContent | None = None,
        base_url: str | None = None,
        extra_body: dict[str, Any] | None = None,
        **parameters: Unpack[TextParameters],
    ) -> TextOutput:
        """Analyze image(s) or video(s) together with a prompt or messages."""
        text_input = TextInput(
            prompt=prompt, messages=messages, image=image, video=video
        )
        result = await self._predict(
            text_input, base_url=base_url, extra_body=extra_body, **parameters
        )
        return result

    def _init_request(self, inputs: TextInput) -> dict[str, Any]:
        """Create the initial request body holding the ``input`` field.

        Messages, when given, are serialized and used as-is. Otherwise a
        single user message is built from the image(s) followed by the prompt.
        """
        messages = inputs.messages
        if messages is not None:
            return {"input": serialize_messages(messages)}

        # NOTE(review): inputs.video is not read here — confirm whether video
        # input is handled elsewhere or is unsupported for this protocol.
        image = inputs.image
        if image is None:
            image_items = []
        elif isinstance(image, list):
            image_items = image
        else:
            image_items = [image]

        parts: list[dict[str, Any]] = [
            {"type": "input_image", "image_url": build_image_data_url(item)}
            for item in image_items
        ]
        # The text part always comes last; a missing prompt becomes "".
        parts.append({"type": "input_text", "text": inputs.prompt or ""})
        return {"input": [{"role": "user", "content": parts}]}

    def _parse_content(
        self,
        response_data: dict[str, Any],
    ) -> TextContent:
        """Extract text content from the response via the protocol helper."""
        return parse_content(super()._parse_content(response_data))

    def _parse_tool_calls(self, response_data: dict[str, Any]) -> list[ToolCall]:
        """Extract tool calls from an OpenResponses response."""
        return parse_tool_calls(response_data)

    def _stream_class(self) -> type[TextStream]:
        """Return the stream class used by this client."""
        return OpenResponsesTextStream
135+
136+
137+
__all__ = ["OpenResponsesTextClient", "OpenResponsesTextStream"]
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""OpenResponses parameter mappers for text."""
2+
3+
from celeste.parameters import ParameterMapper
4+
from celeste.protocols.openresponses.parameters import (
5+
MaxOutputTokensMapper as _MaxOutputTokensMapper,
6+
)
7+
from celeste.protocols.openresponses.parameters import (
8+
TemperatureMapper as _TemperatureMapper,
9+
)
10+
from celeste.protocols.openresponses.parameters import (
11+
TextFormatMapper as _TextFormatMapper,
12+
)
13+
from celeste.protocols.openresponses.parameters import (
14+
ToolsMapper as _ToolsMapper,
15+
)
16+
from celeste.types import TextContent
17+
18+
from ...parameters import TextParameter
19+
20+
21+
class TemperatureMapper(_TemperatureMapper):
    """Bind the OpenResponses temperature mapper to the text ``temperature`` parameter."""

    name = TextParameter.TEMPERATURE
25+
26+
27+
class MaxTokensMapper(_MaxOutputTokensMapper):
    """Bind the OpenResponses max-output-tokens mapper to the text ``max_tokens`` parameter."""

    name = TextParameter.MAX_TOKENS
31+
32+
33+
class OutputSchemaMapper(_TextFormatMapper):
    """Bind the OpenResponses text-format mapper to the text ``output_schema`` parameter."""

    name = TextParameter.OUTPUT_SCHEMA
37+
38+
39+
class ToolsMapper(_ToolsMapper):
    """Bind the OpenResponses tools mapper to the text ``tools`` parameter."""

    name = TextParameter.TOOLS
43+
44+
45+
OPENRESPONSES_PARAMETER_MAPPERS: list[ParameterMapper[TextContent]] = [
46+
TemperatureMapper(),
47+
MaxTokensMapper(),
48+
OutputSchemaMapper(),
49+
ToolsMapper(),
50+
]
51+
52+
__all__ = ["OPENRESPONSES_PARAMETER_MAPPERS"]

src/celeste/modalities/text/providers/cohere/client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,11 @@ def _init_request(self, inputs: TextInput) -> dict[str, Any]:
5959
"""Initialize request from Cohere v2 Chat API messages array format."""
6060
# If messages provided, use them directly (messages take precedence)
6161
if inputs.messages is not None:
62-
return {"messages": [message.model_dump() for message in inputs.messages]}
62+
return {
63+
"messages": [
64+
message.model_dump(exclude_none=True) for message in inputs.messages
65+
]
66+
}
6367

6468
# Fall back to prompt-based input
6569
if inputs.image is None:

src/celeste/modalities/text/providers/google/client.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,15 @@ def _aggregate_tool_calls(
3737
for candidate in event.get("candidates", []):
3838
for part in candidate.get("content", {}).get("parts", []):
3939
if "functionCall" in part:
40+
kwargs: dict[str, Any] = {}
41+
if "thoughtSignature" in part:
42+
kwargs["thoughtSignature"] = part["thoughtSignature"]
4043
tool_calls.append(
4144
ToolCall(
4245
id=str(uuid4()),
4346
name=part["functionCall"]["name"],
4447
arguments=part["functionCall"].get("args", {}),
48+
**kwargs,
4549
)
4650
)
4751
return tool_calls
@@ -139,14 +143,16 @@ def content_to_parts(content: Any) -> list[dict[str, Any]]:
139143
msg_parts = content_to_parts(msg.content)
140144
if msg.tool_calls:
141145
for tc in msg.tool_calls:
142-
msg_parts.append(
143-
{
144-
"functionCall": {
145-
"name": tc.name,
146-
"args": tc.arguments,
147-
}
146+
part: dict[str, Any] = {
147+
"functionCall": {
148+
"name": tc.name,
149+
"args": tc.arguments,
148150
}
149-
)
151+
}
152+
thought_sig = getattr(tc, "thoughtSignature", None)
153+
if thought_sig:
154+
part["thoughtSignature"] = thought_sig
155+
msg_parts.append(part)
150156
contents.append({"role": role, "parts": msg_parts})
151157

152158
result: dict[str, Any] = {"contents": contents}
@@ -219,24 +225,35 @@ def _parse_content(
219225
"""Parse content from response."""
220226
candidates = super()._parse_content(response_data)
221227
parts = candidates[0].get("content", {}).get("parts", [])
222-
text = parts[0].get("text") if parts else ""
223-
return text or ""
228+
for p in parts:
229+
if p.get("thought"):
230+
continue
231+
text = p.get("text")
232+
if text is not None:
233+
return text
234+
return ""
224235

225236
def _parse_tool_calls(self, response_data: dict[str, Any]) -> list[ToolCall]:
226237
"""Parse tool calls from Google response."""
227238
candidates = response_data.get("candidates", [])
228239
if not candidates:
229240
return []
230241
parts = candidates[0].get("content", {}).get("parts", [])
231-
return [
232-
ToolCall(
233-
id=str(uuid4()),
234-
name=p["functionCall"]["name"],
235-
arguments=p["functionCall"].get("args", {}),
236-
)
237-
for p in parts
238-
if "functionCall" in p
239-
]
242+
tool_calls: list[ToolCall] = []
243+
for p in parts:
244+
if "functionCall" in p:
245+
kwargs: dict[str, Any] = {}
246+
if "thoughtSignature" in p:
247+
kwargs["thoughtSignature"] = p["thoughtSignature"]
248+
tool_calls.append(
249+
ToolCall(
250+
id=str(uuid4()),
251+
name=p["functionCall"]["name"],
252+
arguments=p["functionCall"].get("args", {}),
253+
**kwargs,
254+
)
255+
)
256+
return tool_calls
240257

241258
def _stream_class(self) -> type[TextStream]:
242259
"""Return the Stream class for this provider."""

src/celeste/modalities/text/providers/google/models.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,7 @@
141141
Parameter.TEMPERATURE: Range(min=0.0, max=2.0),
142142
Parameter.MAX_TOKENS: Range(min=1, max=65536),
143143
TextParameter.THINKING_LEVEL: Choice(options=["low", "high"]),
144-
TextParameter.WEB_SEARCH: Bool(),
145-
TextParameter.CODE_EXECUTION: Bool(),
144+
TextParameter.TOOLS: ToolSupport(tools=[WebSearch, CodeExecution]),
146145
TextParameter.OUTPUT_SCHEMA: Schema(),
147146
# Media input support
148147
TextParameter.IMAGE: ImagesConstraint(),

0 commit comments

Comments
 (0)