|
5 | 5 | Rewritten to use LiteLLM for unified retry logic, caching, and provider support. |
6 | 6 | """ |
7 | 7 |
|
8 | | -import asyncio |
9 | | -import json |
10 | 8 | import logging |
11 | 9 | import os |
12 | | -from abc import ABC, abstractmethod |
13 | | -from typing import Any, Dict, List, Literal, Optional, Tuple, Union |
| 10 | +from typing import Any, Dict, List, Literal, Optional |
14 | 11 |
|
15 | 12 | import litellm |
16 | | -from litellm import acompletion, completion |
| 13 | +from litellm import acompletion |
| 14 | +from litellm.types.utils import ModelResponse |
| 15 | +from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper |
17 | 16 | from litellm.caching.caching import Cache |
18 | 17 | from litellm.caching.dual_cache import DualCache |
19 | 18 | from litellm.caching.in_memory_cache import InMemoryCache |
@@ -194,7 +193,20 @@ async def _make_llm_call(self, messages: List[Dict[str, Any]], tools: List[Dict[ |
194 | 193 | request_params["tools"] = tools |
195 | 194 |
|
196 | 195 | try: |
197 | | - response = await acompletion(model=self.model_id, **request_params) |
| 196 | + if request_params.get("stream") is True: |
| 197 | + chunks = [] |
| 198 | + stream = await acompletion(model=self.model_id, **request_params) |
| 199 | + |
| 200 | + assert isinstance(stream, CustomStreamWrapper), "Stream should be a CustomStreamWrapper" |
| 201 | + |
| 202 | + async for chunk in stream: # pyright: ignore[reportGeneralTypeIssues] |
| 203 | + chunks.append(chunk) |
| 204 | + response = litellm.stream_chunk_builder(chunks, messages) |
| 205 | + else: |
| 206 | + response = await acompletion(model=self.model_id, **request_params) |
| 207 | + |
| 208 | + assert response is not None, "Response is None" |
| 209 | + assert isinstance(response, ModelResponse), "Response should be ModelResponse" |
198 | 210 |
|
199 | 211 | # Log cache hit/miss for monitoring |
200 | 212 | hidden = getattr(response, "_hidden_params", {}) |
|
0 commit comments