Skip to content

Commit 2632fc1

Browse files
refactor: make _parse_content a pure parser, add Chunk type param (#203)
* refactor: make _parse_content a pure parser, add Chunk type param
  - Remove **parameters from _parse_content across ~28 provider clients, base abstract method, docstrings, templates, and tests
  - Add 5th type parameter (Chunk) to ModalityClient and all 5 modality base clients for correct mypy typing of _stream_class overrides
  - Eliminate 4 duplicated _transform_output overrides in audio providers by implementing parse_output() on their OutputFormatMapper mappers
  - Remove if-value-is-not-None guard in base _transform_output so parse_output() always runs (safe: default is no-op)
  Closes #202
* fix: imagen _parse_content regression for empty images
  When predictions exist but none contain valid image data, _parse_content now returns [] instead of ImageArtifact() sentinel, preventing _transform_output from wrapping it into [ImageArtifact()].
* fix: use dict.get pattern for audio MIME type mapping
  - Google Cloud TTS: replace try/except + AudioMimeType(value) with _mime_map dict.get, consistent with all other audio providers
  - OpenAI: remove "wav"/"pcm" from _mime_map since map() cannot send these formats to the API
1 parent 7d10721 commit 2632fc1

File tree

52 files changed

+133
-94
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+133
-94
lines changed

src/celeste/client.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from celeste.core import Modality, Provider
1313
from celeste.exceptions import StreamingNotSupportedError
1414
from celeste.http import HTTPClient, get_http_client
15-
from celeste.io import Chunk, FinishReason, Input, Output, Usage
15+
from celeste.io import Chunk as ChunkBase
16+
from celeste.io import FinishReason, Input, Output, Usage
1617
from celeste.mime_types import ApplicationMimeType
1718
from celeste.models import Model
1819
from celeste.parameters import ParameterMapper, Parameters
@@ -130,15 +131,19 @@ def _handle_error_response(self, response: httpx.Response) -> None:
130131
super()._handle_error_response(response) # type: ignore[misc]
131132

132133

133-
class ModalityClient[In: Input, Out: Output, Params: Parameters, Content](
134-
APIMixin, BaseModel
135-
):
134+
class ModalityClient[
135+
In: Input,
136+
Out: Output,
137+
Params: Parameters,
138+
Content,
139+
Chunk: ChunkBase,
140+
](APIMixin, BaseModel):
136141
"""Base class for unified modality clients.
137142
138143
Operation methods in subclasses delegate to _predict().
139144
140145
Example:
141-
class ImagesClient(ModalityClient[ImagesInput, ImagesOutput, ImagesParameters, ImageContent]):
146+
class ImagesClient(ModalityClient[ImagesInput, ImagesOutput, ImagesParameters, ImageContent, ImageChunk]):
142147
modality = Modality.IMAGES
143148
144149
async def generate(self, prompt: str, **parameters) -> ImageGenerationOutput:
@@ -198,7 +203,7 @@ async def _predict(
198203
response_data = await self._make_request(
199204
request_body, endpoint=endpoint, extra_headers=extra_headers, **parameters
200205
)
201-
content = self._parse_content(response_data, **parameters)
206+
content = self._parse_content(response_data)
202207
content = self._transform_output(content, **parameters)
203208
return self._output_class()(
204209
content=content,
@@ -277,7 +282,6 @@ def _parse_usage(self, response_data: dict[str, Any]) -> RawUsage:
277282
def _parse_content(
278283
self,
279284
response_data: dict[str, Any],
280-
**parameters: Unpack[Params], # type: ignore[misc]
281285
) -> Content:
282286
"""Parse content from provider response."""
283287
...
@@ -384,8 +388,7 @@ def _transform_output(
384388
"""Transform content using parameter mapper output transformations."""
385389
for mapper in self.parameter_mappers():
386390
value = parameters.get(mapper.name)
387-
if value is not None:
388-
content = mapper.parse_output(content, value)
391+
content = mapper.parse_output(content, value)
389392
return content
390393

391394
@abstractmethod

src/celeste/modalities/audio/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
from celeste.core import Modality
99
from celeste.types import AudioContent
1010

11-
from .io import AudioFinishReason, AudioInput, AudioOutput, AudioUsage
11+
from .io import AudioChunk, AudioFinishReason, AudioInput, AudioOutput, AudioUsage
1212
from .parameters import AudioParameters
1313
from .streaming import AudioStream
1414

1515

1616
class AudioClient(
17-
ModalityClient[AudioInput, AudioOutput, AudioParameters, AudioContent]
17+
ModalityClient[AudioInput, AudioOutput, AudioParameters, AudioContent, AudioChunk]
1818
):
1919
"""Base audio client. Providers implement speak() method."""
2020

src/celeste/modalities/audio/providers/elevenlabs/client.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,14 @@ def _init_request(self, inputs: AudioInput) -> dict[str, Any]:
6464
def _parse_content(
6565
self,
6666
response_data: dict[str, Any],
67-
**parameters: Unpack[AudioParameters],
6867
) -> AudioArtifact:
6968
"""Extract audio bytes from response."""
7069
audio_bytes = response_data.get("audio_bytes")
7170
if not audio_bytes:
7271
msg = "No audio data in response"
7372
raise ValueError(msg)
7473

75-
output_format = parameters.get("output_format")
76-
mime_type = self._map_output_format_to_mime_type(output_format)
77-
78-
return AudioArtifact(data=audio_bytes, mime_type=mime_type)
74+
return AudioArtifact(data=audio_bytes)
7975

8076
def _stream_class(self) -> type[AudioStream]:
8177
"""Return the Stream class for this provider."""

src/celeste/modalities/audio/providers/google/client.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from typing import Any, Unpack
44

55
from celeste.artifacts import AudioArtifact
6-
from celeste.mime_types import AudioMimeType
76
from celeste.parameters import ParameterMapper
87
from celeste.providers.google.cloud_tts import config
98
from celeste.providers.google.cloud_tts.client import (
@@ -16,7 +15,7 @@
1615
AudioInput,
1716
AudioOutput,
1817
)
19-
from ...parameters import AudioParameter, AudioParameters
18+
from ...parameters import AudioParameters
2019
from .parameters import GOOGLE_PARAMETER_MAPPERS
2120

2221

@@ -51,15 +50,10 @@ def _init_request(self, inputs: AudioInput) -> dict[str, Any]:
5150
def _parse_content(
5251
self,
5352
response_data: dict[str, Any],
54-
**parameters: Unpack[AudioParameters],
5553
) -> AudioArtifact:
5654
"""Extract audio bytes from response."""
5755
audio_b64 = super()._parse_content(response_data)
58-
59-
output_format = parameters.get(AudioParameter.OUTPUT_FORMAT)
60-
mime_type = AudioMimeType(output_format) if output_format else AudioMimeType.MP3
61-
62-
return AudioArtifact(data=audio_b64, mime_type=mime_type)
56+
return AudioArtifact(data=audio_b64)
6357

6458

6559
__all__ = ["GoogleAudioClient"]

src/celeste/modalities/audio/providers/gradium/client.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,14 @@ def _init_request(self, inputs: AudioInput) -> dict[str, Any]:
6464
def _parse_content(
6565
self,
6666
response_data: dict[str, Any],
67-
**parameters: Unpack[AudioParameters],
6867
) -> AudioArtifact:
6968
"""Extract audio bytes from response."""
7069
audio_bytes = response_data.get("audio_bytes")
7170
if not audio_bytes:
7271
msg = "No audio data in response"
7372
raise ValueError(msg)
7473

75-
output_format = parameters.get("output_format")
76-
mime_type = self._map_output_format_to_mime_type(output_format)
77-
78-
return AudioArtifact(data=audio_bytes, mime_type=mime_type)
74+
return AudioArtifact(data=audio_bytes)
7975

8076
def _stream_class(self) -> type[AudioStream]:
8177
"""Return the Stream class for this provider."""

src/celeste/modalities/audio/providers/openai/client.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,14 @@ def _init_request(self, inputs: AudioInput) -> dict[str, Any]:
4141
def _parse_content(
4242
self,
4343
response_data: dict[str, Any],
44-
**parameters: Unpack[AudioParameters],
4544
) -> AudioArtifact:
4645
"""Extract audio bytes from response."""
4746
audio_bytes = response_data.get("audio_bytes")
4847
if not audio_bytes:
4948
msg = "No audio data in response"
5049
raise ValueError(msg)
5150

52-
# Use mixin helper to determine MIME type from output_format
53-
output_format = parameters.get("output_format")
54-
mime_type = self._map_response_format_to_mime_type(output_format)
55-
56-
return AudioArtifact(data=audio_bytes, mime_type=mime_type)
51+
return AudioArtifact(data=audio_bytes)
5752

5853
def _parse_finish_reason(self, response_data: dict[str, Any]) -> AudioFinishReason:
5954
"""OpenAI TTS doesn't provide finish reasons."""

src/celeste/modalities/embeddings/client.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from celeste.types import EmbeddingsContent
1010

1111
from .io import (
12+
EmbeddingsChunk,
1213
EmbeddingsFinishReason,
1314
EmbeddingsInput,
1415
EmbeddingsOutput,
@@ -19,7 +20,11 @@
1920

2021
class EmbeddingsClient(
2122
ModalityClient[
22-
EmbeddingsInput, EmbeddingsOutput, EmbeddingsParameters, EmbeddingsContent
23+
EmbeddingsInput,
24+
EmbeddingsOutput,
25+
EmbeddingsParameters,
26+
EmbeddingsContent,
27+
EmbeddingsChunk,
2328
]
2429
):
2530
"""Base embeddings client. Providers implement operation methods."""

src/celeste/modalities/embeddings/providers/google/client.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Google embeddings client."""
22

3-
from typing import Any, Unpack
3+
from typing import Any
44

55
from celeste.parameters import ParameterMapper
66
from celeste.providers.google.embeddings.client import (
@@ -10,7 +10,6 @@
1010

1111
from ...client import EmbeddingsClient
1212
from ...io import EmbeddingsInput
13-
from ...parameters import EmbeddingsParameters
1413
from .parameters import GOOGLE_PARAMETER_MAPPERS
1514

1615

@@ -42,7 +41,6 @@ def _init_request(self, inputs: EmbeddingsInput) -> dict[str, Any]:
4241
def _parse_content(
4342
self,
4443
response_data: dict[str, Any],
45-
**parameters: Unpack[EmbeddingsParameters],
4644
) -> EmbeddingsContent:
4745
"""Parse embedding vectors from response."""
4846
return super()._parse_content(response_data)

src/celeste/modalities/images/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
from celeste.core import Modality
1010
from celeste.types import ImageContent
1111

12-
from .io import ImageFinishReason, ImageInput, ImageOutput, ImageUsage
12+
from .io import ImageChunk, ImageFinishReason, ImageInput, ImageOutput, ImageUsage
1313
from .parameters import ImageParameters
1414
from .streaming import ImagesStream
1515

1616

1717
class ImagesClient(
18-
ModalityClient[ImageInput, ImageOutput, ImageParameters, ImageContent]
18+
ModalityClient[ImageInput, ImageOutput, ImageParameters, ImageContent, ImageChunk]
1919
):
2020
"""Base images client. Providers implement generate/edit methods."""
2121

src/celeste/modalities/images/providers/bfl/client.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def _init_request(self, inputs: ImageInput) -> dict[str, Any]:
5959
def _parse_content(
6060
self,
6161
response_data: dict[str, Any],
62-
**parameters: Unpack[ImageParameters],
6362
) -> ImageArtifact:
6463
"""Parse content from response."""
6564
result = super()._parse_content(response_data)

0 commit comments

Comments
 (0)