diff --git a/backend/main.py b/backend/main.py
index 6931e08..0b6ad90 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -1,3 +1,16 @@
+from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Depends, BackgroundTasks, Request, Query
+from fastapi.responses import JSONResponse
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.middleware.gzip import GZipMiddleware
+from sqlalchemy.orm import Session
+from sqlalchemy import func
+from contextlib import asynccontextmanager
+from functools import lru_cache
+from typing import List, Optional, Any
+from PIL import Image
+import mimetypes
+
+import json
 import os
 import sys
 from pathlib import Path
@@ -18,9 +31,32 @@ from contextlib import asynccontextmanager
 
 import httpx
 import logging
-import asyncio
+import time
+from pywebpush import webpush, WebPushException
+try:
+    import magic
+except ImportError:
+    magic = None
+import httpx
+from async_lru import alru_cache
 
-from backend.database import Base, engine
+from backend.cache import recent_issues_cache, user_upload_cache
+from backend.database import engine, Base, SessionLocal, get_db
+from backend.models import Issue, PushSubscription, Grievance, EscalationAudit, Jurisdiction
+from backend.schemas import (
+    IssueResponse, IssueSummaryResponse, IssueCreateRequest, IssueCreateResponse, ChatRequest, ChatResponse,
+    VoteRequest, VoteResponse, DetectionResponse, UrgencyAnalysisRequest,
+    UrgencyAnalysisResponse, HealthResponse, MLStatusResponse, ResponsibilityMapResponse,
+    ErrorResponse, SuccessResponse, StatsResponse, IssueCategory, IssueStatus,
+    IssueStatusUpdateRequest, IssueStatusUpdateResponse,
+    PushSubscriptionRequest, PushSubscriptionResponse,
+    NearbyIssueResponse, DeduplicationCheckResponse, IssueCreateWithDeduplicationResponse,
+    LeaderboardResponse, LeaderboardEntry,
+    EscalationAuditResponse, GrievanceSummaryResponse, EscalationStatsResponse
+)
+from backend.exceptions import EXCEPTION_HANDLERS
+from backend.database import Base, engine, get_db, SessionLocal
+from backend.models import Issue
 from backend.ai_factory import create_all_ai_services
 from backend.ai_interfaces import initialize_ai_services
 from backend.bot import start_bot_thread, stop_bot_thread
@@ -38,30 +74,223 @@
 )
 logger = logging.getLogger(__name__)
 
-async def background_initialization(app: FastAPI):
-    """Perform non-critical startup tasks in background to speed up app availability"""
+if magic is None:
+    logger.warning(
+        "python-magic is not available; falling back to content_type and "
+        "mimetypes-based detection for uploads."
+    )
+
+# Shared HTTP Client for cached functions
+SHARED_HTTP_CLIENT: Optional[httpx.AsyncClient] = None
+
+# File upload validation constants
+MAX_FILE_SIZE = 20 * 1024 * 1024  # 20MB (increased for better user experience)
+ALLOWED_MIME_TYPES = {
+    'image/jpeg',
+    'image/png',
+    'image/gif',
+    'image/webp',
+    'image/bmp',
+    'image/tiff'
+}
+
+# User upload limits
+UPLOAD_LIMIT_PER_USER = 5  # max uploads per user per hour
+UPLOAD_LIMIT_PER_IP = 10  # max uploads per IP per hour
+
+# Image processing cache to avoid duplicate API calls
+# Replaced custom cache with async_lru for better performance and memory management
+
+def check_upload_limits(identifier: str, limit: int) -> None:
+    """
+    Check if the user/IP has exceeded upload limits and record this upload.
+
+    Args:
+        identifier: Cache key identifying the uploader (user or IP).
+        limit: Maximum number of uploads allowed within the last hour.
+
+    Raises:
+        HTTPException: 429 if `limit` uploads were already recorded in the last hour.
+    """
+    current_uploads = user_upload_cache.get(identifier) or []
+    now = datetime.now()
+    one_hour_ago = now - timedelta(hours=1)
+
+    # Filter out old timestamps (older than 1 hour)
+    recent_uploads = [ts for ts in current_uploads if ts > one_hour_ago]
+
+    if len(recent_uploads) >= limit:
+        raise HTTPException(
+            status_code=429,
+            detail=f"Upload limit exceeded. Maximum {limit} uploads per hour allowed."
+        )
+
+    # Record this upload. NOTE(review): the get/filter/set sequence is not one
+    # atomic operation, so concurrent callers for the same identifier could
+    # overwrite each other's timestamps - confirm whether that matters here.
+    recent_uploads.append(now)
+    user_upload_cache.set(recent_uploads, identifier)
+
+def _validate_uploaded_file_sync(file: UploadFile) -> None:
+    """
+    Synchronous validation logic to be run in a threadpool.
+
+    Security measures:
+    - File size limits
+    - MIME type validation using content detection
+    - Image content validation using PIL
+    - TODO: Add virus/malware scanning (consider integrating ClamAV or similar)
+
+    Side effects:
+        Images wider or taller than 1024px are downscaled in their original
+        format; `file.file` and `file.size` are replaced with the resized data.
+        On success the stream is left positioned at the start.
+
+    Args:
+        file: The uploaded file to validate.
+
+    Raises:
+        HTTPException: 413 if the file is too large; 400 if the type cannot be
+            detected, is not an allowed image type, or PIL cannot process it.
+    """
+    # Check file size
+    file.file.seek(0, 2)  # Seek to end
+    file_size = file.file.tell()
+    file.file.seek(0)  # Reset to beginning
+
+    if file_size > MAX_FILE_SIZE:
+        raise HTTPException(
+            status_code=413,
+            detail=f"File too large. Maximum size allowed is {MAX_FILE_SIZE // (1024*1024)}MB"
+        )
+
+    # Check MIME type from content using python-magic when available
+    detected_mime: Optional[str] = None
+    try:
+        if magic is not None:
+            # Read first 1024 bytes for MIME detection
+            file_content = file.file.read(1024)
+            file.file.seek(0)  # Reset file pointer
+            detected_mime = magic.from_buffer(file_content, mime=True)
+    except Exception as mime_error:
+        logger.warning(
+            f"MIME detection via python-magic failed for {file.filename}: {mime_error}. "
+            "Falling back to content_type/mimetypes.",
+            exc_info=True,
+        )
+        file.file.seek(0)
+
+    if not detected_mime:
+        # Fallback: trust FastAPI's content_type header or guess from filename
+        detected_mime = file.content_type or mimetypes.guess_type(file.filename or "")[0]
+
+    if not detected_mime:
+        raise HTTPException(
+            status_code=400,
+            detail="Unable to detect file type. Only image files are allowed."
+        )
+
+    if detected_mime not in ALLOWED_MIME_TYPES:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Invalid file type. Only image files are allowed. Detected: {detected_mime}"
+        )
+
+    # Additional content validation: Try to open with PIL to ensure it's a valid image
     try:
