diff --git a/src/celeste/client.py b/src/celeste/client.py index 7cd5aea..e79ee9b 100644 --- a/src/celeste/client.py +++ b/src/celeste/client.py @@ -53,9 +53,27 @@ def http_client(self) -> HTTPClient: """HTTP client with connection pooling for this provider.""" ... - def _json_headers(self) -> dict[str, str]: + def _json_headers( + self, extra_headers: dict[str, str] | None = None + ) -> dict[str, str]: """Build standard JSON request headers with auth.""" - return {**self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON} + headers = {**self.auth.get_headers(), "Content-Type": ApplicationMimeType.JSON} + if extra_headers: + headers.update(extra_headers) + return headers + + @staticmethod + def _merge_headers( + headers: dict[str, str], + extra_headers: dict[str, str] | None = None, + ) -> dict[str, str]: + """Merge user-provided extra headers into provider headers. + + User-provided headers take precedence over provider defaults. + """ + if extra_headers: + return {**headers, **extra_headers} + return headers @staticmethod def _deep_merge( @@ -160,6 +178,7 @@ async def _predict( *, endpoint: str | None = None, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[Params], # type: ignore[misc] ) -> Out: """Generic prediction - called by operation methods. @@ -168,6 +187,7 @@ async def _predict( inputs: Operation-specific input object. endpoint: Optional endpoint path (e.g., "/generations"). extra_body: Additional parameters to merge into the request body. + extra_headers: Additional headers to merge into the request headers. **parameters: Operation-specific keyword arguments. Returns: @@ -176,7 +196,7 @@ async def _predict( inputs, parameters = self._validate_artifacts(inputs, **parameters) request_body = self._build_request(inputs, extra_body=extra_body, **parameters) response_data = await self._make_request( - request_body, endpoint=endpoint, **parameters + request_body, endpoint=endpoint, extra_headers=extra_headers, **parameters ) content = self._parse_content(response_data, **parameters) content = self._transform_output(content, **parameters) @@ -195,6 +215,7 @@ def _stream( endpoint: str | None = None, base_url: str | None = None, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[Params], # type: ignore[misc] ) -> Stream[Out, Params, Chunk]: """Generic streaming - called by operation methods. @@ -206,6 +227,7 @@ def _stream( inputs: Operation-specific input object. stream_class: The Stream class to instantiate. extra_body: Additional parameters to merge into the request body. + extra_headers: Additional headers to merge into the request headers. **parameters: Operation-specific keyword arguments. Returns: @@ -222,7 +244,11 @@ def _stream( inputs, extra_body=extra_body, streaming=True, **parameters ) sse_iterator = self._make_stream_request( - request_body, endpoint=endpoint, base_url=base_url, **parameters + request_body, + endpoint=endpoint, + base_url=base_url, + extra_headers=extra_headers, + **parameters, ) return stream_class( sse_iterator, @@ -287,6 +313,7 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[Params], # type: ignore[misc] ) -> dict[str, Any]: """Make HTTP request(s) and return response data.""" @@ -295,6 +322,9 @@ async def _make_request( def _make_stream_request( self, request_body: dict[str, Any], + *, + endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[Params], # type: ignore[misc] ) -> AsyncIterator[dict[str, Any]]: """Make HTTP streaming request and return async iterator of events.""" diff --git a/src/celeste/modalities/audio/client.py b/src/celeste/modalities/audio/client.py index 2b3bf50..6e0e281 100644 --- a/src/celeste/modalities/audio/client.py +++ b/src/celeste/modalities/audio/client.py @@ -49,6 +49,7 @@ def speak( text: str, *, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[AudioParameters], ) -> AudioStream: """Stream speech generation.""" @@ -57,6 +58,7 @@ def speak( inputs, stream_class=self._client._stream_class(), extra_body=extra_body, + extra_headers=extra_headers, **parameters, ) @@ -72,12 +74,13 @@ def speak( text: str, *, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[AudioParameters], ) -> AudioOutput: """Blocking speech generation.""" inputs = AudioInput(text=text) return async_to_sync(self._client._predict)( - inputs, extra_body=extra_body, **parameters + inputs, extra_body=extra_body, extra_headers=extra_headers, **parameters ) @property @@ -97,6 +100,7 @@ def speak( text: str, *, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[AudioParameters], ) -> AudioStream: """Sync streaming speech generation. @@ -110,7 +114,9 @@ def speak( stream.output.content.save("output.mp3") """ # Return same stream as async version - __iter__/__next__ handle sync iteration - return self._client.stream.speak(text, extra_body=extra_body, **parameters) + return self._client.stream.speak( + text, extra_body=extra_body, extra_headers=extra_headers, **parameters + ) __all__ = [ diff --git a/src/celeste/modalities/embeddings/client.py b/src/celeste/modalities/embeddings/client.py index 2b6e7a3..d5ee4d9 100644 --- a/src/celeste/modalities/embeddings/client.py +++ b/src/celeste/modalities/embeddings/client.py @@ -38,6 +38,7 @@ async def embed( text: str | list[str], *, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[EmbeddingsParameters], ) -> EmbeddingsOutput: """Generate embeddings from text. @@ -45,6 +46,7 @@ async def embed( Args: text: Text to embed. Single string or list of strings. extra_body: Additional provider-specific fields to merge into request. + extra_headers: Additional HTTP headers to include in the request. **parameters: Embedding parameters (e.g., dimensions). Returns: @@ -53,7 +55,9 @@ async def embed( - list[list[float]] if text was a list """ inputs = EmbeddingsInput(text=text) - output = await self._predict(inputs, extra_body=extra_body, **parameters) + output = await self._predict( + inputs, extra_body=extra_body, extra_headers=extra_headers, **parameters + ) # If single text input, unwrap from batch format to single embedding if ( @@ -83,11 +87,12 @@ def embed( text: str | list[str], *, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[EmbeddingsParameters], ) -> EmbeddingsOutput: """Blocking embeddings generation.""" return async_to_sync(self._client.embed)( - text, extra_body=extra_body, **parameters + text, extra_body=extra_body, extra_headers=extra_headers, **parameters ) diff --git a/src/celeste/modalities/images/client.py b/src/celeste/modalities/images/client.py index ac01237..ed471e8 100644 --- a/src/celeste/modalities/images/client.py +++ b/src/celeste/modalities/images/client.py @@ -53,6 +53,7 @@ def generate( prompt: str, *, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[ImageParameters], ) -> ImagesStream: """Stream image generation.""" @@ -61,6 +62,7 @@ def generate( inputs, stream_class=self._client._stream_class(), extra_body=extra_body, + extra_headers=extra_headers, **parameters, ) @@ -70,6 +72,7 @@ def edit( prompt: str, *, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[ImageParameters], ) -> ImagesStream: """Stream image editing.""" @@ -78,6 +81,7 @@ def edit( inputs, stream_class=self._client._stream_class(), extra_body=extra_body, + extra_headers=extra_headers, **parameters, ) @@ -96,6 +100,7 @@ def generate( prompt: str, *, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[ImageParameters], ) -> ImageOutput: """Blocking image generation. @@ -106,7 +111,7 @@ def generate( """ inputs = ImageInput(prompt=prompt) return async_to_sync(self._client._predict)( - inputs, extra_body=extra_body, **parameters + inputs, extra_body=extra_body, extra_headers=extra_headers, **parameters ) def edit( @@ -115,6 +120,7 @@ def edit( prompt: str, *, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[ImageParameters], ) -> ImageOutput: """Blocking image edit. @@ -125,7 +131,7 @@ def edit( """ inputs = ImageInput(prompt=prompt, image=image) return async_to_sync(self._client._predict)( - inputs, extra_body=extra_body, **parameters + inputs, extra_body=extra_body, extra_headers=extra_headers, **parameters ) @property @@ -145,6 +151,7 @@ def generate( prompt: str, *, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[ImageParameters], ) -> ImagesStream: """Sync streaming image generation. @@ -158,7 +165,9 @@ def generate( print(stream.output.usage) """ # Return same stream as async version - __iter__/__next__ handle sync iteration - return self._client.stream.generate(prompt, extra_body=extra_body, **parameters) + return self._client.stream.generate( + prompt, extra_body=extra_body, extra_headers=extra_headers, **parameters + ) def edit( self, @@ -166,6 +175,7 @@ def edit( prompt: str, *, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[ImageParameters], ) -> ImagesStream: """Sync streaming image editing. @@ -179,7 +189,11 @@ def edit( print(stream.output.usage) """ return self._client.stream.edit( - image, prompt, extra_body=extra_body, **parameters + image, + prompt, + extra_body=extra_body, + extra_headers=extra_headers, + **parameters, ) diff --git a/src/celeste/modalities/images/providers/byteplus/client.py b/src/celeste/modalities/images/providers/byteplus/client.py index 97ec1dd..b5829fc 100644 --- a/src/celeste/modalities/images/providers/byteplus/client.py +++ b/src/celeste/modalities/images/providers/byteplus/client.py @@ -149,6 +149,7 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[ImageParameters], ) -> dict[str, Any]: """Make HTTP request with parameter validation.""" @@ -164,7 +165,7 @@ async def _make_request( raise ConstraintViolationError(msg) return await super()._make_request( - request_body, endpoint=endpoint, **parameters + request_body, endpoint=endpoint, extra_headers=extra_headers, **parameters ) def _stream_class(self) -> type[ImagesStream]: diff --git a/src/celeste/modalities/images/providers/google/client.py b/src/celeste/modalities/images/providers/google/client.py index bf16beb..265c8dd 100644 --- a/src/celeste/modalities/images/providers/google/client.py +++ b/src/celeste/modalities/images/providers/google/client.py @@ -101,11 +101,13 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[ImageParameters], ) -> dict[str, Any]: return await self._strategy._make_request( # type: ignore[union-attr] request_body, endpoint=endpoint, + extra_headers=extra_headers, **parameters, ) diff --git a/src/celeste/modalities/text/client.py b/src/celeste/modalities/text/client.py index dd59013..50bfc02 100644 --- a/src/celeste/modalities/text/client.py +++ b/src/celeste/modalities/text/client.py @@ -76,6 +76,7 @@ def generate( messages: list[Message] | None = None, base_url: str | None = None, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[TextParameters], ) -> TextStream: """Stream text generation. @@ -90,6 +91,7 @@ def generate( stream_class=self._client._stream_class(), base_url=base_url, extra_body=extra_body, + extra_headers=extra_headers, **parameters, ) @@ -103,6 +105,7 @@ def analyze( audio: AudioContent | None = None, base_url: str | None = None, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[TextParameters], ) -> TextStream: """Stream media analysis (image, video, or audio). @@ -127,6 +130,7 @@ def analyze( stream_class=self._client._stream_class(), base_url=base_url, extra_body=extra_body, + extra_headers=extra_headers, **parameters, ) @@ -147,6 +151,7 @@ def generate( messages: list[Message] | None = None, base_url: str | None = None, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[TextParameters], ) -> TextOutput: """Blocking text generation. @@ -157,7 +162,11 @@ def generate( """ inputs = TextInput(prompt=prompt, messages=messages) return async_to_sync(self._client._predict)( - inputs, base_url=base_url, extra_body=extra_body, **parameters + inputs, + base_url=base_url, + extra_body=extra_body, + extra_headers=extra_headers, + **parameters, ) def analyze( @@ -170,6 +179,7 @@ def analyze( audio: AudioContent | None = None, base_url: str | None = None, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[TextParameters], ) -> TextOutput: """Blocking media analysis (image, video, or audio). @@ -190,7 +200,11 @@ def analyze( prompt=prompt, messages=messages, image=image, video=video, audio=audio ) return async_to_sync(self._client._predict)( - inputs, base_url=base_url, extra_body=extra_body, **parameters + inputs, + base_url=base_url, + extra_body=extra_body, + extra_headers=extra_headers, + **parameters, ) @property @@ -212,6 +226,7 @@ def generate( messages: list[Message] | None = None, base_url: str | None = None, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[TextParameters], ) -> TextStream: """Sync streaming text generation. @@ -230,6 +245,7 @@ def generate( messages=messages, base_url=base_url, extra_body=extra_body, + extra_headers=extra_headers, **parameters, ) @@ -243,6 +259,7 @@ def analyze( audio: AudioContent | None = None, base_url: str | None = None, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[TextParameters], ) -> TextStream: """Sync streaming media analysis (image, video, or audio). @@ -274,6 +291,7 @@ def analyze( audio=audio, base_url=base_url, extra_body=extra_body, + extra_headers=extra_headers, **parameters, ) diff --git a/src/celeste/modalities/videos/client.py b/src/celeste/modalities/videos/client.py index 58955fe..1522058 100644 --- a/src/celeste/modalities/videos/client.py +++ b/src/celeste/modalities/videos/client.py @@ -46,6 +46,7 @@ def generate( prompt: str, *, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[VideoParameters], ) -> VideoOutput: """Blocking video generation. @@ -56,7 +57,7 @@ def generate( """ inputs = VideoInput(prompt=prompt) return async_to_sync(self._client._predict)( - inputs, extra_body=extra_body, **parameters + inputs, extra_body=extra_body, extra_headers=extra_headers, **parameters ) diff --git a/src/celeste/protocols/chatcompletions/client.py b/src/celeste/protocols/chatcompletions/client.py index bd2171f..8f8378c 100644 --- a/src/celeste/protocols/chatcompletions/client.py +++ b/src/celeste/protocols/chatcompletions/client.py @@ -72,13 +72,14 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request to Chat Completions API endpoint.""" if endpoint is None: endpoint = self._default_endpoint - headers = self._json_headers() + headers = self._json_headers(extra_headers) response = await self.http_client.post( self._build_url(endpoint), @@ -94,13 +95,14 @@ def _make_stream_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """Make streaming request to Chat Completions API endpoint.""" if endpoint is None: endpoint = self._default_endpoint - headers = self._json_headers() + headers = self._json_headers(extra_headers) return self.http_client.stream_post( self._build_url(endpoint, streaming=True), diff --git a/src/celeste/protocols/openresponses/client.py b/src/celeste/protocols/openresponses/client.py index fbc153b..d3cba75 100644 --- a/src/celeste/protocols/openresponses/client.py +++ b/src/celeste/protocols/openresponses/client.py @@ -72,13 +72,14 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request to Responses API endpoint.""" if endpoint is None: endpoint = self._default_endpoint - headers = self._json_headers() + headers = self._json_headers(extra_headers) response = await self.http_client.post( self._build_url(endpoint), @@ -94,13 +95,14 @@ def _make_stream_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """Make streaming request to Responses API endpoint.""" if endpoint is None: endpoint = self._default_endpoint - headers = self._json_headers() + headers = self._json_headers(extra_headers) return self.http_client.stream_post( self._build_url(endpoint, streaming=True), diff --git a/src/celeste/providers/anthropic/messages/client.py b/src/celeste/providers/anthropic/messages/client.py index bf7298d..bad241a 100644 --- a/src/celeste/providers/anthropic/messages/client.py +++ b/src/celeste/providers/anthropic/messages/client.py @@ -55,7 +55,11 @@ def _build_url(self, endpoint: str, streaming: bool = False) -> str: ) return f"{config.BASE_URL}{endpoint}" - def _build_headers(self, beta_features: list[str] | None = None) -> dict[str, str]: + def _build_headers( + self, + beta_features: list[str] | None = None, + extra_headers: dict[str, str] | None = None, + ) -> dict[str, str]: """Build Anthropic request headers.""" headers: dict[str, str] = { **self._json_headers(), @@ -67,6 +71,8 @@ def _build_headers(self, beta_features: list[str] | None = None) -> dict[str, st for f in beta_features ] headers[config.HEADER_ANTHROPIC_BETA] = ",".join(beta_values) + if extra_headers: + headers.update(extra_headers) return headers def _build_request( @@ -90,6 +96,7 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request to Anthropic Messages API endpoint.""" @@ -98,7 +105,9 @@ async def _make_request( request_body["max_tokens"] = config.DEFAULT_MAX_TOKENS beta_features: list[str] = request_body.pop("_beta_features", []) - headers = self._build_headers(beta_features=beta_features) + headers = self._build_headers( + beta_features=beta_features, extra_headers=extra_headers + ) if endpoint is None: endpoint = config.AnthropicMessagesEndpoint.CREATE_MESSAGE @@ -117,6 +126,7 @@ def _make_stream_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """Make streaming request to Anthropic Messages API endpoint.""" @@ -125,7 +135,9 @@ def _make_stream_request( request_body["max_tokens"] = config.DEFAULT_MAX_TOKENS beta_features: list[str] = request_body.pop("_beta_features", []) - headers = self._build_headers(beta_features=beta_features) + headers = self._build_headers( + beta_features=beta_features, extra_headers=extra_headers + ) if endpoint is None: endpoint = config.AnthropicMessagesEndpoint.CREATE_MESSAGE diff --git a/src/celeste/providers/bfl/images/client.py b/src/celeste/providers/bfl/images/client.py index d1c3635..1b16565 100644 --- a/src/celeste/providers/bfl/images/client.py +++ b/src/celeste/providers/bfl/images/client.py @@ -40,6 +40,7 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request with async polling for BFL image generation. @@ -49,7 +50,10 @@ async def _make_request( 2. Poll polling_url until Ready/Failed 3. Return response with _submit_metadata for usage parsing """ - headers = {**self._json_headers(), "Accept": ApplicationMimeType.JSON} + headers = { + **self._json_headers(extra_headers), + "Accept": ApplicationMimeType.JSON, + } if endpoint is None: endpoint = config.BFLImagesEndpoint.CREATE_IMAGE @@ -72,7 +76,10 @@ async def _make_request( # Phase 2: Poll for completion start_time = time.monotonic() - poll_headers = {**self.auth.get_headers(), "Accept": ApplicationMimeType.JSON} + poll_headers = self._merge_headers( + {**self.auth.get_headers(), "Accept": ApplicationMimeType.JSON}, + extra_headers, + ) while True: elapsed = time.monotonic() - start_time @@ -107,6 +114,7 @@ def _make_stream_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """BFL Images API does not support SSE streaming in this client.""" diff --git a/src/celeste/providers/byteplus/images/client.py b/src/celeste/providers/byteplus/images/client.py index 81cbf0f..611595b 100644 --- a/src/celeste/providers/byteplus/images/client.py +++ b/src/celeste/providers/byteplus/images/client.py @@ -50,13 +50,14 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request to BytePlus Images API endpoint.""" if endpoint is None: endpoint = config.BytePlusImagesEndpoint.CREATE_IMAGE - headers = self._json_headers() + headers = self._json_headers(extra_headers) response = await self.http_client.post( f"{config.BASE_URL}{endpoint}", @@ -72,13 +73,14 @@ def _make_stream_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """Make streaming request to BytePlus Images API endpoint.""" if endpoint is None: endpoint = config.BytePlusImagesEndpoint.CREATE_IMAGE - headers = self._json_headers() + headers = self._json_headers(extra_headers) return self.http_client.stream_post( f"{config.BASE_URL}{endpoint}", diff --git a/src/celeste/providers/byteplus/videos/client.py b/src/celeste/providers/byteplus/videos/client.py index cc3493c..14200d7 100644 --- a/src/celeste/providers/byteplus/videos/client.py +++ b/src/celeste/providers/byteplus/videos/client.py @@ -57,6 +57,7 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request with async polling for BytePlus video generation. @@ -66,7 +67,7 @@ async def _make_request( 2. Poll CONTENT_STATUS endpoint until succeeded/failed/canceled 3. Return response with final status data """ - headers = self._json_headers() + headers = self._json_headers(extra_headers) if endpoint is None: endpoint = config.BytePlusVideosEndpoint.CREATE_VIDEO @@ -130,6 +131,7 @@ def _make_stream_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """BytePlus Videos API does not support SSE streaming in this client.""" diff --git a/src/celeste/providers/cohere/chat/client.py b/src/celeste/providers/cohere/chat/client.py index 7f8d08a..a5891fa 100644 --- a/src/celeste/providers/cohere/chat/client.py +++ b/src/celeste/providers/cohere/chat/client.py @@ -52,13 +52,14 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request to Cohere Chat API endpoint.""" if endpoint is None: endpoint = config.CohereChatEndpoint.CREATE_CHAT - headers = self._json_headers() + headers = self._json_headers(extra_headers) response = await self.http_client.post( f"{config.BASE_URL}{endpoint}", @@ -74,13 +75,14 @@ def _make_stream_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """Make streaming request to Cohere Chat API endpoint.""" if endpoint is None: endpoint = config.CohereChatEndpoint.CREATE_CHAT - headers = self._json_headers() + headers = self._json_headers(extra_headers) return self.http_client.stream_post( f"{config.BASE_URL}{endpoint}", diff --git a/src/celeste/providers/elevenlabs/text_to_speech/client.py b/src/celeste/providers/elevenlabs/text_to_speech/client.py index 294ad33..457e13e 100644 --- a/src/celeste/providers/elevenlabs/text_to_speech/client.py +++ b/src/celeste/providers/elevenlabs/text_to_speech/client.py @@ -35,6 +35,7 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request to ElevenLabs TTS endpoint. @@ -55,7 +56,7 @@ async def _make_request( endpoint = config.ElevenLabsTextToSpeechEndpoint.CREATE_SPEECH endpoint = endpoint.format(voice_id=voice_id) - headers = self._json_headers() + headers = self._json_headers(extra_headers) response = await self.http_client.post( f"{config.BASE_URL}{endpoint}", @@ -73,6 +74,7 @@ def _make_stream_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """Make HTTP streaming request returning binary audio chunks. @@ -93,7 +95,7 @@ def _make_stream_request( endpoint = config.ElevenLabsTextToSpeechEndpoint.STREAM_SPEECH endpoint = endpoint.format(voice_id=voice_id) - headers = self._json_headers() + headers = self._json_headers(extra_headers) return self._stream_binary_audio( f"{config.BASE_URL}{endpoint}", diff --git a/src/celeste/providers/google/cloud_tts/client.py b/src/celeste/providers/google/cloud_tts/client.py index 8fad1a4..c61414b 100644 --- a/src/celeste/providers/google/cloud_tts/client.py +++ b/src/celeste/providers/google/cloud_tts/client.py @@ -36,6 +36,7 @@ def _make_stream_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """Cloud TTS does not support SSE streaming in this client.""" @@ -60,13 +61,14 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request to Cloud TTS synthesize endpoint.""" if endpoint is None: endpoint = config.GoogleCloudTTSEndpoint.CREATE_SPEECH - headers = self._json_headers() + headers = self._json_headers(extra_headers) response = await self.http_client.post( f"{config.BASE_URL}{endpoint}", diff --git a/src/celeste/providers/google/embeddings/client.py b/src/celeste/providers/google/embeddings/client.py index 482f8fa..235ac1b 100644 --- a/src/celeste/providers/google/embeddings/client.py +++ b/src/celeste/providers/google/embeddings/client.py @@ -35,6 +35,7 @@ def _make_stream_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """Embeddings API does not support SSE streaming in this client.""" @@ -72,6 +73,7 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request to embeddings endpoint.""" @@ -95,7 +97,7 @@ async def _make_request( if endpoint is None: endpoint = endpoint_template - headers = self._json_headers() + headers = self._json_headers(extra_headers) response = await self.http_client.post( self._build_url(endpoint), diff --git a/src/celeste/providers/google/generate_content/client.py b/src/celeste/providers/google/generate_content/client.py index 734c373..c1402dc 100644 --- a/src/celeste/providers/google/generate_content/client.py +++ b/src/celeste/providers/google/generate_content/client.py @@ -63,13 +63,14 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request to generateContent endpoint.""" if endpoint is None: endpoint = config.GoogleGenerateContentEndpoint.GENERATE_CONTENT - headers = self._json_headers() + headers = self._json_headers(extra_headers) response = await self.http_client.post( url=self._build_url(endpoint), headers=headers, @@ -84,13 +85,14 @@ def _make_stream_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """Make streaming request to streamGenerateContent endpoint.""" if endpoint is None: endpoint = config.GoogleGenerateContentEndpoint.STREAM_GENERATE_CONTENT - headers = self._json_headers() + headers = self._json_headers(extra_headers) return self.http_client.stream_post( url=self._build_url(endpoint), headers=headers, diff --git a/src/celeste/providers/google/imagen/client.py b/src/celeste/providers/google/imagen/client.py index cff9217..e91fe6c 100644 --- a/src/celeste/providers/google/imagen/client.py +++ b/src/celeste/providers/google/imagen/client.py @@ -60,13 +60,14 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request to Imagen :predict endpoint.""" if endpoint is None: endpoint = config.GoogleImagenEndpoint.CREATE_IMAGE - headers = self._json_headers() + headers = self._json_headers(extra_headers) response = await self.http_client.post( self._build_url(endpoint), headers=headers, @@ -81,6 +82,7 @@ def _make_stream_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """Imagen API does not support SSE streaming in this client.""" diff --git a/src/celeste/providers/google/interactions/client.py b/src/celeste/providers/google/interactions/client.py index d490ef6..dc9d9d7 100644 --- a/src/celeste/providers/google/interactions/client.py +++ b/src/celeste/providers/google/interactions/client.py @@ -59,13 +59,14 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request to interactions endpoint.""" if endpoint is None: endpoint = config.GoogleInteractionsEndpoint.CREATE_INTERACTION - headers = self._json_headers() + headers = self._json_headers(extra_headers) response = await self.http_client.post( f"{config.BASE_URL}{endpoint}", headers=headers, @@ -80,13 +81,14 @@ def _make_stream_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """Make streaming request to interactions endpoint.""" if endpoint is None: endpoint = config.GoogleInteractionsEndpoint.STREAM_INTERACTION - headers = self._json_headers() + headers = self._json_headers(extra_headers) return self.http_client.stream_post( f"{config.BASE_URL}{endpoint}", headers=headers, @@ -96,6 +98,7 @@ def _make_stream_request( async def _get_interaction( self, interaction_id: str, + extra_headers: dict[str, str] | None = None, ) -> httpx.Response: """Get an existing interaction by ID. @@ -104,7 +107,7 @@ async def _get_interaction( endpoint = config.GoogleInteractionsEndpoint.GET_INTERACTION.format( interaction_id=interaction_id ) - headers = self.auth.get_headers() + headers = self._merge_headers(self.auth.get_headers(), extra_headers) return await self.http_client.get( f"{config.BASE_URL}{endpoint}", diff --git a/src/celeste/providers/google/veo/client.py b/src/celeste/providers/google/veo/client.py index 7817c96..f456b5b 100644 --- a/src/celeste/providers/google/veo/client.py +++ b/src/celeste/providers/google/veo/client.py @@ -65,13 +65,15 @@ def _build_poll_url(self, operation_name: str) -> str: ) return f"{config.BASE_URL}{poll_path}" - async def _make_poll_request(self, operation_name: str) -> dict[str, Any]: + async def _make_poll_request( + self, operation_name: str, extra_headers: dict[str, str] | None = None + ) -> dict[str, Any]: """Poll a long-running operation. Vertex AI uses POST to fetchPredictOperation with operationName in body. AI Studio uses GET to /v1beta/{operation_name}. """ - headers = self._json_headers() + headers = self._json_headers(extra_headers) poll_url = self._build_poll_url(operation_name) if isinstance(self.auth, GoogleADC): @@ -97,6 +99,7 @@ def _make_stream_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """Veo API does not support SSE streaming in this client.""" @@ -107,13 +110,14 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request with async polling for Veo video generation.""" if endpoint is None: endpoint = config.GoogleVeoEndpoint.CREATE_VIDEO - headers = self._json_headers() + headers = self._json_headers(extra_headers) logger.info(f"Initiating video generation with model {self.model.id}") response = await self.http_client.post( @@ -133,7 +137,9 @@ async def _make_request( await asyncio.sleep(config.POLL_INTERVAL) logger.debug(f"Polling operation status: {operation_name}") - operation_data = await self._make_poll_request(operation_name) + operation_data = await self._make_poll_request( + operation_name, extra_headers=extra_headers + ) if operation_data.get("done"): if "error" in operation_data: @@ -193,13 +199,16 @@ def _parse_finish_reason(self, response_data: dict[str, Any]) -> FinishReason: """Veo API doesn't provide finish reasons.""" return FinishReason(reason=None) - async def download_content(self, url: str) -> bytes: + async def download_content( + self, url: str, extra_headers: dict[str, str] | None = None + ) -> bytes: """Download video content from GCS URL. Returns raw bytes that capability clients wrap in VideoArtifact. Args: url: GCS URL (gs://) or HTTPS URL to download from. + extra_headers: Optional extra HTTP headers to include. Returns: Raw video bytes. @@ -210,7 +219,7 @@ async def download_content(self, url: str) -> bytes: logger.info(f"Downloading video from: {download_url}") - headers = self.auth.get_headers() + headers = self._merge_headers(self.auth.get_headers(), extra_headers) response = await self.http_client.get( download_url, diff --git a/src/celeste/providers/gradium/text_to_speech/client.py b/src/celeste/providers/gradium/text_to_speech/client.py index a409054..d10e963 100644 --- a/src/celeste/providers/gradium/text_to_speech/client.py +++ b/src/celeste/providers/gradium/text_to_speech/client.py @@ -36,6 +36,7 @@ async def _make_stream_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """Execute WebSocket TTS flow as async generator. @@ -63,7 +64,7 @@ async def _make_stream_request( if endpoint is None: endpoint = config.GradiumTextToSpeechEndpoint.CREATE_SPEECH url = f"{config.BASE_URL}{endpoint}" - headers = self.auth.get_headers() + headers = self._merge_headers(self.auth.get_headers(), extra_headers) async with ws_connect(url, additional_headers=headers) as ws: # 1. Send setup message @@ -143,6 +144,7 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Collect audio from WebSocket stream. @@ -153,7 +155,7 @@ async def _make_request( output_format = request_body.get("output_format", "wav") async for event in self._make_stream_request( - request_body, endpoint=endpoint, **parameters + request_body, endpoint=endpoint, extra_headers=extra_headers, **parameters ): if "data" in event: audio_chunks.append(event["data"]) diff --git a/src/celeste/providers/ollama/generate/client.py b/src/celeste/providers/ollama/generate/client.py index b07c370..62a4bef 100644 --- a/src/celeste/providers/ollama/generate/client.py +++ b/src/celeste/providers/ollama/generate/client.py @@ -47,6 +47,7 @@ async def _make_request( *, endpoint: str | None = None, base_url: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request to Ollama Generate API.""" @@ -55,7 +56,7 @@ async def _make_request( if base_url is None: base_url = config.DEFAULT_BASE_URL - headers = self._json_headers() + headers = self._json_headers(extra_headers) response = await self.http_client.post( f"{base_url}{endpoint}", @@ -72,6 +73,7 @@ def _make_stream_request( *, endpoint: str | None = None, base_url: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """Make NDJSON streaming request to Ollama Generate API.""" @@ -80,7 +82,7 @@ def _make_stream_request( if base_url is None: base_url = config.DEFAULT_BASE_URL - headers = self._json_headers() + headers = self._json_headers(extra_headers) return self.http_client.stream_post_ndjson( f"{base_url}{endpoint}", diff --git a/src/celeste/providers/openai/audio/client.py b/src/celeste/providers/openai/audio/client.py index 28d7739..e2a3c08 100644 --- a/src/celeste/providers/openai/audio/client.py +++ b/src/celeste/providers/openai/audio/client.py @@ -49,6 +49,7 @@ def _make_stream_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """OpenAI Audio API speech endpoint does not support SSE streaming in this client.""" @@ -59,6 +60,7 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request to OpenAI Audio API speech endpoint. @@ -68,7 +70,7 @@ async def _make_request( if endpoint is None: endpoint = config.OpenAIAudioEndpoint.CREATE_SPEECH - headers = self._json_headers() + headers = self._json_headers(extra_headers) response = await self.http_client.post( f"{config.BASE_URL}{endpoint}", diff --git a/src/celeste/providers/openai/images/client.py b/src/celeste/providers/openai/images/client.py index d8f084c..2cdce0d 100644 --- a/src/celeste/providers/openai/images/client.py +++ b/src/celeste/providers/openai/images/client.py @@ -51,6 +51,7 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request to OpenAI Images API.""" @@ -59,22 +60,28 @@ async def _make_request( # Edit endpoint requires multipart/form-data if endpoint == config.OpenAIImagesEndpoint.CREATE_EDIT: - return await self._make_multipart_request(request_body, endpoint) + return await self._make_multipart_request( + request_body, endpoint, extra_headers=extra_headers + ) # Generate uses JSON - return await self._make_json_request(request_body, endpoint) + return await self._make_json_request( + request_body, endpoint, extra_headers=extra_headers + ) async def _make_json_request( self, request_body: dict[str, Any], endpoint: str, + *, + extra_headers: dict[str, str] | None = None, ) -> dict[str, Any]: """Make JSON request for generate operations.""" # DALL-E 2/3 need b64_json response format if self.model.id in ("dall-e-2", "dall-e-3"): request_body.setdefault("response_format", "b64_json") - headers = self._json_headers() + headers = self._json_headers(extra_headers) response = await self.http_client.post( f"{config.BASE_URL}{endpoint}", @@ -89,6 +96,8 @@ async def _make_multipart_request( self, request_body: dict[str, Any], endpoint: str, + *, + extra_headers: dict[str, str] | None = None, ) -> dict[str, Any]: """Make multipart request for edit operations.""" image_artifact = request_body.pop("image") @@ -112,7 +121,7 @@ async def _make_multipart_request( response = await self.http_client.post_multipart( f"{config.BASE_URL}{endpoint}", - headers=self.auth.get_headers(), + headers=self._merge_headers(self.auth.get_headers(), extra_headers), files=files, data=data, ) @@ -125,6 +134,7 @@ def _make_stream_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """Make streaming request to OpenAI Images API. @@ -144,7 +154,7 @@ def _make_stream_request( request_body["images"] = [{"image_url": build_image_data_url(artifact)}] endpoint = config.OpenAIImagesEndpoint.CREATE_EDIT - headers = self._json_headers() + headers = self._json_headers(extra_headers) return self.http_client.stream_post( f"{config.BASE_URL}{endpoint}", diff --git a/src/celeste/providers/openai/videos/client.py b/src/celeste/providers/openai/videos/client.py index f4f0841..ffb4601 100644 --- a/src/celeste/providers/openai/videos/client.py +++ b/src/celeste/providers/openai/videos/client.py @@ -63,6 +63,7 @@ def _make_stream_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """OpenAI Videos API does not support SSE streaming in this client.""" @@ -73,6 +74,7 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request with async polling for OpenAI video generation. @@ -85,15 +87,13 @@ async def _make_request( if endpoint is None: endpoint = config.OpenAIVideosEndpoint.CREATE_VIDEO - headers = self._json_headers() - files, data = await self._prepare_multipart_request(request_body.copy()) if files: logger.info("Sending multipart request to OpenAI with input_reference") response = await self.http_client.post_multipart( f"{config.BASE_URL}{endpoint}", - headers=headers, + headers=self._merge_headers(self.auth.get_headers(), extra_headers), files=files, data=data, ) @@ -101,7 +101,7 @@ async def _make_request( logger.info(f"Sending request to OpenAI: {request_body}") response = await self.http_client.post( f"{config.BASE_URL}{endpoint}", - headers=headers, + headers=self._json_headers(extra_headers), json_body=request_body, ) @@ -112,10 +112,11 @@ async def _make_request( logger.info(f"Created video job: {video_id}") # Poll for completion + poll_headers = self._json_headers(extra_headers) for _ in range(config.MAX_POLLS): status_response = await self.http_client.get( f"{config.BASE_URL}{endpoint}/{video_id}", - headers=headers, + headers=poll_headers, ) self._handle_error_response(status_response) video_obj = status_response.json() @@ -142,7 +143,7 @@ async def _make_request( # Fetch video content content_response = await self.http_client.get( f"{config.BASE_URL}{endpoint}/{video_id}{config.CONTENT_ENDPOINT_SUFFIX}", - headers=headers, + headers=poll_headers, ) self._handle_error_response(content_response) video_data = content_response.content diff --git a/src/celeste/providers/xai/images/client.py b/src/celeste/providers/xai/images/client.py index 9d5a398..2003f7b 100644 --- a/src/celeste/providers/xai/images/client.py +++ b/src/celeste/providers/xai/images/client.py @@ -47,13 +47,14 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request to xAI Images API.""" if endpoint is None: endpoint = config.XAIImagesEndpoint.CREATE_IMAGE - headers = self._json_headers() + headers = self._json_headers(extra_headers) response = await self.http_client.post( f"{config.BASE_URL}{endpoint}", @@ -69,6 +70,7 @@ def _make_stream_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """xAI Images does not support SSE streaming.""" diff --git a/src/celeste/providers/xai/videos/client.py b/src/celeste/providers/xai/videos/client.py index 42429a1..beb5c3c 100644 --- a/src/celeste/providers/xai/videos/client.py +++ b/src/celeste/providers/xai/videos/client.py @@ -49,6 +49,7 @@ def _make_stream_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """xAI Videos API does not support SSE streaming.""" @@ -59,13 +60,14 @@ async def _make_request( request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request with async polling for xAI video generation.""" if endpoint is None: endpoint = config.XAIVideosEndpoint.CREATE_VIDEO - headers = self._json_headers() + headers = self._json_headers(extra_headers) # Submit video generation request response = await self.http_client.post( diff --git a/templates/modalities/{modality_slug}/client.py.template b/templates/modalities/{modality_slug}/client.py.template index cb74625..134fe61 100644 --- a/templates/modalities/{modality_slug}/client.py.template +++ b/templates/modalities/{modality_slug}/client.py.template @@ -72,6 +72,7 @@ class {Modality}StreamNamespace: prompt: str, *, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[{Modality}Parameters], ) -> {Modality}Stream: """Stream {modality} generation. @@ -85,6 +86,7 @@ class {Modality}StreamNamespace: inputs, stream_class=self._client._stream_class(), extra_body=extra_body, + extra_headers=extra_headers, **parameters, ) @@ -96,6 +98,7 @@ class {Modality}StreamNamespace: video: VideoContent | None = None, audio: AudioContent | None = None, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[{Modality}Parameters], ) -> {Modality}Stream: """Stream media analysis (image, video, or audio). @@ -110,6 +113,7 @@ class {Modality}StreamNamespace: inputs, stream_class=self._client._stream_class(), extra_body=extra_body, + extra_headers=extra_headers, **parameters, ) @@ -128,6 +132,7 @@ class {Modality}SyncNamespace: prompt: str, *, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[{Modality}Parameters], ) -> {Modality}Output: """Blocking {modality} generation. @@ -137,7 +142,7 @@ class {Modality}SyncNamespace: print(result.content) """ inputs = {Modality}Input(prompt=prompt) - return async_to_sync(self._client._predict)(inputs, extra_body=extra_body, **parameters) + return async_to_sync(self._client._predict)(inputs, extra_body=extra_body, extra_headers=extra_headers, **parameters) def analyze( self, @@ -147,6 +152,7 @@ class {Modality}SyncNamespace: video: VideoContent | None = None, audio: AudioContent | None = None, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[{Modality}Parameters], ) -> {Modality}Output: """Blocking media analysis (image, video, or audio). @@ -157,7 +163,7 @@ class {Modality}SyncNamespace: """ self._client._check_media_support(image=image, video=video, audio=audio) inputs = {Modality}Input(prompt=prompt, image=image, video=video, audio=audio) - return async_to_sync(self._client._predict)(inputs, extra_body=extra_body, **parameters) + return async_to_sync(self._client._predict)(inputs, extra_body=extra_body, extra_headers=extra_headers, **parameters) @property def stream(self) -> "{Modality}SyncStreamNamespace": @@ -176,6 +182,7 @@ class {Modality}SyncStreamNamespace: prompt: str, *, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[{Modality}Parameters], ) -> {Modality}Stream: """Sync streaming {modality} generation. @@ -189,7 +196,7 @@ class {Modality}SyncStreamNamespace: print(stream.output.usage) """ # Return same stream as async version - __iter__/__next__ handle sync iteration - return self._client.stream.generate(prompt, extra_body=extra_body, **parameters) + return self._client.stream.generate(prompt, extra_body=extra_body, extra_headers=extra_headers, **parameters) def analyze( self, @@ -199,6 +206,7 @@ class {Modality}SyncStreamNamespace: video: VideoContent | None = None, audio: AudioContent | None = None, extra_body: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Unpack[{Modality}Parameters], ) -> {Modality}Stream: """Sync streaming media analysis (image, video, or audio). @@ -213,7 +221,7 @@ class {Modality}SyncStreamNamespace: """ # Return same stream as async version - __iter__/__next__ handle sync iteration return self._client.stream.analyze( - prompt, image=image, video=video, audio=audio, extra_body=extra_body, **parameters + prompt, image=image, video=video, audio=audio, extra_body=extra_body, extra_headers=extra_headers, **parameters ) diff --git a/templates/protocols/{protocol_slug}/client.py.template b/templates/protocols/{protocol_slug}/client.py.template index 1d1dea8..ccde4c7 100644 --- a/templates/protocols/{protocol_slug}/client.py.template +++ b/templates/protocols/{protocol_slug}/client.py.template @@ -6,8 +6,6 @@ from typing import Any, ClassVar from celeste.client import APIMixin from celeste.core import UsageField from celeste.io import FinishReason -from celeste.mime_types import ApplicationMimeType - from . import config @@ -73,16 +71,14 @@ class {Protocol}Client(APIMixin): request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request to {Protocol} API endpoint.""" if endpoint is None: endpoint = self._default_endpoint - headers = { - **self.auth.get_headers(), - "Content-Type": ApplicationMimeType.JSON, - } + headers = self._json_headers(extra_headers) response = await self.http_client.post( self._build_url(endpoint), @@ -98,16 +94,14 @@ class {Protocol}Client(APIMixin): request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """Make streaming request to {Protocol} API endpoint.""" if endpoint is None: endpoint = self._default_endpoint - headers = { - **self.auth.get_headers(), - "Content-Type": ApplicationMimeType.JSON, - } + headers = self._json_headers(extra_headers) return self.http_client.stream_post( self._build_url(endpoint, streaming=True), diff --git a/templates/providers/{provider_slug}/{api_slug}/client.py.template b/templates/providers/{provider_slug}/{api_slug}/client.py.template index 9047f59..c73fd21 100644 --- a/templates/providers/{provider_slug}/{api_slug}/client.py.template +++ b/templates/providers/{provider_slug}/{api_slug}/client.py.template @@ -78,13 +78,14 @@ class {Provider}{Api}Client(APIMixin): request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> dict[str, Any]: """Make HTTP request to {Provider} {Api} API.""" if endpoint is None: endpoint = config.{Provider}{Api}Endpoint.CREATE_... - headers = self._json_headers() + headers = self._json_headers(extra_headers) response = await self.http_client.post( f"{config.BASE_URL}{endpoint}", @@ -100,6 +101,7 @@ class {Provider}{Api}Client(APIMixin): request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: """Make streaming request to {Provider} {Api} API. @@ -110,6 +112,7 @@ class {Provider}{Api}Client(APIMixin): request_body: dict[str, Any], *, endpoint: str | None = None, + extra_headers: dict[str, str] | None = None, **parameters: Any, ) -> AsyncIterator[dict[str, Any]]: \"\"\"{Provider} {Api} does not support SSE streaming in this client.\"\"\" @@ -118,7 +121,7 @@ class {Provider}{Api}Client(APIMixin): if endpoint is None: endpoint = config.{Provider}{Api}Endpoint.CREATE_... - headers = self._json_headers() + headers = self._json_headers(extra_headers) return self.http_client.stream_post( f"{config.BASE_URL}{endpoint}", @@ -129,13 +132,14 @@ class {Provider}{Api}Client(APIMixin): async def _make_poll_request( self, operation_name: str, + extra_headers: dict[str, str] | None = None, ) -> dict[str, Any]: """Poll a long-running operation. If this API does not use long-running operations, remove this method. Override for Vertex AI support (POST to fetchPredictOperation). """ - headers = self.auth.get_headers() + headers = self._merge_headers(self.auth.get_headers(), extra_headers) poll_url = f"{config.BASE_URL}{config.{Provider}{Api}Endpoint.GET_OPERATION.format(operation_name=operation_name)}" response = await self.http_client.get( diff --git a/tests/unit_tests/test_provider_api_templates.py b/tests/unit_tests/test_provider_api_templates.py index cc675c3..818f630 100644 --- a/tests/unit_tests/test_provider_api_templates.py +++ b/tests/unit_tests/test_provider_api_templates.py @@ -59,6 +59,7 @@ def _extract_template_expectations(template_text: str) -> TemplateExpectations: # Also ensure the endpoint routing contract is present in the template. assert "async def _make_request" in template_text assert "endpoint: str | None = None" in template_text + assert "extra_headers: dict[str, str] | None = None" in template_text assert "def _make_stream_request" in template_text return TemplateExpectations( @@ -202,6 +203,22 @@ def test_all_provider_api_mixins_match_template_contract() -> None: and endpoint_default.value is None ), f"{client_path}: {fn_name} endpoint default must be None" + kw_headers = _kwonly_arg(fn, "extra_headers") + assert kw_headers is not None, ( + f"{client_path}: {fn_name} missing kw-only extra_headers" + ) + headers_arg, headers_default = kw_headers + assert headers_arg.annotation is not None, ( + f"{client_path}: {fn_name} extra_headers missing annotation" + ) + assert ( + ast.unparse(headers_arg.annotation).strip() == "dict[str, str] | None" + ), f"{client_path}: {fn_name} extra_headers annotation mismatch" + assert ( + isinstance(headers_default, ast.Constant) + and headers_default.value is None + ), f"{client_path}: {fn_name} extra_headers default must be None" + # Usage typing parity (matches template) map_usage_fields = methods["map_usage_fields"] assert _has_staticmethod_decorator(map_usage_fields), (