diff --git a/src/celeste/__init__.py b/src/celeste/__init__.py index eb90db0..20c30d6 100644 --- a/src/celeste/__init__.py +++ b/src/celeste/__init__.py @@ -52,6 +52,7 @@ StrictJsonSchemaGenerator, StrictRefResolvingJsonSchemaGenerator, ) +from celeste.tools import CodeExecution, Tool, ToolCall, ToolResult, WebSearch, XSearch from celeste.types import Content, JsonValue, Message, Role from celeste.websocket import WebSocketClient, WebSocketConnection, close_all_ws_clients @@ -246,6 +247,7 @@ def create_client( "Authentication", "Capability", "ClientNotFoundError", + "CodeExecution", "ConstraintViolationError", "Content", "Error", @@ -271,6 +273,9 @@ def create_client( "StreamingNotSupportedError", "StrictJsonSchemaGenerator", "StrictRefResolvingJsonSchemaGenerator", + "Tool", + "ToolCall", + "ToolResult", "UnsupportedCapabilityError", "UnsupportedParameterError", "UnsupportedParameterWarning", @@ -278,8 +283,10 @@ def create_client( "Usage", "UsageField", "ValidationError", + "WebSearch", "WebSocketClient", "WebSocketConnection", + "XSearch", "audio", "close_all_http_clients", "close_all_ws_clients", diff --git a/src/celeste/client.py b/src/celeste/client.py index da30dd8..284167a 100644 --- a/src/celeste/client.py +++ b/src/celeste/client.py @@ -19,6 +19,7 @@ from celeste.models import Model from celeste.parameters import ParameterMapper, Parameters from celeste.streaming import Stream, enrich_stream_errors +from celeste.tools import ToolCall from celeste.types import RawUsage @@ -206,13 +207,19 @@ async def _predict( ) content = self._parse_content(response_data) content = self._transform_output(content, **parameters) + tool_calls = self._parse_tool_calls(response_data) return self._output_class()( content=content, usage=self._get_usage(response_data), finish_reason=self._get_finish_reason(response_data), metadata=self._build_metadata(response_data), + tool_calls=tool_calls, ) + def _parse_tool_calls(self, response_data: dict[str, Any]) -> list[ToolCall]: + 
"""Parse tool calls from response. Override in providers that support tools.""" + return [] + def _stream( self, inputs: In, diff --git a/src/celeste/constraints.py b/src/celeste/constraints.py index d35d3b0..31befbb 100644 --- a/src/celeste/constraints.py +++ b/src/celeste/constraints.py @@ -5,11 +5,12 @@ from abc import ABC, abstractmethod from typing import Any, ClassVar, get_args, get_origin -from pydantic import BaseModel, Field, computed_field +from pydantic import BaseModel, Field, computed_field, field_serializer from celeste.artifacts import AudioArtifact, ImageArtifact, VideoArtifact from celeste.exceptions import ConstraintViolationError from celeste.mime_types import AudioMimeType, ImageMimeType, MimeType, VideoMimeType +from celeste.tools import Tool class Constraint(BaseModel, ABC): @@ -367,6 +368,26 @@ class AudiosConstraint(_MediaListConstraint[AudioMimeType]): _media_label = "audio" +class ToolSupport(Constraint): + """Tool support constraint - validates Tool instances are supported by the model.""" + + tools: list[type[Tool]] + + @field_serializer("tools") + @classmethod + def _serialize_tools(cls, v: list[type[Tool]]) -> list[str]: + return [t.__name__ for t in v] + + def __call__(self, value: list) -> list: + """Validate tools list against supported tools.""" + for item in value: + if isinstance(item, Tool) and type(item) not in self.tools: + supported = [t.__name__ for t in self.tools] + msg = f"Tool '{type(item).__name__}' not supported. 
Supported: {supported}" + raise ConstraintViolationError(msg) + return value + + __all__ = [ "AudioConstraint", "AudiosConstraint", @@ -382,6 +403,7 @@ class AudiosConstraint(_MediaListConstraint[AudioMimeType]): "Range", "Schema", "Str", + "ToolSupport", "VideoConstraint", "VideosConstraint", ] diff --git a/src/celeste/io.py b/src/celeste/io.py index 5757e37..a497280 100644 --- a/src/celeste/io.py +++ b/src/celeste/io.py @@ -9,6 +9,7 @@ from celeste.artifacts import AudioArtifact, ImageArtifact, VideoArtifact from celeste.constraints import Constraint from celeste.core import InputType +from celeste.tools import ToolCall class Input(BaseModel): @@ -38,6 +39,7 @@ class Output[Content](BaseModel): usage: Usage = Field(default_factory=Usage) finish_reason: FinishReason | None = None metadata: dict[str, Any] = Field(default_factory=dict) + tool_calls: list[ToolCall] = Field(default_factory=list) class Chunk[Content](BaseModel): diff --git a/src/celeste/modalities/text/client.py b/src/celeste/modalities/text/client.py index f46ee61..9b174c9 100644 --- a/src/celeste/modalities/text/client.py +++ b/src/celeste/modalities/text/client.py @@ -1,11 +1,13 @@ """Text modality client.""" -from typing import Any, Unpack +import warnings +from typing import Any, ClassVar, Unpack from asgiref.sync import async_to_sync from celeste.client import ModalityClient from celeste.core import InputType, Modality +from celeste.tools import CodeExecution, WebSearch, XSearch from celeste.types import AudioContent, ImageContent, Message, TextContent, VideoContent from .io import TextChunk, TextFinishReason, TextInput, TextOutput, TextUsage @@ -25,11 +27,45 @@ class TextClient( _usage_class = TextUsage _finish_reason_class = TextFinishReason + # Deprecated param → Tool class mapping. + # TODO(deprecation): Remove on 2026-06-07. 
+ _DEPRECATED_TOOL_PARAMS: ClassVar[dict[str, type]] = { + "web_search": WebSearch, + "x_search": XSearch, + "code_execution": CodeExecution, + } + @classmethod def _output_class(cls) -> type[TextOutput]: """Return the Output class for text modality.""" return TextOutput + def _build_request( + self, + inputs: TextInput, + extra_body: dict[str, Any] | None = None, + streaming: bool = False, + **parameters: Unpack[TextParameters], + ) -> dict[str, Any]: + """Build request, migrating deprecated boolean tool params first. + + TODO(deprecation): Remove this override on 2026-06-07. + """ + for old_param, tool_cls in self._DEPRECATED_TOOL_PARAMS.items(): + value = parameters.pop(old_param, None) # type: ignore[misc] + if value: + warnings.warn( + f"'{old_param}=True' is deprecated, " + f"use tools=[{tool_cls.__name__}()] instead. " + "Will be removed on 2026-06-07.", + DeprecationWarning, + stacklevel=4, + ) + parameters.setdefault("tools", []).append(tool_cls()) + return super()._build_request( + inputs, extra_body=extra_body, streaming=streaming, **parameters + ) + def _check_media_support( self, image: ImageContent | None, diff --git a/src/celeste/modalities/text/io.py b/src/celeste/modalities/text/io.py index ebf6826..9492cc2 100644 --- a/src/celeste/modalities/text/io.py +++ b/src/celeste/modalities/text/io.py @@ -10,14 +10,22 @@ from pydantic import Field from celeste.io import Chunk, FinishReason, Input, Output, Usage -from celeste.types import AudioContent, ImageContent, Message, TextContent, VideoContent +from celeste.tools import ToolResult +from celeste.types import ( + AudioContent, + ImageContent, + Message, + Role, + TextContent, + VideoContent, +) class TextInput(Input): """Input for text operations.""" prompt: str | None = None - messages: list[Message] | None = None + messages: list[ToolResult | Message] | None = None text: str | list[str] | None = None image: ImageContent | None = None video: VideoContent | None = None @@ -46,6 +54,15 @@ class 
TextOutput(Output[TextContent]): usage: TextUsage = Field(default_factory=TextUsage) finish_reason: TextFinishReason | None = None + @property + def message(self) -> Message: + """The assistant message for multi-turn conversations.""" + return Message( + role=Role.ASSISTANT, + content=self.content, + tool_calls=self.tool_calls if self.tool_calls else None, + ) + class TextChunk(Chunk[str]): """Chunk for text streaming.""" diff --git a/src/celeste/modalities/text/parameters.py b/src/celeste/modalities/text/parameters.py index 98f6a4b..4f1b29d 100644 --- a/src/celeste/modalities/text/parameters.py +++ b/src/celeste/modalities/text/parameters.py @@ -9,6 +9,7 @@ from pydantic import BaseModel from celeste.parameters import Parameters +from celeste.tools import ToolDefinition class TextParameter(StrEnum): @@ -23,8 +24,12 @@ class TextParameter(StrEnum): THINKING_BUDGET = "thinking_budget" THINKING_LEVEL = "thinking_level" OUTPUT_SCHEMA = "output_schema" - WEB_SEARCH = "web_search" + TOOLS = "tools" VERBOSITY = "verbosity" + + # Deprecated: use tools=[WebSearch()], tools=[XSearch()], tools=[CodeExecution()] instead. + # TODO(deprecation): Remove on 2026-06-07. + WEB_SEARCH = "web_search" X_SEARCH = "x_search" CODE_EXECUTION = "code_execution" @@ -46,8 +51,12 @@ class TextParameters(Parameters): thinking_budget: int | str thinking_level: str output_schema: type[BaseModel] - web_search: bool + tools: list[ToolDefinition] verbosity: str + + # Deprecated: use tools=[WebSearch()], tools=[XSearch()], tools=[CodeExecution()] instead. + # TODO(deprecation): Remove on 2026-06-07. 
+ web_search: bool x_search: bool code_execution: bool diff --git a/src/celeste/modalities/text/protocols/__init__.py b/src/celeste/modalities/text/protocols/__init__.py new file mode 100644 index 0000000..b6ae043 --- /dev/null +++ b/src/celeste/modalities/text/protocols/__init__.py @@ -0,0 +1 @@ +"""Text modality protocol implementations.""" diff --git a/src/celeste/modalities/text/protocols/chatcompletions/__init__.py b/src/celeste/modalities/text/protocols/chatcompletions/__init__.py new file mode 100644 index 0000000..3dc9983 --- /dev/null +++ b/src/celeste/modalities/text/protocols/chatcompletions/__init__.py @@ -0,0 +1,5 @@ +"""Chat Completions protocol for text modality.""" + +from .client import ChatCompletionsTextClient, ChatCompletionsTextStream + +__all__ = ["ChatCompletionsTextClient", "ChatCompletionsTextStream"] diff --git a/src/celeste/modalities/text/protocols/chatcompletions/client.py b/src/celeste/modalities/text/protocols/chatcompletions/client.py new file mode 100644 index 0000000..3cf6e10 --- /dev/null +++ b/src/celeste/modalities/text/protocols/chatcompletions/client.py @@ -0,0 +1,101 @@ +"""Chat Completions text client.""" + +from typing import Any, Unpack + +from celeste.parameters import ParameterMapper +from celeste.protocols.chatcompletions.client import ( + ChatCompletionsClient as ChatCompletionsMixin, +) +from celeste.protocols.chatcompletions.streaming import ( + ChatCompletionsStream as _ChatCompletionsStream, +) +from celeste.protocols.chatcompletions.tools import ( + parse_tool_calls, + serialize_messages, +) +from celeste.tools import ToolCall +from celeste.types import ImageContent, Message, TextContent, VideoContent +from celeste.utils import build_image_data_url + +from ...client import TextClient +from ...io import ( + TextInput, + TextOutput, +) +from ...parameters import TextParameters +from ...streaming import TextStream +from .parameters import CHATCOMPLETIONS_PARAMETER_MAPPERS + + +class 
ChatCompletionsTextStream(_ChatCompletionsStream, TextStream): + """Chat Completions streaming for text modality.""" + + +class ChatCompletionsTextClient(ChatCompletionsMixin, TextClient): + """Chat Completions text client.""" + + @classmethod + def parameter_mappers(cls) -> list[ParameterMapper[TextContent]]: + return CHATCOMPLETIONS_PARAMETER_MAPPERS + + async def generate( + self, + prompt: str | None = None, + *, + messages: list[Message] | None = None, + **parameters: Unpack[TextParameters], + ) -> TextOutput: + """Generate text from prompt.""" + inputs = TextInput(prompt=prompt, messages=messages) + return await self._predict(inputs, **parameters) + + async def analyze( + self, + prompt: str | None = None, + *, + messages: list[Message] | None = None, + image: ImageContent | None = None, + video: VideoContent | None = None, + **parameters: Unpack[TextParameters], + ) -> TextOutput: + """Analyze image(s) or video(s) with prompt or messages.""" + inputs = TextInput(prompt=prompt, messages=messages, image=image, video=video) + return await self._predict(inputs, **parameters) + + def _init_request(self, inputs: TextInput) -> dict[str, Any]: + """Initialize request with Chat Completions message format.""" + if inputs.messages is not None: + return {"messages": serialize_messages(inputs.messages)} + + if inputs.image is None: + content: str | list[dict[str, Any]] = inputs.prompt or "" + else: + images = inputs.image if isinstance(inputs.image, list) else [inputs.image] + content = [ + {"type": "image_url", "image_url": {"url": build_image_data_url(img)}} + for img in images + ] + content.append({"type": "text", "text": inputs.prompt or ""}) + + return {"messages": [{"role": "user", "content": content}]} + + def _parse_content( + self, + response_data: dict[str, Any], + ) -> TextContent: + """Parse text content from response.""" + choices = super()._parse_content(response_data) + message = choices[0].get("message", {}) + content = message.get("content") or "" + 
return content + + def _parse_tool_calls(self, response_data: dict[str, Any]) -> list[ToolCall]: + """Parse tool calls from Chat Completions response.""" + return parse_tool_calls(response_data) + + def _stream_class(self) -> type[TextStream]: + """Return the Stream class for this provider.""" + return ChatCompletionsTextStream + + +__all__ = ["ChatCompletionsTextClient", "ChatCompletionsTextStream"] diff --git a/src/celeste/modalities/text/protocols/chatcompletions/parameters.py b/src/celeste/modalities/text/protocols/chatcompletions/parameters.py new file mode 100644 index 0000000..6d16b4c --- /dev/null +++ b/src/celeste/modalities/text/protocols/chatcompletions/parameters.py @@ -0,0 +1,52 @@ +"""Chat Completions parameter mappers for text.""" + +from celeste.parameters import ParameterMapper +from celeste.protocols.chatcompletions.parameters import ( + MaxTokensMapper as _MaxTokensMapper, +) +from celeste.protocols.chatcompletions.parameters import ( + ResponseFormatMapper as _ResponseFormatMapper, +) +from celeste.protocols.chatcompletions.parameters import ( + TemperatureMapper as _TemperatureMapper, +) +from celeste.protocols.chatcompletions.parameters import ( + ToolsMapper as _ToolsMapper, +) +from celeste.types import TextContent + +from ...parameters import TextParameter + + +class TemperatureMapper(_TemperatureMapper): + """Map temperature to Chat Completions temperature parameter.""" + + name = TextParameter.TEMPERATURE + + +class MaxTokensMapper(_MaxTokensMapper): + """Map max_tokens to Chat Completions max_tokens parameter.""" + + name = TextParameter.MAX_TOKENS + + +class OutputSchemaMapper(_ResponseFormatMapper): + """Map output_schema to Chat Completions response_format parameter.""" + + name = TextParameter.OUTPUT_SCHEMA + + +class ToolsMapper(_ToolsMapper): + """Map tools to Chat Completions tools parameter.""" + + name = TextParameter.TOOLS + + +CHATCOMPLETIONS_PARAMETER_MAPPERS: list[ParameterMapper[TextContent]] = [ + TemperatureMapper(), + 
MaxTokensMapper(), + OutputSchemaMapper(), + ToolsMapper(), +] + +__all__ = ["CHATCOMPLETIONS_PARAMETER_MAPPERS"] diff --git a/src/celeste/modalities/text/protocols/openresponses/__init__.py b/src/celeste/modalities/text/protocols/openresponses/__init__.py new file mode 100644 index 0000000..da4c8aa --- /dev/null +++ b/src/celeste/modalities/text/protocols/openresponses/__init__.py @@ -0,0 +1,5 @@ +"""OpenResponses protocol for text modality.""" + +from .client import OpenResponsesTextClient, OpenResponsesTextStream + +__all__ = ["OpenResponsesTextClient", "OpenResponsesTextStream"] diff --git a/src/celeste/modalities/text/protocols/openresponses/client.py b/src/celeste/modalities/text/protocols/openresponses/client.py new file mode 100644 index 0000000..e8d4331 --- /dev/null +++ b/src/celeste/modalities/text/protocols/openresponses/client.py @@ -0,0 +1,137 @@ +"""OpenResponses text client.""" + +from typing import Any, Unpack + +from celeste.parameters import ParameterMapper +from celeste.protocols.openresponses.client import ( + OpenResponsesClient as OpenResponsesMixin, +) +from celeste.protocols.openresponses.streaming import ( + OpenResponsesStream as _OpenResponsesStream, +) +from celeste.protocols.openresponses.tools import ( + parse_content, + parse_tool_calls, + serialize_messages, +) +from celeste.tools import ToolCall +from celeste.types import ImageContent, Message, TextContent, VideoContent +from celeste.utils import build_image_data_url + +from ...client import TextClient +from ...io import ( + TextChunk, + TextInput, + TextOutput, +) +from ...parameters import TextParameters +from ...streaming import TextStream +from .parameters import OPENRESPONSES_PARAMETER_MAPPERS + + +class OpenResponsesTextStream(_OpenResponsesStream, TextStream): + """OpenResponses streaming for text modality.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._response_data: dict[str, Any] | None = None + + def 
_parse_chunk(self, event_data: dict[str, Any]) -> TextChunk | None: + """Parse one SSE event into a typed chunk (captures response.completed).""" + event_type = event_data.get("type") + if event_type == "response.completed": + response = event_data.get("response") + if isinstance(response, dict): + self._response_data = response + return super()._parse_chunk(event_data) + + def _aggregate_event_data(self, chunks: list[TextChunk]) -> list[dict[str, Any]]: + """Prepend response_data, then delegate to base.""" + events: list[dict[str, Any]] = [] + if self._response_data is not None: + events.append(self._response_data) + events.extend(super()._aggregate_event_data(chunks)) + return events + + def _aggregate_tool_calls( + self, chunks: list[TextChunk], raw_events: list[dict[str, Any]] + ) -> list[ToolCall]: + """Extract tool calls from response.completed data.""" + if self._response_data is None: + return [] + return parse_tool_calls(self._response_data) + + +class OpenResponsesTextClient(OpenResponsesMixin, TextClient): + """OpenResponses text client using Responses API.""" + + @classmethod + def parameter_mappers(cls) -> list[ParameterMapper[TextContent]]: + return OPENRESPONSES_PARAMETER_MAPPERS + + async def generate( + self, + prompt: str | None = None, + *, + messages: list[Message] | None = None, + base_url: str | None = None, + extra_body: dict[str, Any] | None = None, + **parameters: Unpack[TextParameters], + ) -> TextOutput: + """Generate text from prompt.""" + inputs = TextInput(prompt=prompt, messages=messages) + return await self._predict( + inputs, base_url=base_url, extra_body=extra_body, **parameters + ) + + async def analyze( + self, + prompt: str | None = None, + *, + messages: list[Message] | None = None, + image: ImageContent | None = None, + video: VideoContent | None = None, + base_url: str | None = None, + extra_body: dict[str, Any] | None = None, + **parameters: Unpack[TextParameters], + ) -> TextOutput: + """Analyze image(s) or video(s) with 
prompt or messages.""" + inputs = TextInput(prompt=prompt, messages=messages, image=image, video=video) + return await self._predict( + inputs, base_url=base_url, extra_body=extra_body, **parameters + ) + + def _init_request(self, inputs: TextInput) -> dict[str, Any]: + """Initialize request with input content.""" + if inputs.messages is not None: + return {"input": serialize_messages(inputs.messages)} + + content: list[dict[str, Any]] = [] + if inputs.image is not None: + images = inputs.image if isinstance(inputs.image, list) else [inputs.image] + for img in images: + content.append( + {"type": "input_image", "image_url": build_image_data_url(img)} + ) + + content.append({"type": "input_text", "text": inputs.prompt or ""}) + return {"input": [{"role": "user", "content": content}]} + + def _parse_content( + self, + response_data: dict[str, Any], + ) -> TextContent: + """Parse text content from response.""" + output = super()._parse_content(response_data) + return parse_content(output) + + def _parse_tool_calls(self, response_data: dict[str, Any]) -> list[ToolCall]: + """Parse tool calls from OpenResponses response.""" + return parse_tool_calls(response_data) + + def _stream_class(self) -> type[TextStream]: + """Return the Stream class for this provider.""" + return OpenResponsesTextStream + + +__all__ = ["OpenResponsesTextClient", "OpenResponsesTextStream"] diff --git a/src/celeste/modalities/text/protocols/openresponses/parameters.py b/src/celeste/modalities/text/protocols/openresponses/parameters.py new file mode 100644 index 0000000..b9d2f7f --- /dev/null +++ b/src/celeste/modalities/text/protocols/openresponses/parameters.py @@ -0,0 +1,52 @@ +"""OpenResponses parameter mappers for text.""" + +from celeste.parameters import ParameterMapper +from celeste.protocols.openresponses.parameters import ( + MaxOutputTokensMapper as _MaxOutputTokensMapper, +) +from celeste.protocols.openresponses.parameters import ( + TemperatureMapper as _TemperatureMapper, +) +from 
celeste.protocols.openresponses.parameters import ( + TextFormatMapper as _TextFormatMapper, +) +from celeste.protocols.openresponses.parameters import ( + ToolsMapper as _ToolsMapper, +) +from celeste.types import TextContent + +from ...parameters import TextParameter + + +class TemperatureMapper(_TemperatureMapper): + """Map temperature to Responses temperature parameter.""" + + name = TextParameter.TEMPERATURE + + +class MaxTokensMapper(_MaxOutputTokensMapper): + """Map max_tokens to Responses max_output_tokens parameter.""" + + name = TextParameter.MAX_TOKENS + + +class OutputSchemaMapper(_TextFormatMapper): + """Map output_schema to Responses text.format parameter.""" + + name = TextParameter.OUTPUT_SCHEMA + + +class ToolsMapper(_ToolsMapper): + """Map tools to Responses tools parameter.""" + + name = TextParameter.TOOLS + + +OPENRESPONSES_PARAMETER_MAPPERS: list[ParameterMapper[TextContent]] = [ + TemperatureMapper(), + MaxTokensMapper(), + OutputSchemaMapper(), + ToolsMapper(), +] + +__all__ = ["OPENRESPONSES_PARAMETER_MAPPERS"] diff --git a/src/celeste/modalities/text/providers/anthropic/client.py b/src/celeste/modalities/text/providers/anthropic/client.py index 8d0ae8c..6e4453a 100644 --- a/src/celeste/modalities/text/providers/anthropic/client.py +++ b/src/celeste/modalities/text/providers/anthropic/client.py @@ -1,6 +1,7 @@ """Anthropic text client (modality).""" import base64 +import contextlib from typing import Any, Unpack from celeste.artifacts import ImageArtifact @@ -10,6 +11,7 @@ from celeste.providers.anthropic.messages.streaming import ( AnthropicMessagesStream as _AnthropicMessagesStream, ) +from celeste.tools import ToolCall, ToolResult from celeste.types import ImageContent, Message, TextContent, VideoContent from celeste.utils import detect_mime_type @@ -30,15 +32,31 @@ class AnthropicTextStream(_AnthropicMessagesStream, TextStream): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._message_start: 
dict[str, Any] | None = None + self._tool_calls: dict[int, dict[str, Any]] = {} def _parse_chunk(self, event_data: dict[str, Any]) -> TextChunk | None: - """Parse one SSE event into a typed chunk (captures message_start).""" + """Parse one SSE event into a typed chunk (captures message_start and tool_use).""" event_type = event_data.get("type") if event_type == "message_start": message = event_data.get("message") if isinstance(message, dict): self._message_start = message return None + if event_type == "content_block_start": + block = event_data.get("content_block", {}) + if block.get("type") == "tool_use": + idx = event_data.get("index", len(self._tool_calls)) + self._tool_calls[idx] = { + "id": block.get("id", ""), + "name": block.get("name", ""), + "input_json": "", + } + elif event_type == "content_block_delta": + delta = event_data.get("delta", {}) + if delta.get("type") == "input_json_delta": + idx = event_data.get("index", -1) + if idx in self._tool_calls: + self._tool_calls[idx]["input_json"] += delta.get("partial_json", "") return super()._parse_chunk(event_data) def _aggregate_event_data(self, chunks: list[TextChunk]) -> list[dict[str, Any]]: @@ -49,6 +67,21 @@ def _aggregate_event_data(self, chunks: list[TextChunk]) -> list[dict[str, Any]] events.extend(super()._aggregate_event_data(chunks)) return events + def _aggregate_tool_calls( + self, chunks: list[TextChunk], raw_events: list[dict[str, Any]] + ) -> list[ToolCall]: + """Reconstruct tool calls from accumulated content_block events.""" + import json as _json + + result: list[ToolCall] = [] + for tc in self._tool_calls.values(): + arguments = {} + if tc["input_json"]: + with contextlib.suppress(ValueError, TypeError): + arguments = _json.loads(tc["input_json"]) + result.append(ToolCall(id=tc["id"], name=tc["name"], arguments=arguments)) + return result + class AnthropicTextClient(AnthropicMessagesClient, TextClient): """Anthropic text client.""" @@ -86,10 +119,12 @@ def _init_request(self, inputs: 
TextInput) -> dict[str, Any]: if inputs.messages is not None: system_blocks: list[dict[str, Any]] = [] messages: list[dict[str, Any]] = [] + pending_tool_results: list[dict[str, Any]] = [] for message in inputs.messages: role = message.role content = message.content + if role in {"system", "developer"}: if isinstance(content, list): for block in content: @@ -105,7 +140,41 @@ def _init_request(self, inputs: TextInput) -> dict[str, Any]: system_blocks.append({"type": "text", "text": str(content)}) continue - messages.append({"role": role, "content": content}) + if isinstance(message, ToolResult): + pending_tool_results.append( + { + "type": "tool_result", + "tool_use_id": message.tool_call_id, + "content": str(content), + } + ) + continue + + # Flush pending tool results as a single user message + if pending_tool_results: + messages.append({"role": "user", "content": pending_tool_results}) + pending_tool_results = [] + + if role == "assistant" and message.tool_calls: + content_blocks: list[dict[str, Any]] = [] + if content: + content_blocks.append({"type": "text", "text": str(content)}) + for tc in message.tool_calls: + content_blocks.append( + { + "type": "tool_use", + "id": tc.id, + "name": tc.name, + "input": tc.arguments, + } + ) + messages.append({"role": "assistant", "content": content_blocks}) + else: + messages.append({"role": role, "content": content}) + + # Flush remaining tool results + if pending_tool_results: + messages.append({"role": "user", "content": pending_tool_results}) request: dict[str, Any] = {"messages": messages} if system_blocks: @@ -165,6 +234,16 @@ def _parse_content( return text_content + def _parse_tool_calls(self, response_data: dict[str, Any]) -> list[ToolCall]: + """Parse tool calls from Anthropic response.""" + return [ + ToolCall( + id=block["id"], name=block["name"], arguments=block.get("input", {}) + ) + for block in response_data.get("content", []) + if block.get("type") == "tool_use" + ] + def _stream_class(self) -> 
type[TextStream]: """Return the Stream class for this provider.""" return AnthropicTextStream diff --git a/src/celeste/modalities/text/providers/anthropic/models.py b/src/celeste/modalities/text/providers/anthropic/models.py index dc126d5..67c4ac5 100644 --- a/src/celeste/modalities/text/providers/anthropic/models.py +++ b/src/celeste/modalities/text/providers/anthropic/models.py @@ -1,8 +1,9 @@ """Anthropic models for text modality.""" -from celeste.constraints import Bool, ImagesConstraint, Range, Schema +from celeste.constraints import ImagesConstraint, Range, Schema, ToolSupport from celeste.core import Modality, Operation, Parameter, Provider from celeste.models import Model +from celeste.tools import WebSearch from ...parameters import TextParameter @@ -17,7 +18,7 @@ Parameter.MAX_TOKENS: Range(min=1, max=64000), TextParameter.THINKING_BUDGET: Range(min=-1, max=64000), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.IMAGE: ImagesConstraint(), }, ), @@ -31,7 +32,7 @@ Parameter.MAX_TOKENS: Range(min=1, max=64000), TextParameter.THINKING_BUDGET: Range(min=-1, max=32000), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.IMAGE: ImagesConstraint(), }, ), @@ -45,7 +46,7 @@ Parameter.MAX_TOKENS: Range(min=1, max=32000), TextParameter.THINKING_BUDGET: Range(min=-1, max=32000), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.IMAGE: ImagesConstraint(), }, ), @@ -59,7 +60,7 @@ Parameter.MAX_TOKENS: Range(min=1, max=64000), TextParameter.THINKING_BUDGET: Range(min=-1, max=32000), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.IMAGE: ImagesConstraint(), }, ), @@ -73,7 +74,7 @@ 
Parameter.MAX_TOKENS: Range(min=1, max=64000), TextParameter.THINKING_BUDGET: Range(min=-1, max=32000), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.IMAGE: ImagesConstraint(), }, ), @@ -87,7 +88,7 @@ Parameter.MAX_TOKENS: Range(min=1, max=64000), TextParameter.THINKING_BUDGET: Range(min=-1, max=64000), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.IMAGE: ImagesConstraint(), }, ), @@ -100,7 +101,7 @@ parameter_constraints={ Parameter.MAX_TOKENS: Range(min=1, max=64000), TextParameter.THINKING_BUDGET: Range(min=-1, max=64000), - TextParameter.WEB_SEARCH: Bool(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.IMAGE: ImagesConstraint(), }, ), @@ -113,7 +114,7 @@ parameter_constraints={ Parameter.MAX_TOKENS: Range(min=1, max=32000), TextParameter.THINKING_BUDGET: Range(min=-1, max=32000), - TextParameter.WEB_SEARCH: Bool(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.IMAGE: ImagesConstraint(), }, ), diff --git a/src/celeste/modalities/text/providers/anthropic/parameters.py b/src/celeste/modalities/text/providers/anthropic/parameters.py index 0470a3f..0ba40bf 100644 --- a/src/celeste/modalities/text/providers/anthropic/parameters.py +++ b/src/celeste/modalities/text/providers/anthropic/parameters.py @@ -17,7 +17,7 @@ ThinkingMapper as _ThinkingMapper, ) from celeste.providers.anthropic.messages.parameters import ( - WebSearchMapper as _WebSearchMapper, + ToolsMapper as _ToolsMapper, ) from celeste.types import TextContent @@ -67,10 +67,10 @@ class OutputSchemaMapper(_OutputFormatMapper): name = TextParameter.OUTPUT_SCHEMA -class WebSearchMapper(_WebSearchMapper): - """Map web_search to Anthropic's tools parameter.""" +class ToolsMapper(_ToolsMapper): + """Map tools to Anthropic's tools parameter.""" - name = 
TextParameter.WEB_SEARCH + name = TextParameter.TOOLS ANTHROPIC_PARAMETER_MAPPERS: list[ParameterMapper[TextContent]] = [ @@ -78,7 +78,7 @@ class WebSearchMapper(_WebSearchMapper): MaxTokensMapper(), ThinkingBudgetMapper(), OutputSchemaMapper(), - WebSearchMapper(), + ToolsMapper(), ] __all__ = ["ANTHROPIC_PARAMETER_MAPPERS"] diff --git a/src/celeste/modalities/text/providers/cohere/client.py b/src/celeste/modalities/text/providers/cohere/client.py index 31621f8..5fba4fa 100644 --- a/src/celeste/modalities/text/providers/cohere/client.py +++ b/src/celeste/modalities/text/providers/cohere/client.py @@ -59,7 +59,11 @@ def _init_request(self, inputs: TextInput) -> dict[str, Any]: """Initialize request from Cohere v2 Chat API messages array format.""" # If messages provided, use them directly (messages take precedence) if inputs.messages is not None: - return {"messages": [message.model_dump() for message in inputs.messages]} + return { + "messages": [ + message.model_dump(exclude_none=True) for message in inputs.messages + ] + } # Fall back to prompt-based input if inputs.image is None: diff --git a/src/celeste/modalities/text/providers/deepseek/client.py b/src/celeste/modalities/text/providers/deepseek/client.py index 100e481..523825a 100644 --- a/src/celeste/modalities/text/providers/deepseek/client.py +++ b/src/celeste/modalities/text/providers/deepseek/client.py @@ -1,72 +1,33 @@ """DeepSeek text client (modality).""" -from typing import Any, Unpack - from celeste.parameters import ParameterMapper from celeste.providers.deepseek.chat.client import DeepSeekChatClient from celeste.providers.deepseek.chat.streaming import ( DeepSeekChatStream as _DeepSeekChatStream, ) -from celeste.types import Message, TextContent +from celeste.types import TextContent -from ...client import TextClient -from ...io import ( - TextInput, - TextOutput, +from ...protocols.chatcompletions.client import ( + ChatCompletionsTextClient, +) +from ...protocols.chatcompletions.client import ( + 
ChatCompletionsTextStream as _ChatCompletionsTextStream, ) -from ...parameters import TextParameters from ...streaming import TextStream from .parameters import DEEPSEEK_PARAMETER_MAPPERS -class DeepSeekTextStream(_DeepSeekChatStream, TextStream): +class DeepSeekTextStream(_DeepSeekChatStream, _ChatCompletionsTextStream): """DeepSeek streaming for text modality.""" -class DeepSeekTextClient(DeepSeekChatClient, TextClient): +class DeepSeekTextClient(DeepSeekChatClient, ChatCompletionsTextClient): """DeepSeek text client.""" @classmethod def parameter_mappers(cls) -> list[ParameterMapper[TextContent]]: return DEEPSEEK_PARAMETER_MAPPERS - async def generate( - self, - prompt: str | None = None, - *, - messages: list[Message] | None = None, - **parameters: Unpack[TextParameters], - ) -> TextOutput: - """Generate text from prompt.""" - inputs = TextInput(prompt=prompt, messages=messages) - return await self._predict(inputs, **parameters) - - def _init_request(self, inputs: TextInput) -> dict[str, Any]: - """Initialize request from DeepSeek messages array format.""" - # If messages provided, use them directly (messages take precedence) - if inputs.messages is not None: - return {"messages": [message.model_dump() for message in inputs.messages]} - - # Fall back to prompt-based input - messages = [ - { - "role": "user", - "content": inputs.prompt or "", - } - ] - - return {"messages": messages} - - def _parse_content( - self, - response_data: dict[str, Any], - ) -> TextContent: - """Parse content from response.""" - choices = super()._parse_content(response_data) - message = choices[0].get("message", {}) - content = message.get("content") or "" - return content - def _stream_class(self) -> type[TextStream]: """Return the Stream class for this provider.""" return DeepSeekTextStream diff --git a/src/celeste/modalities/text/providers/deepseek/models.py b/src/celeste/modalities/text/providers/deepseek/models.py index e7dae88..73f28ef 100644 --- 
a/src/celeste/modalities/text/providers/deepseek/models.py +++ b/src/celeste/modalities/text/providers/deepseek/models.py @@ -1,6 +1,6 @@ """DeepSeek models for text modality.""" -from celeste.constraints import Range, Schema +from celeste.constraints import Range, Schema, ToolSupport from celeste.core import Modality, Operation, Parameter, Provider from celeste.models import Model @@ -17,6 +17,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=8192, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -29,6 +30,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=65536, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), ] diff --git a/src/celeste/modalities/text/providers/deepseek/parameters.py b/src/celeste/modalities/text/providers/deepseek/parameters.py index f5e6510..dbde213 100644 --- a/src/celeste/modalities/text/providers/deepseek/parameters.py +++ b/src/celeste/modalities/text/providers/deepseek/parameters.py @@ -10,6 +10,9 @@ from celeste.protocols.chatcompletions.parameters import ( TemperatureMapper as _TemperatureMapper, ) +from celeste.protocols.chatcompletions.parameters import ( + ToolsMapper as _ToolsMapper, +) from celeste.types import TextContent from ...parameters import TextParameter @@ -33,10 +36,17 @@ class OutputSchemaMapper(_ResponseFormatMapper): name = TextParameter.OUTPUT_SCHEMA +class ToolsMapper(_ToolsMapper): + """Map tools to DeepSeek's tools parameter (user-defined only).""" + + name = TextParameter.TOOLS + + DEEPSEEK_PARAMETER_MAPPERS: list[ParameterMapper[TextContent]] = [ TemperatureMapper(), MaxTokensMapper(), OutputSchemaMapper(), + ToolsMapper(), ] __all__ = ["DEEPSEEK_PARAMETER_MAPPERS"] diff --git a/src/celeste/modalities/text/providers/google/client.py b/src/celeste/modalities/text/providers/google/client.py index 
d026745..555fe3d 100644 --- a/src/celeste/modalities/text/providers/google/client.py +++ b/src/celeste/modalities/text/providers/google/client.py @@ -2,6 +2,7 @@ import base64 from typing import Any, Unpack +from uuid import uuid4 from celeste.artifacts import AudioArtifact, ImageArtifact, VideoArtifact from celeste.parameters import ParameterMapper @@ -10,6 +11,7 @@ from celeste.providers.google.generate_content.streaming import ( GoogleGenerateContentStream as _GoogleGenerateContentStream, ) +from celeste.tools import ToolCall, ToolResult from celeste.types import AudioContent, ImageContent, Message, TextContent, VideoContent from celeste.utils import detect_mime_type @@ -26,6 +28,28 @@ class GoogleTextStream(_GoogleGenerateContentStream, TextStream): """Google streaming for text modality.""" + def _aggregate_tool_calls( + self, chunks: list, raw_events: list[dict[str, Any]] + ) -> list[ToolCall]: + """Extract tool calls from Google streaming events.""" + tool_calls: list[ToolCall] = [] + for event in raw_events: + for candidate in event.get("candidates", []): + for part in candidate.get("content", {}).get("parts", []): + if "functionCall" in part: + kwargs: dict[str, Any] = {} + if "thoughtSignature" in part: + kwargs["thoughtSignature"] = part["thoughtSignature"] + tool_calls.append( + ToolCall( + id=str(uuid4()), + name=part["functionCall"]["name"], + arguments=part["functionCall"].get("args", {}), + **kwargs, + ) + ) + return tool_calls + class GoogleTextClient(GoogleGenerateContentClient, TextClient): """Google text client.""" @@ -100,11 +124,36 @@ def content_to_parts(content: Any) -> list[dict[str, Any]]: for msg in inputs.messages: if msg.role in ("system", "developer"): system_parts.extend(content_to_parts(msg.content)) - else: - role = "model" if msg.role == "assistant" else msg.role + elif isinstance(msg, ToolResult): contents.append( - {"role": role, "parts": content_to_parts(msg.content)} + { + "role": "user", + "parts": [ + { + "functionResponse": { 
+ "name": msg.name, + "response": {"result": msg.content}, + } + } + ], + } ) + else: + role = "model" if msg.role == "assistant" else msg.role + msg_parts = content_to_parts(msg.content) + if msg.tool_calls: + for tc in msg.tool_calls: + part: dict[str, Any] = { + "functionCall": { + "name": tc.name, + "args": tc.arguments, + } + } + thought_sig = getattr(tc, "thoughtSignature", None) + if thought_sig: + part["thoughtSignature"] = thought_sig + msg_parts.append(part) + contents.append({"role": role, "parts": msg_parts}) result: dict[str, Any] = {"contents": contents} if system_parts: @@ -176,8 +225,35 @@ def _parse_content( """Parse content from response.""" candidates = super()._parse_content(response_data) parts = candidates[0].get("content", {}).get("parts", []) - text = parts[0].get("text") if parts else "" - return text or "" + for p in parts: + if p.get("thought"): + continue + text = p.get("text") + if text is not None: + return text + return "" + + def _parse_tool_calls(self, response_data: dict[str, Any]) -> list[ToolCall]: + """Parse tool calls from Google response.""" + candidates = response_data.get("candidates", []) + if not candidates: + return [] + parts = candidates[0].get("content", {}).get("parts", []) + tool_calls: list[ToolCall] = [] + for p in parts: + if "functionCall" in p: + kwargs: dict[str, Any] = {} + if "thoughtSignature" in p: + kwargs["thoughtSignature"] = p["thoughtSignature"] + tool_calls.append( + ToolCall( + id=str(uuid4()), + name=p["functionCall"]["name"], + arguments=p["functionCall"].get("args", {}), + **kwargs, + ) + ) + return tool_calls def _stream_class(self) -> type[TextStream]: """Return the Stream class for this provider.""" diff --git a/src/celeste/modalities/text/providers/google/models.py b/src/celeste/modalities/text/providers/google/models.py index 2e77a66..2e46bfd 100644 --- a/src/celeste/modalities/text/providers/google/models.py +++ b/src/celeste/modalities/text/providers/google/models.py @@ -2,15 +2,16 @@ from 
celeste.constraints import ( AudioConstraint, - Bool, Choice, ImagesConstraint, Range, Schema, + ToolSupport, VideosConstraint, ) from celeste.core import Modality, Operation, Parameter, Provider from celeste.models import Model +from celeste.tools import CodeExecution, WebSearch from ...parameters import TextParameter @@ -26,8 +27,8 @@ Parameter.MAX_TOKENS: Range(min=1, max=65536), # Flash: allows -1 (dynamic), 0 (disable), or >= 0 TextParameter.THINKING_BUDGET: Range(min=-1, max=24576), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch, CodeExecution]), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), # Media input support TextParameter.IMAGE: ImagesConstraint(), TextParameter.VIDEO: VideosConstraint(), @@ -47,8 +48,8 @@ TextParameter.THINKING_BUDGET: Range( min=512, max=24576, special_values=[-1, 0] ), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch, CodeExecution]), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), # Media input support TextParameter.IMAGE: ImagesConstraint(), TextParameter.VIDEO: VideosConstraint(), @@ -68,8 +69,8 @@ TextParameter.THINKING_BUDGET: Range( min=128, max=32768, special_values=[-1] ), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch, CodeExecution]), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), # Media input support TextParameter.IMAGE: ImagesConstraint(), TextParameter.VIDEO: VideosConstraint(), @@ -86,8 +87,8 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0), Parameter.MAX_TOKENS: Range(min=1, max=65536), TextParameter.THINKING_LEVEL: Choice(options=["low", "high"]), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch, CodeExecution]), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), # Media input support TextParameter.IMAGE: ImagesConstraint(), TextParameter.VIDEO: VideosConstraint(), @@ -104,8 +105,8 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0), Parameter.MAX_TOKENS: Range(min=1, max=65536), 
TextParameter.THINKING_LEVEL: Choice(options=["low", "high"]), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch, CodeExecution]), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), # Media input support TextParameter.IMAGE: ImagesConstraint(), TextParameter.VIDEO: VideosConstraint(), @@ -122,8 +123,8 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0), Parameter.MAX_TOKENS: Range(min=1, max=65536), TextParameter.THINKING_LEVEL: Choice(options=["low", "medium", "high"]), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch, CodeExecution]), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), # Media input support TextParameter.IMAGE: ImagesConstraint(), TextParameter.VIDEO: VideosConstraint(), @@ -140,8 +141,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0), Parameter.MAX_TOKENS: Range(min=1, max=65536), TextParameter.THINKING_LEVEL: Choice(options=["low", "high"]), - TextParameter.WEB_SEARCH: Bool(), - TextParameter.CODE_EXECUTION: Bool(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch, CodeExecution]), TextParameter.OUTPUT_SCHEMA: Schema(), # Media input support TextParameter.IMAGE: ImagesConstraint(), diff --git a/src/celeste/modalities/text/providers/google/parameters.py b/src/celeste/modalities/text/providers/google/parameters.py index d9a0702..df8dea6 100644 --- a/src/celeste/modalities/text/providers/google/parameters.py +++ b/src/celeste/modalities/text/providers/google/parameters.py @@ -17,7 +17,7 @@ ThinkingLevelMapper as _ThinkingLevelMapper, ) from celeste.providers.google.generate_content.parameters import ( - WebSearchMapper as _WebSearchMapper, + ToolsMapper as _ToolsMapper, ) from celeste.types import TextContent @@ -54,10 +54,10 @@ class OutputSchemaMapper(_ResponseJsonSchemaMapper): name = TextParameter.OUTPUT_SCHEMA -class WebSearchMapper(_WebSearchMapper[TextContent]): - """Map web_search to Google's tools parameter.""" +class ToolsMapper(_ToolsMapper[TextContent]): + """Map tools to Google's 
tools parameter.""" - name = TextParameter.WEB_SEARCH + name = TextParameter.TOOLS GOOGLE_PARAMETER_MAPPERS: list[ParameterMapper[TextContent]] = [ @@ -66,7 +66,7 @@ class WebSearchMapper(_WebSearchMapper[TextContent]): ThinkingBudgetMapper(), ThinkingLevelMapper(), OutputSchemaMapper(), - WebSearchMapper(), + ToolsMapper(), ] __all__ = ["GOOGLE_PARAMETER_MAPPERS"] diff --git a/src/celeste/modalities/text/providers/groq/client.py b/src/celeste/modalities/text/providers/groq/client.py index aa180d7..e1ea12d 100644 --- a/src/celeste/modalities/text/providers/groq/client.py +++ b/src/celeste/modalities/text/providers/groq/client.py @@ -1,90 +1,31 @@ """Groq text client (modality).""" -from typing import Any, Unpack - from celeste.parameters import ParameterMapper from celeste.providers.groq.chat.client import GroqChatClient from celeste.providers.groq.chat.streaming import GroqChatStream as _GroqChatStream -from celeste.types import ImageContent, Message, TextContent, VideoContent -from celeste.utils import build_image_data_url +from celeste.types import TextContent -from ...client import TextClient -from ...io import ( - TextInput, - TextOutput, +from ...protocols.chatcompletions.client import ( + ChatCompletionsTextClient, +) +from ...protocols.chatcompletions.client import ( + ChatCompletionsTextStream as _ChatCompletionsTextStream, ) -from ...parameters import TextParameters from ...streaming import TextStream from .parameters import GROQ_PARAMETER_MAPPERS -class GroqTextStream(_GroqChatStream, TextStream): +class GroqTextStream(_GroqChatStream, _ChatCompletionsTextStream): """Groq streaming for text modality.""" -class GroqTextClient(GroqChatClient, TextClient): +class GroqTextClient(GroqChatClient, ChatCompletionsTextClient): """Groq text client.""" @classmethod def parameter_mappers(cls) -> list[ParameterMapper[TextContent]]: return GROQ_PARAMETER_MAPPERS - async def generate( - self, - prompt: str | None = None, - *, - messages: list[Message] | None = None, - 
**parameters: Unpack[TextParameters], - ) -> TextOutput: - """Generate text from prompt.""" - inputs = TextInput(prompt=prompt, messages=messages) - return await self._predict(inputs, **parameters) - - async def analyze( - self, - prompt: str | None = None, - *, - messages: list[Message] | None = None, - image: ImageContent | None = None, - video: VideoContent | None = None, - **parameters: Unpack[TextParameters], - ) -> TextOutput: - """Analyze image(s) or video(s) with prompt or messages.""" - inputs = TextInput(prompt=prompt, messages=messages, image=image, video=video) - return await self._predict(inputs, **parameters) - - def _init_request(self, inputs: TextInput) -> dict[str, Any]: - """Initialize request from Groq messages array format.""" - # If messages provided, use them directly (messages take precedence) - if inputs.messages is not None: - return {"messages": [message.model_dump() for message in inputs.messages]} - - # Fall back to prompt-based input - if inputs.image is None: - content: str | list[dict[str, Any]] = inputs.prompt or "" - else: - images = inputs.image if isinstance(inputs.image, list) else [inputs.image] - content = [ - { - "type": "image_url", - "image_url": {"url": build_image_data_url(img)}, - } - for img in images - ] - content.append({"type": "text", "text": inputs.prompt or ""}) - - return {"messages": [{"role": "user", "content": content}]} - - def _parse_content( - self, - response_data: dict[str, Any], - ) -> TextContent: - """Parse content from response.""" - choices = super()._parse_content(response_data) - message = choices[0].get("message", {}) - content = message.get("content") or "" - return content - def _stream_class(self) -> type[TextStream]: """Return the Stream class for this provider.""" return GroqTextStream diff --git a/src/celeste/modalities/text/providers/groq/models.py b/src/celeste/modalities/text/providers/groq/models.py index 5b1c42f..9e23b5d 100644 --- a/src/celeste/modalities/text/providers/groq/models.py 
+++ b/src/celeste/modalities/text/providers/groq/models.py @@ -1,8 +1,9 @@ """Groq models for text modality.""" -from celeste.constraints import ImagesConstraint, Range, Schema +from celeste.constraints import ImagesConstraint, Range, Schema, ToolSupport from celeste.core import Modality, Operation, Parameter, Provider from celeste.models import Model +from celeste.tools import CodeExecution, WebSearch from ...parameters import TextParameter @@ -17,6 +18,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -29,6 +31,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=131072, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -41,6 +44,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=40960, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -53,6 +57,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=16384, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -65,6 +70,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=16384, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -78,6 +84,7 @@ Parameter.MAX_TOKENS: Range(min=1, max=8192, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), TextParameter.IMAGE: ImagesConstraint(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -91,6 +98,7 @@ Parameter.MAX_TOKENS: Range(min=1, max=8192, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), TextParameter.IMAGE: ImagesConstraint(), + TextParameter.TOOLS: 
ToolSupport(tools=[]), }, ), Model( @@ -103,6 +111,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=65536, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch, CodeExecution]), }, ), Model( @@ -115,6 +124,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=65536, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch, CodeExecution]), }, ), Model( @@ -127,6 +137,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=65536, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch, CodeExecution]), }, ), Model( @@ -139,6 +150,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=8192, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -151,6 +163,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=8192, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -163,6 +176,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=4096, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), ] diff --git a/src/celeste/modalities/text/providers/groq/parameters.py b/src/celeste/modalities/text/providers/groq/parameters.py index d1ce11c..658bb21 100644 --- a/src/celeste/modalities/text/providers/groq/parameters.py +++ b/src/celeste/modalities/text/providers/groq/parameters.py @@ -7,9 +7,13 @@ from celeste.protocols.chatcompletions.parameters import ( TemperatureMapper as _TemperatureMapper, ) +from celeste.protocols.chatcompletions.parameters import ( + ToolsMapper as _ToolsMapper, +) from 
celeste.providers.groq.chat.parameters import ( ResponseFormatMapper as _ResponseFormatMapper, ) +from celeste.providers.groq.chat.tools import TOOL_MAPPERS as GROQ_TOOL_MAPPERS from celeste.types import TextContent from ...parameters import TextParameter @@ -33,10 +37,18 @@ class OutputSchemaMapper(_ResponseFormatMapper): name = TextParameter.OUTPUT_SCHEMA +class ToolsMapper(_ToolsMapper): + """Map tools to Groq's tools parameter.""" + + name = TextParameter.TOOLS + _tool_mappers = GROQ_TOOL_MAPPERS + + GROQ_PARAMETER_MAPPERS: list[ParameterMapper[TextContent]] = [ TemperatureMapper(), MaxTokensMapper(), OutputSchemaMapper(), + ToolsMapper(), ] __all__ = ["GROQ_PARAMETER_MAPPERS"] diff --git a/src/celeste/modalities/text/providers/huggingface/client.py b/src/celeste/modalities/text/providers/huggingface/client.py index 4d292f5..b71fa3b 100644 --- a/src/celeste/modalities/text/providers/huggingface/client.py +++ b/src/celeste/modalities/text/providers/huggingface/client.py @@ -1,92 +1,33 @@ """HuggingFace text client (modality).""" -from typing import Any, Unpack - from celeste.parameters import ParameterMapper from celeste.providers.huggingface.chat.client import HuggingFaceChatClient from celeste.providers.huggingface.chat.streaming import ( HuggingFaceChatStream as _HuggingFaceChatStream, ) -from celeste.types import ImageContent, Message, TextContent, VideoContent -from celeste.utils import build_image_data_url +from celeste.types import TextContent -from ...client import TextClient -from ...io import ( - TextInput, - TextOutput, +from ...protocols.chatcompletions.client import ( + ChatCompletionsTextClient, +) +from ...protocols.chatcompletions.client import ( + ChatCompletionsTextStream as _ChatCompletionsTextStream, ) -from ...parameters import TextParameters from ...streaming import TextStream from .parameters import HUGGINGFACE_PARAMETER_MAPPERS -class HuggingFaceTextStream(_HuggingFaceChatStream, TextStream): +class 
HuggingFaceTextStream(_HuggingFaceChatStream, _ChatCompletionsTextStream): """HuggingFace streaming for text modality.""" -class HuggingFaceTextClient(HuggingFaceChatClient, TextClient): +class HuggingFaceTextClient(HuggingFaceChatClient, ChatCompletionsTextClient): """HuggingFace text client.""" @classmethod def parameter_mappers(cls) -> list[ParameterMapper[TextContent]]: return HUGGINGFACE_PARAMETER_MAPPERS - async def generate( - self, - prompt: str | None = None, - *, - messages: list[Message] | None = None, - **parameters: Unpack[TextParameters], - ) -> TextOutput: - """Generate text from prompt.""" - inputs = TextInput(prompt=prompt, messages=messages) - return await self._predict(inputs, **parameters) - - async def analyze( - self, - prompt: str | None = None, - *, - messages: list[Message] | None = None, - image: ImageContent | None = None, - video: VideoContent | None = None, - **parameters: Unpack[TextParameters], - ) -> TextOutput: - """Analyze image(s) or video(s) with prompt or messages.""" - inputs = TextInput(prompt=prompt, messages=messages, image=image, video=video) - return await self._predict(inputs, **parameters) - - def _init_request(self, inputs: TextInput) -> dict[str, Any]: - """Initialize request from HuggingFace messages array format.""" - # If messages provided, use them directly (messages take precedence) - if inputs.messages is not None: - return {"messages": [message.model_dump() for message in inputs.messages]} - - # Fall back to prompt-based input - if inputs.image is None: - content: str | list[dict[str, Any]] = inputs.prompt or "" - else: - images = inputs.image if isinstance(inputs.image, list) else [inputs.image] - content = [ - { - "type": "image_url", - "image_url": {"url": build_image_data_url(img)}, - } - for img in images - ] - content.append({"type": "text", "text": inputs.prompt or ""}) - - return {"messages": [{"role": "user", "content": content}]} - - def _parse_content( - self, - response_data: dict[str, Any], - ) -> 
TextContent: - """Parse content from response.""" - choices = super()._parse_content(response_data) - message = choices[0].get("message", {}) - content = message.get("content") or "" - return content - def _stream_class(self) -> type[TextStream]: """Return the Stream class for this provider.""" return HuggingFaceTextStream diff --git a/src/celeste/modalities/text/providers/huggingface/models.py b/src/celeste/modalities/text/providers/huggingface/models.py index a1d92e7..d1d9189 100644 --- a/src/celeste/modalities/text/providers/huggingface/models.py +++ b/src/celeste/modalities/text/providers/huggingface/models.py @@ -1,6 +1,6 @@ """HuggingFace models for text modality.""" -from celeste.constraints import ImagesConstraint, Range, Schema +from celeste.constraints import ImagesConstraint, Range, Schema, ToolSupport from celeste.core import Modality, Operation, Parameter, Provider from celeste.models import Model @@ -17,6 +17,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -30,6 +31,7 @@ Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), TextParameter.IMAGE: ImagesConstraint(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), ] diff --git a/src/celeste/modalities/text/providers/huggingface/parameters.py b/src/celeste/modalities/text/providers/huggingface/parameters.py index c03adb4..040a49d 100644 --- a/src/celeste/modalities/text/providers/huggingface/parameters.py +++ b/src/celeste/modalities/text/providers/huggingface/parameters.py @@ -7,6 +7,9 @@ from celeste.protocols.chatcompletions.parameters import ( TemperatureMapper as _TemperatureMapper, ) +from celeste.protocols.chatcompletions.parameters import ( + ToolsMapper as _ToolsMapper, +) from celeste.providers.huggingface.chat.parameters import ( ResponseFormatMapper as _ResponseFormatMapper, ) @@ 
-33,10 +36,17 @@ class OutputSchemaMapper(_ResponseFormatMapper): name = TextParameter.OUTPUT_SCHEMA +class ToolsMapper(_ToolsMapper): + """Map tools to HuggingFace's tools parameter (user-defined only).""" + + name = TextParameter.TOOLS + + HUGGINGFACE_PARAMETER_MAPPERS: list[ParameterMapper[TextContent]] = [ TemperatureMapper(), MaxTokensMapper(), OutputSchemaMapper(), + ToolsMapper(), ] __all__ = ["HUGGINGFACE_PARAMETER_MAPPERS"] diff --git a/src/celeste/modalities/text/providers/mistral/client.py b/src/celeste/modalities/text/providers/mistral/client.py index 896cdf5..c0e338e 100644 --- a/src/celeste/modalities/text/providers/mistral/client.py +++ b/src/celeste/modalities/text/providers/mistral/client.py @@ -1,88 +1,41 @@ """Mistral text client (modality).""" -from typing import Any, Unpack +from typing import Any from celeste.parameters import ParameterMapper from celeste.providers.mistral.chat.client import MistralChatClient from celeste.providers.mistral.chat.streaming import ( MistralChatStream as _MistralChatStream, ) -from celeste.types import ImageContent, Message, TextContent, VideoContent -from celeste.utils import build_image_data_url +from celeste.types import TextContent -from ...client import TextClient -from ...io import ( - TextInput, - TextOutput, +from ...protocols.chatcompletions.client import ( + ChatCompletionsTextClient, +) +from ...protocols.chatcompletions.client import ( + ChatCompletionsTextStream as _ChatCompletionsTextStream, ) -from ...parameters import TextParameters from ...streaming import TextStream from .parameters import MISTRAL_PARAMETER_MAPPERS -class MistralTextStream(_MistralChatStream, TextStream): +class MistralTextStream(_MistralChatStream, _ChatCompletionsTextStream): """Mistral streaming for text modality.""" -class MistralTextClient(MistralChatClient, TextClient): +class MistralTextClient(MistralChatClient, ChatCompletionsTextClient): """Mistral text client.""" @classmethod def parameter_mappers(cls) -> 
list[ParameterMapper[TextContent]]: return MISTRAL_PARAMETER_MAPPERS - async def generate( - self, - prompt: str | None = None, - *, - messages: list[Message] | None = None, - **parameters: Unpack[TextParameters], - ) -> TextOutput: - """Generate text from prompt.""" - inputs = TextInput(prompt=prompt, messages=messages) - return await self._predict(inputs, **parameters) - - async def analyze( - self, - prompt: str | None = None, - *, - messages: list[Message] | None = None, - image: ImageContent | None = None, - video: VideoContent | None = None, - **parameters: Unpack[TextParameters], - ) -> TextOutput: - """Analyze image(s) or video(s) with prompt or messages.""" - inputs = TextInput(prompt=prompt, messages=messages, image=image, video=video) - return await self._predict(inputs, **parameters) - - def _init_request(self, inputs: TextInput) -> dict[str, Any]: - """Initialize request from Mistral messages array format.""" - # If messages provided, use them directly (messages take precedence) - if inputs.messages is not None: - return {"messages": [message.model_dump() for message in inputs.messages]} - - # Fall back to prompt-based input - if inputs.image is None: - content: str | list[dict[str, Any]] = inputs.prompt or "" - else: - images = inputs.image if isinstance(inputs.image, list) else [inputs.image] - content = [ - {"type": "image_url", "image_url": {"url": build_image_data_url(img)}} - for img in images - ] - content.append({"type": "text", "text": inputs.prompt or ""}) - - return {"messages": [{"role": "user", "content": content}]} - def _parse_content( self, response_data: dict[str, Any], ) -> TextContent: - """Parse content from response.""" - choices = super()._parse_content(response_data) - first_choice = choices[0] - message = first_choice.get("message", {}) - content = message.get("content") or "" + """Parse content from response, handling thinking model list content.""" + content = super()._parse_content(response_data) # Handle magistral thinking 
models that return list content if isinstance(content, list): diff --git a/src/celeste/modalities/text/providers/mistral/models.py b/src/celeste/modalities/text/providers/mistral/models.py index 114b5ec..e7857de 100644 --- a/src/celeste/modalities/text/providers/mistral/models.py +++ b/src/celeste/modalities/text/providers/mistral/models.py @@ -1,6 +1,6 @@ """Mistral models for text modality.""" -from celeste.constraints import ImagesConstraint, Range, Schema +from celeste.constraints import ImagesConstraint, Range, Schema, ToolSupport from celeste.core import Modality, Operation, Parameter, Provider from celeste.models import Model @@ -17,6 +17,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -30,6 +31,7 @@ Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), TextParameter.IMAGE: ImagesConstraint(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -43,6 +45,7 @@ Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), TextParameter.IMAGE: ImagesConstraint(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -55,6 +58,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -67,6 +71,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -79,6 +84,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -91,6 +97,7 
@@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -103,6 +110,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -115,6 +123,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -128,6 +137,7 @@ Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), TextParameter.IMAGE: ImagesConstraint(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -141,6 +151,7 @@ Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), TextParameter.IMAGE: ImagesConstraint(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -154,6 +165,7 @@ Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), TextParameter.IMAGE: ImagesConstraint(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -166,6 +178,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -178,6 +191,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -190,6 +204,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: 
ToolSupport(tools=[]), }, ), Model( @@ -202,6 +217,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -216,6 +232,7 @@ TextParameter.THINKING_BUDGET: Range(min=-1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), TextParameter.IMAGE: ImagesConstraint(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), Model( @@ -230,6 +247,7 @@ TextParameter.THINKING_BUDGET: Range(min=-1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), TextParameter.IMAGE: ImagesConstraint(), + TextParameter.TOOLS: ToolSupport(tools=[]), }, ), ] diff --git a/src/celeste/modalities/text/providers/mistral/parameters.py b/src/celeste/modalities/text/providers/mistral/parameters.py index fc96729..993ddcf 100644 --- a/src/celeste/modalities/text/providers/mistral/parameters.py +++ b/src/celeste/modalities/text/providers/mistral/parameters.py @@ -10,6 +10,9 @@ from celeste.protocols.chatcompletions.parameters import ( TemperatureMapper as _TemperatureMapper, ) +from celeste.protocols.chatcompletions.parameters import ( + ToolsMapper as _ToolsMapper, +) from celeste.providers.mistral.chat.parameters import ( ResponseFormatMapper as _ResponseFormatMapper, ) @@ -63,11 +66,18 @@ class OutputSchemaMapper(_ResponseFormatMapper): name = TextParameter.OUTPUT_SCHEMA +class ToolsMapper(_ToolsMapper): + """Map tools to Mistral's tools parameter.""" + + name = TextParameter.TOOLS + + MISTRAL_PARAMETER_MAPPERS: list[ParameterMapper[TextContent]] = [ TemperatureMapper(), MaxTokensMapper(), ThinkingBudgetMapper(), OutputSchemaMapper(), + ToolsMapper(), ] __all__ = ["MISTRAL_PARAMETER_MAPPERS"] diff --git a/src/celeste/modalities/text/providers/moonshot/client.py b/src/celeste/modalities/text/providers/moonshot/client.py index e0fa905..6d2ecdb 100644 --- a/src/celeste/modalities/text/providers/moonshot/client.py +++ 
b/src/celeste/modalities/text/providers/moonshot/client.py @@ -1,89 +1,33 @@ """Moonshot text client (modality).""" -from typing import Any, Unpack - from celeste.parameters import ParameterMapper from celeste.providers.moonshot.chat.client import MoonshotChatClient from celeste.providers.moonshot.chat.streaming import ( MoonshotChatStream as _MoonshotChatStream, ) -from celeste.types import ImageContent, Message, TextContent, VideoContent -from celeste.utils import build_image_data_url +from celeste.types import TextContent -from ...client import TextClient -from ...io import ( - TextInput, - TextOutput, +from ...protocols.chatcompletions.client import ( + ChatCompletionsTextClient, +) +from ...protocols.chatcompletions.client import ( + ChatCompletionsTextStream as _ChatCompletionsTextStream, ) -from ...parameters import TextParameters from ...streaming import TextStream from .parameters import MOONSHOT_PARAMETER_MAPPERS -class MoonshotTextStream(_MoonshotChatStream, TextStream): +class MoonshotTextStream(_MoonshotChatStream, _ChatCompletionsTextStream): """Moonshot streaming for text modality.""" -class MoonshotTextClient(MoonshotChatClient, TextClient): +class MoonshotTextClient(MoonshotChatClient, ChatCompletionsTextClient): """Moonshot text client.""" @classmethod def parameter_mappers(cls) -> list[ParameterMapper[TextContent]]: return MOONSHOT_PARAMETER_MAPPERS - async def generate( - self, - prompt: str | None = None, - *, - messages: list[Message] | None = None, - **parameters: Unpack[TextParameters], - ) -> TextOutput: - """Generate text from prompt.""" - inputs = TextInput(prompt=prompt, messages=messages) - return await self._predict(inputs, **parameters) - - async def analyze( - self, - prompt: str | None = None, - *, - messages: list[Message] | None = None, - image: ImageContent | None = None, - video: VideoContent | None = None, - **parameters: Unpack[TextParameters], - ) -> TextOutput: - """Analyze image(s) or video(s) with prompt or messages.""" - 
inputs = TextInput(prompt=prompt, messages=messages, image=image, video=video) - return await self._predict(inputs, **parameters) - - def _init_request(self, inputs: TextInput) -> dict[str, Any]: - """Initialize request from Moonshot messages array format.""" - # If messages provided, use them directly (messages take precedence) - if inputs.messages is not None: - return {"messages": [message.model_dump() for message in inputs.messages]} - - # Fall back to prompt-based input - if inputs.image is None: - content: str | list[dict[str, Any]] = inputs.prompt or "" - else: - images = inputs.image if isinstance(inputs.image, list) else [inputs.image] - content = [ - {"type": "image_url", "image_url": {"url": build_image_data_url(img)}} - for img in images - ] - content.append({"type": "text", "text": inputs.prompt or ""}) - - return {"messages": [{"role": "user", "content": content}]} - - def _parse_content( - self, - response_data: dict[str, Any], - ) -> TextContent: - """Parse content from response.""" - choices = super()._parse_content(response_data) - message = choices[0].get("message", {}) - content = message.get("content") or "" - return content - def _stream_class(self) -> type[TextStream]: """Return the Stream class for this provider.""" return MoonshotTextStream diff --git a/src/celeste/modalities/text/providers/moonshot/models.py b/src/celeste/modalities/text/providers/moonshot/models.py index fa1d8c2..0c3bf5f 100644 --- a/src/celeste/modalities/text/providers/moonshot/models.py +++ b/src/celeste/modalities/text/providers/moonshot/models.py @@ -1,12 +1,27 @@ """Moonshot models for text modality.""" -from celeste.constraints import ImagesConstraint, Range, Schema +from celeste.constraints import ImagesConstraint, Range, Schema, ToolSupport from celeste.core import Modality, Operation, Parameter, Provider from celeste.models import Model +from celeste.tools import WebSearch from ...parameters import TextParameter MODELS: list[Model] = [ + Model( + id="kimi-k2.5", 
+ provider=Provider.MOONSHOT, + display_name="Kimi K2.5", + operations={Modality.TEXT: {Operation.GENERATE, Operation.ANALYZE}}, + streaming=True, + parameter_constraints={ + Parameter.TEMPERATURE: Range(min=0.0, max=2.0, step=0.01), + Parameter.MAX_TOKENS: Range(min=1, max=65535, step=1), + TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.IMAGE: ImagesConstraint(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), + }, + ), Model( id="moonshot-v1-8k-vision-preview", provider=Provider.MOONSHOT, @@ -18,6 +33,7 @@ Parameter.MAX_TOKENS: Range(min=1, max=8192, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), TextParameter.IMAGE: ImagesConstraint(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), }, ), Model( @@ -30,6 +46,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=1.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), }, ), Model( @@ -42,6 +59,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=1.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), }, ), Model( @@ -54,6 +72,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=1.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), }, ), Model( @@ -66,6 +85,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=1.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), }, ), Model( @@ -78,6 +98,7 @@ Parameter.TEMPERATURE: Range(min=0.0, max=1.0, step=0.01), Parameter.MAX_TOKENS: Range(min=1, max=32768, step=1), TextParameter.OUTPUT_SCHEMA: Schema(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), }, ), ] diff --git a/src/celeste/modalities/text/providers/moonshot/parameters.py 
b/src/celeste/modalities/text/providers/moonshot/parameters.py index 480cf17..d7f11e2 100644 --- a/src/celeste/modalities/text/providers/moonshot/parameters.py +++ b/src/celeste/modalities/text/providers/moonshot/parameters.py @@ -10,6 +10,10 @@ from celeste.protocols.chatcompletions.parameters import ( TemperatureMapper as _TemperatureMapper, ) +from celeste.protocols.chatcompletions.parameters import ( + ToolsMapper as _ToolsMapper, +) +from celeste.providers.moonshot.chat.tools import TOOL_MAPPERS as MOONSHOT_TOOL_MAPPERS from celeste.types import TextContent from ...parameters import TextParameter @@ -33,10 +37,18 @@ class OutputSchemaMapper(_ResponseFormatMapper): name = TextParameter.OUTPUT_SCHEMA +class ToolsMapper(_ToolsMapper): + """Map tools to Moonshot's tools parameter.""" + + name = TextParameter.TOOLS + _tool_mappers = MOONSHOT_TOOL_MAPPERS + + MOONSHOT_PARAMETER_MAPPERS: list[ParameterMapper[TextContent]] = [ TemperatureMapper(), MaxTokensMapper(), OutputSchemaMapper(), + ToolsMapper(), ] __all__ = ["MOONSHOT_PARAMETER_MAPPERS"] diff --git a/src/celeste/modalities/text/providers/ollama/client.py b/src/celeste/modalities/text/providers/ollama/client.py index f87d931..2b65800 100644 --- a/src/celeste/modalities/text/providers/ollama/client.py +++ b/src/celeste/modalities/text/providers/ollama/client.py @@ -1,76 +1,18 @@ """Ollama text client (OpenResponses protocol).""" -from typing import Any, Unpack +from typing import ClassVar from celeste.modalities.text.providers.openresponses.client import ( OpenResponsesTextClient, OpenResponsesTextStream, ) from celeste.providers.ollama.responses.config import DEFAULT_BASE_URL -from celeste.types import ImageContent, Message, VideoContent - -from ...io import TextInput, TextOutput -from ...parameters import TextParameters -from ...streaming import TextStream class OllamaTextClient(OpenResponsesTextClient): """Ollama - OpenResponses with default localhost:11434.""" - async def generate( - self, - prompt: str | 
None = None, - *, - messages: list[Message] | None = None, - base_url: str | None = None, - extra_body: dict[str, Any] | None = None, - **parameters: Unpack[TextParameters], - ) -> TextOutput: - return await super().generate( - prompt, - messages=messages, - base_url=base_url or DEFAULT_BASE_URL, - extra_body=extra_body, - **parameters, - ) - - async def analyze( - self, - prompt: str | None = None, - *, - messages: list[Message] | None = None, - image: ImageContent | None = None, - video: VideoContent | None = None, - base_url: str | None = None, - extra_body: dict[str, Any] | None = None, - **parameters: Unpack[TextParameters], - ) -> TextOutput: - return await super().analyze( - prompt, - messages=messages, - image=image, - video=video, - base_url=base_url or DEFAULT_BASE_URL, - extra_body=extra_body, - **parameters, - ) - - def _stream( - self, - inputs: TextInput, - stream_class: type[TextStream], - *, - base_url: str | None = None, - extra_body: dict[str, Any] | None = None, - **parameters: Unpack[TextParameters], - ) -> TextStream: - return super()._stream( - inputs, - stream_class, - base_url=base_url or DEFAULT_BASE_URL, - extra_body=extra_body, - **parameters, - ) + _default_base_url: ClassVar[str] = DEFAULT_BASE_URL OllamaTextStream = OpenResponsesTextStream diff --git a/src/celeste/modalities/text/providers/openai/client.py b/src/celeste/modalities/text/providers/openai/client.py index 14de021..f565df3 100644 --- a/src/celeste/modalities/text/providers/openai/client.py +++ b/src/celeste/modalities/text/providers/openai/client.py @@ -1,14 +1,17 @@ """OpenAI text client.""" +import json from typing import Any, Unpack from celeste.parameters import ParameterMapper +from celeste.protocols.openresponses.tools import serialize_messages from celeste.providers.openai.responses.client import ( OpenAIResponsesClient as OpenAIResponsesMixin, ) from celeste.providers.openai.responses.streaming import ( OpenAIResponsesStream as _OpenAIResponsesStream, ) +from 
celeste.tools import ToolCall from celeste.types import ImageContent, Message, TextContent, VideoContent from celeste.utils import build_image_data_url @@ -61,7 +64,7 @@ async def analyze( def _init_request(self, inputs: TextInput) -> dict[str, Any]: """Initialize request with input content.""" if inputs.messages is not None: - return {"input": [message.model_dump() for message in inputs.messages]} + return {"input": serialize_messages(inputs.messages)} content: list[dict[str, Any]] = [] @@ -92,6 +95,20 @@ def _parse_content( return "" + def _parse_tool_calls(self, response_data: dict[str, Any]) -> list[ToolCall]: + """Parse tool calls from OpenAI response.""" + return [ + ToolCall( + id=item.get("call_id", item.get("id", "")), + name=item["name"], + arguments=json.loads(item["arguments"]) + if isinstance(item.get("arguments"), str) + else item.get("arguments", {}), + ) + for item in response_data.get("output", []) + if item.get("type") == "function_call" + ] + def _stream_class(self) -> type[TextStream]: """Return the Stream class for this provider.""" return OpenAITextStream diff --git a/src/celeste/modalities/text/providers/openai/models.py b/src/celeste/modalities/text/providers/openai/models.py index 6268cf9..3e53845 100644 --- a/src/celeste/modalities/text/providers/openai/models.py +++ b/src/celeste/modalities/text/providers/openai/models.py @@ -1,14 +1,15 @@ """OpenAI models for text modality.""" from celeste.constraints import ( - Bool, Choice, ImagesConstraint, Range, Schema, + ToolSupport, ) from celeste.core import Modality, Operation, Parameter, Provider from celeste.models import Model +from celeste.tools import WebSearch from ...parameters import TextParameter @@ -22,8 +23,8 @@ parameter_constraints={ Parameter.TEMPERATURE: Range(min=0.0, max=2.0), Parameter.MAX_TOKENS: Range(min=1, max=16384), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), TextParameter.IMAGE: 
ImagesConstraint(), }, ), @@ -36,7 +37,7 @@ parameter_constraints={ Parameter.TEMPERATURE: Range(min=0.0, max=2.0), Parameter.MAX_TOKENS: Range(min=1, max=16384), - TextParameter.WEB_SEARCH: Bool(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.IMAGE: ImagesConstraint(), }, ), @@ -86,8 +87,8 @@ TextParameter.THINKING_BUDGET: Choice( options=["minimal", "low", "medium", "high"] ), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), TextParameter.IMAGE: ImagesConstraint(), }, ), @@ -102,8 +103,8 @@ TextParameter.THINKING_BUDGET: Choice( options=["minimal", "low", "medium", "high", "xhigh"] ), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), TextParameter.IMAGE: ImagesConstraint(), }, ), @@ -119,8 +120,8 @@ options=["minimal", "low", "medium", "high", "xhigh"] ), TextParameter.VERBOSITY: Choice(options=["low", "medium", "high"]), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), TextParameter.IMAGE: ImagesConstraint(), }, ), @@ -135,6 +136,7 @@ TextParameter.THINKING_BUDGET: Choice( options=["low", "medium", "high", "xhigh"] ), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.OUTPUT_SCHEMA: Schema(), TextParameter.IMAGE: ImagesConstraint(), }, @@ -148,7 +150,7 @@ parameter_constraints={ Parameter.TEMPERATURE: Range(min=0.0, max=2.0), Parameter.MAX_TOKENS: Range(min=1, max=128000), - TextParameter.WEB_SEARCH: Bool(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.IMAGE: ImagesConstraint(), }, ), @@ -164,8 +166,8 @@ options=["minimal", "low", "medium", "high"] ), TextParameter.VERBOSITY: Choice(options=["low", "medium", "high"]), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), 
TextParameter.IMAGE: ImagesConstraint(), }, ), @@ -181,8 +183,8 @@ options=["minimal", "low", "medium", "high"] ), TextParameter.VERBOSITY: Choice(options=["low", "medium", "high"]), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), TextParameter.IMAGE: ImagesConstraint(), }, ), @@ -197,8 +199,8 @@ TextParameter.THINKING_BUDGET: Choice( options=["minimal", "low", "medium", "high"] ), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), TextParameter.IMAGE: ImagesConstraint(), }, ), @@ -213,8 +215,8 @@ TextParameter.THINKING_BUDGET: Choice( options=["minimal", "low", "medium", "high"] ), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), TextParameter.IMAGE: ImagesConstraint(), }, ), @@ -230,7 +232,7 @@ options=["minimal", "low", "medium", "high", "xhigh"] ), TextParameter.VERBOSITY: Choice(options=["low", "medium", "high"]), - TextParameter.WEB_SEARCH: Bool(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.OUTPUT_SCHEMA: Schema(), TextParameter.IMAGE: ImagesConstraint(), }, @@ -247,37 +249,7 @@ options=["minimal", "low", "medium", "high", "xhigh"] ), TextParameter.VERBOSITY: Choice(options=["low", "medium", "high"]), - TextParameter.WEB_SEARCH: Bool(), - TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.IMAGE: ImagesConstraint(), - }, - ), - Model( - id="gpt-5.4-mini", - provider=Provider.OPENAI, - display_name="GPT-5.4 Mini", - operations={Modality.TEXT: {Operation.GENERATE, Operation.ANALYZE}}, - streaming=True, - parameter_constraints={ - Parameter.MAX_TOKENS: Range(min=1, max=128000), - TextParameter.THINKING_BUDGET: Choice( - options=["minimal", "low", "medium", "high", "xhigh"] - ), - TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.IMAGE: ImagesConstraint(), - }, - ), - Model( - 
id="gpt-5.4-nano", - provider=Provider.OPENAI, - display_name="GPT-5.4 Nano", - operations={Modality.TEXT: {Operation.GENERATE, Operation.ANALYZE}}, - streaming=True, - parameter_constraints={ - Parameter.MAX_TOKENS: Range(min=1, max=128000), - TextParameter.THINKING_BUDGET: Choice( - options=["minimal", "low", "medium", "high", "xhigh"] - ), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.OUTPUT_SCHEMA: Schema(), TextParameter.IMAGE: ImagesConstraint(), }, @@ -291,8 +263,8 @@ parameter_constraints={ Parameter.TEMPERATURE: Range(min=0.0, max=2.0), Parameter.MAX_TOKENS: Range(min=1, max=32768), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch]), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), TextParameter.IMAGE: ImagesConstraint(), }, ), diff --git a/src/celeste/modalities/text/providers/openai/parameters.py b/src/celeste/modalities/text/providers/openai/parameters.py index 96b6381..b745d4a 100644 --- a/src/celeste/modalities/text/providers/openai/parameters.py +++ b/src/celeste/modalities/text/providers/openai/parameters.py @@ -14,10 +14,10 @@ TextFormatMapper as _TextFormatMapper, ) from celeste.protocols.openresponses.parameters import ( - VerbosityMapper as _VerbosityMapper, + ToolsMapper as _ToolsMapper, ) from celeste.protocols.openresponses.parameters import ( - WebSearchMapper as _WebSearchMapper, + VerbosityMapper as _VerbosityMapper, ) from celeste.types import TextContent @@ -42,10 +42,10 @@ class OutputSchemaMapper(_TextFormatMapper): name = TextParameter.OUTPUT_SCHEMA -class WebSearchMapper(_WebSearchMapper): - """Map web_search to OpenAI's tools parameter.""" +class ToolsMapper(_ToolsMapper): + """Map tools to OpenAI's tools parameter.""" - name = TextParameter.WEB_SEARCH + name = TextParameter.TOOLS class VerbosityMapper(_VerbosityMapper): @@ -64,7 +64,7 @@ class ThinkingBudgetMapper(_ReasoningEffortMapper): TemperatureMapper(), MaxTokensMapper(), OutputSchemaMapper(), - WebSearchMapper(), + 
ToolsMapper(), VerbosityMapper(), ThinkingBudgetMapper(), ] diff --git a/src/celeste/modalities/text/providers/openresponses/client.py b/src/celeste/modalities/text/providers/openresponses/client.py index 59ca587..ac8c229 100644 --- a/src/celeste/modalities/text/providers/openresponses/client.py +++ b/src/celeste/modalities/text/providers/openresponses/client.py @@ -1,5 +1,6 @@ """OpenResponses text client.""" +import json from typing import Any, Unpack from celeste.parameters import ParameterMapper @@ -9,6 +10,8 @@ from celeste.protocols.openresponses.streaming import ( OpenResponsesStream as _OpenResponsesStream, ) +from celeste.protocols.openresponses.tools import serialize_messages +from celeste.tools import ToolCall from celeste.types import ImageContent, Message, TextContent, VideoContent from celeste.utils import build_image_data_url @@ -47,6 +50,24 @@ def _aggregate_event_data(self, chunks: list[TextChunk]) -> list[dict[str, Any]] events.extend(super()._aggregate_event_data(chunks)) return events + def _aggregate_tool_calls( + self, chunks: list[TextChunk], raw_events: list[dict[str, Any]] + ) -> list[ToolCall]: + """Extract tool calls from response.completed data.""" + if self._response_data is None: + return [] + return [ + ToolCall( + id=item.get("call_id", item.get("id", "")), + name=item["name"], + arguments=json.loads(item["arguments"]) + if isinstance(item.get("arguments"), str) + else item.get("arguments", {}), + ) + for item in self._response_data.get("output", []) + if item.get("type") == "function_call" + ] + class OpenResponsesTextClient(OpenResponsesMixin, TextClient): """OpenResponses text client using Responses API.""" @@ -90,7 +111,7 @@ async def analyze( def _init_request(self, inputs: TextInput) -> dict[str, Any]: """Initialize request with input content.""" if inputs.messages is not None: - return {"input": [message.model_dump() for message in inputs.messages]} + return {"input": serialize_messages(inputs.messages)} content: 
list[dict[str, Any]] = [] if inputs.image is not None: @@ -118,6 +139,20 @@ def _parse_content( return "" + def _parse_tool_calls(self, response_data: dict[str, Any]) -> list[ToolCall]: + """Parse tool calls from OpenResponses response.""" + return [ + ToolCall( + id=item.get("call_id", item.get("id", "")), + name=item["name"], + arguments=json.loads(item["arguments"]) + if isinstance(item.get("arguments"), str) + else item.get("arguments", {}), + ) + for item in response_data.get("output", []) + if item.get("type") == "function_call" + ] + def _stream_class(self) -> type[TextStream]: """Return the Stream class for this provider.""" return OpenResponsesTextStream diff --git a/src/celeste/modalities/text/providers/openresponses/parameters.py b/src/celeste/modalities/text/providers/openresponses/parameters.py index 378237b..b9d2f7f 100644 --- a/src/celeste/modalities/text/providers/openresponses/parameters.py +++ b/src/celeste/modalities/text/providers/openresponses/parameters.py @@ -10,6 +10,9 @@ from celeste.protocols.openresponses.parameters import ( TextFormatMapper as _TextFormatMapper, ) +from celeste.protocols.openresponses.parameters import ( + ToolsMapper as _ToolsMapper, +) from celeste.types import TextContent from ...parameters import TextParameter @@ -33,10 +36,17 @@ class OutputSchemaMapper(_TextFormatMapper): name = TextParameter.OUTPUT_SCHEMA +class ToolsMapper(_ToolsMapper): + """Map tools to Responses tools parameter.""" + + name = TextParameter.TOOLS + + OPENRESPONSES_PARAMETER_MAPPERS: list[ParameterMapper[TextContent]] = [ TemperatureMapper(), MaxTokensMapper(), OutputSchemaMapper(), + ToolsMapper(), ] __all__ = ["OPENRESPONSES_PARAMETER_MAPPERS"] diff --git a/src/celeste/modalities/text/providers/xai/client.py b/src/celeste/modalities/text/providers/xai/client.py index e67c40d..c015ab4 100644 --- a/src/celeste/modalities/text/providers/xai/client.py +++ b/src/celeste/modalities/text/providers/xai/client.py @@ -1,12 +1,15 @@ """xAI text client 
(modality).""" +import json from typing import Any, Unpack from celeste.parameters import ParameterMapper +from celeste.protocols.openresponses.tools import serialize_messages from celeste.providers.xai.responses.client import XAIResponsesClient from celeste.providers.xai.responses.streaming import ( XAIResponsesStream as _XAIResponsesStream, ) +from celeste.tools import ToolCall from celeste.types import ImageContent, Message, TextContent, VideoContent from celeste.utils import build_image_data_url @@ -59,7 +62,7 @@ async def analyze( def _init_request(self, inputs: TextInput) -> dict[str, Any]: """Initialize request from XAI Responses API format.""" if inputs.messages is not None: - return {"input": [message.model_dump() for message in inputs.messages]} + return {"input": serialize_messages(inputs.messages)} if inputs.image is None: return {"input": inputs.prompt or ""} @@ -90,6 +93,20 @@ def _parse_content( return "" + def _parse_tool_calls(self, response_data: dict[str, Any]) -> list[ToolCall]: + """Parse tool calls from xAI response.""" + return [ + ToolCall( + id=item.get("call_id", item.get("id", "")), + name=item["name"], + arguments=json.loads(item["arguments"]) + if isinstance(item.get("arguments"), str) + else item.get("arguments", {}), + ) + for item in response_data.get("output", []) + if item.get("type") == "function_call" + ] + def _stream_class(self) -> type[TextStream]: """Return the Stream class for this provider.""" return XAITextStream diff --git a/src/celeste/modalities/text/providers/xai/models.py b/src/celeste/modalities/text/providers/xai/models.py index e189f41..76f124e 100644 --- a/src/celeste/modalities/text/providers/xai/models.py +++ b/src/celeste/modalities/text/providers/xai/models.py @@ -1,14 +1,15 @@ """xAI models for text modality.""" from celeste.constraints import ( - Bool, Choice, ImagesConstraint, Range, Schema, + ToolSupport, ) from celeste.core import Modality, Operation, Parameter, Provider from celeste.models import Model 
+from celeste.tools import CodeExecution, WebSearch, XSearch from ...parameters import TextParameter @@ -22,10 +23,8 @@ parameter_constraints={ Parameter.TEMPERATURE: Range(min=0.0, max=2.0), Parameter.MAX_TOKENS: Range(min=1, max=30000), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch, XSearch, CodeExecution]), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), - TextParameter.X_SEARCH: Bool(), - TextParameter.CODE_EXECUTION: Bool(), }, ), Model( @@ -37,10 +36,8 @@ parameter_constraints={ Parameter.TEMPERATURE: Range(min=0.0, max=2.0), Parameter.MAX_TOKENS: Range(min=1, max=30000), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch, XSearch, CodeExecution]), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), - TextParameter.X_SEARCH: Bool(), - TextParameter.CODE_EXECUTION: Bool(), }, ), Model( @@ -52,10 +49,8 @@ parameter_constraints={ Parameter.TEMPERATURE: Range(min=0.0, max=2.0), Parameter.MAX_TOKENS: Range(min=1, max=30000), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch, XSearch, CodeExecution]), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), - TextParameter.X_SEARCH: Bool(), - TextParameter.CODE_EXECUTION: Bool(), }, ), Model( @@ -67,10 +62,8 @@ parameter_constraints={ Parameter.TEMPERATURE: Range(min=0.0, max=2.0), Parameter.MAX_TOKENS: Range(min=1, max=30000), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch, XSearch, CodeExecution]), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), - TextParameter.X_SEARCH: Bool(), - TextParameter.CODE_EXECUTION: Bool(), }, ), Model( @@ -82,10 +75,8 @@ parameter_constraints={ Parameter.TEMPERATURE: Range(min=0.0, max=2.0), Parameter.MAX_TOKENS: Range(min=1, max=64000), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch, XSearch, CodeExecution]), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), - TextParameter.X_SEARCH: Bool(), - TextParameter.CODE_EXECUTION: Bool(), }, ), Model( @@ 
-97,9 +88,7 @@ parameter_constraints={ Parameter.TEMPERATURE: Range(min=0.0, max=2.0), Parameter.MAX_TOKENS: Range(min=1, max=30000), - TextParameter.WEB_SEARCH: Bool(), - TextParameter.X_SEARCH: Bool(), - TextParameter.CODE_EXECUTION: Bool(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch, XSearch, CodeExecution]), TextParameter.OUTPUT_SCHEMA: Schema(), }, ), @@ -112,9 +101,7 @@ parameter_constraints={ Parameter.TEMPERATURE: Range(min=0.0, max=2.0), Parameter.MAX_TOKENS: Range(min=1, max=30000), - TextParameter.WEB_SEARCH: Bool(), - TextParameter.X_SEARCH: Bool(), - TextParameter.CODE_EXECUTION: Bool(), + TextParameter.TOOLS: ToolSupport(tools=[WebSearch, XSearch, CodeExecution]), TextParameter.OUTPUT_SCHEMA: Schema(), }, ), @@ -128,10 +115,8 @@ Parameter.TEMPERATURE: Range(min=0.0, max=2.0), Parameter.MAX_TOKENS: Range(min=1, max=16000), TextParameter.THINKING_BUDGET: Choice(options=["low", "high"]), + TextParameter.TOOLS: ToolSupport(tools=[]), TextParameter.OUTPUT_SCHEMA: Schema(), - TextParameter.WEB_SEARCH: Bool(), - TextParameter.X_SEARCH: Bool(), - TextParameter.CODE_EXECUTION: Bool(), }, ), Model( diff --git a/src/celeste/modalities/text/providers/xai/parameters.py b/src/celeste/modalities/text/providers/xai/parameters.py index 064d79f..9b6c6dc 100644 --- a/src/celeste/modalities/text/providers/xai/parameters.py +++ b/src/celeste/modalities/text/providers/xai/parameters.py @@ -1,9 +1,6 @@ """xAI parameter mappers for text.""" from celeste.parameters import ParameterMapper -from celeste.protocols.openresponses.parameters import ( - CodeExecutionMapper as _CodeExecutionMapper, -) from celeste.protocols.openresponses.parameters import ( MaxOutputTokensMapper as _MaxOutputTokensMapper, ) @@ -17,10 +14,7 @@ TextFormatMapper as _TextFormatMapper, ) from celeste.protocols.openresponses.parameters import ( - WebSearchMapper as _WebSearchMapper, -) -from celeste.protocols.openresponses.parameters import ( - XSearchMapper as _XSearchMapper, + ToolsMapper as 
_ToolsMapper, ) from celeste.types import TextContent @@ -51,22 +45,10 @@ class OutputSchemaMapper(_TextFormatMapper): name = TextParameter.OUTPUT_SCHEMA -class WebSearchMapper(_WebSearchMapper): - """Map web_search to xAI's tools parameter.""" - - name = TextParameter.WEB_SEARCH - - -class XSearchMapper(_XSearchMapper): - """Map x_search to xAI's tools parameter.""" - - name = TextParameter.X_SEARCH - - -class CodeExecutionMapper(_CodeExecutionMapper): - """Map code_execution to xAI's tools parameter.""" +class ToolsMapper(_ToolsMapper): + """Map tools to xAI's tools parameter.""" - name = TextParameter.CODE_EXECUTION + name = TextParameter.TOOLS XAI_PARAMETER_MAPPERS: list[ParameterMapper[TextContent]] = [ @@ -74,9 +56,7 @@ class CodeExecutionMapper(_CodeExecutionMapper): MaxTokensMapper(), ThinkingBudgetMapper(), OutputSchemaMapper(), - WebSearchMapper(), - XSearchMapper(), - CodeExecutionMapper(), + ToolsMapper(), ] __all__ = ["XAI_PARAMETER_MAPPERS"] diff --git a/src/celeste/protocols/chatcompletions/parameters.py b/src/celeste/protocols/chatcompletions/parameters.py index 3917311..44e15be 100644 --- a/src/celeste/protocols/chatcompletions/parameters.py +++ b/src/celeste/protocols/chatcompletions/parameters.py @@ -1,14 +1,18 @@ """Chat Completions protocol parameter mappers.""" import json -from typing import Any, get_origin +from typing import Any, ClassVar, get_origin from pydantic import BaseModel, TypeAdapter from celeste.models import Model from celeste.parameters import FieldMapper, ParameterMapper +from celeste.structured_outputs import StrictJsonSchemaGenerator +from celeste.tools import Tool, ToolMapper from celeste.types import TextContent +from .tools import TOOL_MAPPERS + class TemperatureMapper(FieldMapper[TextContent]): """Map temperature to Chat Completions temperature field.""" @@ -68,4 +72,66 @@ def parse_output(self, content: TextContent, value: object | None) -> TextConten return TypeAdapter(value).validate_python(parsed) -__all__ = 
["MaxTokensMapper", "ResponseFormatMapper", "TemperatureMapper"] +class ToolsMapper(ParameterMapper[TextContent]): + """Map tools list to Chat Completions tools array. + + Subclasses override _tool_mappers with provider-specific built-in tool mappers. + """ + + _tool_mappers: ClassVar[list[ToolMapper]] = TOOL_MAPPERS + + def map( + self, + request: dict[str, Any], + value: object, + model: Model, + ) -> dict[str, Any]: + """Transform tools into provider request.""" + validated_value = self._validate_value(value, model) + if not validated_value: + return request + + dispatch = {m.tool_type: m for m in self._tool_mappers} + tools = request.setdefault("tools", []) + + for item in validated_value: + if isinstance(item, Tool): + mapper = dispatch.get(type(item)) + if mapper is None: + msg = f"Tool '{type(item).__name__}' is not supported by this provider" + raise ValueError(msg) + tools.append(mapper.map_tool(item)) + elif isinstance(item, dict) and "name" in item: + tools.append(self._map_user_tool(item)) + elif isinstance(item, dict): + tools.append(item) + + return request + + @staticmethod + def _map_user_tool(tool: dict[str, Any]) -> dict[str, Any]: + """Map a user-defined tool dict to Chat Completions function wire format.""" + params = tool.get("parameters", {}) + if isinstance(params, type) and issubclass(params, BaseModel): + schema = TypeAdapter(params).json_schema( + schema_generator=StrictJsonSchemaGenerator, + mode="serialization", + ) + else: + schema = params + + function: dict[str, Any] = {"name": tool["name"]} + if "description" in tool: + function["description"] = tool["description"] + if schema: + function["parameters"] = schema + + return {"type": "function", "function": function} + + +__all__ = [ + "MaxTokensMapper", + "ResponseFormatMapper", + "TemperatureMapper", + "ToolsMapper", +] diff --git a/src/celeste/protocols/chatcompletions/streaming.py b/src/celeste/protocols/chatcompletions/streaming.py index 95656b7..fe6740a 100644 --- 
a/src/celeste/protocols/chatcompletions/streaming.py +++ b/src/celeste/protocols/chatcompletions/streaming.py @@ -1,8 +1,11 @@ """Chat Completions protocol SSE parsing for streaming.""" +import contextlib +import json from typing import Any from celeste.io import FinishReason +from celeste.tools import ToolCall from .client import ChatCompletionsClient @@ -14,6 +17,8 @@ class ChatCompletionsStream: - _parse_chunk_content(event_data) - Extract content from SSE event - _parse_chunk_usage(event_data) - Extract and normalize usage from SSE event - _parse_chunk_finish_reason(event_data) - Extract finish reason from SSE event + - _parse_chunk(event_data) - Capture tool_call deltas before delegating + - _aggregate_tool_calls(chunks, raw_events) - Reconstruct tool calls from deltas - _build_stream_metadata(raw_events) - Filter content-only events Provider streams inherit this and override methods for provider-specific behavior @@ -22,6 +27,10 @@ class ChatCompletionsStream: Modality streams call super() methods which resolve to this via MRO. 
""" + def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401 + super().__init__(*args, **kwargs) + self._tool_call_deltas: dict[int, dict[str, Any]] = {} + def _parse_chunk_content(self, event_data: dict[str, Any]) -> str | None: """Extract content from SSE event.""" object_type = event_data.get("object") @@ -74,6 +83,44 @@ def _parse_chunk_finish_reason( return None + def _parse_chunk(self, event_data: dict[str, Any]) -> Any: # noqa: ANN401 + """Capture tool_call deltas before delegating to base _parse_chunk.""" + choices = event_data.get("choices", []) + if choices and isinstance(choices[0], dict): + delta = choices[0].get("delta", {}) + if isinstance(delta, dict): + for tc_delta in delta.get("tool_calls", []): + idx = tc_delta.get("index", 0) + if idx not in self._tool_call_deltas: + self._tool_call_deltas[idx] = { + "id": tc_delta.get("id", ""), + "name": tc_delta.get("function", {}).get("name", ""), + "arguments": "", + } + else: + if tc_delta.get("id"): + self._tool_call_deltas[idx]["id"] = tc_delta["id"] + fn = tc_delta.get("function", {}) + if fn.get("name"): + self._tool_call_deltas[idx]["name"] = fn["name"] + # Accumulate argument fragments + fn = tc_delta.get("function", {}) + self._tool_call_deltas[idx]["arguments"] += fn.get("arguments", "") + return super()._parse_chunk(event_data) # type: ignore[misc] + + def _aggregate_tool_calls( + self, chunks: list, raw_events: list[dict[str, Any]] + ) -> list[ToolCall]: + """Reconstruct tool calls from accumulated Chat Completions deltas.""" + result: list[ToolCall] = [] + for tc in self._tool_call_deltas.values(): + arguments: dict[str, Any] = {} + if tc["arguments"]: + with contextlib.suppress(json.JSONDecodeError, ValueError, TypeError): + arguments = json.loads(tc["arguments"]) + result.append(ToolCall(id=tc["id"], name=tc["name"], arguments=arguments)) + return result + def _build_stream_metadata( self, raw_events: list[dict[str, Any]] ) -> dict[str, Any]: diff --git 
a/src/celeste/protocols/chatcompletions/tools.py b/src/celeste/protocols/chatcompletions/tools.py new file mode 100644 index 0000000..6d81b89 --- /dev/null +++ b/src/celeste/protocols/chatcompletions/tools.py @@ -0,0 +1,84 @@ +"""Chat Completions protocol tool mappers and shared parsing helpers.""" + +import json +from typing import Any + +from celeste.tools import ToolCall, ToolMapper, ToolResult +from celeste.types import Message + +TOOL_MAPPERS: list[ToolMapper] = [] + + +def parse_tool_calls( + response_data: dict[str, Any], +) -> list[ToolCall]: + """Parse tool calls from Chat Completions API response.""" + choices = response_data.get("choices", []) + if not choices: + return [] + + message = choices[0].get("message", {}) + raw_tool_calls = message.get("tool_calls") + if not raw_tool_calls: + return [] + + tool_calls: list[ToolCall] = [] + for tc in raw_tool_calls: + raw_args = tc.get("function", {}).get("arguments") + if isinstance(raw_args, str): + try: + arguments = json.loads(raw_args) + except (json.JSONDecodeError, ValueError): + arguments = {} + else: + arguments = raw_args if isinstance(raw_args, dict) else {} + tool_calls.append( + ToolCall( + id=tc.get("id", ""), + name=tc.get("function", {}).get("name", ""), + arguments=arguments, + ) + ) + return tool_calls + + +def serialize_messages( + messages: list[Message], +) -> list[dict[str, Any]]: + """Serialize messages to Chat Completions API format.""" + items: list[dict[str, Any]] = [] + for msg in messages: + if isinstance(msg, ToolResult): + items.append( + { + "role": "tool", + "tool_call_id": msg.tool_call_id, + "content": str(msg.content), + } + ) + elif msg.role == "assistant" and msg.tool_calls: + msg_dict = msg.model_dump(exclude_none=True) + msg_dict["tool_calls"] = [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.name, + "arguments": json.dumps(tc.arguments), + }, + } + for tc in msg.tool_calls + ] + items.append(msg_dict) + else: + msg_dict = 
msg.model_dump(exclude_none=True) + msg_dict.pop("tool_calls", None) + items.append(msg_dict) + return items + + +__all__ = [ + "TOOL_MAPPERS", + "parse_tool_calls", + "serialize_messages", +] diff --git a/src/celeste/protocols/openresponses/parameters.py b/src/celeste/protocols/openresponses/parameters.py index 4f7a6d3..5ccfb34 100644 --- a/src/celeste/protocols/openresponses/parameters.py +++ b/src/celeste/protocols/openresponses/parameters.py @@ -8,8 +8,11 @@ from celeste.models import Model from celeste.parameters import FieldMapper, ParameterMapper from celeste.structured_outputs import StrictJsonSchemaGenerator +from celeste.tools import Tool from celeste.types import TextContent +from .tools import TOOL_MAPPERS + class TemperatureMapper(FieldMapper[TextContent]): """Map temperature to Responses temperature field.""" @@ -128,8 +131,8 @@ def map( return request -class WebSearchMapper(ParameterMapper[TextContent]): - """Map web_search to Responses tools array.""" +class ToolsMapper(ParameterMapper[TextContent]): + """Map tools list to Responses tools array.""" def map( self, @@ -137,58 +140,53 @@ def map( value: object, model: Model, ) -> dict[str, Any]: - """Transform web_search into provider request.""" + """Transform tools into provider request.""" validated_value = self._validate_value(value, model) if not validated_value: return request - request.setdefault("tools", []).append({"type": "web_search"}) - return request - - -class XSearchMapper(ParameterMapper[TextContent]): - """Map x_search to Responses tools array (search X/Twitter).""" + dispatch = {m.tool_type: m for m in TOOL_MAPPERS} + tools = request.setdefault("tools", []) + + for item in validated_value: + if isinstance(item, Tool): + mapper = dispatch.get(type(item)) + if mapper is None: + msg = f"Tool '{type(item).__name__}' is not supported by OpenResponses" + raise ValueError(msg) + tools.append(mapper.map_tool(item)) + elif isinstance(item, dict) and "name" in item: + 
tools.append(self._map_user_tool(item)) + elif isinstance(item, dict): + tools.append(item) - def map( - self, - request: dict[str, Any], - value: object, - model: Model, - ) -> dict[str, Any]: - """Transform x_search into provider request.""" - validated_value = self._validate_value(value, model) - if not validated_value: - return request - - request.setdefault("tools", []).append({"type": "x_search"}) return request + @staticmethod + def _map_user_tool(tool: dict[str, Any]) -> dict[str, Any]: + """Map a user-defined tool dict to OpenResponses wire format.""" + params = tool.get("parameters", {}) + if isinstance(params, type) and issubclass(params, BaseModel): + schema = TypeAdapter(params).json_schema( + schema_generator=StrictJsonSchemaGenerator, + mode="serialization", + ) + else: + schema = params -class CodeExecutionMapper(ParameterMapper[TextContent]): - """Map code_execution to Responses tools array.""" - - def map( - self, - request: dict[str, Any], - value: object, - model: Model, - ) -> dict[str, Any]: - """Transform code_execution into provider request.""" - validated_value = self._validate_value(value, model) - if not validated_value: - return request - - request.setdefault("tools", []).append({"type": "code_execution"}) - return request + result: dict[str, Any] = {"type": "function", "name": tool["name"]} + if "description" in tool: + result["description"] = tool["description"] + if schema: + result["parameters"] = schema + return result __all__ = [ - "CodeExecutionMapper", "MaxOutputTokensMapper", "ReasoningEffortMapper", "TemperatureMapper", "TextFormatMapper", + "ToolsMapper", "VerbosityMapper", - "WebSearchMapper", - "XSearchMapper", ] diff --git a/src/celeste/protocols/openresponses/tools.py b/src/celeste/protocols/openresponses/tools.py new file mode 100644 index 0000000..ea7a838 --- /dev/null +++ b/src/celeste/protocols/openresponses/tools.py @@ -0,0 +1,127 @@ +"""OpenResponses protocol tool mappers and shared parsing helpers.""" + +import json 
+from typing import Any + +from celeste.tools import ( + CodeExecution, + Tool, + ToolCall, + ToolMapper, + ToolResult, + WebSearch, + XSearch, +) +from celeste.types import Message + + +class WebSearchMapper(ToolMapper): + """Map WebSearch to OpenResponses web_search wire format.""" + + tool_type = WebSearch + + def map_tool(self, tool: Tool) -> dict[str, Any]: + assert isinstance(tool, WebSearch) + result: dict[str, Any] = {"type": "web_search"} + if tool.allowed_domains is not None: + result.setdefault("filters", {})["allowed_domains"] = tool.allowed_domains + return result + + +class XSearchMapper(ToolMapper): + """Map XSearch to OpenResponses x_search wire format.""" + + tool_type = XSearch + + def map_tool(self, tool: Tool) -> dict[str, Any]: + return {"type": "x_search"} + + +class CodeExecutionMapper(ToolMapper): + """Map CodeExecution to OpenResponses code_execution wire format.""" + + tool_type = CodeExecution + + def map_tool(self, tool: Tool) -> dict[str, Any]: + return {"type": "code_execution"} + + +TOOL_MAPPERS: list[ToolMapper] = [ + WebSearchMapper(), + XSearchMapper(), + CodeExecutionMapper(), +] + + +def parse_tool_calls(response_data: dict[str, Any]) -> list[ToolCall]: + """Parse tool calls from Responses API output.""" + tool_calls: list[ToolCall] = [] + for item in response_data.get("output", []): + if item.get("type") != "function_call": + continue + raw_args = item.get("arguments") + if isinstance(raw_args, str): + try: + arguments = json.loads(raw_args) + except (json.JSONDecodeError, ValueError): + arguments = {} + else: + arguments = raw_args if isinstance(raw_args, dict) else {} + tool_calls.append( + ToolCall( + id=item.get("call_id", item.get("id", "")), + name=item.get("name", ""), + arguments=arguments, + ) + ) + return tool_calls + + +def parse_content(output: list[dict[str, Any]]) -> str: + """Extract text from Responses API output items.""" + for item in output: + if item.get("type") == "message": + for part in item.get("content", 
[]): + if part.get("type") == "output_text": + return part.get("text") or "" + return "" + + +def serialize_messages(messages: list[Message]) -> list[dict[str, Any]]: + """Serialize messages to Responses API input format.""" + items: list[dict[str, Any]] = [] + for msg in messages: + if isinstance(msg, ToolResult): + items.append( + { + "type": "function_call_output", + "call_id": msg.tool_call_id, + "output": str(msg.content), + } + ) + elif msg.role == "assistant" and msg.tool_calls: + for tc in msg.tool_calls: + items.append( + { + "type": "function_call", + "name": tc.name, + "arguments": json.dumps(tc.arguments), + "call_id": tc.id, + } + ) + else: + msg_dict = msg.model_dump(exclude_none=True) + msg_dict.pop("tool_calls", None) + items.append(msg_dict) + return items + + +__all__ = [ + "TOOL_MAPPERS", + "CodeExecutionMapper", + "WebSearchMapper", + "XSearchMapper", + "parse_content", + "parse_tool_calls", + "serialize_messages", +] diff --git a/src/celeste/providers/anthropic/messages/parameters.py b/src/celeste/providers/anthropic/messages/parameters.py index 7146653..191ee17 100644 --- a/src/celeste/providers/anthropic/messages/parameters.py +++ b/src/celeste/providers/anthropic/messages/parameters.py @@ -8,8 +8,11 @@ from celeste.models import Model from celeste.parameters import FieldMapper, ParameterMapper from celeste.structured_outputs import StrictJsonSchemaGenerator +from celeste.tools import Tool from celeste.types import TextContent +from .tools import TOOL_MAPPERS + class TemperatureMapper(FieldMapper[TextContent]): """Map temperature to Anthropic temperature field.""" @@ -67,8 +70,8 @@ def map( return request -class WebSearchMapper(ParameterMapper[TextContent]): - """Map web_search to Anthropic tools field.""" +class ToolsMapper(ParameterMapper[TextContent]): + """Map tools list to Anthropic tools field.""" def map( self, @@ -76,19 +79,45 @@ def map( value: object, model: Model, ) -> dict[str, Any]: - """Transform web_search into provider 
request.""" + """Transform tools into provider request.""" validated_value = self._validate_value(value, model) if not validated_value: return request - request.setdefault("tools", []).append( - { - "type": "web_search_20250305", - "name": "web_search", - } - ) + dispatch = {m.tool_type: m for m in TOOL_MAPPERS} + tools = request.setdefault("tools", []) + + for item in validated_value: + if isinstance(item, Tool): + mapper = dispatch.get(type(item)) + if mapper is None: + msg = f"Tool '{type(item).__name__}' is not supported by Anthropic" + raise ValueError(msg) + tools.append(mapper.map_tool(item)) + elif isinstance(item, dict) and "name" in item: + tools.append(self._map_user_tool(item)) + elif isinstance(item, dict): + tools.append(item) + return request + @staticmethod + def _map_user_tool(tool: dict[str, Any]) -> dict[str, Any]: + """Map a user-defined tool dict to Anthropic wire format.""" + params = tool.get("parameters", {}) + if isinstance(params, type) and issubclass(params, BaseModel): + input_schema = TypeAdapter(params).json_schema( + schema_generator=StrictJsonSchemaGenerator, + mode="serialization", + ) + else: + input_schema = params + + result: dict[str, Any] = {"name": tool["name"], "input_schema": input_schema} + if "description" in tool: + result["description"] = tool["description"] + return result + class OutputFormatMapper(ParameterMapper[TextContent]): """Map output_schema to Anthropic output_format field. 
@@ -161,7 +190,7 @@ def parse_output(self, content: TextContent, value: object | None) -> TextConten "StopSequencesMapper", "TemperatureMapper", "ThinkingMapper", + "ToolsMapper", "TopKMapper", "TopPMapper", - "WebSearchMapper", ] diff --git a/src/celeste/providers/anthropic/messages/tools.py b/src/celeste/providers/anthropic/messages/tools.py new file mode 100644 index 0000000..159dad9 --- /dev/null +++ b/src/celeste/providers/anthropic/messages/tools.py @@ -0,0 +1,27 @@ +"""Anthropic Messages API tool mappers.""" + +from typing import Any + +from celeste.tools import Tool, ToolMapper, WebSearch + + +class WebSearchMapper(ToolMapper): + """Map WebSearch to Anthropic web_search_20250305 wire format.""" + + tool_type = WebSearch + + def map_tool(self, tool: Tool) -> dict[str, Any]: + assert isinstance(tool, WebSearch) + result: dict[str, Any] = {"type": "web_search_20250305", "name": "web_search"} + if tool.allowed_domains is not None: + result["allowed_domains"] = tool.allowed_domains + if tool.blocked_domains is not None: + result["blocked_domains"] = tool.blocked_domains + if tool.max_uses is not None: + result["max_uses"] = tool.max_uses + return result + + +TOOL_MAPPERS: list[ToolMapper] = [WebSearchMapper()] + +__all__ = ["TOOL_MAPPERS", "WebSearchMapper"] diff --git a/src/celeste/providers/google/generate_content/parameters.py b/src/celeste/providers/google/generate_content/parameters.py index e7e979a..5a4d93e 100644 --- a/src/celeste/providers/google/generate_content/parameters.py +++ b/src/celeste/providers/google/generate_content/parameters.py @@ -10,8 +10,11 @@ from celeste.mime_types import ApplicationMimeType from celeste.models import Model from celeste.parameters import ParameterMapper +from celeste.tools import Tool from celeste.types import TextContent +from .tools import TOOL_MAPPERS + class TemperatureMapper[Content](ParameterMapper[Content]): """Map temperature to Google generationConfig.temperature field.""" @@ -181,8 +184,8 @@ def map( return 
request -class WebSearchMapper[Content](ParameterMapper[Content]): - """Map web_search to Google tools field.""" +class ToolsMapper[Content](ParameterMapper[Content]): + """Map tools list to Google tools field.""" def map( self, @@ -190,14 +193,68 @@ def map( value: object, model: Model, ) -> dict[str, Any]: - """Transform web_search into provider request.""" + """Transform tools into provider request.""" validated_value = self._validate_value(value, model) if not validated_value: return request - request["tools"] = [{"google_search": {}}] + dispatch = {m.tool_type: m for m in TOOL_MAPPERS} + tools = request.setdefault("tools", []) + fn_declarations: list[dict[str, Any]] = [] + + for item in validated_value: + if isinstance(item, Tool): + mapper = dispatch.get(type(item)) + if mapper is None: + msg = f"Tool '{type(item).__name__}' is not supported by Google" + raise ValueError(msg) + tools.append(mapper.map_tool(item)) + elif isinstance(item, dict) and "name" in item: + fn_declarations.append(self._map_user_tool(item)) + elif isinstance(item, dict): + tools.append(item) + + if fn_declarations: + tools.append({"functionDeclarations": fn_declarations}) + return request + @staticmethod + def _map_user_tool(tool: dict[str, Any]) -> dict[str, Any]: + """Map a user-defined tool dict to Google functionDeclaration format.""" + params = tool.get("parameters", {}) + if isinstance(params, type) and issubclass(params, BaseModel): + schema = params.model_json_schema() + # Remove unsupported 'title' fields + schema = ToolsMapper._remove_titles(schema) + else: + schema = params + + result: dict[str, Any] = {"name": tool["name"]} + if "description" in tool: + result["description"] = tool["description"] + if schema: + result["parameters"] = schema + return result + + @staticmethod + def _remove_titles(schema: dict[str, Any]) -> dict[str, Any]: + """Remove unsupported 'title' fields from schema for Google.""" + result: dict[str, Any] = {} + for key, value in schema.items(): + if key 
== "title": + continue + if isinstance(value, dict): + result[key] = ToolsMapper._remove_titles(value) + elif isinstance(value, list): + result[key] = [ + ToolsMapper._remove_titles(item) if isinstance(item, dict) else item + for item in value + ] + else: + result[key] = value + return result + class ResponseJsonSchemaMapper(ParameterMapper[TextContent]): """Map output_schema to Google generationConfig.responseJsonSchema field.""" @@ -297,5 +354,5 @@ def _remove_unsupported_fields(self, schema: dict[str, Any]) -> dict[str, Any]: "TemperatureMapper", "ThinkingBudgetMapper", "ThinkingLevelMapper", - "WebSearchMapper", + "ToolsMapper", ] diff --git a/src/celeste/providers/google/generate_content/streaming.py b/src/celeste/providers/google/generate_content/streaming.py index b1f9d63..f31ad92 100644 --- a/src/celeste/providers/google/generate_content/streaming.py +++ b/src/celeste/providers/google/generate_content/streaming.py @@ -30,8 +30,12 @@ def _parse_chunk_content(self, event_data: dict[str, Any]) -> str | None: content = candidate.get("content", {}) parts = content.get("parts", []) - if parts: - return parts[0].get("text") or None + for p in parts: + if p.get("thought"): + continue + text = p.get("text") + if text is not None: + return text return None diff --git a/src/celeste/providers/google/generate_content/tools.py b/src/celeste/providers/google/generate_content/tools.py new file mode 100644 index 0000000..911990d --- /dev/null +++ b/src/celeste/providers/google/generate_content/tools.py @@ -0,0 +1,32 @@ +"""Google GenerateContent API tool mappers.""" + +from typing import Any + +from celeste.tools import CodeExecution, Tool, ToolMapper, WebSearch + + +class WebSearchMapper(ToolMapper): + """Map WebSearch to Google google_search wire format.""" + + tool_type = WebSearch + + def map_tool(self, tool: Tool) -> dict[str, Any]: + assert isinstance(tool, WebSearch) + config: dict[str, Any] = {} + if tool.blocked_domains is not None: + config["exclude_domains"] = 
tool.blocked_domains + return {"google_search": config} + + +class CodeExecutionMapper(ToolMapper): + """Map CodeExecution to Google code_execution wire format.""" + + tool_type = CodeExecution + + def map_tool(self, tool: Tool) -> dict[str, Any]: + return {"code_execution": {}} + + +TOOL_MAPPERS: list[ToolMapper] = [WebSearchMapper(), CodeExecutionMapper()] + +__all__ = ["TOOL_MAPPERS", "CodeExecutionMapper", "WebSearchMapper"] diff --git a/src/celeste/providers/groq/chat/tools.py b/src/celeste/providers/groq/chat/tools.py new file mode 100644 index 0000000..1cdb319 --- /dev/null +++ b/src/celeste/providers/groq/chat/tools.py @@ -0,0 +1,28 @@ +"""Groq Chat Completions tool mappers.""" + +from typing import Any + +from celeste.tools import CodeExecution, Tool, ToolMapper, WebSearch + + +class WebSearchMapper(ToolMapper): + """Map WebSearch to Groq browser_search wire format.""" + + tool_type = WebSearch + + def map_tool(self, tool: Tool) -> dict[str, Any]: + return {"type": "browser_search"} + + +class CodeExecutionMapper(ToolMapper): + """Map CodeExecution to Groq code_interpreter wire format.""" + + tool_type = CodeExecution + + def map_tool(self, tool: Tool) -> dict[str, Any]: + return {"type": "code_interpreter"} + + +TOOL_MAPPERS: list[ToolMapper] = [WebSearchMapper(), CodeExecutionMapper()] + +__all__ = ["TOOL_MAPPERS", "CodeExecutionMapper", "WebSearchMapper"] diff --git a/src/celeste/providers/moonshot/chat/tools.py b/src/celeste/providers/moonshot/chat/tools.py new file mode 100644 index 0000000..10e4855 --- /dev/null +++ b/src/celeste/providers/moonshot/chat/tools.py @@ -0,0 +1,19 @@ +"""Moonshot Chat Completions tool mappers.""" + +from typing import Any + +from celeste.tools import Tool, ToolMapper, WebSearch + + +class WebSearchMapper(ToolMapper): + """Map WebSearch to Moonshot builtin_function wire format.""" + + tool_type = WebSearch + + def map_tool(self, tool: Tool) -> dict[str, Any]: + return {"type": "builtin_function", "function": {"name": 
"$web_search"}} + + +TOOL_MAPPERS: list[ToolMapper] = [WebSearchMapper()] + +__all__ = ["TOOL_MAPPERS", "WebSearchMapper"] diff --git a/src/celeste/streaming.py b/src/celeste/streaming.py index 0ab6b43..e023053 100644 --- a/src/celeste/streaming.py +++ b/src/celeste/streaming.py @@ -13,6 +13,7 @@ from celeste.io import Chunk as ChunkBase from celeste.io import FinishReason, Output, Usage from celeste.parameters import Parameters +from celeste.tools import ToolCall from celeste.types import RawUsage @@ -158,8 +159,15 @@ def _parse_output(self, chunks: list[Chunk], **parameters: Unpack[Params]) -> Ou usage=self._aggregate_usage(chunks), finish_reason=self._aggregate_finish_reason(chunks), metadata=self._build_stream_metadata(raw_events), + tool_calls=self._aggregate_tool_calls(chunks, raw_events), ) + def _aggregate_tool_calls( + self, chunks: list[Chunk], raw_events: list[dict[str, Any]] + ) -> list[ToolCall]: + """Aggregate tool calls from stream events. Override in providers that support tools.""" + return [] + def _parse_chunk_usage(self, event_data: dict[str, Any]) -> RawUsage | None: """Parse usage from chunk event. Override in provider mixin.""" return None diff --git a/src/celeste/tools.py b/src/celeste/tools.py new file mode 100644 index 0000000..3d599ae --- /dev/null +++ b/src/celeste/tools.py @@ -0,0 +1,74 @@ +"""Tool calling types for Celeste.""" + +from abc import ABC, abstractmethod +from typing import Any, ClassVar + +from pydantic import BaseModel, ConfigDict + +from celeste.types import Message, Role, ToolCall + + +class Tool(BaseModel): + """Base for configurable tools. Subclass per tool type. + + Provider-specific ToolMappers translate these to wire format. + """ + + model_config = ConfigDict(frozen=True) + + +class WebSearch(Tool): + """Web search tool with unified cross-provider configuration. 
+ + Config mapping per provider: + - allowed_domains → Anthropic: allowed_domains, OpenAI: filters.allowed_domains + - blocked_domains → Anthropic: blocked_domains, Google: exclude_domains + - max_uses → Anthropic: max_uses + """ + + allowed_domains: list[str] | None = None + blocked_domains: list[str] | None = None + max_uses: int | None = None + + +class XSearch(Tool): + """X/Twitter search tool (OpenAI/xAI only).""" + + +class CodeExecution(Tool): + """Code execution/interpreter tool.""" + + +class ToolMapper(ABC): + """Maps a single Tool type to provider wire format. + + Parallel to FieldMapper for parameters. One per tool type per provider. + """ + + tool_type: ClassVar[type[Tool]] + + @abstractmethod + def map_tool(self, tool: Tool) -> dict[str, Any]: ... + + +type ToolDefinition = Tool | dict[str, Any] + + +class ToolResult(Message): + """A tool result for multi-turn tool use.""" + + role: Role = Role.USER + tool_call_id: str + name: str | None = None + + +__all__ = [ + "CodeExecution", + "Tool", + "ToolCall", + "ToolDefinition", + "ToolMapper", + "ToolResult", + "WebSearch", + "XSearch", +] diff --git a/src/celeste/types.py b/src/celeste/types.py index 7dadd71..7ff6917 100644 --- a/src/celeste/types.py +++ b/src/celeste/types.py @@ -17,7 +17,7 @@ type VideoContent = VideoArtifact | list[VideoArtifact] type EmbeddingsContent = list[float] | list[list[float]] -type Content = str | JsonValue | dict[str, Any] | list[JsonValue | dict[str, Any]] +type Content = TextContent | ImageContent | VideoContent | AudioContent type RawUsage = dict[str, int | float | None] @@ -31,6 +31,16 @@ class Role(StrEnum): DEVELOPER = "developer" +class ToolCall(BaseModel): + """A tool call returned by the model.""" + + model_config = ConfigDict(extra="allow") + + id: str + name: str + arguments: dict[str, Any] + + class Message(BaseModel): """A message in a conversation.""" @@ -38,6 +48,7 @@ class Message(BaseModel): role: Role content: Content + tool_calls: list[ToolCall] | None = 
None __all__ = [ @@ -50,5 +61,6 @@ class Message(BaseModel): "RawUsage", "Role", "TextContent", + "ToolCall", "VideoContent", ] diff --git a/templates/modalities/{modality_slug}/parameters.py.template b/templates/modalities/{modality_slug}/parameters.py.template index 12b0230..a97a721 100644 --- a/templates/modalities/{modality_slug}/parameters.py.template +++ b/templates/modalities/{modality_slug}/parameters.py.template @@ -23,7 +23,7 @@ class {Modality}Parameter(StrEnum): # THINKING_BUDGET = "thinking_budget" # THINKING_LEVEL = "thinking_level" # OUTPUT_SCHEMA = "output_schema" - # WEB_SEARCH = "web_search" + # TOOLS = "tools" # VERBOSITY = "verbosity" # Media input declarations (for optional_input_types) @@ -44,7 +44,7 @@ class {Modality}Parameters(Parameters): # thinking_budget: int | str # thinking_level: str # output_schema: type[BaseModel] - # web_search: bool + # tools: list[ToolDefinition] # verbosity: str diff --git a/templates/providers/{provider_slug}/{api_slug}/parameters.py.template b/templates/providers/{provider_slug}/{api_slug}/parameters.py.template index 95e1a23..85f4389 100644 --- a/templates/providers/{provider_slug}/{api_slug}/parameters.py.template +++ b/templates/providers/{provider_slug}/{api_slug}/parameters.py.template @@ -2,7 +2,7 @@ Naming convention: - Mapper class name MUST match the provider's API parameter name -- Example: API param "web_search" → class WebSearchMapper (not SearchMapper) +- Example: API param "tools" → class ToolsMapper (not ToolMapper) - Example: API param "aspectRatio" → class AspectRatioMapper - The request key should match the provider's expected field name exactly """ diff --git a/tests/integration_tests/text/test_tools.py b/tests/integration_tests/text/test_tools.py new file mode 100644 index 0000000..a25dfc1 --- /dev/null +++ b/tests/integration_tests/text/test_tools.py @@ -0,0 +1,189 @@ +"""Integration tests for tools= parameter - WebSearch, function tools, streaming.""" + +import warnings + +# Suppress 
deprecation warnings from legacy capability packages
+warnings.filterwarnings(
+    "ignore",
+    message=".*capability parameter is deprecated.*",
+    category=DeprecationWarning,
+)
+
+import pytest  # noqa: E402
+
+from celeste import Modality, create_client  # noqa: E402
+from celeste.modalities.text import TextChunk, TextOutput  # noqa: E402
+from celeste.tools import ToolResult, WebSearch, XSearch  # noqa: E402
+from celeste.types import Message, Role  # noqa: E402
+
+# One cheap (provider, model_id) per provider for server-side tools (WebSearch/XSearch)
+# xAI: only grok-4+ supports server-side tools
+SERVER_TOOL_MODELS = [
+    ("anthropic", "claude-haiku-4-5"),
+    ("openai", "gpt-4o-mini"),
+    ("google", "gemini-2.5-flash"),
+    ("xai", "grok-4-fast-non-reasoning"),
+]
+
+# One cheap (provider, model_id) per provider for user-defined function tools
+FUNCTION_TOOL_MODELS = [
+    ("anthropic", "claude-haiku-4-5"),
+    ("openai", "gpt-4o-mini"),
+    ("google", "gemini-2.5-flash"),
+    ("xai", "grok-3-mini"),
+]
+
+WEATHER_TOOL = {  # user-defined function tool, passed as a plain dict in tools=[...]
+    "name": "get_weather",
+    "description": "Get current weather for a city. You MUST call this tool.",
+    "parameters": {
+        "type": "object",
+        "properties": {"city": {"type": "string"}},
+        "required": ["city"],
+    },
+}
+
+TEST_MAX_TOKENS = 500  # max_tokens for every generate() call in these tests
+
+
+# -- WebSearch (non-streaming) --
+
+
+@pytest.mark.parametrize(
+    ("provider", "model_id"),
+    SERVER_TOOL_MODELS,
+    ids=[f"{p}-{m}" for p, m in SERVER_TOOL_MODELS],
+)
+@pytest.mark.integration
+@pytest.mark.asyncio
+async def test_web_search(provider: str, model_id: str) -> None:
+    """Check generate() with WebSearch returns a TextOutput with non-empty content."""
+    client = create_client(modality=Modality.TEXT, provider=provider, model=model_id)
+
+    output = await client.generate(
+        prompt="What year was Python 3.12 released?",
+        tools=[WebSearch()],
+        max_tokens=TEST_MAX_TOKENS,
+    )
+
+    assert isinstance(output, TextOutput)
+    assert output.content
+
+
+# -- WebSearch (streaming) --
+
+
+@pytest.mark.parametrize(
+    ("provider", "model_id"),
+    SERVER_TOOL_MODELS,
+    ids=[f"{p}-{m}" for p, m in SERVER_TOOL_MODELS],
+)
+@pytest.mark.integration
+@pytest.mark.asyncio
+async def test_stream_web_search(provider: str, model_id: str) -> None:
+    """Check stream.generate() with WebSearch yields one or more TextChunk items."""
+    client = create_client(modality=Modality.TEXT, provider=provider, model=model_id)
+
+    chunks: list[TextChunk] = []
+    async for chunk in client.stream.generate(
+        prompt="What year was Python 3.12 released?",
+        tools=[WebSearch()],
+        max_tokens=TEST_MAX_TOKENS,
+    ):
+        chunks.append(chunk)
+
+    assert chunks
+    assert all(isinstance(c, TextChunk) for c in chunks)
+
+
+# -- User-defined function tool -> ToolCall parsing --
+
+
+@pytest.mark.parametrize(
+    ("provider", "model_id"),
+    FUNCTION_TOOL_MODELS,
+    ids=[f"{p}-{m}" for p, m in FUNCTION_TOOL_MODELS],
+)
+@pytest.mark.integration
+@pytest.mark.asyncio
+async def test_function_tool_call(provider: str, model_id: str) -> None:
+    """Check the first returned tool call is get_weather and has a city argument."""
+    client = create_client(modality=Modality.TEXT, provider=provider, model=model_id)
+
+    output = await client.generate(
+        prompt="What is the weather in Paris right now? Use the get_weather tool.",
+        tools=[WEATHER_TOOL],
+        max_tokens=TEST_MAX_TOKENS,
+    )
+
+    assert isinstance(output, TextOutput)
+    assert len(output.tool_calls) > 0, (
+        f"{provider}/{model_id} did not return tool_calls"
+    )
+    tc = output.tool_calls[0]
+    assert tc.name == "get_weather"
+    assert "city" in tc.arguments
+
+
+# -- xAI-specific: XSearch --
+
+
+@pytest.mark.integration
+@pytest.mark.asyncio
+async def test_xai_x_search() -> None:
+    """Check XSearch (xAI-only) returns a TextOutput with non-empty content."""
+    client = create_client(
+        modality=Modality.TEXT, provider="xai", model="grok-4-fast-non-reasoning"
+    )
+
+    output = await client.generate(
+        prompt="What is trending on X right now?",
+        tools=[XSearch()],
+        max_tokens=TEST_MAX_TOKENS,
+    )
+
+    assert isinstance(output, TextOutput)
+    assert output.content
+
+
+# -- Multi-turn ToolResult round-trip --
+
+
+@pytest.mark.parametrize(
+    ("provider", "model_id"),
+    FUNCTION_TOOL_MODELS,
+    ids=[f"{p}-{m}" for p, m in FUNCTION_TOOL_MODELS],
+)
+@pytest.mark.integration
+@pytest.mark.asyncio
+async def test_tool_result_round_trip(provider: str, model_id: str) -> None:
+    """Check the round-trip: tool call -> ToolResult sent back -> non-empty answer."""
+    client = create_client(modality=Modality.TEXT, provider=provider, model=model_id)
+
+    # Step 1: prompt with WEATHER_TOOL and expect a get_weather tool call
+    output1 = await client.generate(
+        prompt="What is the weather in Paris right now? Use the get_weather tool.",
+        tools=[WEATHER_TOOL],
+        max_tokens=TEST_MAX_TOKENS,
+    )
+
+    assert output1.tool_calls, f"{provider}/{model_id} did not return tool_calls"
+    tc = output1.tool_calls[0]
+    assert tc.name == "get_weather"
+
+    # Step 2: replay the user prompt and output1.message, then append the ToolResult
+    output2 = await client.generate(
+        messages=[
+            Message(
+                role=Role.USER,
+                content="What is the weather in Paris right now? Use the get_weather tool.",
+            ),
+            output1.message,
+            ToolResult(content="18°C, sunny", tool_call_id=tc.id, name=tc.name),
+        ],
+        tools=[WEATHER_TOOL],
+        max_tokens=TEST_MAX_TOKENS,
+    )
+
+    assert isinstance(output2, TextOutput)
+    assert output2.content, f"{provider}/{model_id} did not return final text answer"
diff --git a/tests/unit_tests/test_constraints.py b/tests/unit_tests/test_constraints.py
index 42abd18..839be12 100644
--- a/tests/unit_tests/test_constraints.py
+++ b/tests/unit_tests/test_constraints.py
@@ -17,11 +17,13 @@
     Range,
     Schema,
     Str,
