|
18 | 18 |
|
19 | 19 | import json |
20 | 20 | import logging |
21 | | -from typing import Any, Dict, List, Optional |
| 21 | +from typing import Any, Dict, List, Literal, Optional |
22 | 22 |
|
23 | | -from litellm import Router |
| 23 | +from litellm import Router # type: ignore[attr-defined] |
24 | 24 | from litellm.utils import get_model_info |
25 | 25 |
|
26 | 26 | from dembrane.settings import LLMProviderConfig, get_settings |
|
34 | 34 | ROUTER_NUM_RETRIES = 3 |
35 | 35 | ROUTER_ALLOWED_FAILS = 3 # Failures per minute before cooldown |
36 | 36 | ROUTER_COOLDOWN_TIME = 60 # Seconds to cooldown a failed deployment |
37 | | -ROUTER_ROUTING_STRATEGY = "simple-shuffle" # Recommended for production |
| 37 | +ROUTER_ROUTING_STRATEGY: Literal["simple-shuffle"] = "simple-shuffle" # Recommended for production |
38 | 38 |
|
39 | 39 | # Global router instance (lazy initialized) |
40 | 40 | _router: Optional[Router] = None |
@@ -261,9 +261,11 @@ def get_min_context_length(model_group: str) -> int: |
261 | 261 | try: |
262 | 262 | resolved = config.resolve() |
263 | 263 | model_info = get_model_info(resolved.model) |
264 | | - if model_info and model_info.get("max_input_tokens"): |
265 | | - max_tokens = model_info["max_input_tokens"] |
266 | | - if min_tokens is None or max_tokens < min_tokens: |
| 264 | + max_tokens = model_info.get("max_input_tokens") if model_info else None |
| 265 | + if isinstance(max_tokens, int) and max_tokens > 0: |
| 266 | + if min_tokens is None: |
| 267 | + min_tokens = max_tokens |
| 268 | + elif max_tokens < min_tokens: |
267 | 269 | min_tokens = max_tokens |
268 | 270 | logger.debug(f" {model_group}[{suffix}] {resolved.model}: {max_tokens} tokens") |
269 | 271 | except Exception as e: |
|
0 commit comments