|
12 | 12 | from celeste.core import Modality, Provider |
13 | 13 | from celeste.exceptions import StreamingNotSupportedError |
14 | 14 | from celeste.http import HTTPClient, get_http_client |
15 | | -from celeste.io import Chunk, FinishReason, Input, Output, Usage |
| 15 | +from celeste.io import Chunk as ChunkBase |
| 16 | +from celeste.io import FinishReason, Input, Output, Usage |
16 | 17 | from celeste.mime_types import ApplicationMimeType |
17 | 18 | from celeste.models import Model |
18 | 19 | from celeste.parameters import ParameterMapper, Parameters |
@@ -130,15 +131,19 @@ def _handle_error_response(self, response: httpx.Response) -> None: |
130 | 131 | super()._handle_error_response(response) # type: ignore[misc] |
131 | 132 |
|
132 | 133 |
|
133 | | -class ModalityClient[In: Input, Out: Output, Params: Parameters, Content]( |
134 | | - APIMixin, BaseModel |
135 | | -): |
| 134 | +class ModalityClient[ |
| 135 | + In: Input, |
| 136 | + Out: Output, |
| 137 | + Params: Parameters, |
| 138 | + Content, |
| 139 | + Chunk: ChunkBase, |
| 140 | +](APIMixin, BaseModel): |
136 | 141 | """Base class for unified modality clients. |
137 | 142 |
|
138 | 143 | Operation methods in subclasses delegate to _predict(). |
139 | 144 |
|
140 | 145 | Example: |
141 | | - class ImagesClient(ModalityClient[ImagesInput, ImagesOutput, ImagesParameters, ImageContent]): |
| 146 | + class ImagesClient(ModalityClient[ImagesInput, ImagesOutput, ImagesParameters, ImageContent, ImageChunk]): |
142 | 147 | modality = Modality.IMAGES |
143 | 148 |
|
144 | 149 | async def generate(self, prompt: str, **parameters) -> ImageGenerationOutput: |
@@ -198,7 +203,7 @@ async def _predict( |
198 | 203 | response_data = await self._make_request( |
199 | 204 | request_body, endpoint=endpoint, extra_headers=extra_headers, **parameters |
200 | 205 | ) |
201 | | - content = self._parse_content(response_data, **parameters) |
| 206 | + content = self._parse_content(response_data) |
202 | 207 | content = self._transform_output(content, **parameters) |
203 | 208 | return self._output_class()( |
204 | 209 | content=content, |
@@ -277,7 +282,6 @@ def _parse_usage(self, response_data: dict[str, Any]) -> RawUsage: |
277 | 282 | def _parse_content( |
278 | 283 | self, |
279 | 284 | response_data: dict[str, Any], |
280 | | - **parameters: Unpack[Params], # type: ignore[misc] |
281 | 285 | ) -> Content: |
282 | 286 | """Parse content from provider response.""" |
283 | 287 | ... |
@@ -384,8 +388,7 @@ def _transform_output( |
384 | 388 | """Transform content using parameter mapper output transformations.""" |
385 | 389 | for mapper in self.parameter_mappers(): |
386 | 390 | value = parameters.get(mapper.name) |
387 | | - if value is not None: |
388 | | - content = mapper.parse_output(content, value) |
| 391 | + content = mapper.parse_output(content, value) |
389 | 392 | return content |
390 | 393 |
|
391 | 394 | @abstractmethod |
|
0 commit comments