Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/celeste/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
StreamNotExhaustedError,
UnsupportedCapabilityError,
UnsupportedParameterError,
UnsupportedParameterWarning,
UnsupportedProviderError,
ValidationError,
)
Expand Down Expand Up @@ -272,6 +273,7 @@ def create_client(
"StrictRefResolvingJsonSchemaGenerator",
"UnsupportedCapabilityError",
"UnsupportedParameterError",
"UnsupportedParameterWarning",
"UnsupportedProviderError",
"Usage",
"UsageField",
Expand Down
14 changes: 12 additions & 2 deletions src/celeste/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Base client for modality-specific AI operations."""

import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from json import JSONDecodeError
Expand All @@ -10,7 +11,7 @@

from celeste.auth import Authentication
from celeste.core import Modality, Provider
from celeste.exceptions import StreamingNotSupportedError
from celeste.exceptions import StreamingNotSupportedError, UnsupportedParameterWarning
from celeste.http import HTTPClient, get_http_client
from celeste.io import Chunk as ChunkBase
from celeste.io import FinishReason, Input, Output, Usage
Expand Down Expand Up @@ -408,9 +409,18 @@ def _build_request(
request = self._init_request(inputs)

for mapper in self.parameter_mappers():
value = parameters.get(mapper.name)
value = parameters.pop(mapper.name, None)
request = mapper.map(request, value, self.model)

for name, value in parameters.items():
if value is not None:
warnings.warn(
f"Parameter '{name}' is not supported by model "
f"'{self.model.id}' and will be ignored.",
UnsupportedParameterWarning,
stacklevel=4,
)

if extra_body:
self._deep_merge(request, extra_body)

Expand Down
5 changes: 5 additions & 0 deletions src/celeste/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,10 @@ def __init__(self, parameter: str, model_id: str) -> None:
)


class UnsupportedParameterWarning(UserWarning):
"""Emitted when a parameter is not supported by a provider and will be ignored."""


__all__ = [
"ClientNotFoundError",
"ConstraintViolationError",
Expand All @@ -259,5 +263,6 @@ def __init__(self, parameter: str, model_id: str) -> None:
"StreamingNotSupportedError",
"UnsupportedCapabilityError",
"UnsupportedParameterError",
"UnsupportedParameterWarning",
"UnsupportedProviderError",
]
67 changes: 66 additions & 1 deletion tests/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from celeste.auth import APIKey
from celeste.client import ModalityClient
from celeste.core import Modality, Provider
from celeste.exceptions import StreamingNotSupportedError
from celeste.exceptions import StreamingNotSupportedError, UnsupportedParameterWarning
from celeste.io import Chunk, Input, Output, Usage
from celeste.models import Model, Operation
from celeste.parameters import ParameterMapper, Parameters
Expand Down Expand Up @@ -222,6 +222,71 @@ def parameter_mappers(cls) -> list[ParameterMapper[str]]:
assert request["first_param"] == "first"
assert request["second_param"] == "second"

def test_build_request_warns_on_unsupported_parameter(
self, text_model: Model, api_key: str
) -> None:
"""_build_request emits UnsupportedParameterWarning for unmapped parameters."""

class ClientWithOneMapper(ConcreteModalityClient):
@classmethod
def parameter_mappers(cls) -> list[ParameterMapper[str]]:
return [_create_test_mapper(ParamEnum.FIRST_PARAM)]

client = ClientWithOneMapper(
modality=Modality.TEXT,
model=text_model,
provider=text_model.provider,
auth=APIKey(secret=SecretStr(api_key)),
)

inputs = _TestInput(prompt="test")

with pytest.warns(UnsupportedParameterWarning, match="second_param.*gpt-4"):
client._build_request(inputs, first_param="ok", second_param="unsupported")

def test_build_request_no_warning_for_supported_parameters(
self, text_model: Model, api_key: str
) -> None:
"""_build_request does not warn when all parameters have mappers."""
import warnings

class ClientWithMapper(ConcreteModalityClient):
@classmethod
def parameter_mappers(cls) -> list[ParameterMapper[str]]:
return [_create_test_mapper(ParamEnum.TEST_PARAM)]

client = ClientWithMapper(
modality=Modality.TEXT,
model=text_model,
provider=text_model.provider,
auth=APIKey(secret=SecretStr(api_key)),
)

inputs = _TestInput(prompt="test")

with warnings.catch_warnings():
warnings.simplefilter("error", UnsupportedParameterWarning)
client._build_request(inputs, test_param="supported")

def test_build_request_no_warning_for_none_unsupported_parameter(
self, text_model: Model, api_key: str
) -> None:
"""_build_request does not warn when unsupported parameter value is None."""
import warnings

client = ConcreteModalityClient(
modality=Modality.TEXT,
model=text_model,
provider=text_model.provider,
auth=APIKey(secret=SecretStr(api_key)),
)

inputs = _TestInput(prompt="test")

with warnings.catch_warnings():
warnings.simplefilter("error", UnsupportedParameterWarning)
client._build_request(inputs, test_param=None)

@pytest.mark.parametrize(
"param_value,expected_output",
[
Expand Down
Loading