+    ToolSupport,
     VideoConstraint,
     VideosConstraint,
 )
 from celeste.exceptions import ConstraintViolationError
 from celeste.mime_types import AudioMimeType, ImageMimeType, VideoMimeType
+from celeste.tools import WebSearch, XSearch
 
 
 class TestChoice:
@@ -868,3 +870,47 @@ def test_accepts_all_valid_audios(self) -> None:
         result = constraint(artifacts)
 
         assert result == artifacts
+
+
+class TestToolSupport:
+    """Tests for calling a ToolSupport constraint on lists of tools."""
+
+    def test_valid_tool_passes(self) -> None:
+        """A WebSearch instance passes when WebSearch is in the supported tools."""
+        constraint = ToolSupport(tools=[WebSearch])
+
+        result = constraint([WebSearch()])
+
+        assert len(result) == 1
+        assert isinstance(result[0], WebSearch)
+
+    def test_invalid_tool_raises(self) -> None:
+        """XSearch raises ConstraintViolationError when only WebSearch is supported."""
+        constraint = ToolSupport(tools=[WebSearch])
+
+        with pytest.raises(ConstraintViolationError, match="XSearch"):
+            constraint([XSearch()])
+
+    def test_dict_passes_unchecked(self) -> None:
+        """A user-defined tool dict is returned unchanged by the constraint."""
+        constraint = ToolSupport(tools=[WebSearch])
+
+        result = constraint([{"name": "custom_fn", "parameters": {}}])
+
+        assert result == [{"name": "custom_fn", "parameters": {}}]
+
+    def test_mixed_list(self) -> None:
+        """A list mixing a supported Tool instance and a dict keeps both items."""
+        constraint = ToolSupport(tools=[WebSearch])
+
+        result = constraint([WebSearch(), {"name": "custom"}])
+
+        assert len(result) == 2
+
+    def test_empty_list(self) -> None:
+        """An empty list is returned as an empty list."""
+        constraint = ToolSupport(tools=[WebSearch])
+
+        result = constraint([])
+
+        assert result == []