Skip to content

Commit 4381e8a

Browse files
feat: expose extra_body parameter on all modalities (#126)
* feat: expose extra_body parameter on all modalities

Add extra_body parameter to all public methods across images, audio, videos, and embeddings modalities. This allows users to pass provider-specific request fields (e.g., Google's generationConfig, imageConfig) without resorting to private methods.

Updated methods:
- images: generate, edit (stream, sync, sync.stream)
- audio: speak (stream, sync, sync.stream)
- videos: generate (sync)
- embeddings: embed (async, sync)

Also updated the modality client template for future modalities.

Fixes #124

https://claude.ai/code/session_01KYduqFZTvWMNMBW9b1nLXF

* style: format with ruff

https://claude.ai/code/session_01KYduqFZTvWMNMBW9b1nLXF

---------

Co-authored-by: Claude <noreply@anthropic.com>
1 parent 6d7f9d7 commit 4381e8a

6 files changed

Lines changed: 70 additions & 19 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "celeste-ai"
3-
version = "0.9.4"
3+
version = "0.9.5"
44
description = "Open source, type-safe primitives for multi-modal AI. All capabilities, all providers, one interface"
55
authors = [{name = "Kamilbenkirane", email = "kamil@withceleste.ai"}]
66
readme = "README.md"

src/celeste/modalities/audio/client.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
"""Audio modality client."""
22

3-
from typing import Unpack
3+
from typing import Any, Unpack
44

55
from asgiref.sync import async_to_sync
66

@@ -45,13 +45,16 @@ def __init__(self, client: AudioClient) -> None:
4545
def speak(
4646
self,
4747
text: str,
48+
*,
49+
extra_body: dict[str, Any] | None = None,
4850
**parameters: Unpack[AudioParameters],
4951
) -> AudioStream:
5052
"""Stream speech generation."""
5153
inputs = AudioInput(text=text)
5254
return self._client._stream(
5355
inputs,
5456
stream_class=self._client._stream_class(),
57+
extra_body=extra_body,
5558
**parameters,
5659
)
5760

@@ -65,11 +68,15 @@ def __init__(self, client: AudioClient) -> None:
6568
def speak(
6669
self,
6770
text: str,
71+
*,
72+
extra_body: dict[str, Any] | None = None,
6873
**parameters: Unpack[AudioParameters],
6974
) -> AudioOutput:
7075
"""Blocking speech generation."""
7176
inputs = AudioInput(text=text)
72-
return async_to_sync(self._client._predict)(inputs, **parameters)
77+
return async_to_sync(self._client._predict)(
78+
inputs, extra_body=extra_body, **parameters
79+
)
7380

7481
@property
7582
def stream(self) -> "AudioSyncStreamNamespace":
@@ -86,6 +93,8 @@ def __init__(self, client: AudioClient) -> None:
8693
def speak(
8794
self,
8895
text: str,
96+
*,
97+
extra_body: dict[str, Any] | None = None,
8998
**parameters: Unpack[AudioParameters],
9099
) -> AudioStream:
91100
"""Sync streaming speech generation.
@@ -99,7 +108,7 @@ def speak(
99108
stream.output.content.save("output.mp3")
100109
"""
101110
# Return same stream as async version - __iter__/__next__ handle sync iteration
102-
return self._client.stream.speak(text, **parameters)
111+
return self._client.stream.speak(text, extra_body=extra_body, **parameters)
103112

104113

105114
__all__ = [

src/celeste/modalities/embeddings/client.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
"""Embeddings modality client."""
22

3-
from typing import Unpack
3+
from typing import Any, Unpack
44

55
from asgiref.sync import async_to_sync
66

@@ -29,12 +29,15 @@ def _output_class(cls) -> type[EmbeddingsOutput]:
2929
async def embed(
3030
self,
3131
text: str | list[str],
32+
*,
33+
extra_body: dict[str, Any] | None = None,
3234
**parameters: Unpack[EmbeddingsParameters],
3335
) -> EmbeddingsOutput:
3436
"""Generate embeddings from text.
3537
3638
Args:
3739
text: Text to embed. Single string or list of strings.
40+
extra_body: Additional provider-specific fields to merge into request.
3841
**parameters: Embedding parameters (e.g., dimensions).
3942
4043
Returns:
@@ -43,7 +46,7 @@ async def embed(
4346
- list[list[float]] if text was a list
4447
"""
4548
inputs = EmbeddingsInput(text=text)
46-
output = await self._predict(inputs, **parameters)
49+
output = await self._predict(inputs, extra_body=extra_body, **parameters)
4750

4851
# If single text input, unwrap from batch format to single embedding
4952
if (
@@ -71,10 +74,14 @@ def __init__(self, client: EmbeddingsClient) -> None:
7174
def embed(
7275
self,
7376
text: str | list[str],
77+
*,
78+
extra_body: dict[str, Any] | None = None,
7479
**parameters: Unpack[EmbeddingsParameters],
7580
) -> EmbeddingsOutput:
7681
"""Blocking embeddings generation."""
77-
return async_to_sync(self._client.embed)(text, **parameters)
82+
return async_to_sync(self._client.embed)(
83+
text, extra_body=extra_body, **parameters
84+
)
7885

7986

8087
__all__ = [

src/celeste/modalities/images/client.py

Lines changed: 25 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
"""Images modality client."""
22

3-
from typing import Unpack
3+
from typing import Any, Unpack
44

55
from asgiref.sync import async_to_sync
66

@@ -49,27 +49,33 @@ def __init__(self, client: ImagesClient) -> None:
4949
def generate(
5050
self,
5151
prompt: str,
52+
*,
53+
extra_body: dict[str, Any] | None = None,
5254
**parameters: Unpack[ImageParameters],
5355
) -> ImagesStream:
5456
"""Stream image generation."""
5557
inputs = ImageInput(prompt=prompt)
5658
return self._client._stream(
5759
inputs,
5860
stream_class=self._client._stream_class(),
61+
extra_body=extra_body,
5962
**parameters,
6063
)
6164

6265
def edit(
6366
self,
6467
image: ImageArtifact,
6568
prompt: str,
69+
*,
70+
extra_body: dict[str, Any] | None = None,
6671
**parameters: Unpack[ImageParameters],
6772
) -> ImagesStream:
6873
"""Stream image editing."""
6974
inputs = ImageInput(prompt=prompt, image=image)
7075
return self._client._stream(
7176
inputs,
7277
stream_class=self._client._stream_class(),
78+
extra_body=extra_body,
7379
**parameters,
7480
)
7581

@@ -86,6 +92,8 @@ def __init__(self, client: ImagesClient) -> None:
8692
def generate(
8793
self,
8894
prompt: str,
95+
*,
96+
extra_body: dict[str, Any] | None = None,
8997
**parameters: Unpack[ImageParameters],
9098
) -> ImageOutput:
9199
"""Blocking image generation.
@@ -95,12 +103,16 @@ def generate(
95103
result.content.show()
96104
"""
97105
inputs = ImageInput(prompt=prompt)
98-
return async_to_sync(self._client._predict)(inputs, **parameters)
106+
return async_to_sync(self._client._predict)(
107+
inputs, extra_body=extra_body, **parameters
108+
)
99109

100110
def edit(
101111
self,
102112
image: ImageArtifact,
103113
prompt: str,
114+
*,
115+
extra_body: dict[str, Any] | None = None,
104116
**parameters: Unpack[ImageParameters],
105117
) -> ImageOutput:
106118
"""Blocking image edit.
@@ -110,7 +122,9 @@ def edit(
110122
result.content.show()
111123
"""
112124
inputs = ImageInput(prompt=prompt, image=image)
113-
return async_to_sync(self._client._predict)(inputs, **parameters)
125+
return async_to_sync(self._client._predict)(
126+
inputs, extra_body=extra_body, **parameters
127+
)
114128

115129
@property
116130
def stream(self) -> "ImagesSyncStreamNamespace":
@@ -127,6 +141,8 @@ def __init__(self, client: ImagesClient) -> None:
127141
def generate(
128142
self,
129143
prompt: str,
144+
*,
145+
extra_body: dict[str, Any] | None = None,
130146
**parameters: Unpack[ImageParameters],
131147
) -> ImagesStream:
132148
"""Sync streaming image generation.
@@ -140,12 +156,14 @@ def generate(
140156
print(stream.output.usage)
141157
"""
142158
# Return same stream as async version - __iter__/__next__ handle sync iteration
143-
return self._client.stream.generate(prompt, **parameters)
159+
return self._client.stream.generate(prompt, extra_body=extra_body, **parameters)
144160

145161
def edit(
146162
self,
147163
image: ImageArtifact,
148164
prompt: str,
165+
*,
166+
extra_body: dict[str, Any] | None = None,
149167
**parameters: Unpack[ImageParameters],
150168
) -> ImagesStream:
151169
"""Sync streaming image editing.
@@ -158,7 +176,9 @@ def edit(
158176
print(chunk.content)
159177
print(stream.output.usage)
160178
"""
161-
return self._client.stream.edit(image, prompt, **parameters)
179+
return self._client.stream.edit(
180+
image, prompt, extra_body=extra_body, **parameters
181+
)
162182

163183

164184
__all__ = [

src/celeste/modalities/videos/client.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
"""Videos modality client."""
22

3-
from typing import Unpack
3+
from typing import Any, Unpack
44

55
from asgiref.sync import async_to_sync
66

@@ -42,6 +42,8 @@ def __init__(self, client: VideosClient) -> None:
4242
def generate(
4343
self,
4444
prompt: str,
45+
*,
46+
extra_body: dict[str, Any] | None = None,
4547
**parameters: Unpack[VideoParameters],
4648
) -> VideoOutput:
4749
"""Blocking video generation.
@@ -51,7 +53,9 @@ def generate(
5153
result.content.save("video.mp4")
5254
"""
5355
inputs = VideoInput(prompt=prompt)
54-
return async_to_sync(self._client._predict)(inputs, **parameters)
56+
return async_to_sync(self._client._predict)(
57+
inputs, extra_body=extra_body, **parameters
58+
)
5559

5660

5761
__all__ = [

templates/modalities/{modality_slug}/src/celeste_{modality_slug}/client.py.template

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
"""{Modality} modality client."""
22

3-
from typing import Unpack
3+
from typing import Any, Unpack
44

55
from asgiref.sync import async_to_sync
66

@@ -70,6 +70,8 @@ class {Modality}StreamNamespace:
7070
def generate(
7171
self,
7272
prompt: str,
73+
*,
74+
extra_body: dict[str, Any] | None = None,
7375
**parameters: Unpack[{Modality}Parameters],
7476
) -> {Modality}Stream:
7577
"""Stream {modality} generation.
@@ -82,6 +84,7 @@ class {Modality}StreamNamespace:
8284
return self._client._stream(
8385
inputs,
8486
stream_class=self._client._stream_class(),
87+
extra_body=extra_body,
8588
**parameters,
8689
)
8790

@@ -92,6 +95,7 @@ class {Modality}StreamNamespace:
9295
image: ImageContent | None = None,
9396
video: VideoContent | None = None,
9497
audio: AudioContent | None = None,
98+
extra_body: dict[str, Any] | None = None,
9599
**parameters: Unpack[{Modality}Parameters],
96100
) -> {Modality}Stream:
97101
"""Stream media analysis (image, video, or audio).
@@ -105,6 +109,7 @@ class {Modality}StreamNamespace:
105109
return self._client._stream(
106110
inputs,
107111
stream_class=self._client._stream_class(),
112+
extra_body=extra_body,
108113
**parameters,
109114
)
110115

@@ -121,6 +126,8 @@ class {Modality}SyncNamespace:
121126
def generate(
122127
self,
123128
prompt: str,
129+
*,
130+
extra_body: dict[str, Any] | None = None,
124131
**parameters: Unpack[{Modality}Parameters],
125132
) -> {Modality}Output:
126133
"""Blocking {modality} generation.
@@ -130,7 +137,7 @@ class {Modality}SyncNamespace:
130137
print(result.content)
131138
"""
132139
inputs = {Modality}Input(prompt=prompt)
133-
return async_to_sync(self._client._predict)(inputs, **parameters)
140+
return async_to_sync(self._client._predict)(inputs, extra_body=extra_body, **parameters)
134141

135142
def analyze(
136143
self,
@@ -139,6 +146,7 @@ class {Modality}SyncNamespace:
139146
image: ImageContent | None = None,
140147
video: VideoContent | None = None,
141148
audio: AudioContent | None = None,
149+
extra_body: dict[str, Any] | None = None,
142150
**parameters: Unpack[{Modality}Parameters],
143151
) -> {Modality}Output:
144152
"""Blocking media analysis (image, video, or audio).
@@ -149,7 +157,7 @@ class {Modality}SyncNamespace:
149157
"""
150158
self._client._check_media_support(image=image, video=video, audio=audio)
151159
inputs = {Modality}Input(prompt=prompt, image=image, video=video, audio=audio)
152-
return async_to_sync(self._client._predict)(inputs, **parameters)
160+
return async_to_sync(self._client._predict)(inputs, extra_body=extra_body, **parameters)
153161

154162
@property
155163
def stream(self) -> "{Modality}SyncStreamNamespace":
@@ -166,6 +174,8 @@ class {Modality}SyncStreamNamespace:
166174
def generate(
167175
self,
168176
prompt: str,
177+
*,
178+
extra_body: dict[str, Any] | None = None,
169179
**parameters: Unpack[{Modality}Parameters],
170180
) -> {Modality}Stream:
171181
"""Sync streaming {modality} generation.
@@ -179,7 +189,7 @@ class {Modality}SyncStreamNamespace:
179189
print(stream.output.usage)
180190
"""
181191
# Return same stream as async version - __iter__/__next__ handle sync iteration
182-
return self._client.stream.generate(prompt, **parameters)
192+
return self._client.stream.generate(prompt, extra_body=extra_body, **parameters)
183193

184194
def analyze(
185195
self,
@@ -188,6 +198,7 @@ class {Modality}SyncStreamNamespace:
188198
image: ImageContent | None = None,
189199
video: VideoContent | None = None,
190200
audio: AudioContent | None = None,
201+
extra_body: dict[str, Any] | None = None,
191202
**parameters: Unpack[{Modality}Parameters],
192203
) -> {Modality}Stream:
193204
"""Sync streaming media analysis (image, video, or audio).
@@ -202,7 +213,7 @@ class {Modality}SyncStreamNamespace:
202213
"""
203214
# Return same stream as async version - __iter__/__next__ handle sync iteration
204215
return self._client.stream.analyze(
205-
prompt, image=image, video=video, audio=audio, **parameters
216+
prompt, image=image, video=video, audio=audio, extra_body=extra_body, **parameters
206217
)
207218

208219

0 commit comments

Comments
 (0)