Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions src/celeste/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from celeste.core import Modality, Provider
from celeste.exceptions import StreamingNotSupportedError
from celeste.http import HTTPClient, get_http_client
from celeste.io import Chunk, FinishReason, Input, Output, Usage
from celeste.io import Chunk as ChunkBase
from celeste.io import FinishReason, Input, Output, Usage
from celeste.mime_types import ApplicationMimeType
from celeste.models import Model
from celeste.parameters import ParameterMapper, Parameters
Expand Down Expand Up @@ -130,15 +131,19 @@ def _handle_error_response(self, response: httpx.Response) -> None:
super()._handle_error_response(response) # type: ignore[misc]


class ModalityClient[In: Input, Out: Output, Params: Parameters, Content](
APIMixin, BaseModel
):
class ModalityClient[
In: Input,
Out: Output,
Params: Parameters,
Content,
Chunk: ChunkBase,
](APIMixin, BaseModel):
"""Base class for unified modality clients.

Operation methods in subclasses delegate to _predict().

Example:
class ImagesClient(ModalityClient[ImagesInput, ImagesOutput, ImagesParameters, ImageContent]):
class ImagesClient(ModalityClient[ImagesInput, ImagesOutput, ImagesParameters, ImageContent, ImageChunk]):
modality = Modality.IMAGES

async def generate(self, prompt: str, **parameters) -> ImageGenerationOutput:
Expand Down Expand Up @@ -198,7 +203,7 @@ async def _predict(
response_data = await self._make_request(
request_body, endpoint=endpoint, extra_headers=extra_headers, **parameters
)
content = self._parse_content(response_data, **parameters)
content = self._parse_content(response_data)
content = self._transform_output(content, **parameters)
return self._output_class()(
content=content,
Expand Down Expand Up @@ -277,7 +282,6 @@ def _parse_usage(self, response_data: dict[str, Any]) -> RawUsage:
def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[Params], # type: ignore[misc]
) -> Content:
"""Parse content from provider response."""
...
Expand Down Expand Up @@ -384,8 +388,7 @@ def _transform_output(
"""Transform content using parameter mapper output transformations."""
for mapper in self.parameter_mappers():
value = parameters.get(mapper.name)
if value is not None:
content = mapper.parse_output(content, value)
content = mapper.parse_output(content, value)
return content

@abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions src/celeste/modalities/audio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from celeste.core import Modality
from celeste.types import AudioContent

from .io import AudioFinishReason, AudioInput, AudioOutput, AudioUsage
from .io import AudioChunk, AudioFinishReason, AudioInput, AudioOutput, AudioUsage
from .parameters import AudioParameters
from .streaming import AudioStream


class AudioClient(
ModalityClient[AudioInput, AudioOutput, AudioParameters, AudioContent]
ModalityClient[AudioInput, AudioOutput, AudioParameters, AudioContent, AudioChunk]
):
"""Base audio client. Providers implement speak() method."""

Expand Down
6 changes: 1 addition & 5 deletions src/celeste/modalities/audio/providers/elevenlabs/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,14 @@ def _init_request(self, inputs: AudioInput) -> dict[str, Any]:
def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[AudioParameters],
) -> AudioArtifact:
"""Extract audio bytes from response."""
audio_bytes = response_data.get("audio_bytes")
if not audio_bytes:
msg = "No audio data in response"
raise ValueError(msg)

output_format = parameters.get("output_format")
mime_type = self._map_output_format_to_mime_type(output_format)

return AudioArtifact(data=audio_bytes, mime_type=mime_type)
return AudioArtifact(data=audio_bytes)

def _stream_class(self) -> type[AudioStream]:
"""Return the Stream class for this provider."""
Expand Down
10 changes: 2 additions & 8 deletions src/celeste/modalities/audio/providers/google/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Any, Unpack

from celeste.artifacts import AudioArtifact
from celeste.mime_types import AudioMimeType
from celeste.parameters import ParameterMapper
from celeste.providers.google.cloud_tts import config
from celeste.providers.google.cloud_tts.client import (
Expand All @@ -16,7 +15,7 @@
AudioInput,
AudioOutput,
)
from ...parameters import AudioParameter, AudioParameters
from ...parameters import AudioParameters
from .parameters import GOOGLE_PARAMETER_MAPPERS


Expand Down Expand Up @@ -51,15 +50,10 @@ def _init_request(self, inputs: AudioInput) -> dict[str, Any]:
def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[AudioParameters],
) -> AudioArtifact:
"""Extract audio bytes from response."""
audio_b64 = super()._parse_content(response_data)

output_format = parameters.get(AudioParameter.OUTPUT_FORMAT)
mime_type = AudioMimeType(output_format) if output_format else AudioMimeType.MP3

return AudioArtifact(data=audio_b64, mime_type=mime_type)
return AudioArtifact(data=audio_b64)


__all__ = ["GoogleAudioClient"]
6 changes: 1 addition & 5 deletions src/celeste/modalities/audio/providers/gradium/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,14 @@ def _init_request(self, inputs: AudioInput) -> dict[str, Any]:
def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[AudioParameters],
) -> AudioArtifact:
"""Extract audio bytes from response."""
audio_bytes = response_data.get("audio_bytes")
if not audio_bytes:
msg = "No audio data in response"
raise ValueError(msg)

output_format = parameters.get("output_format")
mime_type = self._map_output_format_to_mime_type(output_format)

return AudioArtifact(data=audio_bytes, mime_type=mime_type)
return AudioArtifact(data=audio_bytes)

def _stream_class(self) -> type[AudioStream]:
"""Return the Stream class for this provider."""
Expand Down
7 changes: 1 addition & 6 deletions src/celeste/modalities/audio/providers/openai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,14 @@ def _init_request(self, inputs: AudioInput) -> dict[str, Any]:
def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[AudioParameters],
) -> AudioArtifact:
"""Extract audio bytes from response."""
audio_bytes = response_data.get("audio_bytes")
if not audio_bytes:
msg = "No audio data in response"
raise ValueError(msg)

# Use mixin helper to determine MIME type from output_format
output_format = parameters.get("output_format")
mime_type = self._map_response_format_to_mime_type(output_format)

return AudioArtifact(data=audio_bytes, mime_type=mime_type)
return AudioArtifact(data=audio_bytes)

def _parse_finish_reason(self, response_data: dict[str, Any]) -> AudioFinishReason:
"""OpenAI TTS doesn't provide finish reasons."""
Expand Down
7 changes: 6 additions & 1 deletion src/celeste/modalities/embeddings/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from celeste.types import EmbeddingsContent

