22FastAPI-based API interface for the MultiModelWrapper.
33"""
44
5+ import logging
56from fastapi import FastAPI , HTTPException
67from pydantic import BaseModel , Field
78from typing import List , Dict , Optional , Union
89import asyncio
10+ import json
11+ from typing import Tuple , Any
12+ from functools import lru_cache
913from ..models .factory import ModelFactory
1014from ..models .multi_model import MultiModelWrapper
1115
1216app = FastAPI (title = "Multi-Model API" )
17+ logger = logging .getLogger (__name__ )
18+
19+ # Reuse a single factory across requests to avoid re-loading env / re-allocating caches.
20+ _MODEL_FACTORY = ModelFactory ()
21+
22+ # Cache MultiModelWrapper instances by request parameters.
23+ # Note: wrapper init can be expensive because it initializes provider model instances.
24+ _WRAPPER_CACHE : Dict [Tuple [str , Tuple [str , ...], str ], MultiModelWrapper ] = {}
25+ _WRAPPER_LOCKS : Dict [Tuple [str , Tuple [str , ...], str ], asyncio .Lock ] = {}
26+
27+
28+ def _weights_key (model_weights : Optional [Dict [str , float ]]) -> str :
29+ # Stable string key for dict weights (used for caching).
30+ return json .dumps (model_weights or {}, sort_keys = True , default = str )
31+
32+
33+ async def _get_multi_model (
34+ * ,
35+ primary_model : str ,
36+ fallback_models : List [str ],
37+ model_weights : Optional [Dict [str , float ]],
38+ ) -> MultiModelWrapper :
39+ fallback_tuple = tuple (fallback_models or [])
40+ key = (primary_model , fallback_tuple , _weights_key (model_weights ))
41+
42+ if key in _WRAPPER_CACHE :
43+ return _WRAPPER_CACHE [key ]
44+
45+ lock = _WRAPPER_LOCKS .setdefault (key , asyncio .Lock ())
46+ async with lock :
47+ if key in _WRAPPER_CACHE :
48+ return _WRAPPER_CACHE [key ]
49+
50+ wrapper = MultiModelWrapper (
51+ model_factory = _MODEL_FACTORY ,
52+ primary_model = primary_model ,
53+ fallback_models = list (fallback_tuple ),
54+ model_weights = model_weights ,
55+ )
56+ _WRAPPER_CACHE [key ] = wrapper
57+ return wrapper
1358
1459class GenerateRequest (BaseModel ):
1560 prompt : str
@@ -37,12 +82,10 @@ class EmbeddingsRequest(BaseModel):
3782async def generate (request : GenerateRequest ):
3883 """Generate text using the multi-model wrapper."""
3984 try :
40- factory = ModelFactory ()
41- multi_model = MultiModelWrapper (
42- model_factory = factory ,
85+ multi_model = await _get_multi_model (
4386 primary_model = request .primary_model ,
4487 fallback_models = request .fallback_models ,
45- model_weights = request .model_weights
88+ model_weights = request .model_weights ,
4689 )
4790
4891 response = await multi_model .generate (
@@ -52,18 +95,17 @@ async def generate(request: GenerateRequest):
5295 )
5396 return {"response" : response }
5497 except Exception as e :
55- raise HTTPException (status_code = 500 , detail = str (e ))
98+ logger .exception ("Unhandled error in /generate" )
99+ raise HTTPException (status_code = 500 , detail = "Internal server error" )
56100
57101@app .post ("/chat" )
58102async def chat (request : ChatRequest ):
59103 """Generate chat completion using the multi-model wrapper."""
60104 try :
61- factory = ModelFactory ()
62- multi_model = MultiModelWrapper (
63- model_factory = factory ,
105+ multi_model = await _get_multi_model (
64106 primary_model = request .primary_model ,
65107 fallback_models = request .fallback_models ,
66- model_weights = request .model_weights
108+ model_weights = request .model_weights ,
67109 )
68110
69111 response = await multi_model .chat (
@@ -73,24 +115,24 @@ async def chat(request: ChatRequest):
73115 )
74116 return {"response" : response }
75117 except Exception as e :
76- raise HTTPException (status_code = 500 , detail = str (e ))
118+ logger .exception ("Unhandled error in /chat" )
119+ raise HTTPException (status_code = 500 , detail = "Internal server error" )
77120
78121@app .post ("/embeddings" )
79122async def embeddings (request : EmbeddingsRequest ):
80123 """Generate embeddings using the multi-model wrapper."""
81124 try :
82- factory = ModelFactory ()
83- multi_model = MultiModelWrapper (
84- model_factory = factory ,
125+ multi_model = await _get_multi_model (
85126 primary_model = request .primary_model ,
86127 fallback_models = request .fallback_models ,
87- model_weights = request .model_weights
128+ model_weights = request .model_weights ,
88129 )
89130
90131 embeddings = await multi_model .embeddings (request .text )
91132 return {"embeddings" : embeddings }
92133 except Exception as e :
93- raise HTTPException (status_code = 500 , detail = str (e ))
134+ logger .exception ("Unhandled error in /embeddings" )
135+ raise HTTPException (status_code = 500 , detail = "Internal server error" )
94136
95137@app .get ("/health" )
96138async def health_check ():
0 commit comments