Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/celeste/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
StrictJsonSchemaGenerator,
StrictRefResolvingJsonSchemaGenerator,
)
from celeste.tools import CodeExecution, Tool, ToolCall, ToolResult, WebSearch, XSearch
from celeste.types import Content, JsonValue, Message, Role
from celeste.websocket import WebSocketClient, WebSocketConnection, close_all_ws_clients

Expand Down Expand Up @@ -246,6 +247,7 @@ def create_client(
"Authentication",
"Capability",
"ClientNotFoundError",
"CodeExecution",
"ConstraintViolationError",
"Content",
"Error",
Expand All @@ -271,15 +273,20 @@ def create_client(
"StreamingNotSupportedError",
"StrictJsonSchemaGenerator",
"StrictRefResolvingJsonSchemaGenerator",
"Tool",
"ToolCall",
"ToolResult",
"UnsupportedCapabilityError",
"UnsupportedParameterError",
"UnsupportedParameterWarning",
"UnsupportedProviderError",
"Usage",
"UsageField",
"ValidationError",
"WebSearch",
"WebSocketClient",
"WebSocketConnection",
"XSearch",
"audio",
"close_all_http_clients",
"close_all_ws_clients",
Expand Down
7 changes: 7 additions & 0 deletions src/celeste/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from celeste.models import Model
from celeste.parameters import ParameterMapper, Parameters
from celeste.streaming import Stream, enrich_stream_errors
from celeste.tools import ToolCall
from celeste.types import RawUsage


Expand Down Expand Up @@ -206,13 +207,19 @@ async def _predict(
)
content = self._parse_content(response_data)
content = self._transform_output(content, **parameters)
tool_calls = self._parse_tool_calls(response_data)
return self._output_class()(
content=content,
usage=self._get_usage(response_data),
finish_reason=self._get_finish_reason(response_data),
metadata=self._build_metadata(response_data),
tool_calls=tool_calls,
)

def _parse_tool_calls(self, response_data: dict[str, Any]) -> list[ToolCall]:
"""Parse tool calls from response. Override in providers that support tools."""
return []

def _stream(
self,
inputs: In,
Expand Down
24 changes: 23 additions & 1 deletion src/celeste/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from abc import ABC, abstractmethod
from typing import Any, ClassVar, get_args, get_origin

from pydantic import BaseModel, Field, computed_field
from pydantic import BaseModel, Field, computed_field, field_serializer

from celeste.artifacts import AudioArtifact, ImageArtifact, VideoArtifact
from celeste.exceptions import ConstraintViolationError
from celeste.mime_types import AudioMimeType, ImageMimeType, MimeType, VideoMimeType
from celeste.tools import Tool


class Constraint(BaseModel, ABC):
Expand Down Expand Up @@ -367,6 +368,26 @@ class AudiosConstraint(_MediaListConstraint[AudioMimeType]):
_media_label = "audio"


class ToolSupport(Constraint):
"""Tool support constraint - validates Tool instances are supported by the model."""

tools: list[type[Tool]]

@field_serializer("tools")
@classmethod
def _serialize_tools(cls, v: list[type[Tool]]) -> list[str]:
return [t.__name__ for t in v]

def __call__(self, value: list) -> list:
"""Validate tools list against supported tools."""
for item in value:
if isinstance(item, Tool) and type(item) not in self.tools:
supported = [t.__name__ for t in self.tools]
msg = f"Tool '{type(item).__name__}' not supported. Supported: {supported}"
raise ConstraintViolationError(msg)
return value


__all__ = [
"AudioConstraint",
"AudiosConstraint",
Expand All @@ -382,6 +403,7 @@ class AudiosConstraint(_MediaListConstraint[AudioMimeType]):
"Range",
"Schema",
"Str",
"ToolSupport",
"VideoConstraint",
"VideosConstraint",
]
2 changes: 2 additions & 0 deletions src/celeste/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from celeste.artifacts import AudioArtifact, ImageArtifact, VideoArtifact
from celeste.constraints import Constraint
from celeste.core import InputType
from celeste.tools import ToolCall


class Input(BaseModel):
Expand Down Expand Up @@ -38,6 +39,7 @@ class Output[Content](BaseModel):
usage: Usage = Field(default_factory=Usage)
finish_reason: FinishReason | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
tool_calls: list[ToolCall] = Field(default_factory=list)


class Chunk[Content](BaseModel):
Expand Down
38 changes: 37 additions & 1 deletion src/celeste/modalities/text/client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Text modality client."""

from typing import Any, Unpack
import warnings
from typing import Any, ClassVar, Unpack

from asgiref.sync import async_to_sync

from celeste.client import ModalityClient
from celeste.core import InputType, Modality
from celeste.tools import CodeExecution, WebSearch, XSearch
from celeste.types import AudioContent, ImageContent, Message, TextContent, VideoContent

from .io import TextChunk, TextFinishReason, TextInput, TextOutput, TextUsage
Expand All @@ -25,11 +27,45 @@ class TextClient(
_usage_class = TextUsage
_finish_reason_class = TextFinishReason

# Deprecated param → Tool class mapping.
# TODO(deprecation): Remove on 2026-06-07.
_DEPRECATED_TOOL_PARAMS: ClassVar[dict[str, type]] = {
"web_search": WebSearch,
"x_search": XSearch,
"code_execution": CodeExecution,
}

@classmethod
def _output_class(cls) -> type[TextOutput]:
"""Return the Output class for text modality."""
return TextOutput

def _build_request(
self,
inputs: TextInput,
extra_body: dict[str, Any] | None = None,
streaming: bool = False,
**parameters: Unpack[TextParameters],
) -> dict[str, Any]:
"""Build request, migrating deprecated boolean tool params first.

TODO(deprecation): Remove this override on 2026-06-07.
"""
for old_param, tool_cls in self._DEPRECATED_TOOL_PARAMS.items():
value = parameters.pop(old_param, None) # type: ignore[misc]
if value:
warnings.warn(
f"'{old_param}=True' is deprecated, "
f"use tools=[{tool_cls.__name__}()] instead. "
"Will be removed on 2026-06-07.",
DeprecationWarning,
stacklevel=4,
)
parameters.setdefault("tools", []).append(tool_cls())
return super()._build_request(
inputs, extra_body=extra_body, streaming=streaming, **parameters
)

def _check_media_support(
self,
image: ImageContent | None,
Expand Down
21 changes: 19 additions & 2 deletions src/celeste/modalities/text/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,22 @@
from pydantic import Field

from celeste.io import Chunk, FinishReason, Input, Output, Usage
from celeste.types import AudioContent, ImageContent, Message, TextContent, VideoContent
from celeste.tools import ToolResult
from celeste.types import (
AudioContent,
ImageContent,
Message,
Role,
TextContent,
VideoContent,
)


class TextInput(Input):
"""Input for text operations."""

prompt: str | None = None
messages: list[Message] | None = None
messages: list[ToolResult | Message] | None = None
text: str | list[str] | None = None
image: ImageContent | None = None
video: VideoContent | None = None
Expand Down Expand Up @@ -46,6 +54,15 @@ class TextOutput(Output[TextContent]):
usage: TextUsage = Field(default_factory=TextUsage)
finish_reason: TextFinishReason | None = None

@property
def message(self) -> Message:
"""The assistant message for multi-turn conversations."""
return Message(
role=Role.ASSISTANT,
content=self.content,
tool_calls=self.tool_calls if self.tool_calls else None,
)


class TextChunk(Chunk[str]):
"""Chunk for text streaming."""
Expand Down
13 changes: 11 additions & 2 deletions src/celeste/modalities/text/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pydantic import BaseModel

from celeste.parameters import Parameters
from celeste.tools import ToolDefinition


class TextParameter(StrEnum):
Expand All @@ -23,8 +24,12 @@ class TextParameter(StrEnum):
THINKING_BUDGET = "thinking_budget"
THINKING_LEVEL = "thinking_level"
OUTPUT_SCHEMA = "output_schema"
WEB_SEARCH = "web_search"
TOOLS = "tools"
VERBOSITY = "verbosity"

# Deprecated: use tools=[WebSearch()], tools=[XSearch()], tools=[CodeExecution()] instead.
# TODO(deprecation): Remove on 2026-06-07.
WEB_SEARCH = "web_search"
X_SEARCH = "x_search"
CODE_EXECUTION = "code_execution"

Expand All @@ -46,8 +51,12 @@ class TextParameters(Parameters):
thinking_budget: int | str
thinking_level: str
output_schema: type[BaseModel]
web_search: bool
tools: list[ToolDefinition]
verbosity: str

# Deprecated: use tools=[WebSearch()], tools=[XSearch()], tools=[CodeExecution()] instead.
# TODO(deprecation): Remove on 2026-06-07.
web_search: bool
x_search: bool
code_execution: bool

Expand Down
1 change: 1 addition & 0 deletions src/celeste/modalities/text/protocols/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Text modality protocol implementations."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Chat Completions protocol for text modality."""

from .client import ChatCompletionsTextClient, ChatCompletionsTextStream

__all__ = ["ChatCompletionsTextClient", "ChatCompletionsTextStream"]
101 changes: 101 additions & 0 deletions src/celeste/modalities/text/protocols/chatcompletions/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""Chat Completions text client."""

from typing import Any, Unpack

from celeste.parameters import ParameterMapper
from celeste.protocols.chatcompletions.client import (
ChatCompletionsClient as ChatCompletionsMixin,
)
from celeste.protocols.chatcompletions.streaming import (
ChatCompletionsStream as _ChatCompletionsStream,
)
from celeste.protocols.chatcompletions.tools import (
parse_tool_calls,
serialize_messages,
)
from celeste.tools import ToolCall
from celeste.types import ImageContent, Message, TextContent, VideoContent
from celeste.utils import build_image_data_url

from ...client import TextClient
from ...io import (
TextInput,
TextOutput,
)
from ...parameters import TextParameters
from ...streaming import TextStream
from .parameters import CHATCOMPLETIONS_PARAMETER_MAPPERS


class ChatCompletionsTextStream(_ChatCompletionsStream, TextStream):
"""Chat Completions streaming for text modality."""


class ChatCompletionsTextClient(ChatCompletionsMixin, TextClient):
"""Chat Completions text client."""

@classmethod
def parameter_mappers(cls) -> list[ParameterMapper[TextContent]]:
return CHATCOMPLETIONS_PARAMETER_MAPPERS

async def generate(
self,
prompt: str | None = None,
*,
messages: list[Message] | None = None,
**parameters: Unpack[TextParameters],
) -> TextOutput:
"""Generate text from prompt."""
inputs = TextInput(prompt=prompt, messages=messages)
return await self._predict(inputs, **parameters)

async def analyze(
self,
prompt: str | None = None,
*,
messages: list[Message] | None = None,
image: ImageContent | None = None,
video: VideoContent | None = None,
**parameters: Unpack[TextParameters],
) -> TextOutput:
"""Analyze image(s) or video(s) with prompt or messages."""
inputs = TextInput(prompt=prompt, messages=messages, image=image, video=video)
return await self._predict(inputs, **parameters)

def _init_request(self, inputs: TextInput) -> dict[str, Any]:
"""Initialize request with Chat Completions message format."""
if inputs.messages is not None:
return {"messages": serialize_messages(inputs.messages)}

if inputs.image is None:
content: str | list[dict[str, Any]] = inputs.prompt or ""
else:
images = inputs.image if isinstance(inputs.image, list) else [inputs.image]
content = [
{"type": "image_url", "image_url": {"url": build_image_data_url(img)}}
for img in images
]
content.append({"type": "text", "text": inputs.prompt or ""})

return {"messages": [{"role": "user", "content": content}]}

def _parse_content(
self,
response_data: dict[str, Any],
) -> TextContent:
"""Parse text content from response."""
choices = super()._parse_content(response_data)
message = choices[0].get("message", {})
content = message.get("content") or ""
return content

def _parse_tool_calls(self, response_data: dict[str, Any]) -> list[ToolCall]:
"""Parse tool calls from Chat Completions response."""
return parse_tool_calls(response_data)

def _stream_class(self) -> type[TextStream]:
"""Return the Stream class for this provider."""
return ChatCompletionsTextStream


__all__ = ["ChatCompletionsTextClient", "ChatCompletionsTextStream"]
Loading
Loading