from .io import (
EmbeddingsChunk,
EmbeddingsFinishReason,
EmbeddingsInput,
EmbeddingsOutput,
Expand All @@ -19,7 +20,11 @@

class EmbeddingsClient(
ModalityClient[
EmbeddingsInput, EmbeddingsOutput, EmbeddingsParameters, EmbeddingsContent
EmbeddingsInput,
EmbeddingsOutput,
EmbeddingsParameters,
EmbeddingsContent,
EmbeddingsChunk,
]
):
"""Base embeddings client. Providers implement operation methods."""
Expand Down
4 changes: 1 addition & 3 deletions src/celeste/modalities/embeddings/providers/google/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Google embeddings client."""

from typing import Any, Unpack
from typing import Any

from celeste.parameters import ParameterMapper
from celeste.providers.google.embeddings.client import (
Expand All @@ -10,7 +10,6 @@

from ...client import EmbeddingsClient
from ...io import EmbeddingsInput
from ...parameters import EmbeddingsParameters
from .parameters import GOOGLE_PARAMETER_MAPPERS


Expand Down Expand Up @@ -42,7 +41,6 @@ def _init_request(self, inputs: EmbeddingsInput) -> dict[str, Any]:
def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[EmbeddingsParameters],
) -> EmbeddingsContent:
"""Parse embedding vectors from response."""
return super()._parse_content(response_data)
Expand Down
4 changes: 2 additions & 2 deletions src/celeste/modalities/images/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
from celeste.core import Modality
from celeste.types import ImageContent

from .io import ImageFinishReason, ImageInput, ImageOutput, ImageUsage
from .io import ImageChunk, ImageFinishReason, ImageInput, ImageOutput, ImageUsage
from .parameters import ImageParameters
from .streaming import ImagesStream


class ImagesClient(
ModalityClient[ImageInput, ImageOutput, ImageParameters, ImageContent]
ModalityClient[ImageInput, ImageOutput, ImageParameters, ImageContent, ImageChunk]
):
"""Base images client. Providers implement generate/edit methods."""

Expand Down
1 change: 0 additions & 1 deletion src/celeste/modalities/images/providers/bfl/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def _init_request(self, inputs: ImageInput) -> dict[str, Any]:
def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[ImageParameters],
) -> ImageArtifact:
"""Parse content from response."""
result = super()._parse_content(response_data)
Expand Down
1 change: 0 additions & 1 deletion src/celeste/modalities/images/providers/byteplus/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ def _init_request(self, inputs: ImageInput) -> dict[str, Any]:
def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[ImageParameters],
) -> ImageArtifact:
"""Parse content from response."""
content = super()._parse_content(response_data)
Expand Down
3 changes: 1 addition & 2 deletions src/celeste/modalities/images/providers/google/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,8 @@ def _parse_usage(
def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[ImageParameters],
) -> ImageContent:
return self._strategy._parse_content(response_data, **parameters) # type: ignore[union-attr]
return self._strategy._parse_content(response_data) # type: ignore[union-attr]

def _parse_finish_reason(self, response_data: dict[str, Any]) -> ImageFinishReason:
return self._strategy._parse_finish_reason(response_data) # type: ignore[union-attr]
Expand Down
1 change: 0 additions & 1 deletion src/celeste/modalities/images/providers/google/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ def _parse_usage(
def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[ImageParameters],
) -> ImageContent:
"""Parse image artifacts from Gemini candidates."""
candidates = super()._parse_content(response_data)
Expand Down
26 changes: 19 additions & 7 deletions src/celeste/modalities/images/providers/google/imagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def _init_request(self, inputs: ImageInput) -> dict[str, Any]:
def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[ImageParameters],
) -> ImageContent:
"""Parse image artifacts from Imagen predictions."""
predictions = super()._parse_content(response_data)
Expand All @@ -57,14 +56,27 @@ def _parse_content(
mime_type = ImageMimeType(prediction.get("mimeType", "image/png"))
images.append(ImageArtifact(data=base64_data, mime_type=mime_type))

num_images_requested = parameters.get("num_images")
if num_images_requested == 1:
return images[0] if images else ImageArtifact()
if num_images_requested is not None and num_images_requested > 1:
return images if images else []
if len(images) == 1:
return images[0]
return images if images else ImageArtifact()
return images

def _transform_output(
self,
content: ImageContent,
**parameters: Unpack[ImageParameters],
) -> ImageContent:
"""Singularize/pluralize based on num_images parameter."""
content = super()._transform_output(content, **parameters)
num_images_requested = parameters.get("num_images")
if num_images_requested == 1 and isinstance(content, list):
return content[0] if content else ImageArtifact()
if (
num_images_requested is not None
and num_images_requested > 1
and not isinstance(content, list)
):
return [content]
return content


__all__ = ["ImagenImagesClient"]
1 change: 0 additions & 1 deletion src/celeste/modalities/images/providers/ollama/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def _parse_usage(
def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[ImageParameters],
) -> ImageArtifact:
"""Parse content from response.

Expand Down
1 change: 0 additions & 1 deletion src/celeste/modalities/images/providers/openai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ async def edit(
def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[ImageParameters],
) -> ImageArtifact:
"""Parse content from response."""
data = super()._parse_content(response_data)
Expand Down
1 change: 0 additions & 1 deletion src/celeste/modalities/images/providers/xai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ async def edit(
def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[ImageParameters],
) -> ImageArtifact:
"""Parse content from response."""
data = super()._parse_content(response_data)
Expand Down
6 changes: 4 additions & 2 deletions src/celeste/modalities/text/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
from celeste.core import InputType, Modality
from celeste.types import AudioContent, ImageContent, Message, TextContent, VideoContent

from .io import TextFinishReason, TextInput, TextOutput, TextUsage
from .io import TextChunk, TextFinishReason, TextInput, TextOutput, TextUsage
from .parameters import TextParameters
from .streaming import TextStream


class TextClient(ModalityClient[TextInput, TextOutput, TextParameters, TextContent]):
class TextClient(
ModalityClient[TextInput, TextOutput, TextParameters, TextContent, TextChunk]
):
"""Base text client.

Providers implement operation methods (generate, analyze).
Expand Down
1 change: 0 additions & 1 deletion src/celeste/modalities/text/providers/anthropic/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ def _build_image_source(self, img: ImageArtifact) -> dict[str, Any]:
def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[TextParameters],
) -> TextContent:
"""Parse content from response."""
content = super()._parse_content(response_data)
Expand Down
1 change: 0 additions & 1 deletion src/celeste/modalities/text/providers/cohere/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def _init_request(self, inputs: TextInput) -> dict[str, Any]:
def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[TextParameters],
) -> TextContent:
"""Parse content from response."""
content_array = super()._parse_content(response_data)
Expand Down
1 change: 0 additions & 1 deletion src/celeste/modalities/text/providers/deepseek/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def _init_request(self, inputs: TextInput) -> dict[str, Any]:
def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[TextParameters],
) -> TextContent:
"""Parse content from response."""
choices = super()._parse_content(response_data)
Expand Down
1 change: 0 additions & 1 deletion src/celeste/modalities/text/providers/google/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ def _build_audio_part(self, audio: AudioArtifact) -> dict[str, Any]:
def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[TextParameters],
) -> TextContent:
"""Parse content from response."""
candidates = super()._parse_content(response_data)
Expand Down
1 change: 0 additions & 1 deletion src/celeste/modalities/text/providers/groq/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def _init_request(self, inputs: TextInput) -> dict[str, Any]:
def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[TextParameters],
) -> TextContent:
"""Parse content from response."""
choices = super()._parse_content(response_data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def _init_request(self, inputs: TextInput) -> dict[str, Any]:
def _parse_content(
self,
response_data: dict[str, Any],
**parameters: Unpack[TextParameters],
) -> TextContent:
"""Parse content from response."""
choices = super()._parse_content(response_data)
Expand Down
Loading
Loading