Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions src/celeste/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,27 @@ def http_client(self) -> HTTPClient:
"""HTTP client with connection pooling for this provider."""
...

def _json_headers(self) -> dict[str, str]:
def _json_headers(
self, extra_headers: dict[str, str] | None = None
) -> dict[str, str]:
"""Build standard JSON request headers with auth."""
return {**self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON}
headers = {**self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON}
if extra_headers:
headers.update(extra_headers)
return headers

@staticmethod
def _merge_headers(
headers: dict[str, str],
extra_headers: dict[str, str] | None = None,
) -> dict[str, str]:
"""Merge user-provided extra headers into provider headers.

User-provided headers take precedence over provider defaults.
"""
if extra_headers:
return {**headers, **extra_headers}
return headers

@staticmethod
def _deep_merge(
Expand Down Expand Up @@ -160,6 +178,7 @@ async def _predict(
*,
endpoint: str | None = None,
extra_body: dict[str, Any] | None = None,
extra_headers: dict[str, str] | None = None,
**parameters: Unpack[Params], # type: ignore[misc]
) -> Out:
"""Generic prediction - called by operation methods.
Expand All @@ -168,6 +187,7 @@ async def _predict(
inputs: Operation-specific input object.
endpoint: Optional endpoint path (e.g., "/generations").
extra_body: Additional parameters to merge into the request body.
extra_headers: Additional headers to merge into the request headers.
**parameters: Operation-specific keyword arguments.

Returns:
Expand All @@ -176,7 +196,7 @@ async def _predict(
inputs, parameters = self._validate_artifacts(inputs, **parameters)
request_body = self._build_request(inputs, extra_body=extra_body, **parameters)
response_data = await self._make_request(
request_body, endpoint=endpoint, **parameters
request_body, endpoint=endpoint, extra_headers=extra_headers, **parameters
)
content = self._parse_content(response_data, **parameters)
content = self._transform_output(content, **parameters)
Expand All @@ -195,6 +215,7 @@ def _stream(
endpoint: str | None = None,
base_url: str | None = None,
extra_body: dict[str, Any] | None = None,
extra_headers: dict[str, str] | None = None,
**parameters: Unpack[Params], # type: ignore[misc]
) -> Stream[Out, Params, Chunk]:
"""Generic streaming - called by operation methods.
Expand All @@ -206,6 +227,7 @@ def _stream(
inputs: Operation-specific input object.
stream_class: The Stream class to instantiate.
extra_body: Additional parameters to merge into the request body.
extra_headers: Additional headers to merge into the request headers.
**parameters: Operation-specific keyword arguments.

Returns:
Expand All @@ -222,7 +244,11 @@ def _stream(
inputs, extra_body=extra_body, streaming=True, **parameters
)
sse_iterator = self._make_stream_request(
request_body, endpoint=endpoint, base_url=base_url, **parameters
request_body,
endpoint=endpoint,
base_url=base_url,
extra_headers=extra_headers,
**parameters,
)
return stream_class(
sse_iterator,
Expand Down Expand Up @@ -287,6 +313,7 @@ async def _make_request(
request_body: dict[str, Any],
*,
endpoint: str | None = None,
extra_headers: dict[str, str] | None = None,
**parameters: Unpack[Params], # type: ignore[misc]
) -> dict[str, Any]:
"""Make HTTP request(s) and return response data."""
Expand All @@ -295,6 +322,9 @@ async def _make_request(
def _make_stream_request(
self,
request_body: dict[str, Any],
*,
endpoint: str | None = None,
extra_headers: dict[str, str] | None = None,
**parameters: Unpack[Params], # type: ignore[misc]
) -> AsyncIterator[dict[str, Any]]:
"""Make HTTP streaming request and return async iterator of events."""
Expand Down
10 changes: 8 additions & 2 deletions src/celeste/modalities/audio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def speak(
text: str,
*,
extra_body: dict[str, Any] | None = None,
extra_headers: dict[str, str] | None = None,
**parameters: Unpack[AudioParameters],
) -> AudioStream:
"""Stream speech generation."""
Expand All @@ -57,6 +58,7 @@ def speak(
inputs,
stream_class=self._client._stream_class(),
extra_body=extra_body,
extra_headers=extra_headers,
**parameters,
)

Expand All @@ -72,12 +74,13 @@ def speak(
text: str,
*,
extra_body: dict[str, Any] | None = None,
extra_headers: dict[str, str] | None = None,
**parameters: Unpack[AudioParameters],
) -> AudioOutput:
"""Blocking speech generation."""
inputs = AudioInput(text=text)
return async_to_sync(self._client._predict)(
inputs, extra_body=extra_body, **parameters
inputs, extra_body=extra_body, extra_headers=extra_headers, **parameters
)

@property
Expand All @@ -97,6 +100,7 @@ def speak(
text: str,
*,
extra_body: dict[str, Any] | None = None,
extra_headers: dict[str, str] | None = None,
**parameters: Unpack[AudioParameters],
) -> AudioStream:
"""Sync streaming speech generation.
Expand All @@ -110,7 +114,9 @@ def speak(
stream.output.content.save("output.mp3")
"""
# Return same stream as async version - __iter__/__next__ handle sync iteration
return self._client.stream.speak(text, extra_body=extra_body, **parameters)
return self._client.stream.speak(
text, extra_body=extra_body, extra_headers=extra_headers, **parameters
)


__all__ = [
Expand Down
9 changes: 7 additions & 2 deletions src/celeste/modalities/embeddings/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ async def embed(
text: str | list[str],
*,
extra_body: dict[str, Any] | None = None,
extra_headers: dict[str, str] | None = None,
**parameters: Unpack[EmbeddingsParameters],
) -> EmbeddingsOutput:
"""Generate embeddings from text.

Args:
text: Text to embed. Single string or list of strings.
extra_body: Additional provider-specific fields to merge into request.
extra_headers: Additional HTTP headers to include in the request.
**parameters: Embedding parameters (e.g., dimensions).

Returns:
Expand All @@ -53,7 +55,9 @@ async def embed(
- list[list[float]] if text was a list
"""
inputs = EmbeddingsInput(text=text)
output = await self._predict(inputs, extra_body=extra_body, **parameters)
output = await self._predict(
inputs, extra_body=extra_body, extra_headers=extra_headers, **parameters
)

# If single text input, unwrap from batch format to single embedding
if (
Expand Down Expand Up @@ -83,11 +87,12 @@ def embed(
text: str | list[str],
*,
extra_body: dict[str, Any] | None = None,
extra_headers: dict[str, str] | None = None,
**parameters: Unpack[EmbeddingsParameters],
) -> EmbeddingsOutput:
"""Blocking embeddings generation."""
return async_to_sync(self._client.embed)(
text, extra_body=extra_body, **parameters
text, extra_body=extra_body, extra_headers=extra_headers, **parameters
)


Expand Down
22 changes: 18 additions & 4 deletions src/celeste/modalities/images/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def generate(
prompt: str,
*,
extra_body: dict[str, Any] | None = None,
extra_headers: dict[str, str] | None = None,
**parameters: Unpack[ImageParameters],
) -> ImagesStream:
"""Stream image generation."""
Expand All @@ -61,6 +62,7 @@ def generate(
inputs,
stream_class=self._client._stream_class(),
extra_body=extra_body,
extra_headers=extra_headers,
**parameters,
)

Expand All @@ -70,6 +72,7 @@ def edit(
prompt: str,
*,
extra_body: dict[str, Any] | None = None,
extra_headers: dict[str, str] | None = None,
**parameters: Unpack[ImageParameters],
) -> ImagesStream:
"""Stream image editing."""
Expand All @@ -78,6 +81,7 @@ def edit(
inputs,
stream_class=self._client._stream_class(),
extra_body=extra_body,
extra_headers=extra_headers,
**parameters,
)

Expand All @@ -96,6 +100,7 @@ def generate(
prompt: str,
*,
extra_body: dict[str, Any] | None = None,
extra_headers: dict[str, str] | None = None,
**parameters: Unpack[ImageParameters],
) -> ImageOutput:
"""Blocking image generation.
Expand All @@ -106,7 +111,7 @@ def generate(
"""
inputs = ImageInput(prompt=prompt)
return async_to_sync(self._client._predict)(
inputs, extra_body=extra_body, **parameters
inputs, extra_body=extra_body, extra_headers=extra_headers, **parameters
)

def edit(
Expand All @@ -115,6 +120,7 @@ def edit(
prompt: str,
*,
extra_body: dict[str, Any] | None = None,
extra_headers: dict[str, str] | None = None,
**parameters: Unpack[ImageParameters],
) -> ImageOutput:
"""Blocking image edit.
Expand All @@ -125,7 +131,7 @@ def edit(
"""
inputs = ImageInput(prompt=prompt, image=image)
return async_to_sync(self._client._predict)(
inputs, extra_body=extra_body, **parameters
inputs, extra_body=extra_body, extra_headers=extra_headers, **parameters
)

@property
Expand All @@ -145,6 +151,7 @@ def generate(
prompt: str,
*,
extra_body: dict[str, Any] | None = None,
extra_headers: dict[str, str] | None = None,
**parameters: Unpack[ImageParameters],
) -> ImagesStream:
"""Sync streaming image generation.
Expand All @@ -158,14 +165,17 @@ def generate(
print(stream.output.usage)
"""
# Return same stream as async version - __iter__/__next__ handle sync iteration
return self._client.stream.generate(prompt, extra_body=extra_body, **parameters)
return self._client.stream.generate(
prompt, extra_body=extra_body, extra_headers=extra_headers, **parameters
)

def edit(
self,
image: ImageArtifact,
prompt: str,
*,
extra_body: dict[str, Any] | None = None,
extra_headers: dict[str, str] | None = None,
**parameters: Unpack[ImageParameters],
) -> ImagesStream:
"""Sync streaming image editing.
Expand All @@ -179,7 +189,11 @@ def edit(
print(stream.output.usage)
"""
return self._client.stream.edit(
image, prompt, extra_body=extra_body, **parameters
image,
prompt,
extra_body=extra_body,
extra_headers=extra_headers,
**parameters,
)


Expand Down
3 changes: 2 additions & 1 deletion src/celeste/modalities/images/providers/byteplus/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ async def _make_request(
request_body: dict[str, Any],
*,
endpoint: str | None = None,
extra_headers: dict[str, str] | None = None,
**parameters: Unpack[ImageParameters],
) -> dict[str, Any]:
"""Make HTTP request with parameter validation."""
Expand All @@ -164,7 +165,7 @@ async def _make_request(
raise ConstraintViolationError(msg)

return await super()._make_request(
request_body, endpoint=endpoint, **parameters
request_body, endpoint=endpoint, extra_headers=extra_headers, **parameters
)

def _stream_class(self) -> type[ImagesStream]:
Expand Down
2 changes: 2 additions & 0 deletions src/celeste/modalities/images/providers/google/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,13 @@ async def _make_request(
request_body: dict[str, Any],
*,
endpoint: str | None = None,
extra_headers: dict[str, str] | None = None,
**parameters: Unpack[ImageParameters],
) -> dict[str, Any]:
return await self._strategy._make_request( # type: ignore[union-attr]
request_body,
endpoint=endpoint,
extra_headers=extra_headers,
**parameters,
)

Expand Down
Loading
Loading