diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..1fdc63a8 --- /dev/null +++ b/.flake8 @@ -0,0 +1,4 @@ +[flake8] +max-line-length = 120 +exclude = .git,__pycache__,venv,tests/* +ignore = F541 diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index b801c2d3..fec7aa53 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -40,7 +40,7 @@ jobs: - name: Run Docker container run: | docker run -d -p 8000:8000 common-assessment-tool - sleep 10 # Wait for container to start + sleep 30 # Wait for container to start - name: Test Docker container run: | diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 30c81bdb..6f0caecf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,6 +10,12 @@ jobs: test: runs-on: ubuntu-latest # Use the latest Ubuntu runner + services: + docker: + image: docker:stable-dind + options: --privileged + ports: + - 2375:2375 steps: - name: Checkout Code uses: actions/checkout@v4 # Checkout the repository @@ -24,12 +30,29 @@ jobs: python -m pip install --upgrade pip # Upgrade pip to the latest version pip install setuptools wheel pip install -r requirements.txt # Install dependencies from requirements.txt - pip install pylint pytest + pip install -r requirements-dev.txt + pip install pytest + + - name: Run Flake8 + run: flake8 . + - name: Run Black + run: black . --check + + - name: Run MyPy + run: mypy . - name: Run Tests run: | python -m pytest tests/ + - name: Build Docker Image + run: docker build -t my-app-name:latest . + + - name: Run Docker Container + run: | + docker run --name my-app-container -d -p 8000:8000 my-app-name:latest + # Optionally add commands to check the running container + - name: Print Success Message run: | echo "CI Pipeline completed successfully!" @@ -39,4 +62,5 @@ jobs: echo "✓ Dependencies installed" echo "✓ Tests executed" echo "✓ Linting completed" + echo "✓ Docker Image Built" echo "========================" diff --git a/.gitignore b/.gitignore index 371e45c1..ae97a7e6 100644 --- a/.gitignore +++ b/.gitignore @@ -31,4 +31,8 @@ build/ htmlcov/ .tox/ .coverage -.coverage.* \ No newline at end of file +.coverage.* + +# Reports +*_report.txt +reports/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..dfb1f748 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,17 @@ +repos: + - repo: https://github.com/psf/black + rev: 25.1.0 + hooks: + - id: black + language_version: python3.12 + + - repo: https://github.com/PyCQA/flake8 + rev: '7.1.2' + hooks: + - id: flake8 + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: 'v1.15.0' + hooks: + - id: mypy + diff --git a/Dockerfile b/Dockerfile index a3ff66eb..1bf95a00 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,9 +12,13 @@ RUN pip install --no-cache-dir -r /code/requirements.txt # Copy the rest of your application COPY . /code/ +COPY entrypoint.sh /code/ + +# Make the entrypoint script executable +RUN chmod +x /code/entrypoint.sh # Expose the port your app runs on EXPOSE 8000 -# Command to run the application -CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file +# Set the entrypoint script +ENTRYPOINT ["/code/entrypoint.sh"] \ No newline at end of file diff --git a/app/__init__.py b/app/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/auth/router.py b/app/auth/router.py index 229ee71d..b8ac8afc 100644 --- a/app/auth/router.py +++ b/app/auth/router.py @@ -1,28 +1,34 @@ from datetime import datetime, timedelta from typing import Optional + from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm -from jose import JWTError, jwt +from pydantic import BaseModel, Field, validator +from jose import jwt, JWTError from sqlalchemy.orm import Session -from app.database import get_db -from app.models import User, UserRole from passlib.context import CryptContext -from pydantic import BaseModel, Field, validator + +from app.database import get_db +from app.models import User +from app.enums import UserRole router = APIRouter(prefix="/auth", tags=["authentication"]) + +# Schemas class UserCreate(BaseModel): username: str = Field(..., min_length=3, max_length=50) email: str password: str role: UserRole - @validator('role') + @validator("role") def validate_role(cls, v): if v not in [UserRole.admin, UserRole.case_worker]: - raise ValueError('Role must be either admin or case_worker') + raise ValueError("Role must be either admin or case_worker") return v + class UserResponse(BaseModel): username: str email: str @@ -31,121 +37,146 @@ class UserResponse(BaseModel): class Config: from_attributes = True -# Configuration -SECRET_KEY = "your-secret-key-here" -ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 30 -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token") +# Security Service +class SecurityService: + pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") -def verify_password(plain_password: str, hashed_password: str) -> bool: - return pwd_context.verify(plain_password, hashed_password) + def verify_password(self, plain_password: str, hashed_password: str) -> bool: + return self.pwd_context.verify(plain_password, hashed_password) -def get_password_hash(password: str) -> str: - return pwd_context.hash(password) + def get_password_hash(self, password: str) -> str: + return self.pwd_context.hash(password) -def authenticate_user(db: Session, username: str, password: str) -> Optional[User]: - user = db.query(User).filter(User.username == username).first() - if not user or not verify_password(password, user.hashed_password): - return None - return user -def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): - to_encode = data.copy() - if expires_delta: - expire = datetime.utcnow() + expires_delta - else: - expire = datetime.utcnow() + timedelta(minutes=15) - to_encode.update({"exp": expire}) - encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) - return encoded_jwt +security = SecurityService() + + +# Token Service +class TokenService: + SECRET_KEY = "your-secret-key-here" + ALGORITHM = "HS256" + ACCESS_TOKEN_EXPIRE_MINUTES = 30 + + def create_access_token( + self, data: dict, expires_delta: Optional[timedelta] = None + ): + to_encode = data.copy() + expire = datetime.utcnow() + (expires_delta or timedelta(minutes=15)) + to_encode.update({"exp": expire}) + return jwt.encode(to_encode, self.SECRET_KEY, algorithm=self.ALGORITHM) + + def decode_token(self, token: str): + try: + return jwt.decode(token, self.SECRET_KEY, algorithms=[self.ALGORITHM]) + except JWTError: + return None + + +token_service = TokenService() + + +# User Service +class UserService: + def authenticate_user( + self, db: Session, username: str, password: str + ) -> Optional[User]: + user = db.query(User).filter(User.username == username).first() + if not user or not security.verify_password(password, user.hashed_password): # type: ignore + return None + return user + + def create_user(self, db: Session, user_data: UserCreate) -> User: + if db.query(User).filter(User.username == user_data.username).first(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Username already registered", + ) + + if db.query(User).filter(User.email == user_data.email).first(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Email already registered", + ) + + db_user = User( + username=user_data.username, + email=user_data.email, + hashed_password=security.get_password_hash(user_data.password), + role=user_data.role, + ) + + try: + db.add(db_user) + db.commit() + db.refresh(db_user) + return db_user + except Exception as e: + db.rollback() + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e), + ) + + +user_service = UserService() + + +# Auth Dependencies +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token") + async def get_current_user( - token: str = Depends(oauth2_scheme), - db: Session = Depends(get_db) + token: str = Depends(oauth2_scheme), db: Session = Depends(get_db) ) -> User: credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"}, ) - try: - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - username: str = payload.get("sub") - if username is None: - raise credentials_exception - except JWTError: + payload = token_service.decode_token(token) + username = payload.get("sub") if payload else None + if not username: raise credentials_exception - + user = db.query(User).filter(User.username == username).first() if user is None: raise credentials_exception return user + def get_admin_user(current_user: User = Depends(get_current_user)): if current_user.role != UserRole.admin: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Only admin users can perform this operation" + detail="Only admin users can perform this operation", ) return current_user + +# Routes @router.post("/token") async def login_for_access_token( - form_data: OAuth2PasswordRequestForm = Depends(), - db: Session = Depends(get_db) + form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db) ): - user = authenticate_user(db, form_data.username, form_data.password) + user = user_service.authenticate_user(db, form_data.username, form_data.password) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"}, ) - access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - access_token = create_access_token( + access_token_expires = timedelta(minutes=token_service.ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = token_service.create_access_token( data={"sub": user.username}, expires_delta=access_token_expires ) return {"access_token": access_token, "token_type": "bearer"} + @router.post("/users", response_model=UserResponse) async def create_user( user_data: UserCreate, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): - """Create a new user (admin only)""" - # Check if username exists - if db.query(User).filter(User.username == user_data.username).first(): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Username already registered" - ) - - # Check if email exists - if db.query(User).filter(User.email == user_data.email).first(): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Email already registered" - ) - - # Create new user - db_user = User( - username=user_data.username, - email=user_data.email, - hashed_password=get_password_hash(user_data.password), - role=user_data.role - ) - - try: - db.add(db_user) - db.commit() - db.refresh(db_user) - return db_user - except Exception as e: - db.rollback() - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(e) - ) + return user_service.create_user(db, user_data) diff --git a/app/clients/router.py b/app/clients/router.py index 4ecc83e4..4444a9ef 100644 --- a/app/clients/router.py +++ b/app/clients/router.py @@ -3,41 +3,59 @@ Handles all HTTP requests for client operations including create, read, update, and delete. """ -from fastapi import APIRouter, Depends, HTTPException, status, Query -from sqlalchemy.orm import Session from typing import List, Optional -from app.auth.router import get_current_user, get_admin_user -from app.models import User, UserRole +from fastapi import APIRouter, Depends, status, Query +from sqlalchemy.orm import Session + +from app.auth.router import get_current_user, get_admin_user from app.database import get_db -from app.clients.service.client_service import ClientService +from app.models import User +from app.clients.service.client_service import ClientQueryService, ClientMutationService + from app.clients.schema import ( - ClientResponse, - ClientUpdate, + ClientResponse, + ClientUpdate, ClientListResponse, ServiceResponse, - ServiceUpdate + ServiceUpdate, ) + router = APIRouter(prefix="/clients", tags=["clients"]) + @router.get("/", response_model=ClientListResponse) async def get_clients( - current_user: User = Depends(get_admin_user), + current_user: User = Depends(get_admin_user), skip: int = Query(default=0, ge=0, description="Number of records to skip"), - limit: int = Query(default=50, ge=1, le=150, description="Maximum number of records to return"), - db: Session = Depends(get_db) + limit: int = Query( + default=50, ge=1, le=150, description="Maximum number of records to return" + ), + db: Session = Depends(get_db), ): - return ClientService.get_clients(db, skip, limit) + """ + Retrieve a list of clients with optional pagination. + Parameters: + current_user: User object of the currently authenticated user. + skip: The number of items to skip (for pagination). + limit: The maximum number of items to return. + db: Database session dependency. + Returns: + A list of clients according to the specified pagination rules. + """ + return ClientQueryService.get_clients(db, skip, limit) + @router.get("/{client_id}", response_model=ClientResponse) async def get_client( client_id: int, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Get a specific client by ID""" - return ClientService.get_client(db, client_id) + return ClientQueryService.get_client(db, client_id) + @router.get("/search/by-criteria", response_model=List[ClientResponse]) async def get_clients_by_criteria( @@ -66,10 +84,10 @@ async def get_clients_by_criteria( time_unemployed: Optional[int] = Query(None, ge=0), need_mental_health_support_bool: Optional[bool] = None, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Search clients by any combination of criteria""" - return ClientService.get_clients_by_criteria( + return ClientQueryService.get_clients_by_criteria( db, employment_status=employment_status, education_level=education_level, @@ -94,9 +112,10 @@ async def get_clients_by_criteria( attending_school=attending_school, substance_use=substance_use, time_unemployed=time_unemployed, - need_mental_health_support_bool=need_mental_health_support_bool + need_mental_health_support_bool=need_mental_health_support_bool, ) + @router.get("/search/by-services", response_model=List[ClientResponse]) async def get_clients_by_services( employment_assistance: Optional[bool] = None, @@ -107,10 +126,10 @@ async def get_clients_by_services( employer_financial_supports: Optional[bool] = None, enhanced_referrals: Optional[bool] = None, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Get clients filtered by multiple service statuses""" - return ClientService.get_clients_by_services( + return ClientQueryService.get_clients_by_services( db, employment_assistance=employment_assistance, life_stabilization=life_stabilization, @@ -118,71 +137,82 @@ async def get_clients_by_services( specialized_services=specialized_services, employment_related_financial_supports=employment_related_financial_supports, employer_financial_supports=employer_financial_supports, - enhanced_referrals=enhanced_referrals + enhanced_referrals=enhanced_referrals, ) + @router.get("/{client_id}/services", response_model=List[ServiceResponse]) async def get_client_services( client_id: int, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Get all services and their status for a specific client, including case worker info""" - return ClientService.get_client_services(db, client_id) + return ClientQueryService.get_client_services(db, client_id) + @router.get("/search/success-rate", response_model=List[ClientResponse]) async def get_clients_by_success_rate( - min_rate: int = Query(70, ge=0, le=100, description="Minimum success rate percentage"), + min_rate: int = Query( + 70, ge=0, le=100, description="Minimum success rate percentage" + ), current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Get clients with success rate above specified threshold""" - return ClientService.get_clients_by_success_rate(db, min_rate) + return ClientQueryService.get_clients_by_success_rate(db, min_rate) + @router.get("/case-worker/{case_worker_id}", response_model=List[ClientResponse]) async def get_clients_by_case_worker( case_worker_id: int, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ): - return ClientService.get_clients_by_case_worker(db, case_worker_id) + return ClientQueryService.get_clients_by_case_worker(db, case_worker_id) + @router.put("/{client_id}", response_model=ClientResponse) async def update_client( client_id: int, client_data: ClientUpdate, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Update a client's information""" - return ClientService.update_client(db, client_id, client_data) + return ClientMutationService.update_client(db, client_id, client_data) + @router.put("/{client_id}/services/{user_id}", response_model=ServiceResponse) async def update_client_services( client_id: int, user_id: int, service_update: ServiceUpdate, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ): - return ClientService.update_client_services(db, client_id, user_id, service_update) + return ClientMutationService.update_client_services( + db, client_id, user_id, service_update + ) + @router.post("/{client_id}/case-assignment", response_model=ServiceResponse) async def create_case_assignment( client_id: int, case_worker_id: int = Query(..., description="Case worker ID to assign"), current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Create a new case assignment for a client with a case worker""" - return ClientService.create_case_assignment(db, client_id, case_worker_id) + return ClientMutationService.create_case_assignment(db, client_id, case_worker_id) + @router.delete("/{client_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_client( client_id: int, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Delete a client""" - ClientService.delete_client(db, client_id) + ClientMutationService.delete_client(db, client_id) return None diff --git a/app/clients/schema.py b/app/clients/schema.py index cff28897..a2e1ba4b 100644 --- a/app/clients/schema.py +++ b/app/clients/schema.py @@ -4,21 +4,23 @@ """ # Standard library imports -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field from typing import Optional, List from enum import IntEnum -from app.models import UserRole + # Enums for validation class Gender(IntEnum): MALE = 1 FEMALE = 2 + class PredictionInput(BaseModel): """ Schema for prediction input data containing all client assessment fields. Used for making predictions about client outcomes. """ + age: int gender: str work_experience: int @@ -44,6 +46,7 @@ class PredictionInput(BaseModel): time_unemployed: int need_mental_health_support_bool: str + class ClientBase(BaseModel): age: int = Field(ge=18, description="Age of client, must be 18 or older") gender: Gender = Field(description="Gender: 1 for male, 2 for female") @@ -54,9 +57,15 @@ class ClientBase(BaseModel): citizen_status: bool = Field(description="Client's citizenship status") level_of_schooling: int = Field(ge=1, le=14, description="Education level (1-14)") fluent_english: bool = Field(description="English fluency status") - reading_english_scale: int = Field(ge=0, le=10, description="English reading level (0-10)") - speaking_english_scale: int = Field(ge=0, le=10, description="English speaking level (0-10)") - writing_english_scale: int = Field(ge=0, le=10, description="English writing level (0-10)") + reading_english_scale: int = Field( + ge=0, le=10, description="English reading level (0-10)" + ) + speaking_english_scale: int = Field( + ge=0, le=10, description="English speaking level (0-10)" + ) + writing_english_scale: int = Field( + ge=0, le=10, description="English writing level (0-10)" + ) numeracy_scale: int = Field(ge=0, le=10, description="Numeracy skill level (0-10)") computer_scale: int = Field(ge=0, le=10, description="Computer skill level (0-10)") transportation_bool: bool = Field(description="Has transportation") @@ -68,7 +77,9 @@ class ClientBase(BaseModel): currently_employed: bool = Field(description="Current employment status") substance_use: bool = Field(description="Substance use status") time_unemployed: int = Field(ge=0, description="Time unemployed in months") - need_mental_health_support_bool: bool = Field(description="Needs mental health support") + need_mental_health_support_bool: bool = Field( + description="Needs mental health support" + ) class Config: json_schema_extra = { @@ -96,16 +107,18 @@ class Config: "currently_employed": False, "substance_use": False, "time_unemployed": 6, - "need_mental_health_support_bool": False + "need_mental_health_support_bool": False, } } + class ClientResponse(ClientBase): id: int class Config: from_attributes = True + class ClientUpdate(BaseModel): age: Optional[int] = Field(None, ge=18) gender: Optional[Gender] = None @@ -132,6 +145,7 @@ class ClientUpdate(BaseModel): time_unemployed: Optional[int] = Field(None, ge=0) need_mental_health_support_bool: Optional[bool] = None + class ServiceResponse(BaseModel): client_id: int user_id: int @@ -147,6 +161,7 @@ class ServiceResponse(BaseModel): class Config: from_attributes = True + class ServiceUpdate(BaseModel): employment_assistance: Optional[bool] = None life_stabilization: Optional[bool] = None @@ -157,6 +172,7 @@ class ServiceUpdate(BaseModel): enhanced_referrals: Optional[bool] = None success_rate: Optional[int] = Field(None, ge=0, le=100) + class ClientListResponse(BaseModel): clients: List[ClientResponse] total: int diff --git a/app/clients/service/client_service.py b/app/clients/service/client_service.py index 86c3ef4a..93922f82 100644 --- a/app/clients/service/client_service.py +++ b/app/clients/service/client_service.py @@ -1,48 +1,30 @@ -""" -Client service module handling all database operations for clients. -Provides CRUD operations and business logic for client management. -""" - +from fastapi import HTTPException from sqlalchemy.orm import Session -from sqlalchemy import and_ -from fastapi import HTTPException, status -from typing import List, Optional, Dict, Any +from typing import Optional from app.models import Client, ClientCase, User -from app.clients.schema import ClientUpdate, ServiceUpdate, ServiceResponse +from app.clients.schema import ClientUpdate, ServiceUpdate + -class ClientService: +class ClientQueryService: + # Retrieve a single client by ID @staticmethod def get_client(db: Session, client_id: int): - """Get a specific client by ID""" client = db.query(Client).filter(Client.id == client_id).first() if not client: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" - ) + raise HTTPException(status_code=404, detail=f"Client {client_id} not found") return client + # Retrieve a paginated list of clients @staticmethod def get_clients(db: Session, skip: int = 0, limit: int = 50): - """ - Get clients with optional pagination. - Default shows first 50 clients, which means you'd need 3 pages for 150 records. - """ - if skip < 0: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Skip value cannot be negative" - ) - if limit < 1: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Limit must be greater than 0" - ) - - clients = db.query(Client).offset(skip).limit(limit).all() - total = db.query(Client).count() - return {"clients": clients, "total": total} + if skip < 0 or limit < 1: + raise HTTPException(status_code=400, detail="Invalid pagination parameters") + return { + "clients": db.query(Client).offset(skip).limit(limit).all(), + "total": db.query(Client).count(), + } + # Retrieve clients that match various optional criteria filters @staticmethod def get_clients_by_criteria( db: Session, @@ -69,162 +51,125 @@ def get_clients_by_criteria( attending_school: Optional[bool] = None, substance_use: Optional[bool] = None, time_unemployed: Optional[int] = None, - need_mental_health_support_bool: Optional[bool] = None + need_mental_health_support_bool: Optional[bool] = None, ): - """Get clients filtered by any combination of criteria""" + """Search clients by any combination of criteria""" query = db.query(Client) - - if education_level is not None and not (1 <= education_level <= 14): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Education level must be between 1 and 14" - ) - - if age_min is not None and age_min < 18: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Minimum age must be at least 18" - ) - - if gender is not None and gender not in [1, 2]: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Gender must be 1 or 2" - ) - - # Apply filters for non-None values - if employment_status is not None: - query = query.filter(Client.currently_employed == employment_status) - if age_min is not None: - query = query.filter(Client.age >= age_min) - if gender is not None: - query = query.filter(Client.gender == gender) - if education_level is not None: - query = query.filter(Client.level_of_schooling == education_level) - if work_experience is not None: - query = query.filter(Client.work_experience == work_experience) - if canada_workex is not None: - query = query.filter(Client.canada_workex == canada_workex) - if dep_num is not None: - query = query.filter(Client.dep_num == dep_num) - if canada_born is not None: - query = query.filter(Client.canada_born == canada_born) - if citizen_status is not None: - query = query.filter(Client.citizen_status == citizen_status) - if fluent_english is not None: - query = query.filter(Client.fluent_english == fluent_english) - if reading_english_scale is not None: - query = query.filter(Client.reading_english_scale == reading_english_scale) - if speaking_english_scale is not None: - query = query.filter(Client.speaking_english_scale == speaking_english_scale) - if writing_english_scale is not None: - query = query.filter(Client.writing_english_scale == writing_english_scale) - if numeracy_scale is not None: - query = query.filter(Client.numeracy_scale == numeracy_scale) - if computer_scale is not None: - query = query.filter(Client.computer_scale == computer_scale) - if transportation_bool is not None: - query = query.filter(Client.transportation_bool == transportation_bool) - if caregiver_bool is not None: - query = query.filter(Client.caregiver_bool == caregiver_bool) - if housing is not None: - query = query.filter(Client.housing == housing) - if income_source is not None: - query = query.filter(Client.income_source == income_source) - if felony_bool is not None: - query = query.filter(Client.felony_bool == felony_bool) - if attending_school is not None: - query = query.filter(Client.attending_school == attending_school) - if substance_use is not None: - query = query.filter(Client.substance_use == substance_use) - if time_unemployed is not None: - query = query.filter(Client.time_unemployed == time_unemployed) - if need_mental_health_support_bool is not None: - query = query.filter(Client.need_mental_health_support_bool == need_mental_health_support_bool) - + query = ClientQueryService._apply_criteria_filters( + query, + employment_status=employment_status, + education_level=education_level, + age_min=age_min, + gender=gender, + work_experience=work_experience, + canada_workex=canada_workex, + dep_num=dep_num, + canada_born=canada_born, + citizen_status=citizen_status, + fluent_english=fluent_english, + reading_english_scale=reading_english_scale, + speaking_english_scale=speaking_english_scale, + writing_english_scale=writing_english_scale, + numeracy_scale=numeracy_scale, + computer_scale=computer_scale, + transportation_bool=transportation_bool, + caregiver_bool=caregiver_bool, + housing=housing, + income_source=income_source, + felony_bool=felony_bool, + attending_school=attending_school, + substance_use=substance_use, + time_unemployed=time_unemployed, + need_mental_health_support_bool=need_mental_health_support_bool, + ) try: return query.all() except Exception as e: raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error retrieving clients: {str(e)}" + status_code=500, detail=f"Error retrieving clients: {str(e)}" ) + # Filter clients based on service-related fields @staticmethod - def get_clients_by_services( - db: Session, - **service_filters: Optional[bool] - ): - """ - Get clients filtered by multiple service statuses. - """ + def get_clients_by_services(db: Session, **service_filters: Optional[bool]): query = db.query(Client).join(ClientCase) - - for service_name, status in service_filters.items(): - if status is not None: - filter_criteria = getattr(ClientCase, service_name) == status - query = query.filter(filter_criteria) - + for service, status_val in service_filters.items(): + if status_val is not None: + query = query.filter(getattr(ClientCase, service) == status_val) try: return query.all() except Exception as e: raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error retrieving clients: {str(e)}" + status_code=500, detail=f"Error retrieving clients: {str(e)}" ) + # Get all services associated with a given client @staticmethod def get_client_services(db: Session, client_id: int): - """Get all services for a specific client with case worker info""" - client_cases = db.query(ClientCase).filter(ClientCase.client_id == client_id).all() - if not client_cases: + services = db.query(ClientCase).filter(ClientCase.client_id == client_id).all() + if not services: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"No services found for client with id {client_id}" + status_code=404, detail=f"No services found for client {client_id}" ) - return client_cases + return services + # Get clients with a minimum success rate @staticmethod def get_clients_by_success_rate(db: Session, min_rate: int = 70): - """Get clients with success rate at or above the specified percentage""" if not (0 <= min_rate <= 100): raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Success rate must be between 0 and 100" + status_code=400, detail="Success rate must be between 0 and 100" ) - - return db.query(Client).join(ClientCase).filter( - ClientCase.success_rate >= min_rate - ).all() + return ( + db.query(Client) + .join(ClientCase) + .filter(ClientCase.success_rate >= min_rate) + .all() + ) + # Get all clients assigned to a specific case worker @staticmethod def get_clients_by_case_worker(db: Session, case_worker_id: int): - """Get all clients assigned to a specific case worker""" - case_worker = db.query(User).filter(User.id == case_worker_id).first() - if not case_worker: + if not db.query(User).filter(User.id == case_worker_id).first(): raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Case worker with id {case_worker_id} not found" + status_code=404, detail=f"Case worker {case_worker_id} not found" ) - - return db.query(Client).join(ClientCase).filter( - ClientCase.user_id == case_worker_id - ).all() + return ( + db.query(Client) + .join(ClientCase) + .filter(ClientCase.user_id == case_worker_id) + .all() + ) + # Internal helper method to apply dynamic filtering logic @staticmethod - def update_client(db: Session, client_id: int, client_update: ClientUpdate): - """Update a client's information""" - client = db.query(Client).filter(Client.id == client_id).first() - if not client: + def _apply_criteria_filters(query, **filters): + if filters.get("education_level") and not ( + 1 <= filters["education_level"] <= 14 + ): + raise HTTPException(status_code=400, detail="Invalid education level") + if filters.get("age_min") and filters["age_min"] < 18: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" + status_code=400, detail="Minimum age must be at least 18" ) + if filters.get("gender") and filters["gender"] not in [1, 2]: + raise HTTPException(status_code=400, detail="Gender must be 1 or 2") - update_data = client_update.dict(exclude_unset=True) - for field, value in update_data.items(): - setattr(client, field, value) + for key, val in filters.items(): + if val is not None and hasattr(Client, key): + query = query.filter(getattr(Client, key) == val) + return query + +class ClientMutationService: + # Update client information based on provided fields + @staticmethod + def update_client(db: Session, client_id: int, update_data: ClientUpdate): + client = db.query(Client).filter(Client.id == client_id).first() + if not client: + raise HTTPException(status_code=404, detail=f"Client {client_id} not found") + for field, value in update_data.dict(exclude_unset=True).items(): + setattr(client, field, value) try: db.commit() db.refresh(client) @@ -232,130 +177,87 @@ def update_client(db: Session, client_id: int, client_update: ClientUpdate): except Exception as e: db.rollback() raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update client: {str(e)}" + status_code=500, detail=f"Failed to update client: {str(e)}" ) - + + # Update service info for a specific client-case worker relationship @staticmethod def update_client_services( - db: Session, - client_id: int, - user_id: int, - service_update: ServiceUpdate + db: Session, client_id: int, user_id: int, update_data: ServiceUpdate ): - """Update a client's services and outcomes for a specific case worker""" - client_case = db.query(ClientCase).filter( - ClientCase.client_id == client_id, - ClientCase.user_id == user_id - ).first() - - if not client_case: + case = ( + db.query(ClientCase) + .filter(ClientCase.client_id == client_id, ClientCase.user_id == user_id) + .first() + ) + if not case: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"No case found for client {client_id} with case worker {user_id}. " - f"Cannot update services for a non-existent case assignment." + status_code=404, + detail=f"No case found for client {client_id} and worker {user_id}", ) - - update_data = service_update.dict(exclude_unset=True) - for field, value in update_data.items(): - setattr(client_case, field, value) - + for field, value in update_data.dict(exclude_unset=True).items(): + setattr(case, field, value) try: db.commit() - db.refresh(client_case) - return client_case + db.refresh(case) + return case except Exception as e: db.rollback() raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update client services: {str(e)}" - ) - - @staticmethod - def create_case_assignment( - db: Session, - client_id: int, - case_worker_id: int - ): - """Create a new case assignment""" - # Check if client exists - client = db.query(Client).filter(Client.id == client_id).first() - if not client: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" + status_code=500, detail=f"Failed to update client services: {str(e)}" ) - # Check if case worker exists - case_worker = db.query(User).filter(User.id == case_worker_id).first() - if not case_worker: + # Assign a new case worker to a client, with default service values + @staticmethod + def create_case_assignment(db: Session, client_id: int, worker_id: int): + if not db.query(Client).filter(Client.id == client_id).first(): + raise HTTPException(status_code=404, detail=f"Client {client_id} not found") + if not db.query(User).filter(User.id == worker_id).first(): raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Case worker with id {case_worker_id} not found" + status_code=404, detail=f"Case worker {worker_id} not found" ) + if ( + db.query(ClientCase) + .filter(ClientCase.client_id == client_id, ClientCase.user_id == worker_id) + .first() + ): + raise HTTPException(status_code=400, detail=f"Assignment already exists") - # Check if assignment already exists - existing_case = db.query(ClientCase).filter( - ClientCase.client_id == client_id, - ClientCase.user_id == case_worker_id - ).first() - - if existing_case: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Client {client_id} already has a case assigned to case worker {case_worker_id}" + case = ClientCase( + client_id=client_id, + user_id=worker_id, + employment_assistance=False, + life_stabilization=False, + retention_services=False, + specialized_services=False, + employment_related_financial_supports=False, + employer_financial_supports=False, + enhanced_referrals=False, + success_rate=0, ) - try: - # Create new case assignment with default service values - new_case = ClientCase( - client_id=client_id, - user_id=case_worker_id, - employment_assistance=False, - life_stabilization=False, - retention_services=False, - specialized_services=False, - employment_related_financial_supports=False, - employer_financial_supports=False, - enhanced_referrals=False, - success_rate=0 - ) - db.add(new_case) + db.add(case) db.commit() - db.refresh(new_case) - return new_case - + db.refresh(case) + return case except Exception as e: db.rollback() raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to create case assignment: {str(e)}" + status_code=500, detail=f"Failed to create assignment: {str(e)}" ) - + + # Delete a client and all related case records @staticmethod def delete_client(db: Session, client_id: int): - """Delete a client and their associated records""" - # First check if client exists client = db.query(Client).filter(Client.id == client_id).first() if not client: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" - ) - + raise HTTPException(status_code=404, detail=f"Client {client_id} not found") try: - # Delete associated client_cases - db.query(ClientCase).filter( - ClientCase.client_id == client_id - ).delete() - - # Delete the client + db.query(ClientCase).filter(ClientCase.client_id == client_id).delete() db.delete(client) db.commit() - except Exception as e: db.rollback() raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to delete client: {str(e)}" + status_code=500, detail=f"Failed to delete client: {str(e)}" ) diff --git a/app/clients/service/logic.py b/app/clients/service/logic.py index c25b4217..f57ce6cb 100644 --- a/app/clients/service/logic.py +++ b/app/clients/service/logic.py @@ -5,7 +5,8 @@ # Standard library imports import os -#import json + +# import json from itertools import product # Third-party imports @@ -14,21 +15,22 @@ # Constants COLUMN_INTERVENTIONS = [ - 'Life Stabilization', - 'General Employment Assistance Services', - 'Retention Services', - 'Specialized Services', - 'Employment-Related Financial Supports for Job Seekers and Employers', - 'Employer Financial Supports', - 'Enhanced Referrals for Skills Development' + "Life Stabilization", + "General Employment Assistance Services", + "Retention Services", + "Specialized Services", + "Employment-Related Financial Supports for Job Seekers and Employers", + "Employer Financial Supports", + "Enhanced Referrals for Skills Development", ] # Load model CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) -MODEL_PATH = os.path.join(CURRENT_DIR, 'model.pkl') +MODEL_PATH = os.path.join(CURRENT_DIR, "model.pkl") with open(MODEL_PATH, "rb") as model_file: MODEL = pickle.load(model_file) + def clean_input_data(input_data): """ Clean and transform input data into model-compatible format. @@ -40,13 +42,30 @@ def clean_input_data(input_data): list: Cleaned and formatted data ready for model input """ columns = [ - "age", "gender", "work_experience", "canada_workex", "dep_num", - "canada_born", "citizen_status", "level_of_schooling", "fluent_english", - "reading_english_scale", "speaking_english_scale", "writing_english_scale", - "numeracy_scale", "computer_scale", "transportation_bool", "caregiver_bool", - "housing", "income_source", "felony_bool", "attending_school", - "currently_employed", "substance_use", "time_unemployed", - "need_mental_health_support_bool" + "age", + "gender", + "work_experience", + "canada_workex", + "dep_num", + "canada_born", + "citizen_status", + "level_of_schooling", + "fluent_english", + "reading_english_scale", + "speaking_english_scale", + "writing_english_scale", + "numeracy_scale", + "computer_scale", + "transportation_bool", + "caregiver_bool", + "housing", + "income_source", + "felony_bool", + "attending_school", + "currently_employed", + "substance_use", + "time_unemployed", + "need_mental_health_support_bool", ] demographics = {key: input_data[key] for key in columns} output = [] @@ -57,6 +76,7 @@ def clean_input_data(input_data): output.append(value) return output + def convert_text(text_data: str): """ Convert text answers from front end into numerical values. @@ -68,33 +88,47 @@ def convert_text(text_data: str): int: Converted numerical value """ categorical_mappings = [ + {"": 0, "true": 1, "false": 0, "no": 0, "yes": 1, "No": 0, "Yes": 1}, { - "": 0, "true": 1, "false": 0, "no": 0, "yes": 1, - "No": 0, "Yes": 1 + "Grade 0-8": 1, + "Grade 9": 2, + "Grade 10": 3, + "Grade 11": 4, + "Grade 12 or equivalent": 5, + "OAC or Grade 13": 6, + "Some college": 7, + "Some university": 8, + "Some apprenticeship": 9, + "Certificate of Apprenticeship": 10, + "Journeyperson": 11, + "Certificate/Diploma": 12, + "Bachelor's degree": 13, + "Post graduate": 14, }, { - "Grade 0-8": 1, "Grade 9": 2, "Grade 10": 3, "Grade 11": 4, - "Grade 12 or equivalent": 5, "OAC or Grade 13": 6, - "Some college": 7, "Some university": 8, "Some apprenticeship": 9, - "Certificate of Apprenticeship": 10, "Journeyperson": 11, - "Certificate/Diploma": 12, "Bachelor's degree": 13, - "Post graduate": 14 + "Renting-private": 1, + "Renting-subsidized": 2, + "Boarding or lodging": 3, + "Homeowner": 4, + "Living with family/friend": 5, + "Institution": 6, + "Temporary second residence": 7, + "Band-owned home": 8, + "Homeless or transient": 9, + "Emergency hostel": 10, }, { - "Renting-private": 1, "Renting-subsidized": 2, - "Boarding or lodging": 3, "Homeowner": 4, - "Living with family/friend": 5, "Institution": 6, - "Temporary second residence": 7, "Band-owned home": 8, - "Homeless or transient": 9, "Emergency hostel": 10 - }, - { - "No Source of Income": 1, "Employment Insurance": 2, + "No Source of Income": 1, + "Employment Insurance": 2, "Workplace Safety and Insurance Board": 3, "Ontario Works applied or receiving": 4, "Ontario Disability Support Program applied or receiving": 5, - "Dependent of someone receiving OW or ODSP": 6, "Crown Ward": 7, - "Employment": 8, "Self-Employment": 9, "Other (specify)": 10 - } + "Dependent of someone receiving OW or ODSP": 6, + "Crown Ward": 7, + "Employment": 8, + "Self-Employment": 9, + "Other (specify)": 10, + }, ] for category in categorical_mappings: if text_data in category: @@ -102,6 +136,7 @@ def convert_text(text_data: str): return int(text_data) if text_data.isnumeric() else text_data + def create_matrix(row_data): """ Create matrix of all possible intervention combinations. @@ -116,6 +151,7 @@ def create_matrix(row_data): perms = intervention_permutations(7) return np.concatenate((np.array(data), np.array(perms)), axis=1) + def intervention_permutations(num): """ Generate all possible intervention combinations. @@ -128,6 +164,7 @@ def intervention_permutations(num): """ return np.array(list(product([0, 1], repeat=num))) + def get_baseline_row(row_data): """ Create baseline row with no interventions. @@ -141,6 +178,7 @@ def get_baseline_row(row_data): base_interventions = np.zeros(7) return np.concatenate((np.array(row_data), base_interventions)) + def intervention_row_to_names(row_data): """ Convert intervention row to list of intervention names. @@ -153,6 +191,7 @@ def intervention_row_to_names(row_data): """ return [COLUMN_INTERVENTIONS[i] for i, value in enumerate(row_data) if value == 1] + def process_results(baseline_pred, results_matrix): """ Process model results into structured output. @@ -165,13 +204,10 @@ def process_results(baseline_pred, results_matrix): dict: Processed results with baseline and interventions """ result_list = [ - (row[-1], intervention_row_to_names(row[:-1])) - for row in results_matrix + (row[-1], intervention_row_to_names(row[:-1])) for row in results_matrix ] - return { - "baseline": baseline_pred[-1], - "interventions": result_list - } + return {"baseline": baseline_pred[-1], "interventions": result_list} + def interpret_and_calculate(input_data): """ @@ -188,25 +224,41 @@ def interpret_and_calculate(input_data): intervention_rows = create_matrix(raw_data) baseline_prediction = MODEL.predict(baseline_row) intervention_predictions = MODEL.predict(intervention_rows).reshape(-1, 1) - result_matrix = np.concatenate((intervention_rows, intervention_predictions), axis=1) + result_matrix = np.concatenate( + (intervention_rows, intervention_predictions), axis=1 + ) result_order = result_matrix[:, -1].argsort() result_matrix = result_matrix[result_order] top_results = result_matrix[-3:, -8:] return process_results(baseline_prediction, top_results) + if __name__ == "__main__": test_data = { - "age": "23", "gender": "1", "work_experience": "1", - "canada_workex": "1", "dep_num": "0", "canada_born": "1", - "citizen_status": "2", "level_of_schooling": "2", - "fluent_english": "3", "reading_english_scale": "2", - "speaking_english_scale": "2", "writing_english_scale": "3", - "numeracy_scale": "2", "computer_scale": "3", - "transportation_bool": "2", "caregiver_bool": "1", - "housing": "1", "income_source": "5", "felony_bool": "1", - "attending_school": "0", "currently_employed": "1", - "substance_use": "1", "time_unemployed": "1", - "need_mental_health_support_bool": "1" + "age": "23", + "gender": "1", + "work_experience": "1", + "canada_workex": "1", + "dep_num": "0", + "canada_born": "1", + "citizen_status": "2", + "level_of_schooling": "2", + "fluent_english": "3", + "reading_english_scale": "2", + "speaking_english_scale": "2", + "writing_english_scale": "3", + "numeracy_scale": "2", + "computer_scale": "3", + "transportation_bool": "2", + "caregiver_bool": "1", + "housing": "1", + "income_source": "5", + "felony_bool": "1", + "attending_school": "0", + "currently_employed": "1", + "substance_use": "1", + "time_unemployed": "1", + "need_mental_health_support_bool": "1", } results = interpret_and_calculate(test_data) print(results) diff --git a/app/clients/service/model.py b/app/clients/service/model.py index b2406370..9f906788 100644 --- a/app/clients/service/model.py +++ b/app/clients/service/model.py @@ -12,73 +12,72 @@ from sklearn.model_selection import train_test_split from sklearn.ensemble import RandomForestRegressor + def prepare_models(): """ Prepare and train the Random Forest model using the dataset. - + Returns: RandomForestRegressor: Trained model for predicting success rates """ # Load dataset - data = pd.read_csv('data_commontool.csv') + data = pd.read_csv("data_commontool.csv") # Define feature columns feature_columns = [ - 'age', # Client's age - 'gender', # Client's gender (bool) - 'work_experience', # Years of work experience - 'canada_workex', # Years of work experience in Canada - 'dep_num', # Number of dependents - 'canada_born', # Born in Canada - 'citizen_status', # Citizenship status - 'level_of_schooling', # Highest level achieved (1-14) - 'fluent_english', # English fluency scale (1-10) - 'reading_english_scale', # Reading ability scale (1-10) - 'speaking_english_scale',# Speaking ability scale (1-10) - 'writing_english_scale', # Writing ability scale (1-10) - 'numeracy_scale', # Numeracy ability scale (1-10) - 'computer_scale', # Computer proficiency scale (1-10) - 'transportation_bool', # Needs transportation support (bool) - 'caregiver_bool', # Is primary caregiver (bool) - 'housing', # Housing situation (1-10) - 'income_source', # Source of income (1-10) - 'felony_bool', # Has a felony (bool) - 'attending_school', # Currently a student (bool) - 'currently_employed', # Currently employed (bool) - 'substance_use', # Substance use disorder (bool) - 'time_unemployed', # Years unemployed - 'need_mental_health_support_bool' # Needs mental health support (bool) + "age", # Client's age + "gender", # Client's gender (bool) + "work_experience", # Years of work experience + "canada_workex", # Years of work experience in Canada + "dep_num", # Number of dependents + "canada_born", # Born in Canada + "citizen_status", # Citizenship status + "level_of_schooling", # Highest level achieved (1-14) + "fluent_english", # English fluency scale (1-10) + "reading_english_scale", # Reading ability scale (1-10) + "speaking_english_scale", # Speaking ability scale (1-10) + "writing_english_scale", # Writing ability scale (1-10) + "numeracy_scale", # Numeracy ability scale (1-10) + "computer_scale", # Computer proficiency scale (1-10) + "transportation_bool", # Needs transportation support (bool) + "caregiver_bool", # Is primary caregiver (bool) + "housing", # Housing situation (1-10) + "income_source", # Source of income (1-10) + "felony_bool", # Has a felony (bool) + "attending_school", # Currently a student (bool) + "currently_employed", # Currently employed (bool) + "substance_use", # Substance use disorder (bool) + "time_unemployed", # Years unemployed + "need_mental_health_support_bool", # Needs mental health support (bool) ] # Define intervention columns intervention_columns = [ - 'employment_assistance', - 'life_stabilization', - 'retention_services', - 'specialized_services', - 'employment_related_financial_supports', - 'employer_financial_supports', - 'enhanced_referrals' + "employment_assistance", + "life_stabilization", + "retention_services", + "specialized_services", + "employment_related_financial_supports", + "employer_financial_supports", + "enhanced_referrals", ] # Combine all feature columns all_features = feature_columns + intervention_columns # Prepare training data features = np.array(data[all_features]) # Changed from X to features - targets = np.array(data['success_rate']) # Changed from y to targets + targets = np.array(data["success_rate"]) # Changed from y to targets # Split the dataset features_train, _, targets_train, _ = train_test_split( # Removed unused variables - features, - targets, - test_size=0.2, - random_state=42 + features, targets, test_size=0.2, random_state=42 ) # Initialize and train the model model = RandomForestRegressor(n_estimators=100, random_state=42) model.fit(features_train, targets_train) return model + def save_model(model, filename="model.pkl"): """ Save the trained model to a file. - + Args: model: Trained model to save filename (str): Name of the file to save the model to @@ -86,19 +85,21 @@ def save_model(model, filename="model.pkl"): with open(filename, "wb") as model_file: pickle.dump(model, model_file) + def load_model(filename="model.pkl"): """ Load a trained model from a file. - + Args: filename (str): Name of the file to load the model from - + Returns: The loaded model """ with open(filename, "rb") as model_file: return pickle.load(model_file) + def main(): """Main function to train and save the model.""" print("Starting model training...") @@ -106,5 +107,6 @@ def main(): save_model(model) print("Model training completed and saved successfully.") + if __name__ == "__main__": main() diff --git a/app/database.py b/app/database.py index 3a489f54..6a800078 100644 --- a/app/database.py +++ b/app/database.py @@ -7,22 +7,25 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker -#Here is where the database is located -SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db" +# Here is where the database is located +SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db" -#Open up a connection so that we are able to use the database -engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}) +# Open up a connection so that we are able to use the database +engine = create_engine( + SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} +) -#Bind the engine just created +# Bind the engine just created SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -#Create an object of our database so as to control the database +# Create an object of our database so as to control the database Base = declarative_base() + def get_db(): """ Creates a database session and ensures it's closed after use. - + Yields: Session: SQLAlchemy database session """ diff --git a/app/enums.py b/app/enums.py new file mode 100644 index 00000000..93c00f1a --- /dev/null +++ b/app/enums.py @@ -0,0 +1,15 @@ +""" +Enum definitions for the Common Assessment Tool. +Contains enumerations used across the application. +""" + +import enum + + +class UserRole(str, enum.Enum): + """ + User Role class defining possible user roles in the system. + """ + + admin = "admin" + case_worker = "case_worker" diff --git a/app/main.py b/app/main.py index a8e8fa7f..1fc21e18 100644 --- a/app/main.py +++ b/app/main.py @@ -4,28 +4,37 @@ Handles database initialization and CORS middleware configuration. """ +# Related third-party imports from fastapi import FastAPI -from app import models -from app.database import engine +from fastapi.middleware.cors import CORSMiddleware + +# Local application/library specific imports +from app.database import engine, Base # Add Base here +from app.ml.router import router as models_router # newly added from app.clients.router import router as clients_router from app.auth.router import router as auth_router -from fastapi.middleware.cors import CORSMiddleware + # Initialize database tables -models.Base.metadata.create_all(bind=engine) +Base.metadata.create_all(bind=engine) # Create FastAPI application -app = FastAPI(title="Case Management API", description="API for managing client cases", version="1.0.0") +app = FastAPI( + title="Case Management API", + description="API for managing client cases", + version="1.0.0", +) # Include routers +app.include_router(models_router) app.include_router(auth_router) app.include_router(clients_router) # Configure CORS middleware app.add_middleware( CORSMiddleware, - allow_origins=["*"], # Allows all origins - allow_methods=["*"], # Allows all methods - allow_headers=["*"], # Allows all headers - allow_credentials=True, + allow_origins=["*"], # type: ignore + allow_methods=["*"], # type: ignore + allow_headers=["*"], # type: ignore + allow_credentials=True, # type: ignore ) diff --git a/app/ml/__init__.py b/app/ml/__init__.py new file mode 100644 index 00000000..b5ffa4ee --- /dev/null +++ b/app/ml/__init__.py @@ -0,0 +1,28 @@ +# app/ml/__init__.py +""" +Machine learning package initialization. +""" +from app.ml.model_registry import ModelRegistry +from app.ml.models.random_forest import RandomForestModel +from app.ml.models.gradient_boost import GradientBoostingModel +from app.ml.models.linear_regression import LinearRegressionModel + + +# Initialize the model registry with available models +def initialize_models(): + """Register all available models with the registry.""" + registry = ModelRegistry() + + # Register the models + registry.register_model("RandomForest", RandomForestModel()) + registry.register_model("GradientBoosting", GradientBoostingModel()) + registry.register_model("LinearRegression", LinearRegressionModel()) + + # Set the default model + registry.set_current_model("RandomForest") + + return registry + + +# Initialize models when importing this package +registry = initialize_models() diff --git a/app/ml/base_model.py b/app/ml/base_model.py new file mode 100644 index 00000000..2994136b --- /dev/null +++ b/app/ml/base_model.py @@ -0,0 +1,37 @@ +# app/ml/base_model.py +""" +Base model interface for machine learning models. +All model implementations should inherit from this class. +""" +from abc import ABC, abstractmethod +import pickle + + +class BaseModel(ABC): + """Abstract base class for all ML models.""" + + @abstractmethod + def train(self, features, targets): + """Train the model with the given features and targets.""" + pass + + @abstractmethod + def predict(self, features): + """Make predictions using the trained model.""" + pass + + @abstractmethod + def get_name(self): + """Return the name of the model.""" + pass + + def save(self, filename): + """Save the model to a file.""" + with open(filename, "wb") as model_file: + pickle.dump(self, model_file) + + @classmethod + def load(cls, filename): + """Load the model from a file.""" + with open(filename, "rb") as model_file: + return pickle.load(model_file) diff --git a/app/ml/data_processor.py b/app/ml/data_processor.py new file mode 100644 index 00000000..30a72072 --- /dev/null +++ b/app/ml/data_processor.py @@ -0,0 +1,84 @@ +# app/ml/data_processor.py +""" +Data processing module for machine learning models. +Handles loading and preprocessing data for model training and prediction. +""" +import pandas as pd +import numpy as np +from sklearn.model_selection import train_test_split + + +class DataProcessor: + """Handles data loading and preprocessing for ML models.""" + + def __init__(self, data_file="data_commontool.csv"): + """Initialize with the data file path.""" + self.data_file = data_file + self.feature_columns = [ + "age", + "gender", + "work_experience", + "canada_workex", + "dep_num", + "canada_born", + "citizen_status", + "level_of_schooling", + "fluent_english", + "reading_english_scale", + "speaking_english_scale", + "writing_english_scale", + "numeracy_scale", + "computer_scale", + "transportation_bool", + "caregiver_bool", + "housing", + "income_source", + "felony_bool", + "attending_school", + "currently_employed", + "substance_use", + "time_unemployed", + "need_mental_health_support_bool", + ] + self.intervention_columns = [ + "employment_assistance", + "life_stabilization", + "retention_services", + "specialized_services", + "employment_related_financial_supports", + "employer_financial_supports", + "enhanced_referrals", + ] + + def load_data(self): + """Load the dataset from file.""" + return pd.read_csv(self.data_file) + + def prepare_training_data(self): + """Prepare features and targets for model training.""" + data = self.load_data() + all_features = self.feature_columns + self.intervention_columns + + features = np.array(data[all_features]) + targets = np.array(data["success_rate"]) + + features_train, features_test, targets_train, targets_test = train_test_split( + features, targets, test_size=0.2, random_state=42 + ) + + return features_train, features_test, targets_train, targets_test + + def prepare_prediction_data(self, client_data, interventions): + """Prepare a single client's data for prediction.""" + # Extract client features from the client data + client_features = [client_data.get(col, 0) for col in self.feature_columns] + + # Add intervention features + intervention_features = [ + interventions.get(col, 0) for col in self.intervention_columns + ] + + # Combine all features + all_features = client_features + intervention_features + + return np.array([all_features]) diff --git a/app/ml/model_registry.py b/app/ml/model_registry.py new file mode 100644 index 00000000..7e91b164 --- /dev/null +++ b/app/ml/model_registry.py @@ -0,0 +1,48 @@ +# app/ml/model_registry.py +""" +Model registry for managing different ML models. +""" + + +class ModelRegistry: + """Registry for managing multiple machine learning models.""" + + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(ModelRegistry, cls).__new__(cls) + cls._instance._models = {} + cls._instance._current_model_name = None + return cls._instance + + def register_model(self, model_name, model_class): + """Register a new model with the registry.""" + self._models[model_name] = model_class + if self._current_model_name is None: + self._current_model_name = model_name + + def get_model(self, model_name=None): + """Get a model by name, or the current model if no name is provided.""" + if model_name is None: + model_name = self._current_model_name + + if model_name not in self._models: + raise ValueError(f"Model {model_name} is not registered") + + return self._models[model_name] + + def set_current_model(self, model_name): + """Set the current active model.""" + if model_name not in self._models: + raise ValueError(f"Model {model_name} is not registered") + + self._current_model_name = model_name + + def get_current_model_name(self): + """Get the name of the currently active model.""" + return self._current_model_name + + def list_available_models(self): + """List all available models in the registry.""" + return list(self._models.keys()) diff --git a/app/ml/models/__init__.py b/app/ml/models/__init__.py new file mode 100644 index 00000000..4b1fa474 --- /dev/null +++ b/app/ml/models/__init__.py @@ -0,0 +1,9 @@ +# app/ml/models/__init__.py +""" +Machine learning models package. +""" +from app.ml.models.random_forest import RandomForestModel +from app.ml.models.gradient_boost import GradientBoostingModel +from app.ml.models.linear_regression import LinearRegressionModel + +__all__ = ["RandomForestModel", "GradientBoostingModel", "LinearRegressionModel"] diff --git a/app/ml/models/gradient_boost.py b/app/ml/models/gradient_boost.py new file mode 100644 index 00000000..191dcc6e --- /dev/null +++ b/app/ml/models/gradient_boost.py @@ -0,0 +1,35 @@ +# app/ml/models/gradient_boost.py +""" +Gradient Boosting implementation for success rate prediction. +""" +from sklearn.ensemble import GradientBoostingRegressor +from app.ml.base_model import BaseModel + + +class GradientBoostingModel(BaseModel): + """Gradient Boosting model for predicting success rates.""" + + def __init__(self, n_estimators=100, learning_rate=0.1, random_state=42): + """Initialize the model with parameters.""" + self.model = GradientBoostingRegressor( + n_estimators=n_estimators, + learning_rate=learning_rate, + random_state=random_state, + ) + self.is_trained = False + + def train(self, features, targets): + """Train the model with the given features and targets.""" + self.model.fit(features, targets) + self.is_trained = True + return self + + def predict(self, features): + """Make predictions using the trained model.""" + if not self.is_trained: + raise ValueError("Model must be trained before making predictions") + return self.model.predict(features) + + def get_name(self): + """Return the name of the model.""" + return "GradientBoosting" diff --git a/app/ml/models/linear_regression.py b/app/ml/models/linear_regression.py new file mode 100644 index 00000000..67799637 --- /dev/null +++ b/app/ml/models/linear_regression.py @@ -0,0 +1,31 @@ +# app/ml/models/linear_regression.py +""" +Linear Regression implementation for success rate prediction. +""" +from sklearn.linear_model import LinearRegression +from app.ml.base_model import BaseModel + + +class LinearRegressionModel(BaseModel): + """Linear Regression model for predicting success rates.""" + + def __init__(self): + """Initialize the model.""" + self.model = LinearRegression() + self.is_trained = False + + def train(self, features, targets): + """Train the model with the given features and targets.""" + self.model.fit(features, targets) + self.is_trained = True + return self + + def predict(self, features): + """Make predictions using the trained model.""" + if not self.is_trained: + raise ValueError("Model must be trained before making predictions") + return self.model.predict(features) + + def get_name(self): + """Return the name of the model.""" + return "LinearRegression" diff --git a/app/ml/models/random_forest.py b/app/ml/models/random_forest.py new file mode 100644 index 00000000..d2c7fe74 --- /dev/null +++ b/app/ml/models/random_forest.py @@ -0,0 +1,33 @@ +# app/ml/models/random_forest.py +""" +Random Forest implementation for success rate prediction. +""" +from sklearn.ensemble import RandomForestRegressor +from app.ml.base_model import BaseModel + + +class RandomForestModel(BaseModel): + """Random Forest model for predicting success rates.""" + + def __init__(self, n_estimators=100, random_state=42): + """Initialize the model with parameters.""" + self.model = RandomForestRegressor( + n_estimators=n_estimators, random_state=random_state + ) + self.is_trained = False + + def train(self, features, targets): + """Train the model with the given features and targets.""" + self.model.fit(features, targets) + self.is_trained = True + return self + + def predict(self, features): + """Make predictions using the trained model.""" + if not self.is_trained: + raise ValueError("Model must be trained before making predictions") + return self.model.predict(features) + + def get_name(self): + """Return the name of the model.""" + return "RandomForest" diff --git a/app/ml/router.py b/app/ml/router.py new file mode 100644 index 00000000..dbf6c56e --- /dev/null +++ b/app/ml/router.py @@ -0,0 +1,96 @@ +# app/ml/router.py +""" +Router for ML model management endpoints. +""" +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from app.ml.model_registry import ModelRegistry +import numpy as np + +router = APIRouter( + prefix="/models", + tags=["models"], + responses={404: {"description": "Not found"}}, +) + + +class ModelResponse(BaseModel): + """Response model for model endpoints.""" + + name: str + + +class ModelsListResponse(BaseModel): + """Response model for listing available models.""" + + models: list[str] + current_model: str + + +@router.get("/current", response_model=ModelResponse) +async def get_current_model(): + """Get the currently active model.""" + registry = ModelRegistry() + current_model = registry.get_current_model_name() + return {"name": current_model} + + +@router.get("/available", response_model=ModelsListResponse) +async def list_available_models(): + """List all available models.""" + registry = ModelRegistry() + available_models = registry.list_available_models() + current_model = registry.get_current_model_name() + return {"models": available_models, "current_model": current_model} + + +@router.post("/switch/{model_name}", response_model=ModelResponse) +async def switch_model(model_name: str): + """Switch to a different model.""" + registry = ModelRegistry() + + try: + registry.set_current_model(model_name) + except ValueError: + raise HTTPException( + status_code=404, + detail=f"Model '{model_name}' not found. Available models: {registry.list_available_models()}", + ) + + return {"name": model_name} + + +@router.get("/test-prediction", response_model=dict) +async def test_prediction(): + """ + Test endpoint to demonstrate how predictions change when switching models. + Returns a prediction using the currently selected model. + """ + try: + # Create consistent test data + test_features = np.ones((1, 31)) * 0.5 # Use the same test data each time + + # Get the current model from registry + registry = ModelRegistry() + current_model = registry.get_model() + model_name = registry.get_current_model_name() + + # Train the model with some dummy data if it's not trained + if not hasattr(current_model, "is_trained") or not current_model.is_trained: + # Create dummy training data + X_train = np.random.rand(10, 31) + y_train = np.random.rand(10) + current_model.train(X_train, y_train) + + # Make a prediction + prediction = float(current_model.predict(test_features)[0]) + + return {"model": model_name, "prediction": prediction, "status": "success"} + except Exception as e: + # Return error information instead of letting it crash + return { + "model": registry.get_current_model_name(), + "error": str(e), + "status": "error", + } diff --git a/app/ml/test_models.py b/app/ml/test_models.py new file mode 100644 index 00000000..9c7f1396 --- /dev/null +++ b/app/ml/test_models.py @@ -0,0 +1,57 @@ +# app/ml/test_models.py +""" +Test script for ML models and model registry. +""" +import numpy as np +from app.ml.model_registry import ModelRegistry +from app.ml.models.random_forest import RandomForestModel +from app.ml.models.gradient_boost import GradientBoostingModel +from app.ml.models.linear_regression import LinearRegressionModel + + +def test_model_registry(): + """Test the model registry and model switching.""" + print("Testing model registry and model switching...") + + # Create the registry + registry = ModelRegistry() + + # Register models + registry.register_model("RandomForest", RandomForestModel()) + registry.register_model("GradientBoosting", GradientBoostingModel()) + registry.register_model("LinearRegression", LinearRegressionModel()) + + # List available models + models = registry.list_available_models() + print(f"Available models: {models}") + + # Get current model + current_model_name = registry.get_current_model_name() + print(f"Current model: {current_model_name}") + + # Switch models + registry.set_current_model("GradientBoosting") + print(f"Switched to model: {registry.get_current_model_name()}") + + # Create some dummy data for testing + features = np.random.rand(10, 31) # 24 client features + 7 intervention features + targets = np.random.rand(10) + + # Train all models + for model_name in models: + print(f"Training {model_name}...") + model = registry.get_model(model_name) + model.train(features, targets) + + # Make predictions with each model + for model_name in models: + print(f"Testing predictions with {model_name}...") + model = registry.get_model(model_name) + predictions = model.predict(features[:1]) + print(f" Prediction: {predictions[0]}") + + print("Model testing complete!") + + +if __name__ == "__main__": + test_model_registry() diff --git a/app/models.py b/app/models.py deleted file mode 100644 index df778348..00000000 --- a/app/models.py +++ /dev/null @@ -1,77 +0,0 @@ -""" -Database models module defining SQLAlchemy ORM models for the Common Assessment Tool. -Contains the Client model for storing client information in the database. -""" - -from app.database import Base -from sqlalchemy import Column, Integer, String, Boolean, ForeignKey, CheckConstraint, Enum -from sqlalchemy.orm import relationship -import enum - -class UserRole(str, enum.Enum): - admin = "admin" - case_worker = "case_worker" - - -class User(Base): - __tablename__ = "users" - - id = Column(Integer, primary_key=True, autoincrement=True) - username = Column(String(50), unique=True, nullable=False) - email = Column(String(100), unique=True, nullable=False) - hashed_password = Column(String(200), nullable=False) - role = Column(Enum(UserRole), nullable=False) - - cases = relationship("ClientCase", back_populates="user") - -class Client(Base): - """ - Client model representing client data in the database. - """ - __tablename__ = "clients" - - id = Column(Integer, primary_key=True, autoincrement=True) - age = Column(Integer, CheckConstraint('age >= 18')) - gender = Column(Integer, CheckConstraint("gender = 1 OR gender = 2")) - work_experience = Column(Integer, CheckConstraint('work_experience >= 0')) - canada_workex = Column(Integer, CheckConstraint('canada_workex >= 0')) - dep_num = Column(Integer, CheckConstraint('dep_num >= 0')) - canada_born = Column(Boolean) - citizen_status = Column(Boolean) - level_of_schooling = Column(Integer, CheckConstraint('level_of_schooling >= 1 AND level_of_schooling <= 14')) - fluent_english = Column(Boolean) - reading_english_scale = Column(Integer, CheckConstraint('reading_english_scale >= 0 AND reading_english_scale <= 10')) - speaking_english_scale = Column(Integer, CheckConstraint('speaking_english_scale >= 0 AND speaking_english_scale <= 10')) - writing_english_scale = Column(Integer, CheckConstraint('writing_english_scale >= 0 AND writing_english_scale <= 10')) - numeracy_scale = Column(Integer, CheckConstraint('numeracy_scale >= 0 AND numeracy_scale <= 10')) - computer_scale = Column(Integer, CheckConstraint('computer_scale >= 0 AND computer_scale <= 10')) - transportation_bool = Column(Boolean) - caregiver_bool = Column(Boolean) - housing = Column(Integer, CheckConstraint('housing >= 1 AND housing <= 10')) - income_source = Column(Integer, CheckConstraint('income_source >= 1 AND income_source <= 11')) - felony_bool = Column(Boolean) - attending_school = Column(Boolean) - currently_employed = Column(Boolean) - substance_use = Column(Boolean) - time_unemployed = Column(Integer, CheckConstraint('time_unemployed >= 0')) - need_mental_health_support_bool = Column(Boolean) - - cases = relationship("ClientCase", back_populates="client") - -class ClientCase(Base): - __tablename__ = "client_cases" - - client_id = Column(Integer, ForeignKey("clients.id"), primary_key=True) - user_id = Column(Integer, ForeignKey("users.id"), primary_key=True) - - employment_assistance = Column(Boolean) - life_stabilization = Column(Boolean) - retention_services = Column(Boolean) - specialized_services = Column(Boolean) - employment_related_financial_supports = Column(Boolean) - employer_financial_supports = Column(Boolean) - enhanced_referrals = Column(Boolean) - success_rate = Column(Integer, CheckConstraint('success_rate >= 0 AND success_rate <= 100')) - - client = relationship("Client", back_populates="cases") - user = relationship("User", back_populates="cases") diff --git a/app/models/__init__.py b/app/models/__init__.py new file mode 100644 index 00000000..e73a157c --- /dev/null +++ b/app/models/__init__.py @@ -0,0 +1,10 @@ +# app/models/__init__.py +""" +Database models package for the Common Assessment Tool. +""" + +from app.models.user import User +from app.models.client import Client +from app.models.relationships import ClientCase + +__all__ = ["User", "Client", "ClientCase"] diff --git a/app/models/client.py b/app/models/client.py new file mode 100644 index 00000000..53384858 --- /dev/null +++ b/app/models/client.py @@ -0,0 +1,74 @@ +# app/models/client.py +""" +Client-related database models for the Common Assessment Tool. +Contains models for storing client information. +""" + +from sqlalchemy import Column, Integer, Boolean +from sqlalchemy.orm import relationship + +from app.database import Base +from app.validators import ( + age_constraint, + gender_constraint, + experience_constraint, + school_level_constraint, + scale_constraint, + housing_constraint, + income_source_constraint, +) + + +class Client(Base): + """ + Represents a Client in the database. + Stores personal and case-related information for each client. + """ + + __tablename__ = "clients" + + id = Column(Integer, primary_key=True, autoincrement=True) + age = Column(Integer) + gender = Column(Integer) + work_experience = Column(Integer) + canada_workex = Column(Integer) + dep_num = Column(Integer) + canada_born = Column(Boolean) + citizen_status = Column(Boolean) + level_of_schooling = Column(Integer) + fluent_english = Column(Boolean) + reading_english_scale = Column(Integer) + speaking_english_scale = Column(Integer) + writing_english_scale = Column(Integer) + numeracy_scale = Column(Integer) + computer_scale = Column(Integer) + transportation_bool = Column(Boolean) + caregiver_bool = Column(Boolean) + housing = Column(Integer) + income_source = Column(Integer) + felony_bool = Column(Boolean) + attending_school = Column(Boolean) + currently_employed = Column(Boolean) + substance_use = Column(Boolean) + time_unemployed = Column(Integer) + need_mental_health_support_bool = Column(Boolean) + + cases = relationship("ClientCase", back_populates="client") + + # Apply constraints + __table_args__ = ( + age_constraint(), + gender_constraint(), + experience_constraint("work_experience"), + experience_constraint("canada_workex"), + experience_constraint("dep_num"), + school_level_constraint(), + scale_constraint("reading_english_scale"), + scale_constraint("speaking_english_scale"), + scale_constraint("writing_english_scale"), + scale_constraint("numeracy_scale"), + scale_constraint("computer_scale"), + housing_constraint(), + income_source_constraint(), + experience_constraint("time_unemployed"), + ) diff --git a/app/models/relationships.py b/app/models/relationships.py new file mode 100644 index 00000000..477b3300 --- /dev/null +++ b/app/models/relationships.py @@ -0,0 +1,38 @@ +# app/models/relationships.py +""" +Relationship models for the Common Assessment Tool. +Contains models that connect other entities in the system. +""" + +from sqlalchemy import Column, Integer, Boolean, ForeignKey +from sqlalchemy.orm import relationship + +from app.database import Base +from app.validators import success_rate_constraint + + +class ClientCase(Base): + """ + Represents the relationship between a client and a case worker. + Stores service information and outcomes for each client-user relationship. + """ + + __tablename__ = "client_cases" + + client_id = Column(Integer, ForeignKey("clients.id"), primary_key=True) + user_id = Column(Integer, ForeignKey("users.id"), primary_key=True) + + employment_assistance = Column(Boolean) + life_stabilization = Column(Boolean) + retention_services = Column(Boolean) + specialized_services = Column(Boolean) + employment_related_financial_supports = Column(Boolean) + employer_financial_supports = Column(Boolean) + enhanced_referrals = Column(Boolean) + success_rate = Column(Integer) + + client = relationship("Client", back_populates="cases") + user = relationship("User", back_populates="cases") + + # Apply constraints + __table_args__ = (success_rate_constraint(),) diff --git a/app/models/user.py b/app/models/user.py new file mode 100644 index 00000000..10c48cdd --- /dev/null +++ b/app/models/user.py @@ -0,0 +1,35 @@ +# app/models/user.py +""" +User-related database models for the Common Assessment Tool. +Contains models related to system users and authentication. +""" + +from sqlalchemy import Column, Integer, String, Enum +from sqlalchemy.orm import relationship + +from app.database import Base +from app.enums import UserRole +from app.validators import username_length_constraint, email_format_constraint + + +class User(Base): + """ + Represents a User in the database. + Stores user details including authentication and roles. + """ + + __tablename__ = "users" + + id = Column(Integer, primary_key=True, autoincrement=True) + username = Column(String(50), unique=True, nullable=False) + email = Column(String(100), unique=True, nullable=False) + hashed_password = Column(String(200), nullable=False) + role = Column(Enum(UserRole), nullable=False) # type: ignore + + cases = relationship("ClientCase", back_populates="user") + + # Apply User-specific constraints + __table_args__ = ( + username_length_constraint(), + email_format_constraint(), + ) diff --git a/app/validators.py b/app/validators.py new file mode 100644 index 00000000..bce8c4b7 --- /dev/null +++ b/app/validators.py @@ -0,0 +1,84 @@ +""" +Validation utilities for the Common Assessment Tool. +Contains functions and constants for validating model data. +""" + +from sqlalchemy import CheckConstraint + +# Constants for validation +MIN_AGE = 18 +GENDER_VALUES = [1, 2] +MIN_SCALE_VALUE = 0 +MAX_SCALE_VALUE = 10 +MIN_SCHOOLING_LEVEL = 1 +MAX_SCHOOLING_LEVEL = 14 +MIN_HOUSING_LEVEL = 1 +MAX_HOUSING_LEVEL = 10 +MIN_INCOME_SOURCE = 1 +MAX_INCOME_SOURCE = 11 +MIN_SUCCESS_RATE = 0 +MAX_SUCCESS_RATE = 100 + + +# Constraint creation functions +def age_constraint(): + """Create a constraint to ensure age is at least 18.""" + return CheckConstraint(f"age >= {MIN_AGE}") + + +def gender_constraint(): + """Create a constraint to ensure gender is a valid value.""" + return CheckConstraint("gender = 1 OR gender = 2") + + +def experience_constraint(field_name): + """Create a constraint to ensure experience is not negative.""" + return CheckConstraint(f"{field_name} >= 0") + + +def school_level_constraint(): + """Create a constraint to ensure schooling level is within valid range.""" + return CheckConstraint( + f"level_of_schooling >= {MIN_SCHOOLING_LEVEL} AND " + f"level_of_schooling <= {MAX_SCHOOLING_LEVEL}" + ) + + +def scale_constraint(field_name): + """Create a constraint to ensure scale values are within 0-10 range.""" + return CheckConstraint( + f"{field_name} >= {MIN_SCALE_VALUE} AND " f"{field_name} <= {MAX_SCALE_VALUE}" + ) + + +def housing_constraint(): + """Create a constraint to ensure housing value is valid.""" + return CheckConstraint( + f"housing >= {MIN_HOUSING_LEVEL} AND " f"housing <= {MAX_HOUSING_LEVEL}" + ) + + +def income_source_constraint(): + """Create a constraint to ensure income source value is valid.""" + return CheckConstraint( + f"income_source >= {MIN_INCOME_SOURCE} AND " + f"income_source <= {MAX_INCOME_SOURCE}" + ) + + +def success_rate_constraint(): + """Create a constraint to ensure success rate is within valid range.""" + return CheckConstraint( + f"success_rate >= {MIN_SUCCESS_RATE} AND " f"success_rate <= {MAX_SUCCESS_RATE}" + ) + + +def username_length_constraint(): + """Create a constraint to ensure username has a minimum length.""" + MIN_USERNAME_LENGTH = 3 + return CheckConstraint(f"LENGTH(username) >= {MIN_USERNAME_LENGTH}") + + +def email_format_constraint(): + """Create a constraint to ensure email contains @ symbol.""" + return CheckConstraint("email LIKE '%@%'") diff --git a/code_quality_README.md b/code_quality_README.md new file mode 100644 index 00000000..41cc5f5b --- /dev/null +++ b/code_quality_README.md @@ -0,0 +1,177 @@ +# Code Quality Analysis + +## Setup + +1. Create a virtual environment: + `python -m venv venv` + +2. Activate the virtual environment: + - Windows: `venv\Scripts\activate` + - Mac/Linux: `source venv/bin/activate` + +## Installation + +Install development tools with: +`pip install -r requirements-dev.txt` + +## Running Analysis + +Run these commands to analyze code quality: 2. `flake8 app/ > flake8_report.txt` 3. `black --check --diff app/ > black_report.txt` 4. `mypy app/ > mypy_report.txt` + +## Installation of Dependency on Main Code: + +`pip install -r ./app/requirements.txt` + +## To check the code using pre-commit: + +``` +pre-commit autoupdate +pre-commit run --all-files +``` + +## Refactoring for SOLID Principles + +### Completed Refactorings: + +- Extracted UserRole enum to a separate file (Single Responsibility Principle) + + - Created app/enums.py for all enumeration types + - Improved code organization and reusability + +- Extracted validation logic from models (Single Responsibility Principle) + + - Created app/validators.py with dedicated validation functions + - Defined constants for validation boundaries + - Made validation rules more maintainable and consistent + - Simplified model classes to focus on structure, not validation + +- Split database models into domain-specific files (Single Responsibility Principle) + + - Created separate files for User, Client, and relationship models + - Organized models by domain area in app/models/ directory + - Fixed imports to maintain compatibility with existing code + - Improved code organization and maintainability + - Made domain model boundaries clearer + +- Refactored ML model implementation (Story 2 preparation) + + - Created a base model interface following the Interface Segregation Principle + - Implemented multiple concrete model classes (Random Forest, Gradient Boosting, Linear Regression) + - Added a model registry implementing the Open/Closed Principle for easy model extension + - Separated data processing from model logic following Single Responsibility Principle + - Created a test script to verify model switching functionality + +- Refactored ClientService into two focused service classes: ClientQueryService and ClientMutationService + + - Applied the Single Responsibility Principle (SRP) to separate read and write operations + - Followed Separation of Concerns to isolate query logic from mutation logic + - Removed duplicated filtering code with a shared criteria filter method + +- Refactored the authentication system using modular service classes + - Applied the Single Responsibility Principle by isolating password logic (SecurityService), token logic (TokenService), and user operations (UserService) + - Followed Separation of Concerns to decouple route handlers from business logic + - Prepared the architecture for future enhancements such as role-based access control, refresh tokens, and password policies + +## Testing the ML Models + +To test the machine learning model implementation and model switching capability: + +1. Ensure you have installed the required dependencies: + pip install numpy pandas scikit-learn + +2. Run the test script: + python -m app.ml.test_models + +3. The test will: + +- Initialize the model registry +- Register three different model types +- Display available models +- Switch between models +- Train each model with sample data +- Make predictions with each model + +## Story 2 Implementation: Multiple ML Model Support + +### Completed Features: + +- Implemented a flexible ML model architecture with model switching capability +- Created a base model interface that all model types implement +- Developed three different model implementations: + - Random Forest (default model) + - Gradient Boosting + - Linear Regression +- Created a model registry to manage different model types and switching +- Implemented API endpoints for model management + +### API Endpoints for Model Management + +The following endpoints are now available for interacting with ML models: + +1. **Get Current Model** + + - Endpoint: `GET /models/current` + - Description: Returns the name of the currently active model + - Response example: `{"name": "RandomForest"}` + - Test command: `curl -X GET "http://127.0.0.1:8000/models/current"` + +2. **List Available Models** + + - Endpoint: `GET /models/available` + - Description: Returns all available models and indicates the current model + - Response example: `{"models": ["RandomForest", "GradientBoosting", "LinearRegression"], "current_model": "RandomForest"}` + - Test command: `curl -X GET "http://127.0.0.1:8000/models/available"` + +3. **Switch Model** + - Endpoint: `POST /models/switch/{model_name}` + - Description: Changes the active model to the specified model + - Response example: `{"name": "GradientBoosting"}` + - Test command: `curl -X POST "http://127.0.0.1:8000/models/switch/GradientBoosting"` + +### Testing the Model Switching Functionality + +You can test the model switching functionality in two ways: + +#### 1. Using the Swagger UI: + +- Start the application: `uvicorn app.main:app --reload` +- Navigate to: `http://127.0.0.1:8000/docs` +- Scroll to the "models" section +- Try each endpoint by clicking "Try it out" and "Execute" + +#### 2. Using the ML Test Script: + +- Run: `python -m app.ml.test_models` +- This will test model registration, switching, and prediction capability + +### Implementation Details + +- **Model Registry Pattern**: Used a singleton registry to manage models +- **Strategy Pattern**: Each model implementation follows the same interface +- **Dependency Injection**: High-level components use abstractions rather than concrete implementations +- **Open for Extension**: New models can be added without modifying existing code + +This implementation satisfies all requirements for Story 2 while following SOLID principles. + + +### predictions + To demonstrate the effect of model switching on predictions, a test endpoint has been added: + + - **Test Prediction Endpoint**: + - Path: `GET /models/test-prediction` + - Purpose: Shows how predictions change when switching between models + - Response: Returns the current model name and its prediction for test data + + **How to demonstrate model switching:** + + 1. Call `GET /models/test-prediction` with the default model + - Note the prediction value (e.g., `{"model": "RandomForest", "prediction": 0.45, "status": "success"}`) + + 2. Switch to a different model using `POST /models/switch/GradientBoosting` + - Response confirms the switch: `{"name": "GradientBoosting"}` + + 3. Call `GET /models/test-prediction` again + - Note the different prediction: `{"model": "GradientBoosting", "prediction": 0.72, "status": "success"}` + + 4. Switch to the third model using `POST /models/switch/LinearRegression` + - Call the test endpoint again to see a third different prediction \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 00000000..6a1e9704 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,23 @@ +version: '3.8' +services: + app: + build: . + ports: + - "8000:8000" + depends_on: + - db + environment: + - DATABASE_HOST=db + - DATABASE_USER=user + - DATABASE_PASSWORD=password + - DATABASE_NAME=dbname + db: + image: postgres:latest + environment: + - POSTGRES_DB=dbname + - POSTGRES_USER=user + - POSTGRES_PASSWORD=password + volumes: + - db_data:/var/lib/postgresql/data +volumes: + db_data: diff --git a/entrypoint.sh b/entrypoint.sh new file mode 100755 index 00000000..ef6bf213 --- /dev/null +++ b/entrypoint.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +# Wait for the database to be ready (adjust as needed for your DB setup) +echo "Waiting for the database to be ready..." +sleep 10 # This is a simple sleep; in a production scenario, use a loop with checks. + +# Run data initialization script +echo "Initializing data..." +python initialize_data.py + +# Start the main application +echo "Starting the application..." +exec uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload diff --git a/initialize_data.py b/initialize_data.py index 1444bf41..ed476cb1 100644 --- a/initialize_data.py +++ b/initialize_data.py @@ -1,8 +1,20 @@ import pandas as pd -from sqlalchemy.orm import Session from app.database import SessionLocal -from app.models import Client, User, ClientCase, UserRole -from app.auth.router import get_password_hash +from app.models import Client, User, ClientCase +from app.enums import UserRole + +# from app.auth.router import get_password_hash + +from passlib.context import CryptContext + +# Create password context for hashing +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + + +# Function to hash passwords +def get_password_hash(password: str) -> str: + return pwd_context.hash(password) + def initialize_database(): print("Starting database initialization...") @@ -15,7 +27,7 @@ def initialize_database(): username="admin", email="admin@example.com", hashed_password=get_password_hash("admin123"), - role=UserRole.admin + role=UserRole.admin, ) db.add(admin_user) db.commit() @@ -30,7 +42,7 @@ def initialize_database(): username="case_worker1", email="caseworker1@example.com", hashed_password=get_password_hash("worker123"), - role=UserRole.case_worker + role=UserRole.case_worker, ) db.add(case_worker) db.commit() @@ -40,46 +52,59 @@ def initialize_database(): # Load CSV data print("Loading CSV data...") - df = pd.read_csv('app/clients/service/data_commontool.csv') - + df = pd.read_csv("app/clients/service/data_commontool.csv") + # Convert data types integer_columns = [ - 'age', 'gender', 'work_experience', 'canada_workex', 'dep_num', - 'level_of_schooling', 'reading_english_scale', 'speaking_english_scale', - 'writing_english_scale', 'numeracy_scale', 'computer_scale', - 'housing', 'income_source', 'time_unemployed', 'success_rate' + "age", + "gender", + "work_experience", + "canada_workex", + "dep_num", + "level_of_schooling", + "reading_english_scale", + "speaking_english_scale", + "writing_english_scale", + "numeracy_scale", + "computer_scale", + "housing", + "income_source", + "time_unemployed", + "success_rate", ] for col in integer_columns: - df[col] = pd.to_numeric(df[col], errors='raise') + df[col] = pd.to_numeric(df[col], errors="raise") # Process each row in CSV for index, row in df.iterrows(): # Create client client = Client( - age=int(row['age']), - gender=int(row['gender']), - work_experience=int(row['work_experience']), - canada_workex=int(row['canada_workex']), - dep_num=int(row['dep_num']), - canada_born=bool(row['canada_born']), - citizen_status=bool(row['citizen_status']), - level_of_schooling=int(row['level_of_schooling']), - fluent_english=bool(row['fluent_english']), - reading_english_scale=int(row['reading_english_scale']), - speaking_english_scale=int(row['speaking_english_scale']), - writing_english_scale=int(row['writing_english_scale']), - numeracy_scale=int(row['numeracy_scale']), - computer_scale=int(row['computer_scale']), - transportation_bool=bool(row['transportation_bool']), - caregiver_bool=bool(row['caregiver_bool']), - housing=int(row['housing']), - income_source=int(row['income_source']), - felony_bool=bool(row['felony_bool']), - attending_school=bool(row['attending_school']), - currently_employed=bool(row['currently_employed']), - substance_use=bool(row['substance_use']), - time_unemployed=int(row['time_unemployed']), - need_mental_health_support_bool=bool(row['need_mental_health_support_bool']) + age=int(row["age"]), + gender=int(row["gender"]), + work_experience=int(row["work_experience"]), + canada_workex=int(row["canada_workex"]), + dep_num=int(row["dep_num"]), + canada_born=bool(row["canada_born"]), + citizen_status=bool(row["citizen_status"]), + level_of_schooling=int(row["level_of_schooling"]), + fluent_english=bool(row["fluent_english"]), + reading_english_scale=int(row["reading_english_scale"]), + speaking_english_scale=int(row["speaking_english_scale"]), + writing_english_scale=int(row["writing_english_scale"]), + numeracy_scale=int(row["numeracy_scale"]), + computer_scale=int(row["computer_scale"]), + transportation_bool=bool(row["transportation_bool"]), + caregiver_bool=bool(row["caregiver_bool"]), + housing=int(row["housing"]), + income_source=int(row["income_source"]), + felony_bool=bool(row["felony_bool"]), + attending_school=bool(row["attending_school"]), + currently_employed=bool(row["currently_employed"]), + substance_use=bool(row["substance_use"]), + time_unemployed=int(row["time_unemployed"]), + need_mental_health_support_bool=bool( + row["need_mental_health_support_bool"] + ), ) db.add(client) db.commit() @@ -88,14 +113,16 @@ def initialize_database(): client_case = ClientCase( client_id=client.id, user_id=admin_user.id, # Assign to admin - employment_assistance=bool(row['employment_assistance']), - life_stabilization=bool(row['life_stabilization']), - retention_services=bool(row['retention_services']), - specialized_services=bool(row['specialized_services']), - employment_related_financial_supports=bool(row['employment_related_financial_supports']), - employer_financial_supports=bool(row['employer_financial_supports']), - enhanced_referrals=bool(row['enhanced_referrals']), - success_rate=int(row['success_rate']) + employment_assistance=bool(row["employment_assistance"]), + life_stabilization=bool(row["life_stabilization"]), + retention_services=bool(row["retention_services"]), + specialized_services=bool(row["specialized_services"]), + employment_related_financial_supports=bool( + row["employment_related_financial_supports"] + ), + employer_financial_supports=bool(row["employer_financial_supports"]), + enhanced_referrals=bool(row["enhanced_referrals"]), + success_rate=int(row["success_rate"]), ) db.add(client_case) db.commit() @@ -108,5 +135,6 @@ def initialize_database(): finally: db.close() + if __name__ == "__main__": - initialize_database() \ No newline at end of file + initialize_database() diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..4e84a0cc --- /dev/null +++ b/mypy.ini @@ -0,0 +1,7 @@ +[mypy] +exclude = tests/ +ignore_missing_imports = True +explicit_package_bases = True +namespace_packages = True +[mypy-tests.conftest] +ignore_errors = True \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..973e45cf --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,40 @@ +[tool.black] +line-length = 88 +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.venv + | \.env + | \.eggs + | \.mypy_cache + | \.pytest_cache +)/ +''' + +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.poetry] +name = "common-assessment-tool" +version = "0.1.0" +description = "A FastAPI application for managing assessments." +license = "MIT" + +[tool.poetry.dependencies] +python = "^3.8" +fastapi = "^0.68.0" +uvicorn = "^0.15.0" +sqlalchemy = "^1.4.22" +psycopg2-binary = "^2.9.1" +python-dotenv = "^0.19.0" + +[tool.poetry.dev-dependencies] +pytest = "^6.2" +black = "^21.7b0" +flake8 = "^3.9" +mypy = "^0.910" + + + diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 00000000..821951d7 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,5 @@ +pylint==2.17.5 +flake8==6.1.0 +black==23.7.0 +mypy==1.5.1 +python-jose>=3.3.0 diff --git a/sprint2_README.md b/sprint2_README.md new file mode 100644 index 00000000..e23dd096 --- /dev/null +++ b/sprint2_README.md @@ -0,0 +1,77 @@ +Sprint 2: + +## Installing Docker + +Docker is required to containerize and run the backend application. Follow the instructions below based on your operating system to install Docker. + +### Linux Installation + +To install Docker on Linux, execute the following commands in your terminal: + +```bash +# Download and run the Docker installation script +curl -fsSL https://get.docker.com -o get-docker.sh +sudo sh get-docker.sh + +# Add your user to the Docker group to manage Docker as a non-root user +sudo usermod -aG docker ${USER} +newgrp docker +``` +### Macos Installation + +# Install Docker on macOS using Homebrew +`brew install --cask docker` + +### Window: +# Install Docker on Windows using Chocolatey +`choco install docker-desktop` + +### Verify Version: +`docker --version` + +### Or Using Official Website: +### Installing Docker + +1. **Visit the Docker website**: Go to [Docker's official website](https://docs.docker.com/get-docker/) to download Docker Desktop. +2. **Download Docker**: Choose the version of Docker Desktop that corresponds to your operating system (Windows/MacOS). +3. **Install Docker**: Open the downloaded file and follow the installation instructions. +4. **Start Docker**: Once installed, run Docker from your applications folder. Docker needs to be running for the application to work. + +### Type the following command to start the application using Docker-Compose: + + +`docker-compose up --build` +This command tells Docker to prepare and run the application. It may take a few minutes to complete, especially the first time. + +### Run the programming using Docker Run: +`docker build -t my-app-name . ` +> +`docker run -p 8000:8000 my-app-name` + + +### 5. Accessing the Application +Once the application is running, open your web browser and visit: + +```http://localhost:8000/docs``` +### 6. Stopping the Application +When you are done using the application, you can stop it by going back to your terminal or command prompt, pressing `Ctrl+C`, and then typing: +`docker-compose down` + +## Continuous Integration (CI) Pipeline + +Our project utilizes GitHub Actions to automate the testing and deployment processes, ensuring that every change pushed to the repository maintains code quality and stability before deployment. + +### CI Workflow Overview + +- **Trigger Events**: The CI pipeline is triggered by any push or pull request to the `main` branch, ensuring that all changes are thoroughly tested. + +- **Build and Test**: Every change to the repository initiates the following actions: + - **Code Checkout**: The latest version of the code is checked out. + - **Environment Setup**: The Python environment is set up with the necessary dependencies installed. + - **Linting**: The codebase is linted using tools like Flake8 and Black to ensure adherence to coding standards. + - **Automated Tests**: Our comprehensive test suite runs via pytest to catch any potential bugs introduced. + +### Docker Integration + +- **Docker Build**: A Docker image of the application is built to verify that the application can be containerized without issues. +- **Docker Run**: The Docker container is run to ensure it starts correctly and serves content as expected, particularly from the `/docs` endpoint using a curl command to ensure the Swagger UI is loaded correctly. diff --git a/tests/conftest.py b/tests/conftest.py index aa30d094..2551339b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,28 +3,34 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from app.database import Base, get_db +from app.enums import UserRole from app.main import app -from app.auth.router import get_password_hash -from app.models import User, UserRole, Client, ClientCase +from app.models import User, Client, ClientCase +from app.auth.router import SecurityService + # Create test database SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db" -engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}) +engine = create_engine( + SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} +) TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +# type: ignore @pytest.fixture def test_db(): # Create tables Base.metadata.create_all(bind=engine) - + security_service = SecurityService() db = TestingSessionLocal() try: # Create test admin user admin_user = User( username="testadmin", email="testadmin@example.com", - hashed_password=get_password_hash("testpass123"), - role=UserRole.admin + hashed_password=security_service.get_password_hash("testpass123"), + role=UserRole.admin, ) db.add(admin_user) @@ -32,11 +38,11 @@ def test_db(): case_worker = User( username="testworker", email="worker@example.com", - hashed_password=get_password_hash("workerpass123"), - role=UserRole.case_worker + hashed_password=security_service.get_password_hash("workerpass123"), + role=UserRole.case_worker, ) db.add(case_worker) - + # Create test clients client1 = Client( age=25, @@ -62,9 +68,9 @@ def test_db(): currently_employed=False, substance_use=False, time_unemployed=6, - need_mental_health_support_bool=False + need_mental_health_support_bool=False, ) - + client2 = Client( age=30, gender=2, @@ -89,13 +95,13 @@ def test_db(): currently_employed=True, substance_use=False, time_unemployed=0, - need_mental_health_support_bool=False + need_mental_health_support_bool=False, ) - + db.add(client1) db.add(client2) db.commit() - + # Create test client cases client_case1 = ClientCase( client_id=1, @@ -107,9 +113,9 @@ def test_db(): employment_related_financial_supports=True, employer_financial_supports=False, enhanced_referrals=True, - success_rate=75 + success_rate=75, ) - + client_case2 = ClientCase( client_id=2, user_id=2, # Assigned to case worker @@ -120,18 +126,19 @@ def test_db(): employment_related_financial_supports=False, employer_financial_supports=True, enhanced_referrals=False, - success_rate=85 + success_rate=85, ) - + db.add(client_case1) db.add(client_case2) db.commit() - + yield db finally: db.close() Base.metadata.drop_all(bind=engine) + @pytest.fixture def client(test_db): def override_get_db(): @@ -139,32 +146,33 @@ def override_get_db(): yield test_db finally: test_db.close() - + app.dependency_overrides[get_db] = override_get_db yield TestClient(app) app.dependency_overrides.clear() + @pytest.fixture def admin_token(client): response = client.post( - "/auth/token", - data={"username": "testadmin", "password": "testpass123"} + "/auth/token", data={"username": "testadmin", "password": "testpass123"} ) return response.json()["access_token"] + @pytest.fixture def case_worker_token(client): response = client.post( - "/auth/token", - data={"username": "testworker", "password": "workerpass123"} + "/auth/token", data={"username": "testworker", "password": "workerpass123"} ) return response.json()["access_token"] + @pytest.fixture def admin_headers(admin_token): return {"Authorization": f"Bearer {admin_token}"} + @pytest.fixture def case_worker_headers(case_worker_token): return {"Authorization": f"Bearer {case_worker_token}"} - \ No newline at end of file diff --git a/tests/test_auth.py b/tests/test_auth.py index 1d4692e4..304febdf 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,123 +1,113 @@ import pytest from fastapi import status + def test_create_user_success(client, admin_headers): """Test successful user creation by admin""" user_data = { "username": "newuser", "email": "new@test.com", "password": "testpass123", - "role": "case_worker" + "role": "case_worker", } - response = client.post( - "/auth/users", - headers=admin_headers, - json=user_data - ) + response = client.post("/auth/users", headers=admin_headers, json=user_data) assert response.status_code == status.HTTP_200_OK data = response.json() assert data["username"] == "newuser" assert data["role"] == "case_worker" assert "password" not in data # Password should not be in response + def test_create_user_duplicate_username(client, admin_headers): """Test creating user with existing username""" user_data = { "username": "testadmin", # This username exists in test database "email": "another@test.com", "password": "testpass123", - "role": "case_worker" + "role": "case_worker", } - response = client.post( - "/auth/users", - headers=admin_headers, - json=user_data - ) + response = client.post("/auth/users", headers=admin_headers, json=user_data) assert response.status_code == status.HTTP_400_BAD_REQUEST assert "Username already registered" in response.json()["detail"] + def test_create_user_duplicate_email(client, admin_headers): """Test creating user with existing email""" user_data = { "username": "uniqueuser", "email": "testadmin@example.com", # This email exists in test database "password": "testpass123", - "role": "case_worker" + "role": "case_worker", } - response = client.post( - "/auth/users", - headers=admin_headers, - json=user_data - ) + response = client.post("/auth/users", headers=admin_headers, json=user_data) assert response.status_code == status.HTTP_400_BAD_REQUEST assert "Email already registered" in response.json()["detail"] + def test_create_user_invalid_role(client, admin_headers): """Test creating user with invalid role""" user_data = { "username": "newuser", "email": "new@test.com", "password": "testpass123", - "role": "invalid_role" # Invalid role + "role": "invalid_role", # Invalid role } - response = client.post( - "/auth/users", - headers=admin_headers, - json=user_data - ) + response = client.post("/auth/users", headers=admin_headers, json=user_data) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + def test_create_user_unauthorized(client): """Test user creation without authentication""" user_data = { "username": "newuser", "email": "new@test.com", "password": "testpass123", - "role": "case_worker" + "role": "case_worker", } response = client.post("/auth/users", json=user_data) assert response.status_code == status.HTTP_401_UNAUTHORIZED + def test_login_success_admin(client): """Test successful login for admin""" response = client.post( - "/auth/token", - data={"username": "testadmin", "password": "testpass123"} + "/auth/token", data={"username": "testadmin", "password": "testpass123"} ) assert response.status_code == status.HTTP_200_OK data = response.json() assert "access_token" in data assert data["token_type"] == "bearer" + def test_login_success_case_worker(client): """Test successful login for case worker""" response = client.post( - "/auth/token", - data={"username": "testworker", "password": "workerpass123"} + "/auth/token", data={"username": "testworker", "password": "workerpass123"} ) assert response.status_code == status.HTTP_200_OK data = response.json() assert "access_token" in data assert data["token_type"] == "bearer" + def test_login_wrong_password(client): """Test login with incorrect password""" response = client.post( - "/auth/token", - data={"username": "testadmin", "password": "wrongpassword"} + "/auth/token", data={"username": "testadmin", "password": "wrongpassword"} ) assert response.status_code == status.HTTP_401_UNAUTHORIZED assert "Incorrect username or password" in response.json()["detail"] + def test_login_nonexistent_user(client): """Test login with non-existent username""" response = client.post( - "/auth/token", - data={"username": "nonexistent", "password": "testpass123"} + "/auth/token", data={"username": "nonexistent", "password": "testpass123"} ) assert response.status_code == status.HTTP_401_UNAUTHORIZED assert "Incorrect username or password" in response.json()["detail"] + def test_invalid_token(client): """Test using invalid token""" headers = {"Authorization": "Bearer invalid_token_here"} @@ -125,12 +115,14 @@ def test_invalid_token(client): assert response.status_code == status.HTTP_401_UNAUTHORIZED assert "Could not validate credentials" in response.json()["detail"] + def test_missing_token(client): """Test accessing protected endpoint without token""" response = client.get("/clients/") assert response.status_code == status.HTTP_401_UNAUTHORIZED assert "Not authenticated" in response.json()["detail"] + def test_token_user_deleted(client, admin_headers): """Test using token of deleted user""" # First create a new user as admin @@ -138,22 +130,17 @@ def test_token_user_deleted(client, admin_headers): "username": "temporary", "email": "temp@test.com", "password": "temppass123", - "role": "admin" # Changed to admin so they can access /clients/ + "role": "admin", # Changed to admin so they can access /clients/ } - response = client.post( - "/auth/users", - headers=admin_headers, - json=user_data - ) + response = client.post("/auth/users", headers=admin_headers, json=user_data) assert response.status_code == status.HTTP_200_OK # Get token for new user response = client.post( - "/auth/token", - data={"username": "temporary", "password": "temppass123"} + "/auth/token", data={"username": "temporary", "password": "temppass123"} ) token = response.json()["access_token"] - + # Try using the token headers = {"Authorization": f"Bearer {token}"} response = client.get("/clients/", headers=headers) diff --git a/tests/test_clients.py b/tests/test_clients.py index 611a5b34..40b233c2 100644 --- a/tests/test_clients.py +++ b/tests/test_clients.py @@ -1,12 +1,14 @@ import pytest from fastapi import status + # Test GET Operations def test_get_clients_unauthorized(client): """Test that unauthorized access is prevented""" response = client.get("/clients/") assert response.status_code == status.HTTP_401_UNAUTHORIZED + def test_get_clients_as_admin(client, admin_headers): """Test getting all clients as admin""" response = client.get("/clients/", headers=admin_headers) @@ -16,24 +18,24 @@ def test_get_clients_as_admin(client, admin_headers): assert "total" in data assert len(data["clients"]) > 0 + def test_get_client_by_id(client, admin_headers): """Test getting specific client""" # Test existing client response = client.get("/clients/1", headers=admin_headers) assert response.status_code == status.HTTP_200_OK assert response.json()["id"] == 1 - + # Test non-existent client response = client.get("/clients/999", headers=admin_headers) assert response.status_code == status.HTTP_404_NOT_FOUND + def test_get_clients_by_criteria(client, admin_headers): """Test searching clients by various criteria""" # Test single criterion response = client.get( - "/clients/search/by-criteria", - params={"age_min": 25}, - headers=admin_headers + "/clients/search/by-criteria", params={"age_min": 25}, headers=admin_headers ) assert response.status_code == status.HTTP_200_OK assert len(response.json()) > 0 @@ -41,12 +43,8 @@ def test_get_clients_by_criteria(client, admin_headers): # Test multiple criteria response = client.get( "/clients/search/by-criteria", - params={ - "age_min": 25, - "currently_employed": True, - "gender": 2 - }, - headers=admin_headers + params={"age_min": 25, "currently_employed": True, "gender": 2}, + headers=admin_headers, ) assert response.status_code == status.HTTP_200_OK @@ -54,23 +52,24 @@ def test_get_clients_by_criteria(client, admin_headers): response = client.get( "/clients/search/by-criteria", params={"age_min": 15}, # Below minimum age - headers=admin_headers + headers=admin_headers, ) - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY # Changed from 400 + assert ( + response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + ) # Changed from 400 + def test_get_clients_by_services(client, admin_headers): """Test getting clients by service status""" response = client.get( "/clients/search/by-services", - params={ - "employment_assistance": True, - "life_stabilization": True - }, - headers=admin_headers + params={"employment_assistance": True, "life_stabilization": True}, + headers=admin_headers, ) assert response.status_code == status.HTTP_200_OK assert len(response.json()) > 0 + def test_get_client_services(client, admin_headers): """Test getting services for a specific client""" response = client.get("/clients/1/services", headers=admin_headers) @@ -81,52 +80,46 @@ def test_get_client_services(client, admin_headers): assert "employment_assistance" in services[0] assert "success_rate" in services[0] + def test_get_clients_by_success_rate(client, admin_headers): """Test getting clients by success rate threshold""" response = client.get( - "/clients/search/success-rate", - params={"min_rate": 70}, - headers=admin_headers + "/clients/search/success-rate", params={"min_rate": 70}, headers=admin_headers ) assert response.status_code == status.HTTP_200_OK assert len(response.json()) > 0 + def test_get_clients_by_case_worker(client, admin_headers, case_worker_headers): """Test getting clients assigned to a case worker""" # Test as admin response = client.get("/clients/case-worker/2", headers=admin_headers) assert response.status_code == status.HTTP_200_OK - + # Test as case worker response = client.get("/clients/case-worker/2", headers=case_worker_headers) assert response.status_code == status.HTTP_200_OK + # Test UPDATE Operations def test_update_client(client, admin_headers): """Test updating client information""" - update_data = { - "age": 26, - "currently_employed": True, - "time_unemployed": 0 - } - response = client.put( - "/clients/1", - json=update_data, - headers=admin_headers - ) + update_data = {"age": 26, "currently_employed": True, "time_unemployed": 0} + response = client.put("/clients/1", json=update_data, headers=admin_headers) assert response.status_code == status.HTTP_200_OK updated_client = response.json() assert updated_client["age"] == 26 assert updated_client["currently_employed"] == True assert updated_client["time_unemployed"] == 0 + # Test Create Case Assignment def test_create_case_assignment(client, admin_headers): """Test creating new case assignment""" response = client.post( "/clients/1/case-assignment", params={"case_worker_id": 2}, - headers=admin_headers + headers=admin_headers, ) assert response.status_code == status.HTTP_200_OK @@ -134,10 +127,11 @@ def test_create_case_assignment(client, admin_headers): response = client.post( "/clients/1/case-assignment", params={"case_worker_id": 2}, - headers=admin_headers + headers=admin_headers, ) assert response.status_code == status.HTTP_400_BAD_REQUEST + # Test DELETE Operation def test_delete_client(client, admin_headers): """Test deleting a client"""