-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
87 lines (67 loc) · 2.5 KB
/
main.py
File metadata and controls
87 lines (67 loc) · 2.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import shutil
import tempfile

import uvicorn
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from models.prompt import PromptRequest, PromptResponse
from models.rag import RAGRequest, RAGResponse
from handlers.model_type_handler import *
from rag.llm_setup import save_embeddings, prepare_prompt
from loguru import logger
# Also write every log record to logs/debug_logs.log.
logger.add("logs/debug_logs.log")

app = FastAPI()

# Middleware
# Cross-origin policy: any origin, any method and any header are accepted,
# and credentialed requests are allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; confirm how the installed Starlette version handles this
# combination and restrict allow_origins to the real frontend origin(s) if
# cookies or auth headers are used.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
def process_model_response(model_type: ModelType, prompt: str) -> str:
    """Route ``prompt`` to the handler registered for ``model_type``.

    Args:
        model_type: Which model backend should answer.
        prompt: The prompt text to send to that backend.

    Returns:
        Whatever the selected ``process_with_*`` handler returns.

    Raises:
        HTTPException: 400 if ``model_type`` has no handler below.
    """
    # (member, handler) pairs, compared with ``==`` exactly like the original
    # if/elif chain did.
    routes = (
        (ModelType.llama3, process_with_llama3),
        (ModelType.phi3, process_with_phi3),
        (ModelType.gemini, process_with_gemini),
        (ModelType.gemma2b, process_with_gemma2b),
        # The enum member is "mixtral" but it is served by the Mistral
        # handler. NOTE(review): confirm this pairing is intended.
        (ModelType.mixtral, process_with_mistral),
    )
    for candidate, handler in routes:
        if model_type == candidate:
            return handler(prompt)
    # An unsupported model is a malformed request (400), not a missing
    # resource (404).
    raise HTTPException(status_code=400, detail="Invalid Model Type")
@app.post("/api/v1/prompt", response_model=PromptResponse)
async def process_prompt(request: PromptRequest):
    """Send a single prompt to the requested model and return its reply.

    Args:
        request: Body carrying ``prompt`` (the text to send) and
            ``model_type`` (which backend answers it).

    Returns:
        A ``PromptResponse`` wrapping the model's answer.

    Raises:
        HTTPException: 400 if the prompt is missing, empty or only
            whitespace; otherwise whatever ``process_model_response`` raises.
    """
    prompt = request.prompt
    # An empty prompt is a malformed request (400), not a missing resource (404).
    if not prompt or not prompt.strip():
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")

    model_type = request.model_type
    logger.debug(f"Model Type {model_type}")

    # NOTE(review): process_model_response is a plain (synchronous) function
    # called from this async handler, so a slow model call blocks the event
    # loop, and with it every other request, for its whole duration.
    response = process_model_response(model_type, prompt)
    return PromptResponse(response=response)
@app.post("/api/v1/rag", response_model=RAGResponse)
async def process_rag(
    model_type: ModelType = Form(...),
    query: str = Form(...),
    file: UploadFile = File(...),
):
    """Answer ``query`` using retrieval over an uploaded document.

    The upload is written to a private directory under ``tmp/``, passed to
    ``save_embeddings``, and deleted again. The prompt is then built with
    ``prepare_prompt`` and sent to the selected model.

    Args:
        model_type: Which model backend answers the query.
        query: The user's question; must not be empty or only whitespace.
        file: The document to index; the text after its last "." is passed
            to ``save_embeddings`` as the file type.

    Returns:
        A ``RAGResponse`` wrapping the model's answer.

    Raises:
        HTTPException: 400 for an empty query or an unusable filename;
            500 if embedding the file or preparing the prompt fails.
    """
    if not query or not query.strip():
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    # file.filename is chosen by the client: keep only its final path
    # component so a name like "../../etc/passwd" cannot escape tmp/.
    filename = os.path.basename(file.filename or "")
    if filename in ("", ".", ".."):
        raise HTTPException(
            status_code=400, detail="Uploaded file must have a valid filename"
        )
    file_type = filename.split(".")[-1]

    # Read before touching the filesystem so a failed read leaves nothing behind.
    content = await file.read()

    os.makedirs("tmp", exist_ok=True)
    # One directory per request: concurrent uploads that share a filename no
    # longer overwrite or delete each other's copy, and the basename is kept.
    upload_dir = tempfile.mkdtemp(dir="tmp")
    try:
        filepath = os.path.join(upload_dir, filename)
        with open(filepath, "wb") as f:
            f.write(content)
        try:
            save_embeddings(filepath, file_type)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up even if writing or embedding failed.
        shutil.rmtree(upload_dir, ignore_errors=True)

    try:
        prompt = prepare_prompt(query)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    response = process_model_response(model_type, prompt)
    return RAGResponse(response=response)
if __name__ == "__main__":
    # Running this file directly serves the app with uvicorn on every
    # network interface, port 8000.
    bind_host = '0.0.0.0'
    bind_port = 8000
    uvicorn.run(app, host=bind_host, port=bind_port)