Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 128 additions & 47 deletions backend/app/controllers/knowledge_base_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,61 @@
delete_knowledge_base_file,
get_file_chunks,
)
from app.utils.response_formatter import ResponseFormatter


async def create_collection_controller():
async def create_collection_controller() -> JSONResponse:
"""
Create the knowledge base collection if it doesn't exist.
Returns:
JSONResponse: Standardized response for collection creation.
"""
try:
result = create_knowledge_base_collection_if_not_exists()
if result["status"] == "success":
return JSONResponse(content=result, status_code=200)
return ResponseFormatter.success_response(
data=None,
message=result["message"]
)
else:
return JSONResponse(content=result, status_code=500)
return ResponseFormatter.error_response(
message=result["message"],
status_code=500
)
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error creating collection: {str(e)}"
return ResponseFormatter.error_response(
message=f"Failed to create knowledge base collection: {str(e)}",
status_code=500
)


async def get_files_controller():
async def get_files_controller() -> JSONResponse:
"""
Get all files in the knowledge base.
Returns:
JSONResponse: Standardized response with knowledge base files.
"""
try:
result = get_knowledge_base_files()
if result["status"] == "success":
return JSONResponse(content=result, status_code=200)
files = result.get("files", [])
return ResponseFormatter.success_response(
data=files,
message=f"Retrieved {len(files)} files from knowledge base"
)
else:
return JSONResponse(content=result, status_code=500)
return ResponseFormatter.error_response(
message="Failed to retrieve knowledge base files",
status_code=500
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error getting files: {str(e)}")
return ResponseFormatter.error_response(
message=f"Failed to get knowledge base files: {str(e)}",
status_code=500
)


async def upload_pdf_controller(file: UploadFile = File(...)):
async def upload_pdf_controller(file: UploadFile = File(...)) -> JSONResponse:
"""
Upload a PDF file to the knowledge base.
"""
Expand All @@ -66,14 +88,26 @@ async def upload_pdf_controller(file: UploadFile = File(...)):
result = upload_pdf_file(temp_file_path, file_name)

if result["status"] == "success":
return JSONResponse(content=result, status_code=200)
return ResponseFormatter.success_response(
data={"file_name": file_name},
message=result["message"]
)
else:
return JSONResponse(content=result, status_code=500)
return ResponseFormatter.error_response(
message=result["message"],
status_code=500
)

except HTTPException:
raise
except HTTPException as e:
return ResponseFormatter.error_response(
message=str(e.detail),
status_code=e.status_code
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error uploading PDF: {str(e)}")
return ResponseFormatter.error_response(
message="Failed to upload PDF file",
status_code=500
)
finally:
# Clean up temporary file
if temp_file_path and os.path.exists(temp_file_path):
Expand All @@ -83,9 +117,11 @@ async def upload_pdf_controller(file: UploadFile = File(...)):
pass # Ignore cleanup errors


async def upload_text_controller(file: UploadFile = File(...)):
async def upload_text_controller(file: UploadFile = File(...)) -> JSONResponse:
"""
Upload a text file to the knowledge base.
Returns:
JSONResponse: Standardized response for text file upload.
"""
temp_file_path = None
try:
Expand All @@ -105,15 +141,25 @@ async def upload_text_controller(file: UploadFile = File(...)):
result = upload_text_file(temp_file_path, file_name)

if result["status"] == "success":
return JSONResponse(content=result, status_code=200)
return ResponseFormatter.success_response(
data={"file_name": file_name},
message=result["message"]
)
else:
return JSONResponse(content=result, status_code=500)
return ResponseFormatter.error_response(
message=result["message"],
status_code=500
)

except HTTPException:
raise
except HTTPException as e:
return ResponseFormatter.error_response(
message=str(e.detail),
status_code=e.status_code
)
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error uploading text file: {str(e)}"
return ResponseFormatter.error_response(
message="Failed to upload text file",
status_code=500
)
finally:
# Clean up temporary file
Expand All @@ -124,67 +170,102 @@ async def upload_text_controller(file: UploadFile = File(...)):
pass # Ignore cleanup errors


async def delete_file_controller(file_name: str):
async def delete_file_controller(file_name: str) -> JSONResponse:
"""
Delete a file from the knowledge base.
Returns:
JSONResponse: Standardized response for file deletion.
"""
try:
if not file_name:
raise HTTPException(status_code=400, detail="File name is required")
return ResponseFormatter.error_response(
message="File name is required",
status_code=400
)

result = delete_knowledge_base_file(file_name)
if result["status"] == "success":
return JSONResponse(content=result, status_code=200)
return ResponseFormatter.success_response(
data={"file_name": file_name},
message=result["message"]
)
else:
return JSONResponse(content=result, status_code=500)
except HTTPException:
raise
return ResponseFormatter.error_response(
message=result["message"],
status_code=500
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error deleting file: {str(e)}")
return ResponseFormatter.error_response(
message="Failed to delete file from knowledge base",
status_code=500
)


async def get_similar_controller(query: str, limit: int = 10):
async def get_similar_controller(query: str, limit: int = 10) -> JSONResponse:
"""
Get similar chunks from the knowledge base.
Returns:
JSONResponse: Standardized response with similar chunks.
"""
try:
if not query:
raise HTTPException(status_code=400, detail="Query is required")
return ResponseFormatter.error_response(
message="Query is required",
status_code=400
)

if limit <= 0 or limit > 100:
raise HTTPException(
status_code=400, detail="Limit must be between 1 and 100"
return ResponseFormatter.error_response(
message="Limit must be between 1 and 100",
status_code=400
)

result = get_similar_chunks(query=query, limit=limit)
if result["status"] == "success":
return JSONResponse(content=result, status_code=200)
results = result.get("results", [])
return ResponseFormatter.success_response(
data=results,
message=f"Found {len(results)} similar chunks for query"
)
else:
return JSONResponse(content=result, status_code=500)
except HTTPException:
raise
return ResponseFormatter.error_response(
message=result["message"],
status_code=500
)
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error getting similar chunks: {str(e)}"
return ResponseFormatter.error_response(
message="Failed to perform similarity search",
status_code=500
)


async def get_file_chunks_controller(file_name: str):
async def get_file_chunks_controller(file_name: str) -> JSONResponse:
"""
Get all chunks for a specific file.
Returns:
JSONResponse: Standardized response with file chunks.
"""
try:
if not file_name:
raise HTTPException(status_code=400, detail="File name is required")
return ResponseFormatter.error_response(
message="File name is required",
status_code=400
)

result = get_file_chunks(file_name)
if result["status"] == "success":
return JSONResponse(content=result, status_code=200)
chunks = result.get("chunks", [])
return ResponseFormatter.success_response(
data=chunks,
message=f"Retrieved {len(chunks)} chunks for file {file_name}"
)
else:
return JSONResponse(content=result, status_code=500)
except HTTPException:
raise
return ResponseFormatter.error_response(
message=result["message"],
status_code=500
)
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error getting file chunks: {str(e)}"
return ResponseFormatter.error_response(
message="Failed to get file chunks",
status_code=500
)
Loading