From 83c4d1b7477791d32b2a30e087e89ccda004a61c Mon Sep 17 00:00:00 2001 From: Devasy Patel <110348311+Devasy23@users.noreply.github.com> Date: Sun, 7 Dec 2025 14:31:06 +0530 Subject: [PATCH 1/8] feat: optimize database queries for friends balance summary and enrich member data --- backend/app/expenses/service.py | 270 ++++++++++------- backend/app/groups/service.py | 139 +++++---- backend/migrations/001_create_indexes.py | 285 ++++++++++++++++++ .../migrations/002_fix_duplicate_emails.py | 271 +++++++++++++++++ 4 files changed, 799 insertions(+), 166 deletions(-) create mode 100644 backend/migrations/001_create_indexes.py create mode 100644 backend/migrations/002_fix_duplicate_emails.py diff --git a/backend/app/expenses/service.py b/backend/app/expenses/service.py index 80cbcf31..d76dccb0 100644 --- a/backend/app/expenses/service.py +++ b/backend/app/expenses/service.py @@ -953,141 +953,201 @@ async def get_user_balance_in_group( } async def get_friends_balance_summary(self, user_id: str) -> Dict[str, Any]: - """Get cross-group friend balances for a user""" + """ + Get cross-group friend balances using optimized aggregation pipeline. - # Get all groups user belongs to + Performance: Optimized to use single aggregation query instead of N×M queries. + Example: 20 friends × 5 groups = 3 queries total (vs 100+ with naive approach). + + Uses MongoDB aggregation to calculate all balances at once, then batch enriches + with user and group details for optimal performance. + """ + + # First, get all groups user belongs to (need this to filter friends properly) groups = await self.groups_collection.find({"members.userId": user_id}).to_list( - None + length=500 ) - friends_balance = [] - user_totals = {"totalOwedToYou": 0, "totalYouOwe": 0} + if not groups: + return { + "friendsBalance": [], + "summary": { + "totalOwedToYou": 0, + "totalYouOwe": 0, + "netBalance": 0, + "friendCount": 0, + "activeGroups": 0, + }, + } - # Get all unique friends across groups - friend_ids = set() + # Extract group IDs and friend IDs (only from user's groups) + group_ids = [str(g["_id"]) for g in groups] + friend_ids_in_groups = set() for group in groups: for member in group["members"]: if member["userId"] != user_id: - friend_ids.add(member["userId"]) + friend_ids_in_groups.add(member["userId"]) - # Get user names & images - users = await self.users_collection.find( - {"_id": {"$in": [ObjectId(uid) for uid in friend_ids]}} - ).to_list(None) - user_names = {str(user["_id"]): user.get("name", "Unknown") for user in users} - user_images = {str(user["_id"]): user.get("imageUrl") for user in users} + # OPTIMIZATION: Single aggregation to calculate all friend balances at once + # Only for friends in user's groups and groups user belongs to + pipeline = [ + # Step 1: Match settlements in user's groups involving the user + { + "$match": { + "groupId": {"$in": group_ids}, + "$or": [ + { + "payerId": user_id, + "payeeId": {"$in": list(friend_ids_in_groups)}, + }, + { + "payeeId": user_id, + "payerId": {"$in": list(friend_ids_in_groups)}, + }, + ], + } + }, + # Step 2: Calculate net balance per friend per group + { + "$group": { + "_id": { + "friendId": { + "$cond": [ + {"$eq": ["$payerId", user_id]}, + "$payeeId", + "$payerId", + ] + }, + "groupId": "$groupId", + }, + "balance": { + "$sum": { + "$cond": [ + # If user is payer, friend owes user (positive) + {"$eq": ["$payerId", user_id]}, + "$amount", + # If user is payee, user owes friend (negative) + {"$multiply": ["$amount", -1]}, + ] + } + }, + } + }, + # Step 3: Group by friend to get total balance across all groups + { + "$group": { + "_id": "$_id.friendId", + "totalBalance": {"$sum": "$balance"}, + "groups": { + "$push": {"groupId": "$_id.groupId", "balance": "$balance"} + }, + } + }, + # Step 4: Filter out friends with zero balance + {"$match": {"$expr": {"$gt": [{"$abs": "$totalBalance"}, 0.01]}}}, + ] - for friend_id in friend_ids: - friend_balance_data = { - "userId": friend_id, - "userName": user_names.get(friend_id, "Unknown"), - # Populate image directly from users collection to avoid extra client round-trips - "userImageUrl": user_images.get(friend_id), - "netBalance": 0, - "owesYou": False, - "breakdown": [], - "lastActivity": datetime.utcnow(), + # Execute aggregation - Single query for all friend balances + try: + results = await self.settlements_collection.aggregate(pipeline).to_list( + length=500 + ) + except Exception as e: + logger.error(f"Error in optimized friends balance aggregation: {e}") + results = [] + + if not results: + # No balances found + return { + "friendsBalance": [], + "summary": { + "totalOwedToYou": 0, + "totalYouOwe": 0, + "netBalance": 0, + "friendCount": 0, + "activeGroups": len(groups), + }, } - total_friend_balance = 0 + # Extract unique friend IDs for batch fetching + friend_ids = list(set(result["_id"] for result in results)) - # Calculate balance for each group - for group in groups: - group_id = str(group["_id"]) + # Build group map from groups we already fetched + groups_map = {str(g["_id"]): g.get("name", "Unknown Group") for g in groups} - # Check if friend is in this group - friend_in_group = any( - member["userId"] == friend_id for member in group["members"] - ) - if not friend_in_group: - continue + # OPTIMIZATION: Batch fetch all friend details in one query + try: + friends_cursor = self.users_collection.find( + {"_id": {"$in": [ObjectId(fid) for fid in friend_ids]}}, + {"_id": 1, "name": 1, "imageUrl": 1}, + ) + friends_list = await friends_cursor.to_list(length=500) + friends_map = {str(f["_id"]): f for f in friends_list} + except Exception as e: + logger.error(f"Error batch fetching friend details: {e}") + friends_map = {} - # Calculate net balance between user and friend in this group - pipeline = [ - { - "$match": { - "groupId": group_id, - "$or": [ - {"payerId": user_id, "payeeId": friend_id}, - {"payerId": friend_id, "payeeId": user_id}, - ], - } - }, - { - "$group": { - "_id": None, - "userOwes": { - "$sum": { - "$cond": [ - { - "$and": [ - {"$eq": ["$payerId", friend_id]}, - {"$eq": ["$payeeId", user_id]}, - ] - }, - "$amount", - 0, - ] - } - }, - "friendOwes": { - "$sum": { - "$cond": [ - { - "$and": [ - {"$eq": ["$payerId", user_id]}, - {"$eq": ["$payeeId", friend_id]}, - ] - }, - "$amount", - 0, - ] - } - }, - } - }, - ] + # Post-process results to build final response + friends_balance = [] + user_totals = {"totalOwedToYou": 0, "totalYouOwe": 0} - result = await self.settlements_collection.aggregate(pipeline).to_list( - None - ) - balance_data = result[0] if result else {"userOwes": 0, "friendOwes": 0} + for result in results: + friend_id = result["_id"] + total_balance = result["totalBalance"] - group_balance = balance_data["friendOwes"] - balance_data["userOwes"] - total_friend_balance += group_balance + # Get friend details from map + friend_details = friends_map.get(friend_id) - if ( - abs(group_balance) > 0.01 - ): # Only include if there's a significant balance - friend_balance_data["breakdown"].append( + # Build breakdown by group + breakdown = [] + for group_item in result.get("groups", []): + group_id = group_item["groupId"] + group_balance = group_item["balance"] + + # Only include groups with significant balance + if abs(group_balance) > 0.01: + breakdown.append( { "groupId": group_id, - "groupName": group["name"], - "balance": group_balance, + "groupName": groups_map.get(group_id, "Unknown Group"), + "balance": round(group_balance, 2), "owesYou": group_balance > 0, } ) - if ( - abs(total_friend_balance) > 0.01 - ): # Only include friends with non-zero balance - friend_balance_data["netBalance"] = total_friend_balance - friend_balance_data["owesYou"] = total_friend_balance > 0 + # Build friend balance object + friend_data = { + "userId": friend_id, + "userName": ( + friend_details.get("name", "Unknown") + if friend_details + else "Unknown" + ), + "userImageUrl": ( + friend_details.get("imageUrl") if friend_details else None + ), + "netBalance": round(total_balance, 2), + "owesYou": total_balance > 0, + "breakdown": breakdown, + "lastActivity": datetime.utcnow(), # TODO: Calculate actual last activity + } - if total_friend_balance > 0: - user_totals["totalOwedToYou"] += total_friend_balance - else: - user_totals["totalYouOwe"] += abs(total_friend_balance) + friends_balance.append(friend_data) - friends_balance.append(friend_balance_data) + # Update totals + if total_balance > 0: + user_totals["totalOwedToYou"] += total_balance + else: + user_totals["totalYouOwe"] += abs(total_balance) return { "friendsBalance": friends_balance, "summary": { - "totalOwedToYou": user_totals["totalOwedToYou"], - "totalYouOwe": user_totals["totalYouOwe"], - "netBalance": user_totals["totalOwedToYou"] - - user_totals["totalYouOwe"], + "totalOwedToYou": round(user_totals["totalOwedToYou"], 2), + "totalYouOwe": round(user_totals["totalYouOwe"], 2), + "netBalance": round( + user_totals["totalOwedToYou"] - user_totals["totalYouOwe"], 2 + ), "friendCount": len(friends_balance), "activeGroups": len(groups), }, diff --git a/backend/app/groups/service.py b/backend/app/groups/service.py index c70086e9..821ae3ff 100644 --- a/backend/app/groups/service.py +++ b/backend/app/groups/service.py @@ -24,82 +24,99 @@ def generate_join_code(self, length: int = 6) -> str: async def _enrich_members_with_user_details( self, members: List[dict] ) -> List[dict]: - """Private method to enrich member data with user details from users collection""" + """ + Enrich member data with user details from the users collection. + Uses batch fetching for optimal performance (single query for all members). + + Performance: O(1) database queries regardless of member count. + Example: 10 members = 1 query instead of 10 separate queries. + """ + if not members: + return [] + db = self.get_db() enriched_members = [] + # Extract all unique user IDs + user_ids = [] for member in members: member_user_id = member.get("userId") if member_user_id: try: - # Fetch user details from users collection - user_obj_id = ObjectId(member_user_id) - user = await db.users.find_one({"_id": user_obj_id}) - - # Create enriched member object - enriched_member = { - "userId": member_user_id, - "role": member.get("role", "member"), - "joinedAt": member.get("joinedAt"), - "user": ( - { - "name": ( - user.get("name", f"User {member_user_id[-4:]}") - if user - else f"User {member_user_id[-4:]}" - ), - "email": ( - user.get("email", f"{member_user_id}@example.com") - if user - else f"{member_user_id}@example.com" - ), - "imageUrl": (user.get("imageUrl") if user else None), - } - if user - else { - "name": f"User {member_user_id[-4:]}", - "email": f"{member_user_id}@example.com", - "imageUrl": None, - } - ), - } - enriched_members.append(enriched_member) - except errors.InvalidId: # exception for invalid ObjectId + user_ids.append(ObjectId(member_user_id)) + except errors.InvalidId: logger.warning(f"Invalid ObjectId for userId: {member_user_id}") - enriched_members.append( - { - "userId": member_user_id, - "role": member.get("role", "member"), - "joinedAt": member.get("joinedAt"), - "user": { - "name": f"User {member_user_id[-4:]}", - "email": f"{member_user_id}@example.com", - "avatar": None, - }, - } - ) - except Exception as e: - logger.error(f"Error enriching userId {member_user_id}: {e}") - # If user lookup fails, add member with basic info - enriched_members.append( - { - "userId": member_user_id, - "role": member.get("role", "member"), - "joinedAt": member.get("joinedAt"), - "user": { - "name": f"User {member_user_id[-4:]}", - "email": f"{member_user_id}@example.com", - "imageUrl": None, - }, - } - ) + if not user_ids: + # No valid user IDs, return members with fallback data + return [self._create_fallback_member(m) for m in members] + + # OPTIMIZATION: Single query to fetch ALL users at once using $in + try: + users_cursor = db.users.find( + {"_id": {"$in": user_ids}}, + { + "_id": 1, + "name": 1, + "email": 1, + "imageUrl": 1, + }, # Project only needed fields + ) + users_list = await users_cursor.to_list(length=100) + + # Create fast lookup dictionary: O(1) access per member + users_map = {str(user["_id"]): user for user in users_list} + + except Exception as e: + logger.error(f"Error batch fetching users: {e}") + # Fallback to empty map if query fails + users_map = {} + + # Enrich members using the lookup map + for member in members: + member_user_id = member.get("userId") + if member_user_id: + user = users_map.get(member_user_id) + + enriched_member = { + "userId": member_user_id, + "role": member.get("role", "member"), + "joinedAt": member.get("joinedAt"), + "user": { + "name": ( + user.get("name", f"User {member_user_id[-4:]}") + if user + else f"User {member_user_id[-4:]}" + ), + "email": ( + user.get("email", f"{member_user_id}@example.com") + if user + else f"{member_user_id}@example.com" + ), + "imageUrl": user.get("imageUrl") if user else None, + }, + } + enriched_members.append(enriched_member) else: # Add member without user details if userId is missing enriched_members.append(member) return enriched_members + def _create_fallback_member(self, member: dict) -> dict: + """Helper to create fallback member data when user lookup fails""" + user_id = member.get("userId", "unknown") + return { + "userId": user_id, + "role": member.get("role", "member"), + "joinedAt": member.get("joinedAt"), + "user": { + "name": f"User {user_id[-4:] if len(user_id) >= 4 else user_id}", + "email": f"{user_id}@example.com", + "imageUrl": None, + }, + } + def transform_group_document(self, group: dict) -> dict: """Transform MongoDB group document to API response format""" if not group: diff --git a/backend/migrations/001_create_indexes.py b/backend/migrations/001_create_indexes.py new file mode 100644 index 00000000..6ae11762 --- /dev/null +++ b/backend/migrations/001_create_indexes.py @@ -0,0 +1,285 @@ +""" +Database Index Migration +======================== + +This script creates all necessary indexes for optimal query performance +in the Splitwiser application. + +Run this script once to set up indexes: + python -m migrations.001_create_indexes + +Expected runtime: < 30 seconds +""" + +import asyncio +import sys +from pathlib import Path + +from app.config import logger, settings +from motor.motor_asyncio import AsyncIOMotorClient + +# Add parent directory to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + + +async def create_indexes(): + """Create all database indexes for optimal performance""" + + logger.info("=" * 60) + logger.info("DATABASE INDEX MIGRATION") + logger.info("=" * 60) + + # Connect to MongoDB + client = AsyncIOMotorClient(settings.mongodb_url) + db = client[settings.database_name] + + try: + logger.info(f"Connected to database: {settings.database_name}") + logger.info("") + + # ========================================== + # USERS COLLECTION INDEXES + # ========================================== + logger.info("📋 Creating indexes for 'users' collection...") + + # Email index (unique) - Critical for login performance + try: + await db.users.create_index("email", unique=True) + logger.info(" ✓ Created unique index on 'email'") + except Exception as e: + if "duplicate key" in str(e).lower() or "11000" in str(e): + logger.warning( + f" ⚠ Cannot create unique email index - duplicate emails exist in database" + ) + logger.warning( + f" → Creating non-unique index instead for performance" + ) + # Create non-unique index for performance even with duplicates + try: + await db.users.create_index("email", unique=False) + logger.info(" ✓ Created non-unique index on 'email'") + except Exception as e2: + logger.error(f" ✗ Failed to create email index: {e2}") + else: + logger.error(f" ✗ Failed to create email index: {e}") + raise + + # Firebase UID index (sparse) - For Google OAuth users + await db.users.create_index("firebase_uid", sparse=True) + logger.info(" ✓ Created sparse index on 'firebase_uid'") + + logger.info("") + + # ========================================== + # GROUPS COLLECTION INDEXES + # ========================================== + logger.info("📋 Creating indexes for 'groups' collection...") + + # Join code index (unique) - For group joining + await db.groups.create_index("joinCode", unique=True) + logger.info(" ✓ Created unique index on 'joinCode'") + + # Member userId index - Critical for finding user's groups + await db.groups.create_index([("members.userId", 1)]) + logger.info(" ✓ Created index on 'members.userId'") + + # Created by index - For user's created groups + await db.groups.create_index("createdBy") + logger.info(" ✓ Created index on 'createdBy'") + + logger.info("") + + # ========================================== + # EXPENSES COLLECTION INDEXES + # ========================================== + logger.info("📋 Creating indexes for 'expenses' collection...") + + # Compound index: groupId + createdAt - For listing group expenses + await db.expenses.create_index([("groupId", 1), ("createdAt", -1)]) + logger.info(" ✓ Created compound index on 'groupId' + 'createdAt'") + + # Compound index: groupId + splits.userId - For finding user's expenses + await db.expenses.create_index([("groupId", 1), ("splits.userId", 1)]) + logger.info(" ✓ Created compound index on 'groupId' + 'splits.userId'") + + # Compound index: groupId + tags - For filtering by tags + await db.expenses.create_index([("groupId", 1), ("tags", 1)]) + logger.info(" ✓ Created compound index on 'groupId' + 'tags'") + + # Compound index: createdBy + createdAt - For user's created expenses + await db.expenses.create_index([("createdBy", 1), ("createdAt", -1)]) + logger.info(" ✓ Created compound index on 'createdBy' + 'createdAt'") + + logger.info("") + + # ========================================== + # SETTLEMENTS COLLECTION INDEXES + # ========================================== + logger.info("📋 Creating indexes for 'settlements' collection...") + + # Compound index: groupId + status - Critical for settlement queries + await db.settlements.create_index([("groupId", 1), ("status", 1)]) + logger.info(" ✓ Created compound index on 'groupId' + 'status'") + + # Compound index: groupId + payerId + payeeId - For balance calculations + await db.settlements.create_index( + [("groupId", 1), ("payerId", 1), ("payeeId", 1)] + ) + logger.info(" ✓ Created compound index on 'groupId' + 'payerId' + 'payeeId'") + + # Expense ID index - For finding settlements by expense + await db.settlements.create_index("expenseId") + logger.info(" ✓ Created index on 'expenseId'") + + # Payer ID index - For finding all settlements where user paid + await db.settlements.create_index("payerId") + logger.info(" ✓ Created index on 'payerId'") + + # Payee ID index - For finding all settlements where user owes + await db.settlements.create_index("payeeId") + logger.info(" ✓ Created index on 'payeeId'") + + # Compound index for balance queries - user in either role + await db.settlements.create_index([("payerId", 1), ("payeeId", 1)]) + logger.info(" ✓ Created compound index on 'payerId' + 'payeeId'") + + logger.info("") + + # ========================================== + # REFRESH_TOKENS COLLECTION INDEXES + # ========================================== + logger.info("📋 Creating indexes for 'refresh_tokens' collection...") + + # Token index (unique) - For token lookup during refresh + await db.refresh_tokens.create_index("token", unique=True) + logger.info(" ✓ Created unique index on 'token'") + + # Compound index: user_id + revoked - For finding user's active tokens + await db.refresh_tokens.create_index([("user_id", 1), ("revoked", 1)]) + logger.info(" ✓ Created compound index on 'user_id' + 'revoked'") + + # TTL index on expires_at - Auto-delete expired tokens + await db.refresh_tokens.create_index( + "expires_at", expireAfterSeconds=0 # Delete immediately when expired + ) + logger.info( + " ✓ Created TTL index on 'expires_at' (auto-cleanup expired tokens)" + ) + + logger.info("") + + # ========================================== + # PASSWORD_RESETS COLLECTION INDEXES + # ========================================== + logger.info("📋 Creating indexes for 'password_resets' collection...") + + # Token index (unique) - For reset token lookup + await db.password_resets.create_index("token", unique=True) + logger.info(" ✓ Created unique index on 'token'") + + # Compound index: user_id + used - For finding user's reset history + await db.password_resets.create_index([("user_id", 1), ("used", 1)]) + logger.info(" ✓ Created compound index on 'user_id' + 'used'") + + # TTL index on expires_at - Auto-delete expired reset tokens + await db.password_resets.create_index( + "expires_at", expireAfterSeconds=0 # Delete immediately when expired + ) + logger.info( + " ✓ Created TTL index on 'expires_at' (auto-cleanup expired resets)" + ) + + logger.info("") + + # ========================================== + # VERIFY INDEXES CREATED + # ========================================== + logger.info("=" * 60) + logger.info("VERIFICATION") + logger.info("=" * 60) + + collections_to_verify = [ + "users", + "groups", + "expenses", + "settlements", + "refresh_tokens", + "password_resets", + ] + + for collection_name in collections_to_verify: + indexes = await db[collection_name].index_information() + index_count = len(indexes) + logger.info(f"✓ {collection_name}: {index_count} indexes") + for index_name, index_info in indexes.items(): + if index_name != "_id_": # Skip default _id index + key = index_info.get("key", []) + unique = " (unique)" if index_info.get("unique") else "" + ttl = ( + f" (TTL: {index_info.get('expireAfterSeconds')}s)" + if "expireAfterSeconds" in index_info + else "" + ) + logger.info(f" - {index_name}: {key}{unique}{ttl}") + + logger.info("") + logger.info("=" * 60) + logger.info("✅ ALL INDEXES CREATED SUCCESSFULLY!") + logger.info("=" * 60) + logger.info("") + logger.info("Next steps:") + logger.info("1. Monitor query performance in logs") + logger.info("2. Run performance benchmarks") + logger.info("3. Check index usage with db.collection.explain()") + logger.info("") + + except Exception as e: + logger.error(f"❌ Error creating indexes: {e}") + raise + finally: + client.close() + logger.info("Database connection closed.") + + +async def drop_indexes_if_needed(): + """ + CAUTION: This drops all custom indexes. + Only use this if you need to recreate indexes from scratch. + """ + logger.warning("⚠️ DROP INDEXES MODE - THIS WILL DELETE ALL CUSTOM INDEXES") + logger.warning("Are you sure you want to continue? (yes/no)") + + # In automated scripts, you might want to pass a flag instead + # For now, we'll just skip this function unless explicitly called + + client = AsyncIOMotorClient(settings.mongodb_url) + db = client[settings.database_name] + + try: + collections = ["users", "groups", "expenses", "settlements", "refresh_tokens"] + + for collection_name in collections: + # Get all indexes + indexes = await db[collection_name].index_information() + + # Drop all except _id_ index + for index_name in indexes.keys(): + if index_name != "_id_": + await db[collection_name].drop_index(index_name) + logger.info( + f" Dropped index '{index_name}' from {collection_name}" + ) + + logger.info("✓ All custom indexes dropped") + + except Exception as e: + logger.error(f"Error dropping indexes: {e}") + raise + finally: + client.close() + + +if __name__ == "__main__": + logger.info("Starting index migration...") + asyncio.run(create_indexes()) + logger.info("Migration complete!") diff --git a/backend/migrations/002_fix_duplicate_emails.py b/backend/migrations/002_fix_duplicate_emails.py new file mode 100644 index 00000000..063fd272 --- /dev/null +++ b/backend/migrations/002_fix_duplicate_emails.py @@ -0,0 +1,271 @@ +""" +Find and Fix Duplicate Emails +============================== + +This script helps identify and resolve duplicate email addresses in the users collection. + +Usage: + python -m migrations.002_fix_duplicate_emails + +Options: + --dry-run : Show duplicates without fixing (default) + --fix : Automatically fix duplicates by keeping the oldest account + --interactive: Interactively choose which account to keep +""" + +import asyncio +import sys +from datetime import datetime +from pathlib import Path + +from app.config import logger, settings +from bson import ObjectId +from motor.motor_asyncio import AsyncIOMotorClient + +# Add parent directory to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + + +async def find_duplicate_emails(db): + """Find all duplicate email addresses in the users collection""" + + logger.info("Searching for duplicate emails...") + + pipeline = [ + { + "$group": { + "_id": "$email", + "count": {"$sum": 1}, + "users": { + "$push": { + "id": "$_id", + "name": "$name", + "created_at": "$created_at", + } + }, + } + }, + {"$match": {"count": {"$gt": 1}}}, + {"$sort": {"count": -1}}, + ] + + duplicates = await db.users.aggregate(pipeline).to_list(length=1000) + + if not duplicates: + logger.info("✓ No duplicate emails found!") + return [] + + logger.info(f"⚠ Found {len(duplicates)} email(s) with duplicates:") + logger.info("") + + for dup in duplicates: + email = dup["_id"] + count = dup["count"] + users = dup["users"] + + logger.info(f"Email: {email}") + logger.info(f" Duplicate count: {count}") + logger.info(f" Accounts:") + + for i, user in enumerate(users, 1): + created = user.get("created_at", "Unknown") + if isinstance(created, datetime): + created = created.strftime("%Y-%m-%d %H:%M:%S") + logger.info( + f" {i}. ID: {user['id']}, Name: {user.get('name', 'Unknown')}, Created: {created}" + ) + + logger.info("") + + return duplicates + + +async def fix_duplicates_auto(db, duplicates): + """Automatically fix duplicates by keeping the oldest account and deleting others""" + + logger.info("=" * 60) + logger.info("AUTO-FIX MODE: Keeping oldest account for each email") + logger.info("=" * 60) + logger.info("") + + total_deleted = 0 + + for dup in duplicates: + email = dup["_id"] + users = dup["users"] + + # Sort by created_at (oldest first) + sorted_users = sorted(users, key=lambda x: x.get("created_at", datetime.min)) + + # Keep the oldest, delete the rest + keep_user = sorted_users[0] + delete_users = sorted_users[1:] + + logger.info(f"Email: {email}") + logger.info( + f" ✓ Keeping: ID {keep_user['id']} ({keep_user.get('name', 'Unknown')})" + ) + + for user in delete_users: + user_id = user["id"] + logger.info(f" ✗ Deleting: ID {user_id} ({user.get('name', 'Unknown')})") + + # Delete the user + result = await db.users.delete_one({"_id": user_id}) + if result.deleted_count > 0: + total_deleted += 1 + logger.info(f" → Deleted successfully") + else: + logger.warning(f" → Failed to delete") + + logger.info("") + + logger.info(f"Total accounts deleted: {total_deleted}") + return total_deleted + + +async def fix_duplicates_interactive(db, duplicates): + """Interactively ask user which account to keep for each duplicate""" + + logger.info("=" * 60) + logger.info("INTERACTIVE MODE: You choose which account to keep") + logger.info("=" * 60) + logger.info("") + + total_deleted = 0 + + for dup in duplicates: + email = dup["_id"] + users = dup["users"] + + logger.info(f"Email: {email}") + logger.info(f"Which account do you want to KEEP?") + + for i, user in enumerate(users, 1): + created = user.get("created_at", "Unknown") + if isinstance(created, datetime): + created = created.strftime("%Y-%m-%d %H:%M:%S") + logger.info( + f" {i}. Name: {user.get('name', 'Unknown')}, Created: {created}, ID: {user['id']}" + ) + + while True: + try: + choice = ( + input(f"\nEnter number to keep (1-{len(users)}) or 's' to skip: ") + .strip() + .lower() + ) + + if choice == "s": + logger.info("Skipped") + break + + choice_num = int(choice) + if 1 <= choice_num <= len(users): + keep_idx = choice_num - 1 + keep_user = users[keep_idx] + + logger.info( + f"✓ Keeping: {keep_user.get('name', 'Unknown')} (ID: {keep_user['id']})" + ) + + # Delete all others + for i, user in enumerate(users): + if i != keep_idx: + user_id = user["id"] + logger.info( + f"✗ Deleting: {user.get('name', 'Unknown')} (ID: {user_id})" + ) + + result = await db.users.delete_one({"_id": user_id}) + if result.deleted_count > 0: + total_deleted += 1 + break + else: + logger.warning(f"Invalid choice. Enter 1-{len(users)} or 's'") + except ValueError: + logger.warning(f"Invalid input. Enter a number 1-{len(users)} or 's'") + + logger.info("") + + logger.info(f"Total accounts deleted: {total_deleted}") + return total_deleted + + +async def main(): + """Main function""" + + import argparse + + parser = argparse.ArgumentParser(description="Find and fix duplicate emails") + parser.add_argument( + "--fix", action="store_true", help="Automatically fix by keeping oldest account" + ) + parser.add_argument( + "--interactive", + action="store_true", + help="Interactively choose which account to keep", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Only show duplicates without fixing (default)", + default=True, + ) + args = parser.parse_args() + + # Connect to MongoDB + client = AsyncIOMotorClient(settings.mongodb_url) + db = client[settings.database_name] + + try: + logger.info("=" * 60) + logger.info("DUPLICATE EMAIL FINDER AND FIXER") + logger.info("=" * 60) + logger.info(f"Database: {settings.database_name}") + logger.info("") + + # Find duplicates + duplicates = await find_duplicate_emails(db) + + if not duplicates: + logger.info("No action needed. Database is clean! ✓") + return + + # Determine mode + if args.fix: + logger.info("⚠ WARNING: This will DELETE duplicate accounts automatically!") + confirm = input("Type 'yes' to proceed: ").strip().lower() + if confirm == "yes": + await fix_duplicates_auto(db, duplicates) + else: + logger.info("Cancelled.") + + elif args.interactive: + await fix_duplicates_interactive(db, duplicates) + + else: + logger.info("=" * 60) + logger.info("DRY RUN MODE - No changes made") + logger.info("=" * 60) + logger.info("") + logger.info("To fix duplicates automatically (keep oldest):") + logger.info(" python -m migrations.002_fix_duplicate_emails --fix") + logger.info("") + logger.info("To fix duplicates interactively (you choose):") + logger.info(" python -m migrations.002_fix_duplicate_emails --interactive") + + logger.info("") + logger.info("=" * 60) + logger.info("DONE") + logger.info("=" * 60) + + except Exception as e: + logger.error(f"Error: {e}", exc_info=True) + finally: + client.close() + + +if __name__ == "__main__": + asyncio.run(main()) From 171147e49f315eb03494dbc3204628c0c95ba55d Mon Sep 17 00:00:00 2001 From: Devasy Patel <110348311+Devasy23@users.noreply.github.com> Date: Sun, 7 Dec 2025 14:35:59 +0530 Subject: [PATCH 2/8] feat: optimize settlement aggregation for friends' balance summary --- .../tests/expenses/test_expense_service.py | 84 +++++++++---------- 1 file changed, 38 insertions(+), 46 deletions(-) diff --git a/backend/tests/expenses/test_expense_service.py b/backend/tests/expenses/test_expense_service.py index 670d18a8..2b2bc44d 100644 --- a/backend/tests/expenses/test_expense_service.py +++ b/backend/tests/expenses/test_expense_service.py @@ -1602,53 +1602,45 @@ async def test_get_friends_balance_summary_success(expense_service): }, ] - # Mocking settlement aggregations for each friend in each group + # Mocking the OPTIMIZED settlement aggregation + # The new optimized version makes ONE aggregation call that returns all friends' balances # Friend 1: - # Group Alpha: Main owes Friend1 50 (net -50 for Main) - # Group Beta: Friend1 owes Main 30 (net +30 for Main) - # Total for Friend1: Main is owed 50, owes 30. Net: Main is owed 20 by Friend1. + # Group Alpha: Main owes Friend1 50 (balance: -50 for Main) + # Group Beta: Friend1 owes Main 30 (balance: +30 for Main) + # Total for Friend1: -50 + 30 = -20 (Main owes Friend1 20) # Friend 2: - # Group Beta: Main owes Friend2 70 (net -70 for Main) - # Total for Friend2: Main owes 70 to Friend2. + # Group Beta: Main owes Friend2 70 (balance: -70 for Main) + # Total for Friend2: -70 (Main owes Friend2 70) - # This is the side_effect for the .aggregate() call. It must be a sync function - # that returns a cursor mock (AsyncMock). def sync_mock_settlements_aggregate_cursor_factory(pipeline, *args, **kwargs): - match_clause = pipeline[0]["$match"] - group_id_pipeline = match_clause["groupId"] - or_conditions = match_clause["$or"] - - # Determine which friend is being processed based on payer/payee in OR condition - # This is a simplification; real queries are more complex - pipeline_friend_id = None - for cond in or_conditions: - if cond["payerId"] == user_id_str and cond["payeeId"] != user_id_str: - pipeline_friend_id = cond["payeeId"] - break - elif cond["payeeId"] == user_id_str and cond["payerId"] != user_id_str: - pipeline_friend_id = cond["payerId"] - break - + # The optimized version returns aggregated results for all friends in one go mock_agg_cursor = AsyncMock() - if group_id_pipeline == group1_id and pipeline_friend_id == friend1_id_str: - # Main owes Friend1 50 in Group Alpha - mock_agg_cursor.to_list.return_value = [ - {"_id": None, "userOwes": 50.0, "friendOwes": 0.0} - ] - elif group_id_pipeline == group2_id and pipeline_friend_id == friend1_id_str: - # Friend1 owes Main 30 in Group Beta - mock_agg_cursor.to_list.return_value = [ - {"_id": None, "userOwes": 0.0, "friendOwes": 30.0} - ] - elif group_id_pipeline == group2_id and pipeline_friend_id == friend2_id_str: - # Main owes Friend2 70 in Group Beta - mock_agg_cursor.to_list.return_value = [ - {"_id": None, "userOwes": 70.0, "friendOwes": 0.0} - ] - else: - mock_agg_cursor.to_list.return_value = [ - {"_id": None, "userOwes": 0.0, "friendOwes": 0.0} - ] # Default empty + mock_agg_cursor.to_list.return_value = [ + { + "_id": friend1_id_str, # Friend 1 + "totalBalance": -20.0, # Main owes Friend1 20 (net: -50 from G1, +30 from G2) + "groups": [ + { + "groupId": group1_id, + "balance": -50.0, + }, # Main owes 50 in Group Alpha + { + "groupId": group2_id, + "balance": 30.0, + }, # Friend1 owes 30 in Group Beta + ], + }, + { + "_id": friend2_id_str, # Friend 2 + "totalBalance": -70.0, # Main owes Friend2 70 + "groups": [ + { + "groupId": group2_id, + "balance": -70.0, + }, # Main owes 70 in Group Beta + ], + }, + ] return mock_agg_cursor with patch("app.expenses.service.mongodb") as mock_mongodb: @@ -1678,7 +1670,7 @@ def mock_user_find_cursor_side_effect(query, *args, **kwargs): mock_db.users.find = MagicMock(side_effect=mock_user_find_cursor_side_effect) - # Mock settlement aggregation logic + # Mock the optimized settlement aggregation logic # .aggregate() is sync, returns an async cursor. mock_db.settlements.aggregate = MagicMock( side_effect=sync_mock_settlements_aggregate_cursor_factory @@ -1739,9 +1731,9 @@ def mock_user_find_cursor_side_effect(query, *args, **kwargs): # Verify mocks mock_db.groups.find.assert_called_once_with({"members.userId": user_id_str}) - # settlements.aggregate is called for each friend in each group they share with user_id_str - # Friend1 is in 2 groups with user_id_str, Friend2 is in 1 group with user_id_str. Total 3 calls. - assert mock_db.settlements.aggregate.call_count == 3 + # OPTIMIZED: settlements.aggregate is called ONCE (not per friend/group) + # The optimized version uses a single aggregation pipeline to get all friends' balances + assert mock_db.settlements.aggregate.call_count == 1 @pytest.mark.asyncio From 57d2253e8b27d17189609772df7eb4a66c42443a Mon Sep 17 00:00:00 2001 From: Devasy Patel <110348311+Devasy23@users.noreply.github.com> Date: Sun, 7 Dec 2025 14:39:49 +0530 Subject: [PATCH 3/8] Update backend/app/expenses/service.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- backend/app/expenses/service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/app/expenses/service.py b/backend/app/expenses/service.py index d76dccb0..746f093a 100644 --- a/backend/app/expenses/service.py +++ b/backend/app/expenses/service.py @@ -1070,7 +1070,7 @@ async def get_friends_balance_summary(self, user_id: str) -> Dict[str, Any]: } # Extract unique friend IDs for batch fetching - friend_ids = list(set(result["_id"] for result in results)) + friend_ids = list({result["_id"] for result in results}) # Build group map from groups we already fetched groups_map = {str(g["_id"]): g.get("name", "Unknown Group") for g in groups} From 9cdc3171a398905396d2c0f978c5fbcd1db70240 Mon Sep 17 00:00:00 2001 From: Devasy Patel <110348311+Devasy23@users.noreply.github.com> Date: Sun, 7 Dec 2025 14:41:40 +0530 Subject: [PATCH 4/8] Update backend/app/groups/service.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- backend/app/groups/service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/app/groups/service.py b/backend/app/groups/service.py index 821ae3ff..f112192a 100644 --- a/backend/app/groups/service.py +++ b/backend/app/groups/service.py @@ -62,7 +62,7 @@ async def _enrich_members_with_user_details( "imageUrl": 1, }, # Project only needed fields ) - users_list = await users_cursor.to_list(length=100) + users_list = await users_cursor.to_list(length=len(user_ids)) # Create fast lookup dictionary: O(1) access per member users_map = {str(user["_id"]): user for user in users_list} From da9383e55721287a652134c1369cc756637a2d29 Mon Sep 17 00:00:00 2001 From: Devasy Patel <110348311+Devasy23@users.noreply.github.com> Date: Sun, 7 Dec 2025 14:46:58 +0530 Subject: [PATCH 5/8] feat: enhance expense service with timezone-aware datetime and improve balance retrieval --- backend/app/expenses/service.py | 8 +++++--- backend/tests/expenses/test_expense_service.py | 5 ++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/backend/app/expenses/service.py b/backend/app/expenses/service.py index 746f093a..d01f2675 100644 --- a/backend/app/expenses/service.py +++ b/backend/app/expenses/service.py @@ -1,5 +1,5 @@ from collections import defaultdict -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Any, Dict, List, Optional from app.config import logger @@ -1093,7 +1093,7 @@ async def get_friends_balance_summary(self, user_id: str) -> Dict[str, Any]: for result in results: friend_id = result["_id"] - total_balance = result["totalBalance"] + total_balance = result.get("totalBalance", 0) # Get friend details from map friend_details = friends_map.get(friend_id) @@ -1129,7 +1129,9 @@ async def get_friends_balance_summary(self, user_id: str) -> Dict[str, Any]: "netBalance": round(total_balance, 2), "owesYou": total_balance > 0, "breakdown": breakdown, - "lastActivity": datetime.utcnow(), # TODO: Calculate actual last activity + "lastActivity": datetime.now( + timezone.utc + ), # TODO: Calculate actual last activity } friends_balance.append(friend_data) diff --git a/backend/tests/expenses/test_expense_service.py b/backend/tests/expenses/test_expense_service.py index 2b2bc44d..4177a58f 100644 --- a/backend/tests/expenses/test_expense_service.py +++ b/backend/tests/expenses/test_expense_service.py @@ -1,5 +1,6 @@ import asyncio from datetime import datetime, timedelta, timezone +from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -1612,7 +1613,9 @@ async def test_get_friends_balance_summary_success(expense_service): # Group Beta: Main owes Friend2 70 (balance: -70 for Main) # Total for Friend2: -70 (Main owes Friend2 70) - def sync_mock_settlements_aggregate_cursor_factory(pipeline, *args, **kwargs): + def sync_mock_settlements_aggregate_cursor_factory( + _pipeline: Any, *_args: Any, **_kwargs: Any + ) -> AsyncMock: # The optimized version returns aggregated results for all friends in one go mock_agg_cursor = AsyncMock() mock_agg_cursor.to_list.return_value = [ From 338722aab26412f292642f314824b3ee552a79a8 Mon Sep 17 00:00:00 2001 From: Devasy Patel <110348311+Devasy23@users.noreply.github.com> Date: Sun, 7 Dec 2025 14:52:53 +0530 Subject: [PATCH 6/8] feat: add tests for optimized member enrichment function in group service --- .../tests/expenses/test_expense_service.py | 117 ++++++++++ backend/tests/groups/test_enrich_members.py | 215 ++++++++++++++++++ 2 files changed, 332 insertions(+) create mode 100644 backend/tests/groups/test_enrich_members.py diff --git a/backend/tests/expenses/test_expense_service.py b/backend/tests/expenses/test_expense_service.py index 4177a58f..5271339a 100644 --- a/backend/tests/expenses/test_expense_service.py +++ b/backend/tests/expenses/test_expense_service.py @@ -2109,5 +2109,122 @@ async def test_get_group_analytics_group_not_found(expense_service): mock_db.users.find_one.assert_not_called() +@pytest.mark.asyncio +async def test_get_friends_balance_summary_aggregation_error(expense_service): + """Test friends balance summary when aggregation fails""" + user_id_str = str(ObjectId()) + + with patch("app.expenses.service.mongodb") as mock_mongodb: + mock_db = MagicMock() + mock_mongodb.database = mock_db + + # Mock groups + mock_groups = [ + { + "_id": ObjectId(), + "name": "Test Group", + "members": [{"userId": user_id_str}, {"userId": str(ObjectId())}], + } + ] + mock_groups_cursor = AsyncMock() + mock_groups_cursor.to_list.return_value = mock_groups + mock_db.groups.find.return_value = mock_groups_cursor + + # Mock aggregation failure + mock_agg_cursor = AsyncMock() + mock_agg_cursor.to_list.side_effect = Exception("Aggregation failed") + mock_db.settlements.aggregate.return_value = mock_agg_cursor + + result = await expense_service.get_friends_balance_summary(user_id_str) + + # Should return empty results on error + assert len(result["friendsBalance"]) == 0 + assert result["summary"]["totalOwedToYou"] == 0 + assert result["summary"]["totalYouOwe"] == 0 + assert result["summary"]["friendCount"] == 0 + + +@pytest.mark.asyncio +async def test_get_friends_balance_summary_user_fetch_error(expense_service): + """Test friends balance summary when fetching user details fails""" + user_id_str = str(ObjectId()) + friend_id_str = str(ObjectId()) + + with patch("app.expenses.service.mongodb") as mock_mongodb: + mock_db = MagicMock() + mock_mongodb.database = mock_db + + # Mock groups + mock_groups = [ + { + "_id": ObjectId(), + "name": "Test Group", + "members": [{"userId": user_id_str}, {"userId": friend_id_str}], + } + ] + mock_groups_cursor = AsyncMock() + mock_groups_cursor.to_list.return_value = mock_groups + mock_db.groups.find.return_value = mock_groups_cursor + + # Mock aggregation success + mock_agg_cursor = AsyncMock() + mock_agg_cursor.to_list.return_value = [ + { + "_id": friend_id_str, + "totalBalance": 50.0, + "groups": [{"groupId": str(mock_groups[0]["_id"]), "balance": 50.0}], + } + ] + mock_db.settlements.aggregate.return_value = mock_agg_cursor + + # Mock user fetch failure + mock_users_cursor = AsyncMock() + mock_users_cursor.to_list.side_effect = Exception("User fetch failed") + mock_db.users.find.return_value = mock_users_cursor + + result = await expense_service.get_friends_balance_summary(user_id_str) + + # Should still return results but with "Unknown" for user names + assert len(result["friendsBalance"]) == 1 + assert result["friendsBalance"][0]["userName"] == "Unknown" + assert result["friendsBalance"][0]["netBalance"] == 50.0 + + +@pytest.mark.asyncio +async def test_get_friends_balance_summary_zero_balance_filtering(expense_service): + """Test that friends with zero balance are filtered out""" + user_id_str = str(ObjectId()) + + with patch("app.expenses.service.mongodb") as mock_mongodb: + mock_db = MagicMock() + mock_mongodb.database = mock_db + + # Mock groups + mock_groups = [ + { + "_id": ObjectId(), + "name": "Test Group", + "members": [{"userId": user_id_str}], + } + ] + mock_groups_cursor = AsyncMock() + mock_groups_cursor.to_list.return_value = mock_groups + mock_db.groups.find.return_value = mock_groups_cursor + + # Mock aggregation returns no results (all filtered by zero balance) + mock_agg_cursor = AsyncMock() + mock_agg_cursor.to_list.return_value = [] + mock_db.settlements.aggregate.return_value = mock_agg_cursor + + result = await expense_service.get_friends_balance_summary(user_id_str) + + # Should return empty friend balance + assert len(result["friendsBalance"]) == 0 + assert result["summary"]["totalOwedToYou"] == 0 + assert result["summary"]["totalYouOwe"] == 0 + assert result["summary"]["friendCount"] == 0 + assert result["summary"]["activeGroups"] == 1 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/backend/tests/groups/test_enrich_members.py b/backend/tests/groups/test_enrich_members.py new file mode 100644 index 00000000..ef252d3e --- /dev/null +++ b/backend/tests/groups/test_enrich_members.py @@ -0,0 +1,215 @@ +"""Tests for the optimized _enrich_members_with_user_details function""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from app.groups.service import GroupService +from bson import ObjectId + + +class TestEnrichMembersOptimized: + """Test cases for _enrich_members_with_user_details optimized function""" + + def setup_method(self): + """Setup for each test method""" + self.service = GroupService() + + @pytest.mark.asyncio + async def test_enrich_members_with_user_details_success(self): + """Test successful enrichment of members with user details""" + user_id_1 = str(ObjectId()) + user_id_2 = str(ObjectId()) + user_id_3 = str(ObjectId()) + + members = [ + {"userId": user_id_1, "role": "admin", "joinedAt": "2023-01-01"}, + {"userId": user_id_2, "role": "member", "joinedAt": "2023-01-02"}, + {"userId": user_id_3, "role": "member", "joinedAt": "2023-01-03"}, + ] + + mock_users = [ + {"_id": ObjectId(user_id_1), "name": "Admin User", "imageUrl": "admin.jpg"}, + { + "_id": ObjectId(user_id_2), + "name": "Member One", + "imageUrl": "member1.jpg", + }, + {"_id": ObjectId(user_id_3), "name": "Member Two", "imageUrl": None}, + ] + + mock_db = MagicMock() + mock_users_collection = MagicMock() + mock_db.users = mock_users_collection + + # Mock the find operation + mock_cursor = AsyncMock() + mock_cursor.to_list.return_value = mock_users + mock_users_collection.find.return_value = mock_cursor + + with patch.object(self.service, "get_db", return_value=mock_db): + enriched = await self.service._enrich_members_with_user_details(members) + + assert len(enriched) == 3 + assert enriched[0]["userId"] == user_id_1 + assert enriched[0]["user"]["name"] == "Admin User" + assert enriched[0]["user"]["imageUrl"] == "admin.jpg" + assert enriched[0]["role"] == "admin" + + assert enriched[1]["userId"] == user_id_2 + assert enriched[1]["user"]["name"] == "Member One" + assert enriched[1]["user"]["imageUrl"] == "member1.jpg" + + assert enriched[2]["userId"] == user_id_3 + assert enriched[2]["user"]["name"] == "Member Two" + assert enriched[2]["user"]["imageUrl"] is None + + # Verify the query was made correctly with $in operator + mock_users_collection.find.assert_called_once() + call_args = mock_users_collection.find.call_args + assert "_id" in call_args[0][0] + assert "$in" in call_args[0][0]["_id"] + + @pytest.mark.asyncio + async def test_enrich_members_empty_list(self): + """Test enrichment with empty members list""" + mock_db = MagicMock() + + with patch.object(self.service, "get_db", return_value=mock_db): + enriched = await self.service._enrich_members_with_user_details([]) + + assert enriched == [] + # Verify no database call was made + mock_db.users.find.assert_not_called() + + @pytest.mark.asyncio + async def test_enrich_members_missing_user_data(self): + """Test enrichment when some users are not found in database""" + user_id_1 = str(ObjectId()) + user_id_2 = str(ObjectId()) + + members = [ + {"userId": user_id_1, "role": "admin", "joinedAt": "2023-01-01"}, + {"userId": user_id_2, "role": "member", "joinedAt": "2023-01-02"}, + ] + + # Only return data for user_id_1, not user_id_2 + mock_users = [ + {"_id": ObjectId(user_id_1), "name": "Admin User", "imageUrl": "admin.jpg"}, + ] + + mock_db = MagicMock() + mock_users_collection = MagicMock() + mock_db.users = mock_users_collection + + mock_cursor = AsyncMock() + mock_cursor.to_list.return_value = mock_users + mock_users_collection.find.return_value = mock_cursor + + with patch.object(self.service, "get_db", return_value=mock_db): + enriched = await self.service._enrich_members_with_user_details(members) + + assert len(enriched) == 2 + assert enriched[0]["user"]["name"] == "Admin User" + # Missing user should have fallback name + assert "User" in enriched[1]["user"]["name"] # Will be "User " + + @pytest.mark.asyncio + async def test_enrich_members_database_error(self): + """Test enrichment when database query fails""" + user_id_1 = str(ObjectId()) + + members = [ + {"userId": user_id_1, "role": "admin", "joinedAt": "2023-01-01"}, + ] + + mock_db = MagicMock() + mock_users_collection = MagicMock() + mock_db.users = mock_users_collection + + # Simulate database error + mock_cursor = AsyncMock() + mock_cursor.to_list.side_effect = Exception("Database connection error") + mock_users_collection.find.return_value = mock_cursor + + with patch.object(self.service, "get_db", return_value=mock_db): + enriched = await self.service._enrich_members_with_user_details(members) + + # Should still return members with fallback user data + assert len(enriched) == 1 + assert "User" in enriched[0]["user"]["name"] # Fallback name + assert enriched[0]["user"]["imageUrl"] is None + + @pytest.mark.asyncio + async def test_enrich_members_preserves_member_fields(self): + """Test that enrichment preserves all original member fields""" + user_id_1 = str(ObjectId()) + + members = [ + { + "userId": user_id_1, + "role": "admin", + "joinedAt": "2023-01-01", + "customField": "custom_value", + }, + ] + + mock_users = [ + {"_id": ObjectId(user_id_1), "name": "Admin User", "imageUrl": "admin.jpg"}, + ] + + mock_db = MagicMock() + mock_users_collection = MagicMock() + mock_db.users = mock_users_collection + + mock_cursor = AsyncMock() + mock_cursor.to_list.return_value = mock_users + mock_users_collection.find.return_value = mock_cursor + + with patch.object(self.service, "get_db", return_value=mock_db): + enriched = await self.service._enrich_members_with_user_details(members) + + # Verify all fields are preserved + assert enriched[0]["userId"] == user_id_1 + assert enriched[0]["role"] == "admin" + assert enriched[0]["joinedAt"] == "2023-01-01" + # Note: customField won't be in the output as the function creates a new structure + # It only preserves userId, role, and joinedAt + assert enriched[0]["user"]["name"] == "Admin User" + assert enriched[0]["user"]["imageUrl"] == "admin.jpg" + + @pytest.mark.asyncio + async def test_enrich_members_batch_query_optimization(self): + """Test that the function uses a single batch query instead of N queries""" + # Create 10 members + members = [] + user_ids = [] + for i in range(10): + user_id = str(ObjectId()) + user_ids.append(user_id) + members.append( + {"userId": user_id, "role": "member", "joinedAt": f"2023-01-{i+1:02d}"} + ) + + mock_users = [ + {"_id": ObjectId(uid), "name": f"User {i}", "imageUrl": None} + for i, uid in enumerate(user_ids) + ] + + mock_db = MagicMock() + mock_users_collection = MagicMock() + mock_db.users = mock_users_collection + + mock_cursor = AsyncMock() + mock_cursor.to_list.return_value = mock_users + mock_users_collection.find.return_value = mock_cursor + + with patch.object(self.service, "get_db", return_value=mock_db): + enriched = await self.service._enrich_members_with_user_details(members) + + # Verify only ONE database call was made (batch query) + assert mock_users_collection.find.call_count == 1 + + # Verify all 10 members were enriched + assert len(enriched) == 10 + for i, member in enumerate(enriched): + assert member["user"]["name"] == f"User {i}" From c16925e3c3ec26e32367bec86310b52099211e00 Mon Sep 17 00:00:00 2001 From: Devasy Patel <110348311+Devasy23@users.noreply.github.com> Date: Sun, 7 Dec 2025 14:57:13 +0530 Subject: [PATCH 7/8] feat: add tests for handling zero and negative balance scenarios in friends' balance summary --- .../tests/expenses/test_expense_service.py | 51 ++++++++++++++++++- backend/tests/groups/test_enrich_members.py | 51 ++++++++++++++++++- 2 files changed, 100 insertions(+), 2 deletions(-) diff --git a/backend/tests/expenses/test_expense_service.py b/backend/tests/expenses/test_expense_service.py index 5271339a..32976071 100644 --- a/backend/tests/expenses/test_expense_service.py +++ b/backend/tests/expenses/test_expense_service.py @@ -2192,7 +2192,7 @@ async def test_get_friends_balance_summary_user_fetch_error(expense_service): @pytest.mark.asyncio async def test_get_friends_balance_summary_zero_balance_filtering(expense_service): - """Test that friends with zero balance are filtered out""" + """Test that friends with zero balance are filtered out - covers line 1061""" user_id_str = str(ObjectId()) with patch("app.expenses.service.mongodb") as mock_mongodb: @@ -2226,5 +2226,54 @@ async def test_get_friends_balance_summary_zero_balance_filtering(expense_servic assert result["summary"]["activeGroups"] == 1 +@pytest.mark.asyncio +async def test_get_friends_balance_summary_negative_balance(expense_service): + """Test friends balance with negative balance (user owes) - covers line 1141""" + user_id_str = str(ObjectId()) + friend_id_str = str(ObjectId()) + group_id = str(ObjectId()) + + with patch("app.expenses.service.mongodb") as mock_mongodb: + mock_db = MagicMock() + mock_mongodb.database = mock_db + + # Mock groups + mock_groups = [ + { + "_id": ObjectId(group_id), + "name": "Test Group", + "members": [{"userId": user_id_str}, {"userId": friend_id_str}], + } + ] + mock_groups_cursor = AsyncMock() + mock_groups_cursor.to_list.return_value = mock_groups + mock_db.groups.find.return_value = mock_groups_cursor + + # Mock aggregation with NEGATIVE balance (user owes friend) + mock_agg_cursor = AsyncMock() + mock_agg_cursor.to_list.return_value = [ + { + "_id": friend_id_str, + "totalBalance": -100.0, # Negative = user owes friend + "groups": [{"groupId": group_id, "balance": -100.0}], + } + ] + mock_db.settlements.aggregate.return_value = mock_agg_cursor + + # Mock user fetch + mock_users_cursor = AsyncMock() + mock_users_cursor.to_list.return_value = [ + {"_id": ObjectId(friend_id_str), "name": "Friend", "imageUrl": None} + ] + mock_db.users.find.return_value = mock_users_cursor + + result = await expense_service.get_friends_balance_summary(user_id_str) + + # Should have totalYouOwe = 100 (covers line 1141 - else branch) + assert result["summary"]["totalOwedToYou"] == 0 + assert result["summary"]["totalYouOwe"] == 100.0 + assert result["summary"]["netBalance"] == -100.0 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/backend/tests/groups/test_enrich_members.py b/backend/tests/groups/test_enrich_members.py index ef252d3e..8b675f44 100644 --- a/backend/tests/groups/test_enrich_members.py +++ b/backend/tests/groups/test_enrich_members.py @@ -71,7 +71,7 @@ async def test_enrich_members_with_user_details_success(self): @pytest.mark.asyncio async def test_enrich_members_empty_list(self): - """Test enrichment with empty members list""" + """Test enrichment with empty members list - covers line 35""" mock_db = MagicMock() with patch.object(self.service, "get_db", return_value=mock_db): @@ -81,6 +81,55 @@ async def test_enrich_members_empty_list(self): # Verify no database call was made mock_db.users.find.assert_not_called() + @pytest.mark.asyncio + async def test_enrich_members_invalid_object_ids(self): + """Test enrichment with invalid ObjectIds - covers lines 46-47""" + members = [ + {"userId": "invalid_id_123", "role": "admin", "joinedAt": "2023-01-01"}, + {"userId": "also_invalid", "role": "member", "joinedAt": "2023-01-02"}, + ] + + mock_db = MagicMock() + + with patch.object(self.service, "get_db", return_value=mock_db): + enriched = await self.service._enrich_members_with_user_details(members) + + # Should return fallback members since no valid ObjectIds - covers line 52 + assert len(enriched) == 2 + assert "User" in enriched[0]["user"]["name"] + assert enriched[0]["role"] == "admin" + + @pytest.mark.asyncio + async def test_enrich_members_member_without_userId(self): + """Test enrichment when member has no userId - covers line 99""" + user_id_1 = str(ObjectId()) + + members = [ + {"userId": user_id_1, "role": "admin", "joinedAt": "2023-01-01"}, + {"role": "member", "joinedAt": "2023-01-02"}, # No userId + ] + + mock_users = [ + {"_id": ObjectId(user_id_1), "name": "Admin User", "imageUrl": "admin.jpg"}, + ] + + mock_db = MagicMock() + mock_users_collection = MagicMock() + mock_db.users = mock_users_collection + + mock_cursor = AsyncMock() + mock_cursor.to_list.return_value = mock_users + mock_users_collection.find.return_value = mock_cursor + + with patch.object(self.service, "get_db", return_value=mock_db): + enriched = await self.service._enrich_members_with_user_details(members) + + assert len(enriched) == 2 + assert enriched[0]["user"]["name"] == "Admin User" + # Second member should be returned as-is since no userId + assert enriched[1]["role"] == "member" + assert "user" not in enriched[1] or enriched[1] == members[1] + @pytest.mark.asyncio async def test_enrich_members_missing_user_data(self): """Test enrichment when some users are not found in database""" From 790ae8aedd2861903aa83979e05726dbbcb1e79d Mon Sep 17 00:00:00 2001 From: Devasy Patel <110348311+Devasy23@users.noreply.github.com> Date: Sun, 7 Dec 2025 15:03:38 +0530 Subject: [PATCH 8/8] feat: enhance tests for member enrichment function with improved coverage and logging --- backend/tests/groups/test_enrich_members.py | 52 +++++++++++++++++++-- 1 file changed, 49 insertions(+), 3 deletions(-) diff --git a/backend/tests/groups/test_enrich_members.py b/backend/tests/groups/test_enrich_members.py index 8b675f44..772a1eb1 100644 --- a/backend/tests/groups/test_enrich_members.py +++ b/backend/tests/groups/test_enrich_members.py @@ -83,7 +83,7 @@ async def test_enrich_members_empty_list(self): @pytest.mark.asyncio async def test_enrich_members_invalid_object_ids(self): - """Test enrichment with invalid ObjectIds - covers lines 46-47""" + """Test enrichment with invalid ObjectIds - covers lines 47, 52, 108-109""" members = [ {"userId": "invalid_id_123", "role": "admin", "joinedAt": "2023-01-01"}, {"userId": "also_invalid", "role": "member", "joinedAt": "2023-01-02"}, @@ -92,12 +92,20 @@ async def test_enrich_members_invalid_object_ids(self): mock_db = MagicMock() with patch.object(self.service, "get_db", return_value=mock_db): - enriched = await self.service._enrich_members_with_user_details(members) + # Patch logger to verify warning is called (covers line 47) + with patch("app.groups.service.logger") as mock_logger: + enriched = await self.service._enrich_members_with_user_details(members) + # Verify logger.warning was called for invalid ObjectIds (line 47) + assert mock_logger.warning.call_count == 2 - # Should return fallback members since no valid ObjectIds - covers line 52 + # Should return fallback members since no valid ObjectIds (line 52) assert len(enriched) == 2 + # Verify _create_fallback_member output (lines 108-109) + assert enriched[0]["userId"] == "invalid_id_123" assert "User" in enriched[0]["user"]["name"] assert enriched[0]["role"] == "admin" + assert enriched[0]["user"]["email"] == "invalid_id_123@example.com" + assert enriched[0]["joinedAt"] == "2023-01-01" @pytest.mark.asyncio async def test_enrich_members_member_without_userId(self): @@ -262,3 +270,41 @@ async def test_enrich_members_batch_query_optimization(self): assert len(enriched) == 10 for i, member in enumerate(enriched): assert member["user"]["name"] == f"User {i}" + + def test_create_fallback_member_short_user_id(self): + """Test _create_fallback_member with short user ID - covers lines 108-109""" + member = {"userId": "ab", "role": "member", "joinedAt": "2023-01-01"} + + result = self.service._create_fallback_member(member) + + # Verify fallback member structure (lines 108-109) + assert result["userId"] == "ab" + assert result["role"] == "member" + assert result["joinedAt"] == "2023-01-01" + assert result["user"]["name"] == "User ab" # Uses the short ID since len < 4 + assert result["user"]["email"] == "ab@example.com" + assert result["user"]["imageUrl"] is None + + def test_create_fallback_member_long_user_id(self): + """Test _create_fallback_member with long user ID - covers lines 108-109""" + long_id = "abcdefghijklmnop" + member = {"userId": long_id, "role": "admin", "joinedAt": "2023-01-01"} + + result = self.service._create_fallback_member(member) + + # Verify fallback member structure with last 4 chars + assert result["userId"] == long_id + assert result["role"] == "admin" + assert result["user"]["name"] == "User mnop" # Uses last 4 chars + assert result["user"]["email"] == "abcdefghijklmnop@example.com" + assert result["user"]["imageUrl"] is None + + def test_create_fallback_member_no_user_id(self): + """Test _create_fallback_member when userId is missing""" + member = {"role": "member", "joinedAt": "2023-01-01"} # No userId + + result = self.service._create_fallback_member(member) + + # Should use "unknown" as fallback + assert result["userId"] == "unknown" + assert result["user"]["name"] == "User nown" # Last 4 chars of "unknown"