-        # 1. AI Services initialization
-        # These can take a few seconds due to imports and configuration
-        action_plan_service, chat_service, mla_summary_service = await run_in_threadpool(create_all_ai_services)
-
-        initialize_ai_services(
-            action_plan_service=action_plan_service,
-            chat_service=chat_service,
-            mla_summary_service=mla_summary_service
+        img = Image.open(file.file)
+        img.verify()  # Verify the image is not corrupted
+        file.file.seek(0)  # Reset after PIL operations
+
+        # Resize large images for better performance
+        img = Image.open(file.file)
+        # Image.resize() returns a new Image whose .format is None, so capture
+        # the source format first; otherwise every resized upload is re-encoded
+        # as JPEG, which fails for RGBA/palette images (e.g. transparent PNGs).
+        source_format = img.format or 'JPEG'
+        if img.width > 1024 or img.height > 1024:
+            # Calculate new size maintaining aspect ratio (at least 1px per side)
+            ratio = min(1024 / img.width, 1024 / img.height)
+            new_width = max(1, int(img.width * ratio))
+            new_height = max(1, int(img.height * ratio))
+
+            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+            # Save resized image back to file
+            output = io.BytesIO()
+            img.save(output, format=source_format, quality=85)
+
+            # Record the correct size before rewinding buffer
+            resized_size = output.tell()
+            output.seek(0)
+
+            # Replace file content
+            file.file = output
+            file.size = resized_size
+        else:
+            # Image.open() has read from the stream; rewind so later readers
+            # of the upload start from the first byte.
+            file.file.seek(0)
+
+    except Exception as pil_error:
+        logger.error(f"PIL validation failed for {file.filename}: {pil_error}")
+        raise HTTPException(
+            status_code=400,
+            detail="Invalid image file. The file appears to be corrupted or not a valid image."
         )
-        logger.info("AI services initialized successfully.")
 
-        # 2. Static data pre-loading (loads large JSONs into memory)
-        await run_in_threadpool(load_maharashtra_pincode_data)
-        await run_in_threadpool(load_maharashtra_mla_data)
-        logger.info("Maharashtra data pre-loaded successfully.")
+async def validate_uploaded_file(file: UploadFile) -> None:
+    """
+    Validate uploaded file for security and safety (async wrapper).
+
+    Args:
+        file: The uploaded file to validate
+
+    Raises:
+        HTTPException: If validation fails
+    """
+    await run_in_threadpool(_validate_uploaded_file_sync, file)
+
+async def process_and_detect(image: UploadFile, detection_func) -> DetectionResponse:
+    """
+    Helper to process uploaded image and run detection.
+    Handles validation, loading, and error handling.
+
+    Args:
+        image: The uploaded image file.
+        detection_func: Awaited with the opened PIL image; its result is
+            passed to DetectionResponse as `detections`.
+
+    Returns:
+        DetectionResponse built from the detections.
+
+    Raises:
+        HTTPException: from validation, 400 if the image cannot be loaded, 500 if detection fails.
+    """
+    # Validate uploaded file
+    await validate_uploaded_file(image)
 
-        # 3. Start Telegram Bot in separate thread
-        await run_in_threadpool(start_bot_thread)
-        logger.info("Telegram bot started in separate thread.")
+    # Convert to PIL Image directly from file object to save memory
+    try:
+        pil_image = await run_in_threadpool(Image.open, image.file)
+        # Validate image for processing
+        await run_in_threadpool(validate_image_for_processing, pil_image)
+    except HTTPException:
+        raise  # Re-raise HTTP exceptions from validation
     except Exception as e:
-        logger.error(f"Error during background initialization: {e}", exc_info=True)
+        logger.error(f"Invalid image file during processing: {e}", exc_info=True)
+        raise HTTPException(status_code=400, detail="Invalid image file")
+
+    # Run detection
+    try:
+        detections = await detection_func(pil_image)
+        return DetectionResponse(detections=detections)
+    except Exception as e:
+        logger.error(f"Detection error: {e}", exc_info=True)
+        raise HTTPException(status_code=500, detail="Detection service temporarily unavailable")
+
+# Create tables if they don't exist
+Base.metadata.create_all(bind=engine)
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):