Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions client/src/api/generated/schemas.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ export const ForgotPasswordRequestSchema = {
properties: {
email: {
type: 'string',
maxLength: 255,
format: 'email',
title: 'Email'
}
Expand Down Expand Up @@ -182,11 +183,31 @@ export const HTTPValidationErrorSchema = {
export const LoginRequestSchema = {
properties: {
username: {
type: 'string',
maxLength: 50,
minLength: 3,
anyOf: [
{
type: 'string',
maxLength: 50,
minLength: 3
},
{
type: 'null'
}
],
title: 'Username'
},
email: {
anyOf: [
{
type: 'string',
maxLength: 255,
format: 'email'
},
{
type: 'null'
}
],
title: 'Email'
},
password: {
type: 'string',
maxLength: 64,
Expand All @@ -196,7 +217,6 @@ export const LoginRequestSchema = {
},
type: 'object',
required: [
'username',
'password'
],
title: 'LoginRequest'
Expand Down Expand Up @@ -236,6 +256,7 @@ export const RequestAccessRequestSchema = {
properties: {
email: {
type: 'string',
maxLength: 255,
format: 'email',
title: 'Email'
},
Expand Down
6 changes: 5 additions & 1 deletion client/src/api/generated/types.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,11 @@ export type LoginRequest = {
/**
* Username
*/
username: string;
username?: string | null;
/**
* Email
*/
email?: string | null;
/**
* Password
*/
Expand Down
7 changes: 4 additions & 3 deletions client/src/api/generated/zod.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,15 @@ export const zCreateFeedbackRequest = z.object({
* ForgotPasswordRequest
*/
export const zForgotPasswordRequest = z.object({
email: z.email()
email: z.email().max(255)
});

/**
* LoginRequest
*/
export const zLoginRequest = z.object({
username: z.string().min(3).max(50),
username: z.string().min(3).max(50).nullish(),
email: z.email().max(255).nullish(),
password: z.string().min(8).max(64)
});

Expand All @@ -63,7 +64,7 @@ export const zRegisterRequest = z.object({
* RequestAccessRequest
*/
export const zRequestAccessRequest = z.object({
email: z.email(),
email: z.email().max(255),
first_name: z.string().min(1).max(50),
last_name: z.string().min(1).max(50)
});
Expand Down
46 changes: 36 additions & 10 deletions client/src/pages/Login.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,22 @@ import { useLocation, useNavigate } from 'react-router'
import { Link } from 'react-router-dom'
import { z } from 'zod'

type LoginForm = z.infer<typeof zLoginRequest>
const loginFormSchema = z.object({
identifier: z
.string()
.trim()
.refine((value) => {
const isEmailValid =
zLoginRequest.shape.email.safeParse(value).success
const isUsernameValid =
zLoginRequest.shape.username.safeParse(value).success
return isEmailValid || isUsernameValid
}, 'Enter a valid username or email'),
password: zLoginRequest.shape.password,
})

type LoginForm = z.infer<typeof loginFormSchema>
type LoginRequestBody = z.infer<typeof zLoginRequest>

export function Login() {
const { refresh } = useSession()
Expand All @@ -36,13 +51,20 @@ export function Login() {
formState: { errors, isSubmitting },
reset,
} = useForm({
resolver: zodResolver(zLoginRequest),
resolver: zodResolver(loginFormSchema),
mode: 'onSubmit',
reValidateMode: 'onChange',
})

const onSubmit = async (form: LoginForm) => {
const { error } = await AuthService.login({ body: form })
const identifier = form.identifier.trim()
const isEmail = z.email().safeParse(identifier).success
const body: LoginRequestBody = zLoginRequest.parse(
isEmail
? { email: identifier, password: form.password }
: { username: identifier, password: form.password }
)
const { error } = await AuthService.login({ body })
if (error) {
await handleApiError(error, {
fallbackMessage: 'Failed to log in',
Expand Down Expand Up @@ -72,19 +94,23 @@ export function Login() {
}}
>
<div className="space-y-1">
<Label htmlFor="username">Username</Label>
<Label htmlFor="identifier">
Username or Email
</Label>
<Input
id="username"
id="identifier"
autoComplete="username"
aria-invalid={!!errors.username}
aria-invalid={!!errors.identifier}
className={
errors.username ? 'border-destructive' : ''
errors.identifier
? 'border-destructive'
: ''
}
{...register('username')}
{...register('identifier')}
/>
{errors.username && (
{errors.identifier && (
<p className="text-sm text-destructive">
{errors.username.message}
{errors.identifier.message}
</p>
)}
</div>
Expand Down
5 changes: 4 additions & 1 deletion server/app/api/endpoints/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,10 @@ async def login_endpoint(
res: Response,
):
result = await login(
username=req.username, password=req.password, db=db, settings=settings
identifier=req.identifier,
password=req.password,
db=db,
settings=settings,
)
res.set_cookie(
key=ACCESS_JWT_KEY,
Expand Down
6 changes: 3 additions & 3 deletions server/app/core/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from app.models.errors import InvalidCredentials
from app.models.schemas.user import JWTData
from app.services.token import expire_tokens, get_tokens_by_prefix
from app.services.user import get_user_by_username
from app.services.user import get_user_by_identifier

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -136,11 +136,11 @@ def create_password_reset_token(


async def authenticate_user(
username: str,
identifier: str,
password: str,
db: AsyncSession,
) -> User | None:
user = await get_user_by_username(username, db)
user = await get_user_by_identifier(identifier, db)
if not user or not verify_secret(password, user.password_hash):
return None
return user
Expand Down
57 changes: 45 additions & 12 deletions server/app/models/schemas/auth.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,60 @@
from pydantic import BaseModel, EmailStr, Field
from typing import Self

from pydantic import (
BaseModel,
field_validator,
model_validator,
)

from app.models.schemas.types import (
Email,
Name,
Password,
Token,
Username,
is_email_identifier,
)


class RequestAccessRequest(BaseModel):
email: EmailStr
first_name: str = Field(min_length=1, max_length=50)
last_name: str = Field(min_length=1, max_length=50)
email: Email
first_name: Name
last_name: Name


class RegisterRequest(BaseModel):
token: str = Field(min_length=1, max_length=64)
username: str = Field(min_length=3, max_length=50)
password: str = Field(min_length=8, max_length=64)
token: Token
username: Username
password: Password

@field_validator("username")
@classmethod
def validate_username_not_email(cls, value: str) -> str:
if is_email_identifier(value):
raise ValueError("Username cannot be an email address")
return value


class ForgotPasswordRequest(BaseModel):
email: EmailStr
email: Email


class ResetPasswordRequest(BaseModel):
token: str = Field(min_length=1, max_length=64)
password: str = Field(min_length=8, max_length=64)
token: Token
password: Password


class LoginRequest(BaseModel):
username: str = Field(min_length=3, max_length=50)
password: str = Field(min_length=8, max_length=64)
username: Username | None = None
email: Email | None = None
password: Password

@model_validator(mode="after")
def validate_login_identifier(self) -> Self:
if not self.username and not self.email:
raise ValueError("At least one of username or email must be provided")
return self

@property
def identifier(self) -> str:
return self.username or str(self.email)
19 changes: 19 additions & 0 deletions server/app/models/schemas/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import Annotated

from pydantic import EmailStr, Field, StringConstraints, TypeAdapter, ValidationError

Name = Annotated[str, StringConstraints(min_length=1, max_length=50)]
Username = Annotated[str, StringConstraints(min_length=3, max_length=50)]
Password = Annotated[str, StringConstraints(min_length=8, max_length=64)]
Token = Annotated[str, StringConstraints(min_length=1, max_length=64)]
Email = Annotated[EmailStr, Field(max_length=255)]

EMAIL_TYPE_ADAPTER: TypeAdapter[EmailStr] = TypeAdapter(EmailStr)


def is_email_identifier(identifier: str) -> bool:
try:
EMAIL_TYPE_ADAPTER.validate_python(identifier)
return True
except ValidationError:
return False
23 changes: 16 additions & 7 deletions server/app/services/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ async def request_access(
"""Returns True if access was already approved, False otherwise"""
logger.info(f"Requesting access for email: {email}")

existing_user = await get_user_by_email(email, db)
if existing_user:
existing_user_by_email = await get_user_by_email(email, db)
existing_user_by_username = await get_user_by_username(email, db)
if existing_user_by_email or existing_user_by_username:
raise EmailAlreadyRegistered()

existing_request = await get_latest_access_request_by_email(email, db)
Expand Down Expand Up @@ -115,8 +116,9 @@ async def register(
if access_request.status != AccessRequestStatus.APPROVED:
raise InvalidToken()

existing_user = await get_user_by_username(username, db)
if existing_user:
existing_user_by_username = await get_user_by_username(username, db)
existing_user_by_email = await get_user_by_email(username, db)
if existing_user_by_username or existing_user_by_email:
raise UsernameAlreadyRegistered()

token.used_at = datetime.now(UTC)
Expand All @@ -142,6 +144,10 @@ async def request_password_reset(
) -> None:
logger.info(f"Requesting password reset for email: {email}")

if email == settings.admin.email:
logger.warning("Password reset requested for admin email, ignoring")
return

user = await get_user_by_email(email, db)
if not user:
logger.info(f"Password reset requested for unregistered email: {email}")
Expand Down Expand Up @@ -181,11 +187,14 @@ async def reset_password(


async def login(
username: str, password: str, db: AsyncSession, settings: Settings
identifier: str,
password: str,
db: AsyncSession,
settings: Settings,
) -> LoginResult:
logger.info(f"Logging in user {username}")
logger.info(f"Logging in user with identifier {identifier}")

user = await authenticate_user(username, password, db)
user = await authenticate_user(identifier, password, db)
if not user:
raise InvalidCredentials()

Expand Down
14 changes: 14 additions & 0 deletions server/app/services/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.database.user import User
from app.models.schemas.types import is_email_identifier


async def get_admin_users(db: AsyncSession) -> Sequence[User]:
Expand All @@ -27,6 +28,19 @@ async def get_user_by_email(
return result.scalar_one_or_none()


async def get_user_by_identifier(
identifier: str,
db: AsyncSession,
) -> User | None:
if is_email_identifier(identifier):
user = await get_user_by_email(identifier, db)
if user:
return user
return await get_user_by_username(identifier, db)

return await get_user_by_username(identifier, db)


async def get_users_ordered_by_username(db: AsyncSession) -> Sequence[User]:
result = await db.execute(select(User).order_by(User.username.asc()))
return result.scalars().all()
Loading