Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ jobs:
python -m pip install --upgrade pip
cd backend
pip install -r requirements.txt
pip install -r requirements-dev.txt

- name: Run Backend Tests with Coverage
run: |
Expand Down
2 changes: 0 additions & 2 deletions backend/app/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
from .routes import router
from .schemas import UserResponse
from .security import create_access_token, verify_token
from .service import auth_service

__all__ = [
"router",
"auth_service",
"verify_token",
"create_access_token",
"UserResponse",
Expand Down
38 changes: 28 additions & 10 deletions backend/app/auth/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
TokenVerifyRequest,
UserResponse,
)
from app.auth.security import create_access_token, oauth2_scheme # Import oauth2_scheme
from app.auth.service import auth_service
from app.auth.security import create_access_token, oauth2_scheme
from app.auth.service import AuthService
from app.config import settings
from app.dependencies import get_auth_service
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import ( # Import OAuth2PasswordRequestForm
OAuth2PasswordRequestForm,
Expand All @@ -28,7 +29,10 @@
@router.post(
"/token", response_model=TokenResponse, include_in_schema=False
) # include_in_schema=False to hide from docs if desired, or True to show
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
async def login_for_access_token(
form_data: OAuth2PasswordRequestForm = Depends(),
auth_service: AuthService = Depends(get_auth_service),
):
"""
OAuth2 compatible token login, get an access token for future requests.
This endpoint is used by Swagger UI for authorization.
Expand Down Expand Up @@ -59,7 +63,9 @@ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(


@router.post("/signup/email", response_model=AuthResponse)
async def signup_with_email(request: EmailSignupRequest):
async def signup_with_email(
request: EmailSignupRequest, auth_service: AuthService = Depends(get_auth_service)
):
"""
Registers a new user using email, password, and name, and returns authentication tokens and user information.

Expand Down Expand Up @@ -101,7 +107,9 @@ async def signup_with_email(request: EmailSignupRequest):


@router.post("/login/email", response_model=AuthResponse)
async def login_with_email(request: EmailLoginRequest):
async def login_with_email(
request: EmailLoginRequest, auth_service: AuthService = Depends(get_auth_service)
):
"""
Authenticates a user using email and password credentials.

Expand Down Expand Up @@ -136,7 +144,9 @@ async def login_with_email(request: EmailLoginRequest):


@router.post("/login/google", response_model=AuthResponse)
async def login_with_google(request: GoogleLoginRequest):
async def login_with_google(
request: GoogleLoginRequest, auth_service: AuthService = Depends(get_auth_service)
):
"""
Authenticates or registers a user using a Google OAuth ID token.

Expand Down Expand Up @@ -169,7 +179,9 @@ async def login_with_google(request: GoogleLoginRequest):


@router.post("/refresh", response_model=TokenResponse)
async def refresh_token(request: RefreshTokenRequest):
async def refresh_token(
request: RefreshTokenRequest, auth_service: AuthService = Depends(get_auth_service)
):
"""
Refreshes JWT tokens using a valid refresh token.

Expand Down Expand Up @@ -213,7 +225,9 @@ async def refresh_token(request: RefreshTokenRequest):


@router.post("/token/verify", response_model=UserResponse)
async def verify_token(request: TokenVerifyRequest):
async def verify_token(
request: TokenVerifyRequest, auth_service: AuthService = Depends(get_auth_service)
):
"""
Verifies an access token and returns the associated user information.

Expand All @@ -236,7 +250,9 @@ async def verify_token(request: TokenVerifyRequest):


@router.post("/password/reset/request", response_model=SuccessResponse)
async def request_password_reset(request: PasswordResetRequest):
async def request_password_reset(
request: PasswordResetRequest, auth_service: AuthService = Depends(get_auth_service)
):
"""
Initiates a password reset process by sending a reset link to the provided email address.

Expand All @@ -256,7 +272,9 @@ async def request_password_reset(request: PasswordResetRequest):


@router.post("/password/reset/confirm", response_model=SuccessResponse)
async def confirm_password_reset(request: PasswordResetConfirm):
async def confirm_password_reset(
request: PasswordResetConfirm, auth_service: AuthService = Depends(get_auth_service)
):
"""
Resets a user's password using a valid password reset token.

Expand Down
77 changes: 23 additions & 54 deletions backend/app/auth/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,30 +55,16 @@
},
)
logger.info("Firebase initialized with credentials from environment variables")
# Fall back to service account JSON file if env vars are not available
elif os.path.exists(settings.firebase_service_account_path):
cred = credentials.Certificate(settings.firebase_service_account_path)
firebase_admin.initialize_app(
cred,
{
"projectId": settings.firebase_project_id,
},
)
logger.info("Firebase initialized with service account file")
else:
logger.warning("Firebase service account not found. Google auth will not work.")


