diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 254312af..f132e2b2 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -25,6 +25,7 @@ jobs: python -m pip install --upgrade pip cd backend pip install -r requirements.txt + pip install -r requirements-dev.txt - name: Run Backend Tests with Coverage run: | diff --git a/backend/app/auth/__init__.py b/backend/app/auth/__init__.py index 6f00fa6e..d1a732da 100644 --- a/backend/app/auth/__init__.py +++ b/backend/app/auth/__init__.py @@ -2,11 +2,9 @@ from .routes import router from .schemas import UserResponse from .security import create_access_token, verify_token -from .service import auth_service __all__ = [ "router", - "auth_service", "verify_token", "create_access_token", "UserResponse", diff --git a/backend/app/auth/routes.py b/backend/app/auth/routes.py index 066e16f3..4ac9f086 100644 --- a/backend/app/auth/routes.py +++ b/backend/app/auth/routes.py @@ -14,9 +14,10 @@ TokenVerifyRequest, UserResponse, ) -from app.auth.security import create_access_token, oauth2_scheme # Import oauth2_scheme -from app.auth.service import auth_service +from app.auth.security import create_access_token, oauth2_scheme +from app.auth.service import AuthService from app.config import settings +from app.dependencies import get_auth_service from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import ( # Import OAuth2PasswordRequestForm OAuth2PasswordRequestForm, @@ -28,7 +29,10 @@ @router.post( "/token", response_model=TokenResponse, include_in_schema=False ) # include_in_schema=False to hide from docs if desired, or True to show -async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()): +async def login_for_access_token( + form_data: OAuth2PasswordRequestForm = Depends(), + auth_service: AuthService = Depends(get_auth_service), +): """ OAuth2 compatible token login, get an access token for future requests. This endpoint is used by Swagger UI for authorization. @@ -59,7 +63,9 @@ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends( @router.post("/signup/email", response_model=AuthResponse) -async def signup_with_email(request: EmailSignupRequest): +async def signup_with_email( + request: EmailSignupRequest, auth_service: AuthService = Depends(get_auth_service) +): """ Registers a new user using email, password, and name, and returns authentication tokens and user information. @@ -101,7 +107,9 @@ async def signup_with_email(request: EmailSignupRequest): @router.post("/login/email", response_model=AuthResponse) -async def login_with_email(request: EmailLoginRequest): +async def login_with_email( + request: EmailLoginRequest, auth_service: AuthService = Depends(get_auth_service) +): """ Authenticates a user using email and password credentials. @@ -136,7 +144,9 @@ async def login_with_email(request: EmailLoginRequest): @router.post("/login/google", response_model=AuthResponse) -async def login_with_google(request: GoogleLoginRequest): +async def login_with_google( + request: GoogleLoginRequest, auth_service: AuthService = Depends(get_auth_service) +): """ Authenticates or registers a user using a Google OAuth ID token. @@ -169,7 +179,9 @@ async def login_with_google(request: GoogleLoginRequest): @router.post("/refresh", response_model=TokenResponse) -async def refresh_token(request: RefreshTokenRequest): +async def refresh_token( + request: RefreshTokenRequest, auth_service: AuthService = Depends(get_auth_service) +): """ Refreshes JWT tokens using a valid refresh token. @@ -213,7 +225,9 @@ async def refresh_token(request: RefreshTokenRequest): @router.post("/token/verify", response_model=UserResponse) -async def verify_token(request: TokenVerifyRequest): +async def verify_token( + request: TokenVerifyRequest, auth_service: AuthService = Depends(get_auth_service) +): """ Verifies an access token and returns the associated user information. @@ -236,7 +250,9 @@ async def verify_token(request: TokenVerifyRequest): @router.post("/password/reset/request", response_model=SuccessResponse) -async def request_password_reset(request: PasswordResetRequest): +async def request_password_reset( + request: PasswordResetRequest, auth_service: AuthService = Depends(get_auth_service) +): """ Initiates a password reset process by sending a reset link to the provided email address. @@ -256,7 +272,9 @@ async def request_password_reset(request: PasswordResetRequest): @router.post("/password/reset/confirm", response_model=SuccessResponse) -async def confirm_password_reset(request: PasswordResetConfirm): +async def confirm_password_reset( + request: PasswordResetConfirm, auth_service: AuthService = Depends(get_auth_service) +): """ Resets a user's password using a valid password reset token. diff --git a/backend/app/auth/service.py b/backend/app/auth/service.py index eebe63d6..d3c112d5 100644 --- a/backend/app/auth/service.py +++ b/backend/app/auth/service.py @@ -55,30 +55,16 @@ }, ) logger.info("Firebase initialized with credentials from environment variables") - # Fall back to service account JSON file if env vars are not available - elif os.path.exists(settings.firebase_service_account_path): - cred = credentials.Certificate(settings.firebase_service_account_path) - firebase_admin.initialize_app( - cred, - { - "projectId": settings.firebase_project_id, - }, - ) - logger.info("Firebase initialized with service account file") else: logger.warning("Firebase service account not found. Google auth will not work.") class AuthService: def __init__(self): - # Initializes the AuthService instance. - pass - - def get_db(self): - """ - Returns a database connection instance from the application's database module. - """ - return get_database() + self.db = get_database() + self.users_collection = self.db["users"] + self.refresh_tokens_collection = self.db["refresh_tokens"] + self.password_resets_collection = self.db["password_resets"] async def create_user_with_email( self, email: str, password: str, name: str @@ -99,10 +85,8 @@ async def create_user_with_email( Raises: HTTPException: If a user with the given email already exists. """ - db = self.get_db() - # Check if user already exists - existing_user = await db.users.find_one({"email": email}) + existing_user = await self.users_collection.find_one({"email": email}) if existing_user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -122,7 +106,7 @@ async def create_user_with_email( } try: - result = await db.users.insert_one(user_doc) + result = await self.users_collection.insert_one(user_doc) user_doc["_id"] = str(result.inserted_id) # Create refresh token @@ -154,9 +138,8 @@ async def authenticate_user_with_email( Returns: A dictionary containing the authenticated user and a new refresh token. """ - db = self.get_db() try: - user = await db.users.find_one({"email": email}) + user = await self.users_collection.find_one({"email": email}) except PyMongoError as e: logger.error(f"Database error during user lookup: {e}") raise HTTPException( @@ -215,11 +198,9 @@ async def authenticate_with_google(self, id_token: str) -> Dict[str, Any]: detail="Email not provided by Google", ) - db = self.get_db() - # Check if user exists try: - user = await db.users.find_one( + user = await self.users_collection.find_one( {"$or": [{"email": email}, {"firebase_uid": firebase_uid}]} ) except PyMongoError as e: @@ -238,7 +219,7 @@ async def authenticate_with_google(self, id_token: str) -> Dict[str, Any]: if update_data: try: - await db.users.update_one( + await self.users_collection.update_one( {"_id": user["_id"]}, {"$set": update_data} ) user.update(update_data) @@ -257,7 +238,7 @@ async def authenticate_with_google(self, id_token: str) -> Dict[str, Any]: "hashed_password": None, } try: - result = await db.users.insert_one(user_doc) + result = await self.users_collection.insert_one(user_doc) user_doc["_id"] = result.inserted_id user = user_doc except PyMongoError as e: @@ -303,11 +284,9 @@ async def refresh_access_token(self, refresh_token: str) -> str: Returns: A new refresh token string. """ - db = self.get_db() - # Find and validate refresh token try: - token_record = await db.refresh_tokens.find_one( + token_record = await self.refresh_tokens_collection.find_one( { "token": refresh_token, "revoked": False, @@ -329,7 +308,9 @@ async def refresh_access_token(self, refresh_token: str) -> str: # Get user try: - user = await db.users.find_one({"_id": token_record["user_id"]}) + user = await self.users_collection.find_one( + {"_id": token_record["user_id"]} + ) except PyMongoError as e: logger.error("Error while fetching user: %s", str(e)) raise HTTPException( @@ -355,7 +336,7 @@ async def refresh_access_token(self, refresh_token: str) -> str: # Revoke old token try: - await db.refresh_tokens.update_one( + await self.refresh_tokens_collection.update_one( {"_id": token_record["_id"]}, {"$set": {"revoked": True}} ) except PyMongoError as e: @@ -393,10 +374,8 @@ async def verify_access_token(self, token: str) -> Dict[str, Any]: status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token" ) - db = self.get_db() - try: - user = await db.users.find_one({"_id": user_id}) + user = await self.users_collection.find_one({"_id": user_id}) except Exception as e: logger.error("Error while verifying token: %s", str(e)) raise HTTPException( @@ -417,10 +396,8 @@ async def request_password_reset(self, email: str) -> bool: If the user exists, generates a password reset token with a 1-hour expiration and stores it in the database. The reset token and link are logged for development purposes. Always returns True to avoid revealing whether the email is registered. """ - db = self.get_db() - try: - user = await db.users.find_one({"email": email}) + user = await self.users_collection.find_one({"email": email}) except PyMongoError as e: logger.error( f"Database error while fetching user by email {email}: {str(e)}" @@ -439,7 +416,7 @@ async def request_password_reset(self, email: str) -> bool: try: # Store reset token - await db.password_resets.insert_one( + await self.password_resets_collection.insert_one( { "user_id": user["_id"], "token": reset_token, @@ -481,11 +458,9 @@ async def confirm_password_reset(self, reset_token: str, new_password: str) -> b Raises: HTTPException: If the reset token is invalid or expired. """ - db = self.get_db() - try: # Find and validate reset token - reset_record = await db.password_resets.find_one( + reset_record = await self.password_resets_collection.find_one( { "token": reset_token, "used": False, @@ -502,18 +477,18 @@ async def confirm_password_reset(self, reset_token: str, new_password: str) -> b # Update user password new_hash = get_password_hash(new_password) - await db.users.update_one( + await self.users_collection.update_one( {"_id": reset_record["user_id"]}, {"$set": {"hashed_password": new_hash}}, ) # Mark token as used - await db.password_resets.update_one( + await self.password_resets_collection.update_one( {"_id": reset_record["_id"]}, {"$set": {"used": True}} ) # Revoke all refresh tokens for this user (force re-login) - await db.refresh_tokens.update_many( + await self.refresh_tokens_collection.update_many( {"user_id": reset_record["user_id"]}, {"$set": {"revoked": True}} ) logger.info( @@ -542,15 +517,13 @@ async def _create_refresh_token_record(self, user_id: str) -> str: Returns: The generated refresh token string. """ - db = self.get_db() - refresh_token = create_refresh_token() expires_at = datetime.now(timezone.utc) + timedelta( days=settings.refresh_token_expire_days ) try: - await db.refresh_tokens.insert_one( + await self.refresh_tokens_collection.insert_one( { "token": refresh_token, "user_id": ( @@ -571,7 +544,3 @@ async def _create_refresh_token_record(self, user_id: str) -> str: ) return refresh_token - - -# Create service instance -auth_service = AuthService() diff --git a/backend/app/config.py b/backend/app/config.py index 3ee6f809..5a522014 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -2,7 +2,7 @@ import os import time from logging.config import dictConfig -from typing import Optional +from typing import List, Optional from pydantic_settings import BaseSettings from starlette.middleware.base import BaseHTTPMiddleware @@ -15,13 +15,13 @@ class Settings(BaseSettings): database_name: str = "splitwiser" # JWT - secret_key: str = "your-super-secret-jwt-key-change-this-in-production" + secret_key: str algorithm: str = "HS256" access_token_expire_minutes: int = 15 refresh_token_expire_days: int = 30 + # Firebase firebase_project_id: Optional[str] = None - firebase_service_account_path: str = "./firebase-service-account.json" # Firebase service account credentials as environment variables firebase_type: Optional[str] = None firebase_private_key_id: Optional[str] = None @@ -37,9 +37,12 @@ class Settings(BaseSettings): debug: bool = False # CORS - Add your frontend domain here for production - allowed_origins: str = ( - "http://localhost:3000,http://localhost:5173,http://127.0.0.1:3000,http://localhost:8081" - ) + allowed_origins: List[str] = [ + "http://localhost:3000", + "http://localhost:5173", + "http://127.0.0.1:3000", + "http://localhost:8081", + ] allow_all_origins: bool = False class Config: diff --git a/backend/app/dependencies.py b/backend/app/dependencies.py index 38877af6..4851c853 100644 --- a/backend/app/dependencies.py +++ b/backend/app/dependencies.py @@ -57,3 +57,25 @@ async def get_current_user( detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"}, ) + + +from app.auth.service import AuthService +from app.expenses.service import ExpenseService +from app.groups.service import GroupService +from app.user.service import UserService + + +def get_user_service() -> UserService: + return UserService() + + +def get_expense_service() -> ExpenseService: + return ExpenseService() + + +def get_group_service() -> GroupService: + return GroupService() + + +def get_auth_service() -> AuthService: + return AuthService() diff --git a/backend/app/expenses/routes.py b/backend/app/expenses/routes.py index a21e4ca9..f548ce85 100644 --- a/backend/app/expenses/routes.py +++ b/backend/app/expenses/routes.py @@ -5,6 +5,7 @@ from app.auth.security import get_current_user from app.config import logger +from app.dependencies import get_expense_service from app.expenses.schemas import ( AttachmentUploadResponse, BalanceSummaryResponse, @@ -22,7 +23,7 @@ SettlementUpdateRequest, UserBalance, ) -from app.expenses.service import expense_service +from app.expenses.service import ExpenseService from fastapi import ( APIRouter, Depends, @@ -49,6 +50,7 @@ async def create_expense( group_id: str, expense_data: ExpenseCreateRequest, current_user: Dict[str, Any] = Depends(get_current_user), + expense_service: ExpenseService = Depends(get_expense_service), ): """Create a new expense within a group""" try: @@ -71,6 +73,7 @@ async def list_group_expenses( to_date: Optional[datetime] = Query(None, alias="to"), tags: Optional[str] = Query(None), current_user: Dict[str, Any] = Depends(get_current_user), + expense_service: ExpenseService = Depends(get_expense_service), ): """List all expenses for a group with pagination and filtering""" try: @@ -90,6 +93,7 @@ async def get_single_expense( group_id: str, expense_id: str, current_user: Dict[str, Any] = Depends(get_current_user), + expense_service: ExpenseService = Depends(get_expense_service), ): """Retrieve details for a single expense""" try: @@ -109,6 +113,7 @@ async def update_expense( expense_id: str, updates: ExpenseUpdateRequest, current_user: Dict[str, Any] = Depends(get_current_user), + expense_service: ExpenseService = Depends(get_expense_service), ): """Update an existing expense""" try: @@ -130,6 +135,7 @@ async def delete_expense( group_id: str, expense_id: str, current_user: Dict[str, Any] = Depends(get_current_user), + expense_service: ExpenseService = Depends(get_expense_service), ): """Delete an expense""" try: @@ -158,6 +164,7 @@ async def upload_attachment_for_expense( expense_id: str, file: UploadFile = File(...), current_user: Dict[str, Any] = Depends(get_current_user), + expense_service: ExpenseService = Depends(get_expense_service), ): """Upload attachment for an expense""" try: @@ -190,6 +197,7 @@ async def get_attachment( expense_id: str, key: str, current_user: Dict[str, Any] = Depends(get_current_user), + expense_service: ExpenseService = Depends(get_expense_service), ): """Get/download an attachment""" try: @@ -219,6 +227,7 @@ async def manually_record_payment( group_id: str, settlement_data: SettlementCreateRequest, current_user: Dict[str, Any] = Depends(get_current_user), + expense_service: ExpenseService = Depends(get_expense_service), ): """Manually record a payment settlement between users in a group""" try: @@ -242,6 +251,7 @@ async def get_group_settlements( "advanced", description="Settlement algorithm: 'normal' or 'advanced'" ), current_user: Dict[str, Any] = Depends(get_current_user), + expense_service: ExpenseService = Depends(get_expense_service), ): """Retrieve pending and optimized settlements for a group""" try: @@ -295,6 +305,7 @@ async def get_single_settlement( group_id: str, settlement_id: str, current_user: Dict[str, Any] = Depends(get_current_user), + expense_service: ExpenseService = Depends(get_expense_service), ): """Retrieve details for a single settlement""" try: @@ -314,6 +325,7 @@ async def mark_settlement_as_paid( settlement_id: str, updates: SettlementUpdateRequest, current_user: Dict[str, Any] = Depends(get_current_user), + expense_service: ExpenseService = Depends(get_expense_service), ): """Mark a settlement as paid""" try: @@ -332,6 +344,7 @@ async def delete_settlement( group_id: str, settlement_id: str, current_user: Dict[str, Any] = Depends(get_current_user), + expense_service: ExpenseService = Depends(get_expense_service), ): """Delete/undo a recorded settlement""" try: @@ -355,6 +368,7 @@ async def calculate_optimized_settlements( "advanced", description="Settlement algorithm: 'normal' or 'advanced'" ), current_user: Dict[str, Any] = Depends(get_current_user), + expense_service: ExpenseService = Depends(get_expense_service), ): """Calculate and return optimized (simplified) settlements for a group""" try: @@ -399,6 +413,7 @@ async def calculate_optimized_settlements( @balance_router.get("/friends-balance", response_model=FriendsBalanceResponse) async def get_cross_group_friend_balances( current_user: Dict[str, Any] = Depends(get_current_user), + expense_service: ExpenseService = Depends(get_expense_service), ): """Retrieve the current user's aggregated balances with all friends""" try: @@ -411,6 +426,7 @@ async def get_cross_group_friend_balances( @balance_router.get("/balance-summary", response_model=BalanceSummaryResponse) async def get_overall_user_balance_summary( current_user: Dict[str, Any] = Depends(get_current_user), + expense_service: ExpenseService = Depends(get_expense_service), ): """Retrieve an overall balance summary for the current user""" try: @@ -426,6 +442,7 @@ async def get_user_balance_in_specific_group( group_id: str, user_id: str, current_user: Dict[str, Any] = Depends(get_current_user), + expense_service: ExpenseService = Depends(get_expense_service), ): """Get a specific user's balance within a particular group""" try: @@ -449,6 +466,7 @@ async def group_expense_analytics( year: int = Query(...), month: Optional[int] = Query(None), current_user: Dict[str, Any] = Depends(get_current_user), + expense_service: ExpenseService = Depends(get_expense_service), ): """Provide expense analytics for a group""" try: diff --git a/backend/app/expenses/service.py b/backend/app/expenses/service.py index 80cbcf31..d0e5d7fb 100644 --- a/backend/app/expenses/service.py +++ b/backend/app/expenses/service.py @@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional from app.config import logger -from app.database import mongodb from app.expenses.schemas import ( ExpenseCreateRequest, ExpenseResponse, @@ -14,29 +13,17 @@ SettlementStatus, SplitType, ) +from app.services.base import BaseService from bson import ObjectId, errors from fastapi import HTTPException -class ExpenseService: +class ExpenseService(BaseService): def __init__(self): - pass - - @property - def expenses_collection(self): - return mongodb.database.expenses - - @property - def settlements_collection(self): - return mongodb.database.settlements - - @property - def groups_collection(self): - return mongodb.database.groups - - @property - def users_collection(self): - return mongodb.database.users + super().__init__("expenses") + self.settlements_collection = self.db["settlements"] + self.groups_collection = self.db["groups"] + self.users_collection = self.db["users"] async def create_expense( self, group_id: str, expense_data: ExpenseCreateRequest, user_id: str @@ -92,7 +79,7 @@ async def create_expense( } # Insert expense - await self.expenses_collection.insert_one(expense_doc) + await self.collection.insert_one(expense_doc) # Create settlements settlements = await self._create_settlements_for_expense( @@ -188,15 +175,12 @@ async def list_group_expenses( query["tags"] = {"$in": tags} # Get total count - total = await self.expenses_collection.count_documents(query) + total = await self.collection.count_documents(query) # Get expenses with pagination skip = (page - 1) * limit expenses_cursor = ( - self.expenses_collection.find(query) - .sort("createdAt", -1) - .skip(skip) - .limit(limit) + self.collection.find(query).sort("createdAt", -1).skip(skip).limit(limit) ) expenses_docs = await expenses_cursor.to_list(None) @@ -217,9 +201,7 @@ async def list_group_expenses( } }, ] - summary_result = await self.expenses_collection.aggregate(pipeline).to_list( - None - ) + summary_result = await self.collection.aggregate(pipeline).to_list(None) summary = ( summary_result[0] if summary_result @@ -269,7 +251,7 @@ async def get_expense_by_id( status_code=403, detail="You are not a member of this group" ) - expense_doc = await self.expenses_collection.find_one( + expense_doc = await self.collection.find_one( {"_id": expense_obj_id, "groupId": group_id} ) if not expense_doc: # Expense not found @@ -307,7 +289,7 @@ async def update_expense( raise HTTPException(status_code=400, detail="Invalid expense ID format") # Verify user access and that they created the expense - expense_doc = await self.expenses_collection.find_one( + expense_doc = await self.collection.find_one( {"_id": expense_obj_id, "groupId": group_id, "createdBy": user_id} ) if not expense_doc: # Expense not found or user not authorized @@ -379,7 +361,7 @@ async def update_expense( } # Update expense with both $set and $push operations - result = await self.expenses_collection.update_one( + result = await self.collection.update_one( {"_id": expense_obj_id}, {"$set": update_doc, "$push": {"history": history_entry}}, ) @@ -390,7 +372,7 @@ async def update_expense( ) else: # No actual changes, just update the timestamp - result = await self.expenses_collection.update_one( + result = await self.collection.update_one( {"_id": expense_obj_id}, {"$set": update_doc} ) @@ -408,7 +390,7 @@ async def update_expense( ) # Get updated expense - updated_expense = await self.expenses_collection.find_one( + updated_expense = await self.collection.find_one( {"_id": expense_obj_id} ) @@ -424,9 +406,7 @@ async def update_expense( # Continue anyway, as the expense update succeeded # Return updated expense - updated_expense = await self.expenses_collection.find_one( - {"_id": expense_obj_id} - ) + updated_expense = await self.collection.find_one({"_id": expense_obj_id}) if not updated_expense: raise HTTPException( status_code=500, detail="Failed to retrieve updated expense" @@ -456,7 +436,7 @@ async def delete_expense( """Delete an expense""" # Verify user access and that they created the expense - expense_doc = await self.expenses_collection.find_one( + expense_doc = await self.collection.find_one( {"_id": ObjectId(expense_id), "groupId": group_id, "createdBy": user_id} ) if not expense_doc: @@ -472,9 +452,7 @@ async def delete_expense( await self.settlements_collection.delete_many({"expenseId": expense_id}) # Delete the expense - result = await self.expenses_collection.delete_one( - {"_id": ObjectId(expense_id)} - ) + result = await self.collection.delete_one({"_id": ObjectId(expense_id)}) return result.deleted_count > 0 async def calculate_optimized_settlements( @@ -690,9 +668,7 @@ async def _get_group_summary( } }, ] - expense_result = await self.expenses_collection.aggregate(pipeline).to_list( - None - ) + expense_result = await self.collection.aggregate(pipeline).to_list(None) expense_stats = ( expense_result[0] if expense_result @@ -909,7 +885,7 @@ async def get_user_balance_in_group( # Get recent expenses where user was involved recent_expenses = ( - await self.expenses_collection.find( + await self.collection.find( { "groupId": group_id, "$or": [ @@ -1204,7 +1180,7 @@ async def get_group_analytics( period_str = f"{now.year}-{now.month:02d}" # Get expenses in the period - expenses = await self.expenses_collection.find( + expenses = await self.collection.find( {"groupId": group_id, "createdAt": {"$gte": start_date, "$lt": end_date}} ).to_list(None) @@ -1296,7 +1272,3 @@ async def get_group_analytics( "memberContributions": member_contributions, "expenseTrends": expense_trends, } - - -# Create service instance -expense_service = ExpenseService() diff --git a/backend/app/groups/routes.py b/backend/app/groups/routes.py index 4c36707b..ed421e20 100644 --- a/backend/app/groups/routes.py +++ b/backend/app/groups/routes.py @@ -1,6 +1,7 @@ from typing import Any, Dict, List from app.auth.security import get_current_user +from app.dependencies import get_group_service from app.groups.schemas import ( DeleteGroupResponse, GroupCreateRequest, @@ -14,7 +15,7 @@ MemberRoleUpdateRequest, RemoveMemberResponse, ) -from app.groups.service import group_service +from app.groups.service import GroupService from fastapi import APIRouter, Depends, HTTPException, status router = APIRouter(prefix="/groups", tags=["Groups"]) @@ -24,6 +25,7 @@ async def create_group( group_data: GroupCreateRequest, current_user: Dict[str, Any] = Depends(get_current_user), + group_service: GroupService = Depends(get_group_service), ): """Create a new group""" group = await group_service.create_group( @@ -35,7 +37,10 @@ async def create_group( @router.get("", response_model=GroupListResponse) -async def list_user_groups(current_user: Dict[str, Any] = Depends(get_current_user)): +async def list_user_groups( + current_user: Dict[str, Any] = Depends(get_current_user), + group_service: GroupService = Depends(get_group_service), +): """List all groups the current user belongs to""" groups = await group_service.get_user_groups(current_user["_id"]) return {"groups": groups} @@ -43,7 +48,9 @@ async def list_user_groups(current_user: Dict[str, Any] = Depends(get_current_us @router.get("/{group_id}", response_model=GroupResponse) async def get_group_details( - group_id: str, current_user: Dict[str, Any] = Depends(get_current_user) + group_id: str, + current_user: Dict[str, Any] = Depends(get_current_user), + group_service: GroupService = Depends(get_group_service), ): """Get group details including members""" group = await group_service.get_group_by_id(group_id, current_user["_id"]) @@ -57,6 +64,7 @@ async def update_group_metadata( group_id: str, updates: GroupUpdateRequest, current_user: Dict[str, Any] = Depends(get_current_user), + group_service: GroupService = Depends(get_group_service), ): """Update group metadata (admin only)""" update_data = updates.model_dump(exclude_unset=True) @@ -73,7 +81,9 @@ async def update_group_metadata( @router.delete("/{group_id}", response_model=DeleteGroupResponse) async def delete_group( - group_id: str, current_user: Dict[str, Any] = Depends(get_current_user) + group_id: str, + current_user: Dict[str, Any] = Depends(get_current_user), + group_service: GroupService = Depends(get_group_service), ): """Delete a group (admin only)""" deleted = await group_service.delete_group(group_id, current_user["_id"]) @@ -86,6 +96,7 @@ async def delete_group( async def join_group_by_code( join_data: JoinGroupRequest, current_user: Dict[str, Any] = Depends(get_current_user), + group_service: GroupService = Depends(get_group_service), ): """Join a group using a join code""" group = await group_service.join_group_by_code( @@ -98,7 +109,9 @@ async def join_group_by_code( @router.post("/{group_id}/leave", response_model=LeaveGroupResponse) async def leave_group( - group_id: str, current_user: Dict[str, Any] = Depends(get_current_user) + group_id: str, + current_user: Dict[str, Any] = Depends(get_current_user), + group_service: GroupService = Depends(get_group_service), ): """Leave a group (only if no outstanding balances)""" left = await group_service.leave_group(group_id, current_user["_id"]) @@ -109,7 +122,9 @@ async def leave_group( @router.get("/{group_id}/members", response_model=List[GroupMemberWithDetails]) async def get_group_members( - group_id: str, current_user: Dict[str, Any] = Depends(get_current_user) + group_id: str, + current_user: Dict[str, Any] = Depends(get_current_user), + group_service: GroupService = Depends(get_group_service), ): """Get list of group members with detailed user information""" members = await group_service.get_group_members(group_id, current_user["_id"]) @@ -122,6 +137,7 @@ async def update_member_role( member_id: str, role_update: MemberRoleUpdateRequest, current_user: Dict[str, Any] = Depends(get_current_user), + group_service: GroupService = Depends(get_group_service), ): """Change member role (admin only)""" updated = await group_service.update_member_role( @@ -137,6 +153,7 @@ async def remove_group_member( group_id: str, member_id: str, current_user: Dict[str, Any] = Depends(get_current_user), + group_service: GroupService = Depends(get_group_service), ): """Remove a member from the group (admin only)""" removed = await group_service.remove_member( diff --git a/backend/app/groups/service.py b/backend/app/groups/service.py index c70086e9..4bcef7fc 100644 --- a/backend/app/groups/service.py +++ b/backend/app/groups/service.py @@ -4,17 +4,16 @@ from typing import Any, Dict, List, Optional from app.config import logger -from app.database import get_database +from app.services.base import BaseService from bson import ObjectId, errors from fastapi import HTTPException -class GroupService: +class GroupService(BaseService): def __init__(self): - pass - - def get_db(self): - return get_database() + super().__init__("groups") + self.users_collection = self.db["users"] + self.settlements_collection = self.db["settlements"] def generate_join_code(self, length: int = 6) -> str: """Generate a random alphanumeric join code""" @@ -25,7 +24,6 @@ async def _enrich_members_with_user_details( self, members: List[dict] ) -> List[dict]: """Private method to enrich member data with user details from users collection""" - db = self.get_db() enriched_members = [] for member in members: @@ -34,7 +32,7 @@ async def _enrich_members_with_user_details( try: # Fetch user details from users collection user_obj_id = ObjectId(member_user_id) - user = await db.users.find_one({"_id": user_obj_id}) + user = await self.users_collection.find_one({"_id": user_obj_id}) # Create enriched member object enriched_member = { @@ -123,13 +121,11 @@ def transform_group_document(self, group: dict) -> dict: async def create_group(self, group_data: dict, user_id: str) -> dict: """Create a new group with the user as admin""" - db = self.get_db() - # Generate unique join code join_code = None for _ in range(10): # Try up to 10 times to generate unique code join_code = self.generate_join_code() - existing = await db.groups.find_one({"joinCode": join_code}) + existing = await self.collection.find_one({"joinCode": join_code}) if not existing: break @@ -149,14 +145,13 @@ async def create_group(self, group_data: dict, user_id: str) -> dict: "members": [{"userId": user_id, "role": "admin", "joinedAt": now}], } - result = await db.groups.insert_one(group_doc) - created_group = await db.groups.find_one({"_id": result.inserted_id}) + result = await self.collection.insert_one(group_doc) + created_group = await self.collection.find_one({"_id": result.inserted_id}) return self.transform_group_document(created_group) async def get_user_groups(self, user_id: str) -> List[dict]: """Get all groups where user is a member""" - db = self.get_db() - cursor = db.groups.find({"members.userId": user_id}) + cursor = self.collection.find({"members.userId": user_id}) groups = [] async for group in cursor: transformed = self.transform_group_document(group) @@ -166,7 +161,6 @@ async def get_user_groups(self, user_id: str) -> List[dict]: async def get_group_by_id(self, group_id: str, user_id: str) -> Optional[dict]: """Get group details by ID with enriched member information, only if user is a member""" - db = self.get_db() try: obj_id = ObjectId(group_id) except errors.InvalidId: @@ -176,7 +170,9 @@ async def get_group_by_id(self, group_id: str, user_id: str) -> Optional[dict]: logger.error(f"Unexpected error converting group_id to ObjectId: {e}") return None - group = await db.groups.find_one({"_id": obj_id, "members.userId": user_id}) + group = await self.collection.find_one( + {"_id": obj_id, "members.userId": user_id} + ) if not group: return None @@ -197,7 +193,6 @@ async def update_group( self, group_id: str, updates: dict, user_id: str ) -> Optional[dict]: """Update group metadata (admin only)""" - db = self.get_db() try: obj_id = ObjectId(group_id) except errors.InvalidId: @@ -208,7 +203,7 @@ async def update_group( return None # Check if user is admin - group = await db.groups.find_one( + group = await self.collection.find_one( { "_id": obj_id, "members": {"$elemMatch": {"userId": user_id, "role": "admin"}}, @@ -219,14 +214,13 @@ async def update_group( status_code=403, detail="Only group admins can update group details" ) - result = await db.groups.find_one_and_update( + result = await self.collection.find_one_and_update( {"_id": obj_id}, {"$set": updates}, return_document=True ) return self.transform_group_document(result) async def delete_group(self, group_id: str, user_id: str) -> bool: """Delete group (admin only)""" - db = self.get_db() try: obj_id = ObjectId(group_id) except errors.InvalidId: @@ -237,7 +231,7 @@ async def delete_group(self, group_id: str, user_id: str) -> bool: return False # Check if user is admin - group = await db.groups.find_one( + group = await self.collection.find_one( { "_id": obj_id, "members": {"$elemMatch": {"userId": user_id, "role": "admin"}}, @@ -248,15 +242,13 @@ async def delete_group(self, group_id: str, user_id: str) -> bool: status_code=403, detail="Only group admins can delete groups" ) - result = await db.groups.delete_one({"_id": obj_id}) + result = await self.collection.delete_one({"_id": obj_id}) return result.deleted_count == 1 async def join_group_by_code(self, join_code: str, user_id: str) -> Optional[dict]: """Join a group using join code""" - db = self.get_db() - # Find group by join code - group = await db.groups.find_one({"joinCode": join_code.upper()}) + group = await self.collection.find_one({"joinCode": join_code.upper()}) if not group: raise HTTPException(status_code=404, detail="Invalid join code") @@ -276,7 +268,7 @@ async def join_group_by_code(self, join_code: str, user_id: str) -> Optional[dic "joinedAt": datetime.now(timezone.utc), } - result = await db.groups.find_one_and_update( + result = await self.collection.find_one_and_update( {"_id": group["_id"]}, {"$push": {"members": new_member}}, return_document=True, @@ -285,14 +277,15 @@ async def join_group_by_code(self, join_code: str, user_id: str) -> Optional[dic async def leave_group(self, group_id: str, user_id: str) -> bool: """Leave a group (only if user has no outstanding balances)""" - db = self.get_db() try: obj_id = ObjectId(group_id) except Exception: return False # Check if user is a member - group = await db.groups.find_one({"_id": obj_id, "members.userId": user_id}) + group = await self.collection.find_one( + {"_id": obj_id, "members.userId": user_id} + ) if not group: raise HTTPException( status_code=404, detail="Group not found or you are not a member" @@ -314,7 +307,7 @@ async def leave_group(self, group_id: str, user_id: str) -> bool: # Block leaving when there are unsettled balances involving this user try: - pending = await db.settlements.find_one( + pending = await self.settlements_collection.find_one( { "groupId": group_id, # settlements store string groupId "status": "pending", @@ -336,20 +329,21 @@ async def leave_group(self, group_id: str, user_id: str) -> bool: detail="Cannot leave group with unsettled balances. Please settle up first.", ) - result = await db.groups.update_one( + result = await self.collection.update_one( {"_id": obj_id}, {"$pull": {"members": {"userId": user_id}}} ) return result.modified_count == 1 async def get_group_members(self, group_id: str, user_id: str) -> List[dict]: """Get list of group members with detailed user information""" - db = self.get_db() try: obj_id = ObjectId(group_id) except Exception: return [] - group = await db.groups.find_one({"_id": obj_id, "members.userId": user_id}) + group = await self.collection.find_one( + {"_id": obj_id, "members.userId": user_id} + ) if not group: return [] @@ -364,14 +358,13 @@ async def update_member_role( self, group_id: str, member_id: str, new_role: str, user_id: str ) -> bool: """Update member role (admin only)""" - db = self.get_db() try: obj_id = ObjectId(group_id) except Exception: return False # Check if user is admin - group = await db.groups.find_one( + group = await self.collection.find_one( { "_id": obj_id, "members": {"$elemMatch": {"userId": user_id, "role": "admin"}}, @@ -400,7 +393,7 @@ async def update_member_role( detail="Cannot demote yourself when you are the only admin. Promote another member to admin first.", ) - result = await db.groups.update_one( + result = await self.collection.update_one( {"_id": obj_id, "members.userId": member_id}, {"$set": {"members.$.role": new_role}}, ) @@ -408,14 +401,13 @@ async def update_member_role( async def remove_member(self, group_id: str, member_id: str, user_id: str) -> bool: """Remove a member from group (admin only)""" - db = self.get_db() try: obj_id = ObjectId(group_id) except Exception: return False # Check if group exists and user is admin - group = await db.groups.find_one( + group = await self.collection.find_one( { "_id": obj_id, "members": {"$elemMatch": {"userId": user_id, "role": "admin"}}, @@ -423,7 +415,7 @@ async def remove_member(self, group_id: str, member_id: str, user_id: str) -> bo ) if not group: # Check if group exists at all - group_exists = await db.groups.find_one({"_id": obj_id}) + group_exists = await self.collection.find_one({"_id": obj_id}) if not group_exists: raise HTTPException(status_code=404, detail="Group not found") else: @@ -446,7 +438,7 @@ async def remove_member(self, group_id: str, member_id: str, user_id: str) -> bo # Block removal when there are unsettled balances involving the target member try: - pending = await db.settlements.find_one( + pending = await self.settlements_collection.find_one( { "groupId": group_id, # settlements store string groupId "status": "pending", @@ -468,10 +460,7 @@ async def remove_member(self, group_id: str, member_id: str, user_id: str) -> bo detail="Cannot remove member with unsettled balances. Please settle up first.", ) - result = await db.groups.update_one( + result = await self.collection.update_one( {"_id": obj_id}, {"$pull": {"members": {"userId": member_id}}} ) return result.modified_count == 1 - - -group_service = GroupService() diff --git a/backend/app/services/__init__.py b/backend/app/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/app/services/base.py b/backend/app/services/base.py new file mode 100644 index 00000000..abb1a6d1 --- /dev/null +++ b/backend/app/services/base.py @@ -0,0 +1,45 @@ +from typing import Any, Dict, Optional + +from app.database import get_database +from bson import ObjectId, errors + + +class BaseService: + def __init__(self, collection_name: str): + self.collection_name = collection_name + self._db = None + self._collection = None + + @property + def db(self): + if self._db is None: + self._db = get_database() + return self._db + + @property + def collection(self): + if self._collection is None: + self._collection = self.db[self.collection_name] + return self._collection + + def _to_json(self, item: Dict[str, Any]) -> Dict[str, Any]: + """Converts a MongoDB document to a JSON-serializable dictionary.""" + if item and "_id" in item: + item["id"] = str(item.pop("_id")) + return item + + async def get_by_id(self, item_id: str) -> Optional[Dict[str, Any]]: + try: + obj_id = ObjectId(item_id) + except errors.InvalidId: + return None + item = await self.collection.find_one({"_id": obj_id}) + return self._to_json(item) + + async def delete(self, item_id: str) -> bool: + try: + obj_id = ObjectId(item_id) + except errors.InvalidId: + return False + result = await self.collection.delete_one({"_id": obj_id}) + return result.deleted_count > 0 diff --git a/backend/app/user/routes.py b/backend/app/user/routes.py index 6f0b283f..25ddef23 100644 --- a/backend/app/user/routes.py +++ b/backend/app/user/routes.py @@ -1,12 +1,13 @@ from typing import Any, Dict from app.auth.security import get_current_user +from app.dependencies import get_user_service from app.user.schemas import ( DeleteUserResponse, UserProfileResponse, UserProfileUpdateRequest, ) -from app.user.service import user_service +from app.user.service import UserService from fastapi import APIRouter, Depends, HTTPException, status router = APIRouter(prefix="/users", tags=["User"]) @@ -15,6 +16,7 @@ @router.get("/me", response_model=UserProfileResponse) async def get_current_user_profile( current_user: Dict[str, Any] = Depends(get_current_user), + user_service: UserService = Depends(get_user_service), ): user = await user_service.get_user_by_id(current_user["_id"]) if not user: @@ -28,6 +30,7 @@ async def get_current_user_profile( async def update_user_profile( updates: UserProfileUpdateRequest, current_user: Dict[str, Any] = Depends(get_current_user), + user_service: UserService = Depends(get_user_service), ): update_data = updates.model_dump(exclude_unset=True) if not update_data: @@ -46,7 +49,10 @@ async def update_user_profile( @router.delete("/me", response_model=DeleteUserResponse) -async def delete_user_account(current_user: Dict[str, Any] = Depends(get_current_user)): +async def delete_user_account( + current_user: Dict[str, Any] = Depends(get_current_user), + user_service: UserService = Depends(get_user_service), +): deleted = await user_service.delete_user(current_user["_id"]) if not deleted: raise HTTPException( diff --git a/backend/app/user/service.py b/backend/app/user/service.py index 8ce83dd5..e4d9c078 100644 --- a/backend/app/user/service.py +++ b/backend/app/user/service.py @@ -2,16 +2,13 @@ from typing import Any, Dict, Optional from app.config import logger -from app.database import get_database +from app.services.base import BaseService from bson import ObjectId, errors -class UserService: +class UserService(BaseService): def __init__(self): - pass - - def get_db(self): - return get_database() + super().__init__("users") def transform_user_document(self, user: dict) -> dict: if not user: @@ -36,11 +33,7 @@ def iso(dt): ) # Logging failed datetime transformation return str(dt) - try: - user_id = str(user["_id"]) - except (KeyError, TypeError) as e: - logger.error(f"Invalid user document format: {e}") - return None # Handle invalid ObjectId gracefully + user_id = str(user.get("id") or user.get("_id")) return { "id": user_id, "name": user.get("name"), @@ -52,45 +45,23 @@ def iso(dt): } async def get_user_by_id(self, user_id: str) -> Optional[dict]: - db = self.get_db() - try: - obj_id = ObjectId(user_id) - except errors.InvalidId as e: - # Invalid ObjectId format - logger.warning(f"Invalid User ID format: {e}") - return None # Handle invalid ObjectId gracefully - user = await db.users.find_one({"_id": obj_id}) + user = await self.get_by_id(user_id) return self.transform_user_document(user) async def update_user_profile(self, user_id: str, updates: dict) -> Optional[dict]: - db = self.get_db() try: obj_id = ObjectId(user_id) except errors.InvalidId as e: - logger.warning( - f"Invalid User ID format: {e}" - ) # Invalid ObjectId format for profile update - return None # Handle invalid ObjectId gracefully + logger.warning(f"Invalid User ID format: {e}") + return None # Only allow certain fields allowed = {"name", "imageUrl", "currency"} updates = {k: v for k, v in updates.items() if k in allowed} updates["updated_at"] = datetime.now(timezone.utc) - result = await db.users.find_one_and_update( + result = await self.collection.find_one_and_update( {"_id": obj_id}, {"$set": updates}, return_document=True ) return self.transform_user_document(result) async def delete_user(self, user_id: str) -> bool: - db = self.get_db() - try: - obj_id = ObjectId(user_id) - except errors.InvalidId as e: - logger.warning( - f"Invalid User ID format: {e}" - ) # Invalid ObjectId format for deletion - return False # Handle invalid ObjectId gracefully - result = await db.users.delete_one({"_id": obj_id}) - return result.deleted_count > 0 - - -user_service = UserService() + return await self.delete(user_id) diff --git a/backend/main.py b/backend/main.py index 3372ffb8..1d223970 100644 --- a/backend/main.py +++ b/backend/main.py @@ -34,82 +34,28 @@ async def lifespan(app: FastAPI): lifespan=lifespan, ) -# CORS middleware - Enhanced configuration for production -allowed_origins = [] +# CORS middleware if settings.allow_all_origins: - # Allow all origins in development mode - allowed_origins = ["*"] - logger.debug("Development mode: CORS configured to allow all origins") -elif settings.allowed_origins: - # Use specified origins in production mode - allowed_origins = [ - origin.strip() - for origin in settings.allowed_origins.split(",") - if origin.strip() - ] + logger.info("Allowing all origins for CORS") + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) else: - # Fallback to allow all origins if not specified (not recommended for production) - allowed_origins = ["*"] - -logger.info(f"Allowed CORS origins: {allowed_origins}") + logger.info(f"Allowing specific origins for CORS: {settings.allowed_origins}") + app.add_middleware( + CORSMiddleware, + allow_origins=settings.allowed_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) app.add_middleware(RequestResponseLoggingMiddleware) -app.add_middleware( - CORSMiddleware, - allow_origins=allowed_origins, - allow_credentials=True, - allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"], - allow_headers=[ - "Accept", - "Accept-Language", - "Content-Language", - "Content-Type", - "Authorization", - "X-Requested-With", - "Origin", - "Cache-Control", - "Pragma", - "X-CSRFToken", - ], - expose_headers=["*"], - max_age=3600, # Cache preflight responses for 1 hour -) - - -# Add a catch-all OPTIONS handler that should work for any path -@app.options("/{path:path}") -async def options_handler(request: Request, path: str): - """Handle all OPTIONS requests""" - logger.info(f"OPTIONS request received for path: /{path}") - logger.info(f"Origin: {request.headers.get('origin', 'No origin header')}") - - response = Response(status_code=200) - - # Manually set CORS headers for debugging - origin = request.headers.get("origin") - if origin and (origin in allowed_origins or "*" in allowed_origins): - response.headers["Access-Control-Allow-Origin"] = origin - response.headers["Access-Control-Allow-Methods"] = ( - "GET, POST, PUT, DELETE, OPTIONS, HEAD, PATCH" - ) - response.headers["Access-Control-Allow-Headers"] = ( - "Accept, Accept-Language, Content-Language, Content-Type, Authorization, X-Requested-With, Origin, Cache-Control, Pragma, X-CSRFToken" - ) - response.headers["Access-Control-Allow-Credentials"] = "true" - response.headers["Access-Control-Max-Age"] = "3600" - elif "*" in allowed_origins: - response.headers["Access-Control-Allow-Origin"] = "*" - response.headers["Access-Control-Allow-Methods"] = ( - "GET, POST, PUT, DELETE, OPTIONS, HEAD, PATCH" - ) - response.headers["Access-Control-Allow-Headers"] = ( - "Accept, Accept-Language, Content-Language, Content-Type, Authorization, X-Requested-With, Origin, Cache-Control, Pragma, X-CSRFToken" - ) - response.headers["Access-Control-Max-Age"] = "3600" - - return response - # Health check @app.get("/health") diff --git a/backend/requirements-dev.txt b/backend/requirements-dev.txt new file mode 100644 index 00000000..7545764e --- /dev/null +++ b/backend/requirements-dev.txt @@ -0,0 +1,7 @@ +pytest +pytest-asyncio +httpx +mongomock-motor +pytest-env +pytest-cov +pytest-mock diff --git a/backend/requirements.txt b/backend/requirements.txt index 14825b14..0d66227f 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -11,10 +11,3 @@ firebase-admin==6.9.0 python-dotenv==1.0.0 bcrypt==4.0.1 email-validator==2.2.0 -pytest -pytest-asyncio -httpx -mongomock-motor -pytest-env -pytest-cov -pytest-mock diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 18cc0247..1dd65df0 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -73,9 +73,7 @@ async def mock_db(): # Patch get_database for all services that use it patches = [ - patch("app.auth.service.get_database", return_value=mock_database_instance), - patch("app.user.service.get_database", return_value=mock_database_instance), - patch("app.groups.service.get_database", return_value=mock_database_instance), + patch("app.database.get_database", return_value=mock_database_instance), ] # Start all patches