Skip to content

Commit 10133e6

Browse files
feat: add extra_headers support to all client methods (#193)
* feat: add extra_headers support to all client methods (#178) Add `extra_headers: dict[str, str] | None = None` alongside `extra_body` on all client methods, enabling users to pass provider-specific HTTP headers (e.g., Anthropic's 1M context beta header). - Extend `_json_headers()` to accept and merge extra headers - Add `_merge_headers()` static helper for edge cases (WebSocket, multipart) - Thread `extra_headers` through `_predict()` → `_make_request()` and `_stream()` → `_make_stream_request()` across all providers - Update all 3 templates and the template contract test Closes #178 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: improve extra_headers consistency across all providers - Make _merge_headers return a new dict instead of mutating input - Add extra_headers to base _make_stream_request stub signature - Fix OpenAI videos multipart sending wrong Content-Type - Add extra_headers to Google _get_interaction and download_content - Simplify BFL header construction to use _json_headers(extra_headers) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent eb0771d commit 10133e6

File tree

33 files changed

+254
-81
lines changed

33 files changed

+254
-81
lines changed

src/celeste/client.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,27 @@ def http_client(self) -> HTTPClient:
5353
"""HTTP client with connection pooling for this provider."""
5454
...
5555

56-
def _json_headers(self) -> dict[str, str]:
56+
def _json_headers(
57+
self, extra_headers: dict[str, str] | None = None
58+
) -> dict[str, str]:
5759
"""Build standard JSON request headers with auth."""
58-
return {**self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON}
60+
headers = {**self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON}
61+
if extra_headers:
62+
headers.update(extra_headers)
63+
return headers
64+
65+
@staticmethod
66+
def _merge_headers(
67+
headers: dict[str, str],
68+
extra_headers: dict[str, str] | None = None,
69+
) -> dict[str, str]:
70+
"""Merge user-provided extra headers into provider headers.
71+
72+
User-provided headers take precedence over provider defaults.
73+
"""
74+
if extra_headers:
75+
return {**headers, **extra_headers}
76+
return headers
5977

6078
@staticmethod
6179
def _deep_merge(
@@ -160,6 +178,7 @@ async def _predict(
160178
*,
161179
endpoint: str | None = None,
162180
extra_body: dict[str, Any] | None = None,
181+
extra_headers: dict[str, str] | None = None,
163182
**parameters: Unpack[Params], # type: ignore[misc]
164183
) -> Out:
165184
"""Generic prediction - called by operation methods.
@@ -168,6 +187,7 @@ async def _predict(
168187
inputs: Operation-specific input object.
169188
endpoint: Optional endpoint path (e.g., "/generations").
170189
extra_body: Additional parameters to merge into the request body.
190+
extra_headers: Additional headers to merge into the request headers.
171191
**parameters: Operation-specific keyword arguments.
172192
173193
Returns:
@@ -176,7 +196,7 @@ async def _predict(
176196
inputs, parameters = self._validate_artifacts(inputs, **parameters)
177197
request_body = self._build_request(inputs, extra_body=extra_body, **parameters)
178198
response_data = await self._make_request(
179-
request_body, endpoint=endpoint, **parameters
199+
request_body, endpoint=endpoint, extra_headers=extra_headers, **parameters
180200
)
181201
content = self._parse_content(response_data, **parameters)
182202
content = self._transform_output(content, **parameters)
@@ -195,6 +215,7 @@ def _stream(
195215
endpoint: str | None = None,
196216
base_url: str | None = None,
197217
extra_body: dict[str, Any] | None = None,
218+
extra_headers: dict[str, str] | None = None,
198219
**parameters: Unpack[Params], # type: ignore[misc]
199220
) -> Stream[Out, Params, Chunk]:
200221
"""Generic streaming - called by operation methods.
@@ -206,6 +227,7 @@ def _stream(
206227
inputs: Operation-specific input object.
207228
stream_class: The Stream class to instantiate.
208229
extra_body: Additional parameters to merge into the request body.
230+
extra_headers: Additional headers to merge into the request headers.
209231
**parameters: Operation-specific keyword arguments.
210232
211233
Returns:
@@ -222,7 +244,11 @@ def _stream(
222244
inputs, extra_body=extra_body, streaming=True, **parameters
223245
)
224246
sse_iterator = self._make_stream_request(
225-
request_body, endpoint=endpoint, base_url=base_url, **parameters
247+
request_body,
248+
endpoint=endpoint,
249+
base_url=base_url,
250+
extra_headers=extra_headers,
251+
**parameters,
226252
)
227253
return stream_class(
228254
sse_iterator,
@@ -287,6 +313,7 @@ async def _make_request(
287313
request_body: dict[str, Any],
288314
*,
289315
endpoint: str | None = None,
316+
extra_headers: dict[str, str] | None = None,
290317
**parameters: Unpack[Params], # type: ignore[misc]
291318
) -> dict[str, Any]:
292319
"""Make HTTP request(s) and return response data."""
@@ -295,6 +322,9 @@ async def _make_request(
295322
def _make_stream_request(
296323
self,
297324
request_body: dict[str, Any],
325+
*,
326+
endpoint: str | None = None,
327+
extra_headers: dict[str, str] | None = None,
298328
**parameters: Unpack[Params], # type: ignore[misc]
299329
) -> AsyncIterator[dict[str, Any]]:
300330
"""Make HTTP streaming request and return async iterator of events."""

src/celeste/modalities/audio/client.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def speak(
4949
text: str,
5050
*,
5151
extra_body: dict[str, Any] | None = None,
52+
extra_headers: dict[str, str] | None = None,
5253
**parameters: Unpack[AudioParameters],
5354
) -> AudioStream:
5455
"""Stream speech generation."""
@@ -57,6 +58,7 @@ def speak(
5758
inputs,
5859
stream_class=self._client._stream_class(),
5960
extra_body=extra_body,
61+
extra_headers=extra_headers,
6062
**parameters,
6163
)
6264

@@ -72,12 +74,13 @@ def speak(
7274
text: str,
7375
*,
7476
extra_body: dict[str, Any] | None = None,
77+
extra_headers: dict[str, str] | None = None,
7578
**parameters: Unpack[AudioParameters],
7679
) -> AudioOutput:
7780
"""Blocking speech generation."""
7881
inputs = AudioInput(text=text)
7982
return async_to_sync(self._client._predict)(
80-
inputs, extra_body=extra_body, **parameters
83+
inputs, extra_body=extra_body, extra_headers=extra_headers, **parameters
8184
)
8285

8386
@property
@@ -97,6 +100,7 @@ def speak(
97100
text: str,
98101
*,
99102
extra_body: dict[str, Any] | None = None,
103+
extra_headers: dict[str, str] | None = None,
100104
**parameters: Unpack[AudioParameters],
101105
) -> AudioStream:
102106
"""Sync streaming speech generation.
@@ -110,7 +114,9 @@ def speak(
110114
stream.output.content.save("output.mp3")
111115
"""
112116
# Return same stream as async version - __iter__/__next__ handle sync iteration
113-
return self._client.stream.speak(text, extra_body=extra_body, **parameters)
117+
return self._client.stream.speak(
118+
text, extra_body=extra_body, extra_headers=extra_headers, **parameters
119+
)
114120

115121

116122
__all__ = [

src/celeste/modalities/embeddings/client.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@ async def embed(
3838
text: str | list[str],
3939
*,
4040
extra_body: dict[str, Any] | None = None,
41+
extra_headers: dict[str, str] | None = None,
4142
**parameters: Unpack[EmbeddingsParameters],
4243
) -> EmbeddingsOutput:
4344
"""Generate embeddings from text.
4445
4546
Args:
4647
text: Text to embed. Single string or list of strings.
4748
extra_body: Additional provider-specific fields to merge into request.
49+
extra_headers: Additional HTTP headers to include in the request.
4850
**parameters: Embedding parameters (e.g., dimensions).
4951
5052
Returns:
@@ -53,7 +55,9 @@ async def embed(
5355
- list[list[float]] if text was a list
5456
"""
5557
inputs = EmbeddingsInput(text=text)
56-
output = await self._predict(inputs, extra_body=extra_body, **parameters)
58+
output = await self._predict(
59+
inputs, extra_body=extra_body, extra_headers=extra_headers, **parameters
60+
)
5761

5862
# If single text input, unwrap from batch format to single embedding
5963
if (
@@ -83,11 +87,12 @@ def embed(
8387
text: str | list[str],
8488
*,
8589
extra_body: dict[str, Any] | None = None,
90+
extra_headers: dict[str, str] | None = None,
8691
**parameters: Unpack[EmbeddingsParameters],
8792
) -> EmbeddingsOutput:
8893
"""Blocking embeddings generation."""
8994
return async_to_sync(self._client.embed)(
90-
text, extra_body=extra_body, **parameters
95+
text, extra_body=extra_body, extra_headers=extra_headers, **parameters
9196
)
9297

9398

src/celeste/modalities/images/client.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def generate(
5353
prompt: str,
5454
*,
5555
extra_body: dict[str, Any] | None = None,
56+
extra_headers: dict[str, str] | None = None,
5657
**parameters: Unpack[ImageParameters],
5758
) -> ImagesStream:
5859
"""Stream image generation."""
@@ -61,6 +62,7 @@ def generate(
6162
inputs,
6263
stream_class=self._client._stream_class(),
6364
extra_body=extra_body,
65+
extra_headers=extra_headers,
6466
**parameters,
6567
)
6668

@@ -70,6 +72,7 @@ def edit(
7072
prompt: str,
7173
*,
7274
extra_body: dict[str, Any] | None = None,
75+
extra_headers: dict[str, str] | None = None,
7376
**parameters: Unpack[ImageParameters],
7477
) -> ImagesStream:
7578
"""Stream image editing."""
@@ -78,6 +81,7 @@ def edit(
7881
inputs,
7982
stream_class=self._client._stream_class(),
8083
extra_body=extra_body,
84+
extra_headers=extra_headers,
8185
**parameters,
8286
)
8387

@@ -96,6 +100,7 @@ def generate(
96100
prompt: str,
97101
*,
98102
extra_body: dict[str, Any] | None = None,
103+
extra_headers: dict[str, str] | None = None,
99104
**parameters: Unpack[ImageParameters],
100105
) -> ImageOutput:
101106
"""Blocking image generation.
@@ -106,7 +111,7 @@ def generate(
106111
"""
107112
inputs = ImageInput(prompt=prompt)
108113
return async_to_sync(self._client._predict)(
109-
inputs, extra_body=extra_body, **parameters
114+
inputs, extra_body=extra_body, extra_headers=extra_headers, **parameters
110115
)
111116

112117
def edit(
@@ -115,6 +120,7 @@ def edit(
115120
prompt: str,
116121
*,
117122
extra_body: dict[str, Any] | None = None,
123+
extra_headers: dict[str, str] | None = None,
118124
**parameters: Unpack[ImageParameters],
119125
) -> ImageOutput:
120126
"""Blocking image edit.
@@ -125,7 +131,7 @@ def edit(
125131
"""
126132
inputs = ImageInput(prompt=prompt, image=image)
127133
return async_to_sync(self._client._predict)(
128-
inputs, extra_body=extra_body, **parameters
134+
inputs, extra_body=extra_body, extra_headers=extra_headers, **parameters
129135
)
130136

131137
@property
@@ -145,6 +151,7 @@ def generate(
145151
prompt: str,
146152
*,
147153
extra_body: dict[str, Any] | None = None,
154+
extra_headers: dict[str, str] | None = None,
148155
**parameters: Unpack[ImageParameters],
149156
) -> ImagesStream:
150157
"""Sync streaming image generation.
@@ -158,14 +165,17 @@ def generate(
158165
print(stream.output.usage)
159166
"""
160167
# Return same stream as async version - __iter__/__next__ handle sync iteration
161-
return self._client.stream.generate(prompt, extra_body=extra_body, **parameters)
168+
return self._client.stream.generate(
169+
prompt, extra_body=extra_body, extra_headers=extra_headers, **parameters
170+
)
162171

163172
def edit(
164173
self,
165174
image: ImageArtifact,
166175
prompt: str,
167176
*,
168177
extra_body: dict[str, Any] | None = None,
178+
extra_headers: dict[str, str] | None = None,
169179
**parameters: Unpack[ImageParameters],
170180
) -> ImagesStream:
171181
"""Sync streaming image editing.
@@ -179,7 +189,11 @@ def edit(
179189
print(stream.output.usage)
180190
"""
181191
return self._client.stream.edit(
182-
image, prompt, extra_body=extra_body, **parameters
192+
image,
193+
prompt,
194+
extra_body=extra_body,
195+
extra_headers=extra_headers,
196+
**parameters,
183197
)
184198

185199

src/celeste/modalities/images/providers/byteplus/client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ async def _make_request(
149149
request_body: dict[str, Any],
150150
*,
151151
endpoint: str | None = None,
152+
extra_headers: dict[str, str] | None = None,
152153
**parameters: Unpack[ImageParameters],
153154
) -> dict[str, Any]:
154155
"""Make HTTP request with parameter validation."""
@@ -164,7 +165,7 @@ async def _make_request(
164165
raise ConstraintViolationError(msg)
165166

166167
return await super()._make_request(
167-
request_body, endpoint=endpoint, **parameters
168+
request_body, endpoint=endpoint, extra_headers=extra_headers, **parameters
168169
)
169170

170171
def _stream_class(self) -> type[ImagesStream]:

src/celeste/modalities/images/providers/google/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,13 @@ async def _make_request(
101101
request_body: dict[str, Any],
102102
*,
103103
endpoint: str | None = None,
104+
extra_headers: dict[str, str] | None = None,
104105
**parameters: Unpack[ImageParameters],
105106
) -> dict[str, Any]:
106107
return await self._strategy._make_request( # type: ignore[union-attr]
107108
request_body,
108109
endpoint=endpoint,
110+
extra_headers=extra_headers,
109111
**parameters,
110112
)
111113

0 commit comments

Comments
 (0)