class AuthService:
def __init__(self):
# Initializes the AuthService instance.
pass

def get_db(self):
"""
Returns a database connection instance from the application's database module.
"""
return get_database()
self.db = get_database()
self.users_collection = self.db["users"]
self.refresh_tokens_collection = self.db["refresh_tokens"]
self.password_resets_collection = self.db["password_resets"]

async def create_user_with_email(
self, email: str, password: str, name: str
Expand All @@ -99,10 +85,8 @@ async def create_user_with_email(
Raises:
HTTPException: If a user with the given email already exists.
"""
db = self.get_db()

# Check if user already exists
existing_user = await db.users.find_one({"email": email})
existing_user = await self.users_collection.find_one({"email": email})
if existing_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
Expand All @@ -122,7 +106,7 @@ async def create_user_with_email(
}

try:
result = await db.users.insert_one(user_doc)
result = await self.users_collection.insert_one(user_doc)
user_doc["_id"] = str(result.inserted_id)

# Create refresh token
Expand Down Expand Up @@ -154,9 +138,8 @@ async def authenticate_user_with_email(
Returns:
A dictionary containing the authenticated user and a new refresh token.
"""
db = self.get_db()
try:
user = await db.users.find_one({"email": email})
user = await self.users_collection.find_one({"email": email})
except PyMongoError as e:
logger.error(f"Database error during user lookup: {e}")
raise HTTPException(
Expand Down Expand Up @@ -215,11 +198,9 @@ async def authenticate_with_google(self, id_token: str) -> Dict[str, Any]:
detail="Email not provided by Google",
)

db = self.get_db()

# Check if user exists
try:
user = await db.users.find_one(
user = await self.users_collection.find_one(
{"$or": [{"email": email}, {"firebase_uid": firebase_uid}]}
)
except PyMongoError as e:
Expand All @@ -238,7 +219,7 @@ async def authenticate_with_google(self, id_token: str) -> Dict[str, Any]:

if update_data:
try:
await db.users.update_one(
await self.users_collection.update_one(
{"_id": user["_id"]}, {"$set": update_data}
)
user.update(update_data)
Expand All @@ -257,7 +238,7 @@ async def authenticate_with_google(self, id_token: str) -> Dict[str, Any]:
"hashed_password": None,
}
try:
result = await db.users.insert_one(user_doc)
result = await self.users_collection.insert_one(user_doc)
user_doc["_id"] = result.inserted_id
user = user_doc
except PyMongoError as e:
Expand Down Expand Up @@ -303,11 +284,9 @@ async def refresh_access_token(self, refresh_token: str) -> str:
Returns:
A new refresh token string.
"""
db = self.get_db()

# Find and validate refresh token
try:
token_record = await db.refresh_tokens.find_one(
token_record = await self.refresh_tokens_collection.find_one(
{
"token": refresh_token,
"revoked": False,
Expand All @@ -329,7 +308,9 @@ async def refresh_access_token(self, refresh_token: str) -> str:

# Get user
try:
user = await db.users.find_one({"_id": token_record["user_id"]})
user = await self.users_collection.find_one(
{"_id": token_record["user_id"]}
)
except PyMongoError as e:
logger.error("Error while fetching user: %s", str(e))
raise HTTPException(
Expand All @@ -355,7 +336,7 @@ async def refresh_access_token(self, refresh_token: str) -> str:

# Revoke old token
try:
await db.refresh_tokens.update_one(
await self.refresh_tokens_collection.update_one(
{"_id": token_record["_id"]}, {"$set": {"revoked": True}}
)
except PyMongoError as e:
Expand Down Expand Up @@ -393,10 +374,8 @@ async def verify_access_token(self, token: str) -> Dict[str, Any]:
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
)

db = self.get_db()

try:
user = await db.users.find_one({"_id": user_id})
user = await self.users_collection.find_one({"_id": user_id})
except Exception as e:
logger.error("Error while verifying token: %s", str(e))
raise HTTPException(
Expand All @@ -417,10 +396,8 @@ async def request_password_reset(self, email: str) -> bool:

If the user exists, generates a password reset token with a 1-hour expiration and stores it in the database. The reset token and link are logged for development purposes. Always returns True to avoid revealing whether the email is registered.
"""
db = self.get_db()

try:
user = await db.users.find_one({"email": email})
user = await self.users_collection.find_one({"email": email})
except PyMongoError as e:
logger.error(
f"Database error while fetching user by email {email}: {str(e)}"
Expand All @@ -439,7 +416,7 @@ async def request_password_reset(self, email: str) -> bool:

try:
# Store reset token
await db.password_resets.insert_one(
await self.password_resets_collection.insert_one(
{
"user_id": user["_id"],
"token": reset_token,
Expand Down Expand Up @@ -481,11 +458,9 @@ async def confirm_password_reset(self, reset_token: str, new_password: str) -> b
Raises:
HTTPException: If the reset token is invalid or expired.
"""
db = self.get_db()

try:
# Find and validate reset token
reset_record = await db.password_resets.find_one(
reset_record = await self.password_resets_collection.find_one(
{
"token": reset_token,
"used": False,
Expand All @@ -502,18 +477,18 @@ async def confirm_password_reset(self, reset_token: str, new_password: str) -> b

# Update user password
new_hash = get_password_hash(new_password)
await db.users.update_one(
await self.users_collection.update_one(
{"_id": reset_record["user_id"]},
{"$set": {"hashed_password": new_hash}},
)

# Mark token as used
await db.password_resets.update_one(
await self.password_resets_collection.update_one(
{"_id": reset_record["_id"]}, {"$set": {"used": True}}
)

# Revoke all refresh tokens for this user (force re-login)
await db.refresh_tokens.update_many(
await self.refresh_tokens_collection.update_many(
{"user_id": reset_record["user_id"]}, {"$set": {"revoked": True}}
)
logger.info(
Expand Down Expand Up @@ -542,15 +517,13 @@ async def _create_refresh_token_record(self, user_id: str) -> str:
Returns:
The generated refresh token string.
"""
db = self.get_db()

refresh_token = create_refresh_token()
expires_at = datetime.now(timezone.utc) + timedelta(
days=settings.refresh_token_expire_days
)

try:
await db.refresh_tokens.insert_one(
await self.refresh_tokens_collection.insert_one(
{
"token": refresh_token,
"user_id": (
Expand All @@ -571,7 +544,3 @@ async def _create_refresh_token_record(self, user_id: str) -> str:
)

return refresh_token


# Create service instance
auth_service = AuthService()
15 changes: 9 additions & 6 deletions backend/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import time
from logging.config import dictConfig
from typing import Optional
from typing import List, Optional

from pydantic_settings import BaseSettings
from starlette.middleware.base import BaseHTTPMiddleware
Expand All @@ -15,13 +15,13 @@ class Settings(BaseSettings):
database_name: str = "splitwiser"

# JWT
secret_key: str = "your-super-secret-jwt-key-change-this-in-production"
secret_key: str
algorithm: str = "HS256"
access_token_expire_minutes: int = 15
refresh_token_expire_days: int = 30

# Firebase
firebase_project_id: Optional[str] = None
firebase_service_account_path: str = "./firebase-service-account.json"
# Firebase service account credentials as environment variables
firebase_type: Optional[str] = None
firebase_private_key_id: Optional[str] = None
Expand All @@ -37,9 +37,12 @@ class Settings(BaseSettings):
debug: bool = False

# CORS - Add your frontend domain here for production
allowed_origins: str = (
"http://localhost:3000,http://localhost:5173,http://127.0.0.1:3000,http://localhost:8081"
)
allowed_origins: List[str] = [
"http://localhost:3000",
"http://localhost:5173",
"http://127.0.0.1:3000",
"http://localhost:8081",
]
allow_all_origins: bool = False

class Config:
Expand Down
Loading
Loading