forked from bayuzen19/dtsense-api
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
95 lines (81 loc) · 3 KB
/
app.py
File metadata and controls
95 lines (81 loc) · 3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import logging
import os
from typing import Optional

from fastapi import FastAPI, Request, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, Field

from src.document_pipeline import DocumentPipeline
# Configure process-wide logging (timestamp, logger name, level, message)
# at INFO, then obtain this module's named logger.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
# Define request/response models with Pydantic
class QueryRequest(BaseModel):
question: str = Field(..., min_length=1, description="The question to ask the medical chatbot")
class QueryResponse(BaseModel):
    # Body of a successful POST /query response.
    answer: str  # the text returned by pipeline.query
class ErrorResponse(BaseModel):
    # Error body advertised in the OpenAPI docs for /query (400 and 500).
    # NOTE(review): /query raises HTTPException(detail=...), which FastAPI
    # renders as {"detail": ...}; the ``error`` field may therefore never
    # appear in real responses. Confirm the documented schema matches.
    error: str  # short error label (required)
    detail: Optional[str] = Field(default=None)  # optional extra context
# Create the ASGI application. The title, version and description populate
# the generated OpenAPI schema and its docs pages.
app = FastAPI(
    version="1.0.0",
    title="Medical Chatbot API",
    description="A FastAPI application serving a medical chatbot interface",
)
# Configure CORS.
#
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable. The default is "*" (any origin), which is convenient for
# development; set an explicit list in production, e.g.
#   CORS_ALLOW_ORIGINS="https://chat.example.com,https://admin.example.com"
#
# The CORS spec forbids a wildcard Access-Control-Allow-Origin on credentialed
# requests. Middleware that works around this by echoing the caller's Origin
# would let *any* site make credentialed calls, so credentials (cookies /
# Authorization) are enabled only when an explicit origin list is configured.
# The chat UI itself is served from this same origin ("/"), so it does not
# depend on CORS.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
_allow_all_origins = "*" in _cors_origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=not _allow_all_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Serve files from the "static" directory (relative to the working directory)
# under the /static URL prefix, and load HTML page templates from the
# "template" directory.
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="template")
# Build the DocumentPipeline once, at import time. The app cannot serve
# /query without it, so any failure is logged and re-raised. That makes the
# import of this module fail, and the server does not start.
PDF_DIR = os.getenv("PDF_DIR", "Data")
INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "medicalbot")
logger.info(
    "Initializing DocumentPipeline with PDF_DIR=%s, INDEX_NAME=%s",
    PDF_DIR,
    INDEX_NAME,
)
try:
    pipeline = DocumentPipeline(pdf_dir=PDF_DIR, index_name=INDEX_NAME)
    # NOTE: "vectore" presumably matches the method name as defined on
    # DocumentPipeline (src/document_pipeline); don't correct the spelling
    # here without renaming it there too.
    pipeline.load_vectore_store()
    pipeline.create_retrieval_chain()
except Exception as e:
    # logger.exception records the full traceback along with the message,
    # so the root cause is visible in the application log.
    logger.exception("Failed to initialize DocumentPipeline: %s", e)
    raise
logger.info("DocumentPipeline initialized successfully")
@app.get("/", response_class=HTMLResponse)
async def get_chat(request: Request):
    """Render the browser chat page from the ``chat.html`` template."""
    # Jinja2Templates expects the incoming request inside the context.
    context = {"request": request}
    return templates.TemplateResponse("chat.html", context)
@app.post("/query", response_model=QueryResponse, responses={
    400: {"model": ErrorResponse},
    500: {"model": ErrorResponse},
})
async def query(request: QueryRequest):
    """
    Process a question and return the answer from the medical chatbot.

    Args:
        request: JSON body carrying the user's ``question``.

    Returns:
        QueryResponse: the answer produced by the document pipeline.

    Raises:
        HTTPException: 400 if the question is blank (whitespace only).
            500 if the pipeline fails. The 500 detail is deliberately
            generic so internal error text never reaches clients; the full
            traceback is written to the log instead.
    """
    # ``min_length=1`` on the model still admits whitespace-only input.
    # Reject it before the try block so it is reported as a 400, not a 500.
    if not request.question.strip():
        raise HTTPException(status_code=400, detail="Question must not be blank")

    try:
        # NOTE(review): privacy. This logs the raw, possibly sensitive medical
        # question at INFO level; consider redacting it or lowering the level.
        logger.info("Received query: %s", request.question)
        # pipeline.query is a plain synchronous call. Calling it directly
        # inside this coroutine would block the event loop, and with it every
        # other request, including /health, for the whole call. Run it on the
        # threadpool instead.
        # NOTE(review): assumes DocumentPipeline.query is safe to call from
        # several worker threads at once; confirm.
        answer = await run_in_threadpool(pipeline.query, request.question)
        logger.info("Query processed successfully")
        return QueryResponse(answer=answer)
    except Exception as e:
        logger.exception("Error processing query: %s", e)
        # Don't echo str(e) to the client: exception text can expose
        # internals (hosts, keys, stack details).
        raise HTTPException(status_code=500, detail="Internal server error") from e
@app.get("/health")
async def health_check():
    """Report service status with a constant ``{"status": "healthy"}`` payload."""
    return dict(status="healthy")
if __name__ == "__main__":
    # Development entry point (``python app.py``). reload=True restarts the
    # server whenever source files change.
    import uvicorn

    # Bind address and port can be overridden with the HOST / PORT
    # environment variables; the defaults are 0.0.0.0 and 8000. An empty
    # PORT value falls back to the default instead of failing in int().
    uvicorn.run(
        "app:app",
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT") or "8000"),
        reload=True,
    )