From af2654782175518c58a005b770d7cb598d63a75f Mon Sep 17 00:00:00 2001 From: Helen Wang Date: Tue, 18 Mar 2025 22:03:10 -0700 Subject: [PATCH 01/23] Add code quality tools setup and instructions --- .gitignore | 6 +++++- code_quality_README.md | 20 ++++++++++++++++++++ requirements-dev.txt | 4 ++++ 3 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 code_quality_README.md create mode 100644 requirements-dev.txt diff --git a/.gitignore b/.gitignore index 371e45c1..ae97a7e6 100644 --- a/.gitignore +++ b/.gitignore @@ -31,4 +31,8 @@ build/ htmlcov/ .tox/ .coverage -.coverage.* \ No newline at end of file +.coverage.* + +# Reports +*_report.txt +reports/ diff --git a/code_quality_README.md b/code_quality_README.md new file mode 100644 index 00000000..7d1ceab8 --- /dev/null +++ b/code_quality_README.md @@ -0,0 +1,20 @@ +# Code Quality Analysis + +## Setup +1. Create a virtual environment: + `python -m venv venv` + +2. Activate the virtual environment: + - Windows: `venv\Scripts\activate` + - Mac/Linux: `source venv/bin/activate` + +## Installation +Install development tools with: +`pip install -r requirements-dev.txt` + +## Running Analysis +Run these commands to analyze code quality: +1. `pylint app/ > pylint_report.txt` +2. `flake8 app/ > flake8_report.txt` +3. `black --check --diff app/ > black_report.txt` +4. `mypy app/ > mypy_report.txt` diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 00000000..51b8e8ab --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,4 @@ +pylint==2.17.5 +flake8==6.1.0 +black==23.7.0 +mypy==1.5.1 From a9abac98783bc04e0e4636e8159a3fa2984fe4f8 Mon Sep 17 00:00:00 2001 From: johnsonli010801 <97709195+johnsonli010801@users.noreply.github.com> Date: Wed, 19 Mar 2025 19:20:00 -0700 Subject: [PATCH 02/23] Update code and automated tools, add configuration files and half way finish refactoring --- .flake8 | 3 + .github/workflows/ci.yml | 9 ++ .pre-commit-config.yaml | 17 +++ app/auth/router.py | 41 ++++--- app/clients/router.py | 77 +++++++++---- app/clients/schema.py | 30 +++-- app/clients/service/client_service.py | 158 +++++++++++++------------- app/clients/service/logic.py | 158 +++++++++++++++++--------- app/clients/service/model.py | 84 +++++++------- app/database.py | 17 +-- app/main.py | 17 ++- app/models.py | 85 ++++++++++---- code_quality_README.md | 10 +- initialize_data.py | 102 ++++++++++------- mypy.ini | 2 + pyproject.toml | 40 +++++++ requirements-dev.txt | 1 + tests/conftest.py | 47 ++++---- tests/test_auth.py | 71 +++++------- tests/test_clients.py | 60 +++++----- 20 files changed, 638 insertions(+), 391 deletions(-) create mode 100644 .flake8 create mode 100644 .pre-commit-config.yaml create mode 100644 mypy.ini create mode 100644 pyproject.toml diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..041c5f07 --- /dev/null +++ b/.flake8 @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 120 +exclude = .git,__pycache__,venv,tests/* diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 30c81bdb..6aee7c50 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,8 +24,17 @@ jobs: python -m pip install --upgrade pip # Upgrade pip to the latest version pip install setuptools wheel pip install -r requirements.txt # Install dependencies from requirements.txt + pip install -r requirements-dev.txt pip install pylint pytest + + - name: Run Flake8 + run: flake8 . + - name: Run Black + run: black . --check + + - name: Run MyPy + run: mypy . - name: Run Tests run: | python -m pytest tests/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..49b66366 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,17 @@ +repos: + - repo: https://github.com/psf/black + rev: 25.1.0 # Use the latest stable version; check the GitHub releases for the latest version + hooks: + - id: black + language_version: python3.12 # Ensure this matches your project's Python version + + - repo: https://github.com/PyCQA/flake8 + rev: '7.1.2' # Use the latest stable release; check the GitHub releases for the latest version + hooks: + - id: flake8 + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: 'v1.15.0' # Use the latest stable release; check the GitHub releases for the latest version + hooks: + - id: mypy + diff --git a/app/auth/router.py b/app/auth/router.py index 229ee71d..35433b54 100644 --- a/app/auth/router.py +++ b/app/auth/router.py @@ -11,18 +11,20 @@ router = APIRouter(prefix="/auth", tags=["authentication"]) + class UserCreate(BaseModel): username: str = Field(..., min_length=3, max_length=50) email: str password: str role: UserRole - @validator('role') + @validator("role") def validate_role(cls, v): if v not in [UserRole.admin, UserRole.case_worker]: - raise ValueError('Role must be either admin or case_worker') + raise ValueError("Role must be either admin or case_worker") return v + class UserResponse(BaseModel): username: str email: str @@ -31,6 +33,7 @@ class UserResponse(BaseModel): class Config: from_attributes = True + # Configuration SECRET_KEY = "your-secret-key-here" ALGORITHM = "HS256" @@ -39,18 +42,22 @@ class Config: pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token") + def verify_password(plain_password: str, hashed_password: str) -> bool: return pwd_context.verify(plain_password, hashed_password) + def get_password_hash(password: str) -> str: return pwd_context.hash(password) + def authenticate_user(db: Session, username: str, password: str) -> Optional[User]: user = db.query(User).filter(User.username == username).first() if not user or not verify_password(password, user.hashed_password): return None return user + def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): to_encode = data.copy() if expires_delta: @@ -61,9 +68,9 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt + async def get_current_user( - token: str = Depends(oauth2_scheme), - db: Session = Depends(get_db) + token: str = Depends(oauth2_scheme), db: Session = Depends(get_db) ) -> User: credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -77,24 +84,25 @@ async def get_current_user( raise credentials_exception except JWTError: raise credentials_exception - + user = db.query(User).filter(User.username == username).first() if user is None: raise credentials_exception return user + def get_admin_user(current_user: User = Depends(get_current_user)): if current_user.role != UserRole.admin: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Only admin users can perform this operation" + detail="Only admin users can perform this operation", ) return current_user + @router.post("/token") async def login_for_access_token( - form_data: OAuth2PasswordRequestForm = Depends(), - db: Session = Depends(get_db) + form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db) ): user = authenticate_user(db, form_data.username, form_data.password) if not user: @@ -109,25 +117,25 @@ async def login_for_access_token( ) return {"access_token": access_token, "token_type": "bearer"} + @router.post("/users", response_model=UserResponse) async def create_user( user_data: UserCreate, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Create a new user (admin only)""" # Check if username exists if db.query(User).filter(User.username == user_data.username).first(): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Username already registered" + detail="Username already registered", ) - + # Check if email exists if db.query(User).filter(User.email == user_data.email).first(): raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Email already registered" + status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered" ) # Create new user @@ -135,9 +143,9 @@ async def create_user( username=user_data.username, email=user_data.email, hashed_password=get_password_hash(user_data.password), - role=user_data.role + role=user_data.role, ) - + try: db.add(db_user) db.commit() @@ -146,6 +154,5 @@ async def create_user( except Exception as e: db.rollback() raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(e) + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) ) diff --git a/app/clients/router.py b/app/clients/router.py index 4ecc83e4..02708963 100644 --- a/app/clients/router.py +++ b/app/clients/router.py @@ -3,42 +3,59 @@ Handles all HTTP requests for client operations including create, read, update, and delete. """ -from fastapi import APIRouter, Depends, HTTPException, status, Query -from sqlalchemy.orm import Session from typing import List, Optional -from app.auth.router import get_current_user, get_admin_user -from app.models import User, UserRole +from fastapi import APIRouter, Depends, status, Query +from sqlalchemy.orm import Session + +from app.auth.router import get_current_user, get_admin_user from app.database import get_db +from app.models import User from app.clients.service.client_service import ClientService from app.clients.schema import ( - ClientResponse, - ClientUpdate, + ClientResponse, + ClientUpdate, ClientListResponse, ServiceResponse, - ServiceUpdate + ServiceUpdate, ) + router = APIRouter(prefix="/clients", tags=["clients"]) + @router.get("/", response_model=ClientListResponse) async def get_clients( - current_user: User = Depends(get_admin_user), + current_user: User = Depends(get_admin_user), skip: int = Query(default=0, ge=0, description="Number of records to skip"), - limit: int = Query(default=50, ge=1, le=150, description="Maximum number of records to return"), - db: Session = Depends(get_db) + limit: int = Query( + default=50, ge=1, le=150, description="Maximum number of records to return" + ), + db: Session = Depends(get_db), ): + """ + Retrieve a list of clients with optional pagination. + Parameters: + current_user: User object of the currently authenticated user. + skip: The number of items to skip (for pagination). + limit: The maximum number of items to return. + db: Database session dependency. + Returns: + A list of clients according to the specified pagination rules. + """ return ClientService.get_clients(db, skip, limit) + @router.get("/{client_id}", response_model=ClientResponse) async def get_client( client_id: int, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Get a specific client by ID""" return ClientService.get_client(db, client_id) + @router.get("/search/by-criteria", response_model=List[ClientResponse]) async def get_clients_by_criteria( employment_status: Optional[bool] = None, @@ -66,7 +83,7 @@ async def get_clients_by_criteria( time_unemployed: Optional[int] = Query(None, ge=0), need_mental_health_support_bool: Optional[bool] = None, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Search clients by any combination of criteria""" return ClientService.get_clients_by_criteria( @@ -94,9 +111,10 @@ async def get_clients_by_criteria( attending_school=attending_school, substance_use=substance_use, time_unemployed=time_unemployed, - need_mental_health_support_bool=need_mental_health_support_bool + need_mental_health_support_bool=need_mental_health_support_bool, ) + @router.get("/search/by-services", response_model=List[ClientResponse]) async def get_clients_by_services( employment_assistance: Optional[bool] = None, @@ -107,7 +125,7 @@ async def get_clients_by_services( employer_financial_supports: Optional[bool] = None, enhanced_referrals: Optional[bool] = None, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Get clients filtered by multiple service statuses""" return ClientService.get_clients_by_services( @@ -118,70 +136,79 @@ async def get_clients_by_services( specialized_services=specialized_services, employment_related_financial_supports=employment_related_financial_supports, employer_financial_supports=employer_financial_supports, - enhanced_referrals=enhanced_referrals + enhanced_referrals=enhanced_referrals, ) + @router.get("/{client_id}/services", response_model=List[ServiceResponse]) async def get_client_services( client_id: int, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Get all services and their status for a specific client, including case worker info""" return ClientService.get_client_services(db, client_id) + @router.get("/search/success-rate", response_model=List[ClientResponse]) async def get_clients_by_success_rate( - min_rate: int = Query(70, ge=0, le=100, description="Minimum success rate percentage"), + min_rate: int = Query( + 70, ge=0, le=100, description="Minimum success rate percentage" + ), current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Get clients with success rate above specified threshold""" return ClientService.get_clients_by_success_rate(db, min_rate) + @router.get("/case-worker/{case_worker_id}", response_model=List[ClientResponse]) async def get_clients_by_case_worker( case_worker_id: int, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ): return ClientService.get_clients_by_case_worker(db, case_worker_id) + @router.put("/{client_id}", response_model=ClientResponse) async def update_client( client_id: int, client_data: ClientUpdate, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Update a client's information""" return ClientService.update_client(db, client_id, client_data) + @router.put("/{client_id}/services/{user_id}", response_model=ServiceResponse) async def update_client_services( client_id: int, user_id: int, service_update: ServiceUpdate, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ): return ClientService.update_client_services(db, client_id, user_id, service_update) + @router.post("/{client_id}/case-assignment", response_model=ServiceResponse) async def create_case_assignment( client_id: int, case_worker_id: int = Query(..., description="Case worker ID to assign"), current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Create a new case assignment for a client with a case worker""" return ClientService.create_case_assignment(db, client_id, case_worker_id) + @router.delete("/{client_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_client( client_id: int, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Delete a client""" ClientService.delete_client(db, client_id) diff --git a/app/clients/schema.py b/app/clients/schema.py index cff28897..a2e1ba4b 100644 --- a/app/clients/schema.py +++ b/app/clients/schema.py @@ -4,21 +4,23 @@ """ # Standard library imports -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field from typing import Optional, List from enum import IntEnum -from app.models import UserRole + # Enums for validation class Gender(IntEnum): MALE = 1 FEMALE = 2 + class PredictionInput(BaseModel): """ Schema for prediction input data containing all client assessment fields. Used for making predictions about client outcomes. """ + age: int gender: str work_experience: int @@ -44,6 +46,7 @@ class PredictionInput(BaseModel): time_unemployed: int need_mental_health_support_bool: str + class ClientBase(BaseModel): age: int = Field(ge=18, description="Age of client, must be 18 or older") gender: Gender = Field(description="Gender: 1 for male, 2 for female") @@ -54,9 +57,15 @@ class ClientBase(BaseModel): citizen_status: bool = Field(description="Client's citizenship status") level_of_schooling: int = Field(ge=1, le=14, description="Education level (1-14)") fluent_english: bool = Field(description="English fluency status") - reading_english_scale: int = Field(ge=0, le=10, description="English reading level (0-10)") - speaking_english_scale: int = Field(ge=0, le=10, description="English speaking level (0-10)") - writing_english_scale: int = Field(ge=0, le=10, description="English writing level (0-10)") + reading_english_scale: int = Field( + ge=0, le=10, description="English reading level (0-10)" + ) + speaking_english_scale: int = Field( + ge=0, le=10, description="English speaking level (0-10)" + ) + writing_english_scale: int = Field( + ge=0, le=10, description="English writing level (0-10)" + ) numeracy_scale: int = Field(ge=0, le=10, description="Numeracy skill level (0-10)") computer_scale: int = Field(ge=0, le=10, description="Computer skill level (0-10)") transportation_bool: bool = Field(description="Has transportation") @@ -68,7 +77,9 @@ class ClientBase(BaseModel): currently_employed: bool = Field(description="Current employment status") substance_use: bool = Field(description="Substance use status") time_unemployed: int = Field(ge=0, description="Time unemployed in months") - need_mental_health_support_bool: bool = Field(description="Needs mental health support") + need_mental_health_support_bool: bool = Field( + description="Needs mental health support" + ) class Config: json_schema_extra = { @@ -96,16 +107,18 @@ class Config: "currently_employed": False, "substance_use": False, "time_unemployed": 6, - "need_mental_health_support_bool": False + "need_mental_health_support_bool": False, } } + class ClientResponse(ClientBase): id: int class Config: from_attributes = True + class ClientUpdate(BaseModel): age: Optional[int] = Field(None, ge=18) gender: Optional[Gender] = None @@ -132,6 +145,7 @@ class ClientUpdate(BaseModel): time_unemployed: Optional[int] = Field(None, ge=0) need_mental_health_support_bool: Optional[bool] = None + class ServiceResponse(BaseModel): client_id: int user_id: int @@ -147,6 +161,7 @@ class ServiceResponse(BaseModel): class Config: from_attributes = True + class ServiceUpdate(BaseModel): employment_assistance: Optional[bool] = None life_stabilization: Optional[bool] = None @@ -157,6 +172,7 @@ class ServiceUpdate(BaseModel): enhanced_referrals: Optional[bool] = None success_rate: Optional[int] = Field(None, ge=0, le=100) + class ClientListResponse(BaseModel): clients: List[ClientResponse] total: int diff --git a/app/clients/service/client_service.py b/app/clients/service/client_service.py index 86c3ef4a..6af668df 100644 --- a/app/clients/service/client_service.py +++ b/app/clients/service/client_service.py @@ -4,11 +4,11 @@ """ from sqlalchemy.orm import Session -from sqlalchemy import and_ from fastapi import HTTPException, status -from typing import List, Optional, Dict, Any +from typing import Optional from app.models import Client, ClientCase, User -from app.clients.schema import ClientUpdate, ServiceUpdate, ServiceResponse +from app.clients.schema import ClientUpdate, ServiceUpdate + class ClientService: @staticmethod @@ -18,7 +18,7 @@ def get_client(db: Session, client_id: int): if not client: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" + detail=f"Client with id {client_id} not found", ) return client @@ -31,14 +31,14 @@ def get_clients(db: Session, skip: int = 0, limit: int = 50): if skip < 0: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Skip value cannot be negative" + detail="Skip value cannot be negative", ) if limit < 1: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Limit must be greater than 0" + detail="Limit must be greater than 0", ) - + clients = db.query(Client).offset(skip).limit(limit).all() total = db.query(Client).count() return {"clients": clients, "total": total} @@ -69,27 +69,26 @@ def get_clients_by_criteria( attending_school: Optional[bool] = None, substance_use: Optional[bool] = None, time_unemployed: Optional[int] = None, - need_mental_health_support_bool: Optional[bool] = None + need_mental_health_support_bool: Optional[bool] = None, ): """Get clients filtered by any combination of criteria""" query = db.query(Client) - + if education_level is not None and not (1 <= education_level <= 14): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Education level must be between 1 and 14" + detail="Education level must be between 1 and 14", ) - + if age_min is not None and age_min < 18: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Minimum age must be at least 18" + detail="Minimum age must be at least 18", ) if gender is not None and gender not in [1, 2]: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Gender must be 1 or 2" + status_code=status.HTTP_400_BAD_REQUEST, detail="Gender must be 1 or 2" ) # Apply filters for non-None values @@ -116,7 +115,9 @@ def get_clients_by_criteria( if reading_english_scale is not None: query = query.filter(Client.reading_english_scale == reading_english_scale) if speaking_english_scale is not None: - query = query.filter(Client.speaking_english_scale == speaking_english_scale) + query = query.filter( + Client.speaking_english_scale == speaking_english_scale + ) if writing_english_scale is not None: query = query.filter(Client.writing_english_scale == writing_english_scale) if numeracy_scale is not None: @@ -140,47 +141,49 @@ def get_clients_by_criteria( if time_unemployed is not None: query = query.filter(Client.time_unemployed == time_unemployed) if need_mental_health_support_bool is not None: - query = query.filter(Client.need_mental_health_support_bool == need_mental_health_support_bool) + query = query.filter( + Client.need_mental_health_support_bool + == need_mental_health_support_bool + ) try: return query.all() except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error retrieving clients: {str(e)}" + detail=f"Error retrieving clients: {str(e)}", ) @staticmethod - def get_clients_by_services( - db: Session, - **service_filters: Optional[bool] - ): + def get_clients_by_services(db: Session, **service_filters: Optional[bool]): """ Get clients filtered by multiple service statuses. """ query = db.query(Client).join(ClientCase) - - for service_name, status in service_filters.items(): - if status is not None: + + for service_name, statu in service_filters.items(): + if statu is not None: filter_criteria = getattr(ClientCase, service_name) == status query = query.filter(filter_criteria) - + try: return query.all() except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error retrieving clients: {str(e)}" + detail=f"Error retrieving clients: {str(e)}", ) @staticmethod def get_client_services(db: Session, client_id: int): """Get all services for a specific client with case worker info""" - client_cases = db.query(ClientCase).filter(ClientCase.client_id == client_id).all() + client_cases = ( + db.query(ClientCase).filter(ClientCase.client_id == client_id).all() + ) if not client_cases: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"No services found for client with id {client_id}" + detail=f"No services found for client with id {client_id}", ) return client_cases @@ -190,12 +193,15 @@ def get_clients_by_success_rate(db: Session, min_rate: int = 70): if not (0 <= min_rate <= 100): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Success rate must be between 0 and 100" + detail="Success rate must be between 0 and 100", ) - - return db.query(Client).join(ClientCase).filter( - ClientCase.success_rate >= min_rate - ).all() + + return ( + db.query(Client) + .join(ClientCase) + .filter(ClientCase.success_rate >= min_rate) + .all() + ) @staticmethod def get_clients_by_case_worker(db: Session, case_worker_id: int): @@ -204,12 +210,15 @@ def get_clients_by_case_worker(db: Session, case_worker_id: int): if not case_worker: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Case worker with id {case_worker_id} not found" + detail=f"Case worker with id {case_worker_id} not found", ) - - return db.query(Client).join(ClientCase).filter( - ClientCase.user_id == case_worker_id - ).all() + + return ( + db.query(Client) + .join(ClientCase) + .filter(ClientCase.user_id == case_worker_id) + .all() + ) @staticmethod def update_client(db: Session, client_id: int, client_update: ClientUpdate): @@ -218,7 +227,7 @@ def update_client(db: Session, client_id: int, client_update: ClientUpdate): if not client: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" + detail=f"Client with id {client_id} not found", ) update_data = client_update.dict(exclude_unset=True) @@ -233,27 +242,25 @@ def update_client(db: Session, client_id: int, client_update: ClientUpdate): db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update client: {str(e)}" + detail=f"Failed to update client: {str(e)}", ) - + @staticmethod def update_client_services( - db: Session, - client_id: int, - user_id: int, - service_update: ServiceUpdate + db: Session, client_id: int, user_id: int, service_update: ServiceUpdate ): """Update a client's services and outcomes for a specific case worker""" - client_case = db.query(ClientCase).filter( - ClientCase.client_id == client_id, - ClientCase.user_id == user_id - ).first() - + client_case = ( + db.query(ClientCase) + .filter(ClientCase.client_id == client_id, ClientCase.user_id == user_id) + .first() + ) + if not client_case: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"No case found for client {client_id} with case worker {user_id}. " - f"Cannot update services for a non-existent case assignment." + f"Cannot update services for a non-existent case assignment.", ) update_data = service_update.dict(exclude_unset=True) @@ -268,22 +275,18 @@ def update_client_services( db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update client services: {str(e)}" + detail=f"Failed to update client services: {str(e)}", ) - + @staticmethod - def create_case_assignment( - db: Session, - client_id: int, - case_worker_id: int - ): + def create_case_assignment(db: Session, client_id: int, case_worker_id: int): """Create a new case assignment""" # Check if client exists client = db.query(Client).filter(Client.id == client_id).first() if not client: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" + detail=f"Client with id {client_id} not found", ) # Check if case worker exists @@ -291,20 +294,23 @@ def create_case_assignment( if not case_worker: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Case worker with id {case_worker_id} not found" + detail=f"Case worker with id {case_worker_id} not found", ) # Check if assignment already exists - existing_case = db.query(ClientCase).filter( - ClientCase.client_id == client_id, - ClientCase.user_id == case_worker_id - ).first() - + existing_case = ( + db.query(ClientCase) + .filter( + ClientCase.client_id == client_id, ClientCase.user_id == case_worker_id + ) + .first() + ) + if existing_case: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Client {client_id} already has a case assigned to case worker {case_worker_id}" - ) + detail=f"Client {client_id} already has a case assigned to case worker {case_worker_id}", + ) try: # Create new case assignment with default service values @@ -318,7 +324,7 @@ def create_case_assignment( employment_related_financial_supports=False, employer_financial_supports=False, enhanced_referrals=False, - success_rate=0 + success_rate=0, ) db.add(new_case) db.commit() @@ -329,9 +335,9 @@ def create_case_assignment( db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to create case assignment: {str(e)}" + detail=f"Failed to create case assignment: {str(e)}", ) - + @staticmethod def delete_client(db: Session, client_id: int): """Delete a client and their associated records""" @@ -340,22 +346,20 @@ def delete_client(db: Session, client_id: int): if not client: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" + detail=f"Client with id {client_id} not found", ) try: # Delete associated client_cases - db.query(ClientCase).filter( - ClientCase.client_id == client_id - ).delete() - + db.query(ClientCase).filter(ClientCase.client_id == client_id).delete() + # Delete the client db.delete(client) db.commit() - + except Exception as e: db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to delete client: {str(e)}" + detail=f"Failed to delete client: {str(e)}", ) diff --git a/app/clients/service/logic.py b/app/clients/service/logic.py index c25b4217..f57ce6cb 100644 --- a/app/clients/service/logic.py +++ b/app/clients/service/logic.py @@ -5,7 +5,8 @@ # Standard library imports import os -#import json + +# import json from itertools import product # Third-party imports @@ -14,21 +15,22 @@ # Constants COLUMN_INTERVENTIONS = [ - 'Life Stabilization', - 'General Employment Assistance Services', - 'Retention Services', - 'Specialized Services', - 'Employment-Related Financial Supports for Job Seekers and Employers', - 'Employer Financial Supports', - 'Enhanced Referrals for Skills Development' + "Life Stabilization", + "General Employment Assistance Services", + "Retention Services", + "Specialized Services", + "Employment-Related Financial Supports for Job Seekers and Employers", + "Employer Financial Supports", + "Enhanced Referrals for Skills Development", ] # Load model CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) -MODEL_PATH = os.path.join(CURRENT_DIR, 'model.pkl') +MODEL_PATH = os.path.join(CURRENT_DIR, "model.pkl") with open(MODEL_PATH, "rb") as model_file: MODEL = pickle.load(model_file) + def clean_input_data(input_data): """ Clean and transform input data into model-compatible format. @@ -40,13 +42,30 @@ def clean_input_data(input_data): list: Cleaned and formatted data ready for model input """ columns = [ - "age", "gender", "work_experience", "canada_workex", "dep_num", - "canada_born", "citizen_status", "level_of_schooling", "fluent_english", - "reading_english_scale", "speaking_english_scale", "writing_english_scale", - "numeracy_scale", "computer_scale", "transportation_bool", "caregiver_bool", - "housing", "income_source", "felony_bool", "attending_school", - "currently_employed", "substance_use", "time_unemployed", - "need_mental_health_support_bool" + "age", + "gender", + "work_experience", + "canada_workex", + "dep_num", + "canada_born", + "citizen_status", + "level_of_schooling", + "fluent_english", + "reading_english_scale", + "speaking_english_scale", + "writing_english_scale", + "numeracy_scale", + "computer_scale", + "transportation_bool", + "caregiver_bool", + "housing", + "income_source", + "felony_bool", + "attending_school", + "currently_employed", + "substance_use", + "time_unemployed", + "need_mental_health_support_bool", ] demographics = {key: input_data[key] for key in columns} output = [] @@ -57,6 +76,7 @@ def clean_input_data(input_data): output.append(value) return output + def convert_text(text_data: str): """ Convert text answers from front end into numerical values. @@ -68,33 +88,47 @@ def convert_text(text_data: str): int: Converted numerical value """ categorical_mappings = [ + {"": 0, "true": 1, "false": 0, "no": 0, "yes": 1, "No": 0, "Yes": 1}, { - "": 0, "true": 1, "false": 0, "no": 0, "yes": 1, - "No": 0, "Yes": 1 + "Grade 0-8": 1, + "Grade 9": 2, + "Grade 10": 3, + "Grade 11": 4, + "Grade 12 or equivalent": 5, + "OAC or Grade 13": 6, + "Some college": 7, + "Some university": 8, + "Some apprenticeship": 9, + "Certificate of Apprenticeship": 10, + "Journeyperson": 11, + "Certificate/Diploma": 12, + "Bachelor's degree": 13, + "Post graduate": 14, }, { - "Grade 0-8": 1, "Grade 9": 2, "Grade 10": 3, "Grade 11": 4, - "Grade 12 or equivalent": 5, "OAC or Grade 13": 6, - "Some college": 7, "Some university": 8, "Some apprenticeship": 9, - "Certificate of Apprenticeship": 10, "Journeyperson": 11, - "Certificate/Diploma": 12, "Bachelor's degree": 13, - "Post graduate": 14 + "Renting-private": 1, + "Renting-subsidized": 2, + "Boarding or lodging": 3, + "Homeowner": 4, + "Living with family/friend": 5, + "Institution": 6, + "Temporary second residence": 7, + "Band-owned home": 8, + "Homeless or transient": 9, + "Emergency hostel": 10, }, { - "Renting-private": 1, "Renting-subsidized": 2, - "Boarding or lodging": 3, "Homeowner": 4, - "Living with family/friend": 5, "Institution": 6, - "Temporary second residence": 7, "Band-owned home": 8, - "Homeless or transient": 9, "Emergency hostel": 10 - }, - { - "No Source of Income": 1, "Employment Insurance": 2, + "No Source of Income": 1, + "Employment Insurance": 2, "Workplace Safety and Insurance Board": 3, "Ontario Works applied or receiving": 4, "Ontario Disability Support Program applied or receiving": 5, - "Dependent of someone receiving OW or ODSP": 6, "Crown Ward": 7, - "Employment": 8, "Self-Employment": 9, "Other (specify)": 10 - } + "Dependent of someone receiving OW or ODSP": 6, + "Crown Ward": 7, + "Employment": 8, + "Self-Employment": 9, + "Other (specify)": 10, + }, ] for category in categorical_mappings: if text_data in category: @@ -102,6 +136,7 @@ def convert_text(text_data: str): return int(text_data) if text_data.isnumeric() else text_data + def create_matrix(row_data): """ Create matrix of all possible intervention combinations. @@ -116,6 +151,7 @@ def create_matrix(row_data): perms = intervention_permutations(7) return np.concatenate((np.array(data), np.array(perms)), axis=1) + def intervention_permutations(num): """ Generate all possible intervention combinations. @@ -128,6 +164,7 @@ def intervention_permutations(num): """ return np.array(list(product([0, 1], repeat=num))) + def get_baseline_row(row_data): """ Create baseline row with no interventions. @@ -141,6 +178,7 @@ def get_baseline_row(row_data): base_interventions = np.zeros(7) return np.concatenate((np.array(row_data), base_interventions)) + def intervention_row_to_names(row_data): """ Convert intervention row to list of intervention names. @@ -153,6 +191,7 @@ def intervention_row_to_names(row_data): """ return [COLUMN_INTERVENTIONS[i] for i, value in enumerate(row_data) if value == 1] + def process_results(baseline_pred, results_matrix): """ Process model results into structured output. @@ -165,13 +204,10 @@ def process_results(baseline_pred, results_matrix): dict: Processed results with baseline and interventions """ result_list = [ - (row[-1], intervention_row_to_names(row[:-1])) - for row in results_matrix + (row[-1], intervention_row_to_names(row[:-1])) for row in results_matrix ] - return { - "baseline": baseline_pred[-1], - "interventions": result_list - } + return {"baseline": baseline_pred[-1], "interventions": result_list} + def interpret_and_calculate(input_data): """ @@ -188,25 +224,41 @@ def interpret_and_calculate(input_data): intervention_rows = create_matrix(raw_data) baseline_prediction = MODEL.predict(baseline_row) intervention_predictions = MODEL.predict(intervention_rows).reshape(-1, 1) - result_matrix = np.concatenate((intervention_rows, intervention_predictions), axis=1) + result_matrix = np.concatenate( + (intervention_rows, intervention_predictions), axis=1 + ) result_order = result_matrix[:, -1].argsort() result_matrix = result_matrix[result_order] top_results = result_matrix[-3:, -8:] return process_results(baseline_prediction, top_results) + if __name__ == "__main__": test_data = { - "age": "23", "gender": "1", "work_experience": "1", - "canada_workex": "1", "dep_num": "0", "canada_born": "1", - "citizen_status": "2", "level_of_schooling": "2", - "fluent_english": "3", "reading_english_scale": "2", - "speaking_english_scale": "2", "writing_english_scale": "3", - "numeracy_scale": "2", "computer_scale": "3", - "transportation_bool": "2", "caregiver_bool": "1", - "housing": "1", "income_source": "5", "felony_bool": "1", - "attending_school": "0", "currently_employed": "1", - "substance_use": "1", "time_unemployed": "1", - "need_mental_health_support_bool": "1" + "age": "23", + "gender": "1", + "work_experience": "1", + "canada_workex": "1", + "dep_num": "0", + "canada_born": "1", + "citizen_status": "2", + "level_of_schooling": "2", + "fluent_english": "3", + "reading_english_scale": "2", + "speaking_english_scale": "2", + "writing_english_scale": "3", + "numeracy_scale": "2", + "computer_scale": "3", + "transportation_bool": "2", + "caregiver_bool": "1", + "housing": "1", + "income_source": "5", + "felony_bool": "1", + "attending_school": "0", + "currently_employed": "1", + "substance_use": "1", + "time_unemployed": "1", + "need_mental_health_support_bool": "1", } results = interpret_and_calculate(test_data) print(results) diff --git a/app/clients/service/model.py b/app/clients/service/model.py index b2406370..9f906788 100644 --- a/app/clients/service/model.py +++ b/app/clients/service/model.py @@ -12,73 +12,72 @@ from sklearn.model_selection import train_test_split from sklearn.ensemble import RandomForestRegressor + def prepare_models(): """ Prepare and train the Random Forest model using the dataset. - + Returns: RandomForestRegressor: Trained model for predicting success rates """ # Load dataset - data = pd.read_csv('data_commontool.csv') + data = pd.read_csv("data_commontool.csv") # Define feature columns feature_columns = [ - 'age', # Client's age - 'gender', # Client's gender (bool) - 'work_experience', # Years of work experience - 'canada_workex', # Years of work experience in Canada - 'dep_num', # Number of dependents - 'canada_born', # Born in Canada - 'citizen_status', # Citizenship status - 'level_of_schooling', # Highest level achieved (1-14) - 'fluent_english', # English fluency scale (1-10) - 'reading_english_scale', # Reading ability scale (1-10) - 'speaking_english_scale',# Speaking ability scale (1-10) - 'writing_english_scale', # Writing ability scale (1-10) - 'numeracy_scale', # Numeracy ability scale (1-10) - 'computer_scale', # Computer proficiency scale (1-10) - 'transportation_bool', # Needs transportation support (bool) - 'caregiver_bool', # Is primary caregiver (bool) - 'housing', # Housing situation (1-10) - 'income_source', # Source of income (1-10) - 'felony_bool', # Has a felony (bool) - 'attending_school', # Currently a student (bool) - 'currently_employed', # Currently employed (bool) - 'substance_use', # Substance use disorder (bool) - 'time_unemployed', # Years unemployed - 'need_mental_health_support_bool' # Needs mental health support (bool) + "age", # Client's age + "gender", # Client's gender (bool) + "work_experience", # Years of work experience + "canada_workex", # Years of work experience in Canada + "dep_num", # Number of dependents + "canada_born", # Born in Canada + "citizen_status", # Citizenship status + "level_of_schooling", # Highest level achieved (1-14) + "fluent_english", # English fluency scale (1-10) + "reading_english_scale", # Reading ability scale (1-10) + "speaking_english_scale", # Speaking ability scale (1-10) + "writing_english_scale", # Writing ability scale (1-10) + "numeracy_scale", # Numeracy ability scale (1-10) + "computer_scale", # Computer proficiency scale (1-10) + "transportation_bool", # Needs transportation support (bool) + "caregiver_bool", # Is primary caregiver (bool) + "housing", # Housing situation (1-10) + "income_source", # Source of income (1-10) + "felony_bool", # Has a felony (bool) + "attending_school", # Currently a student (bool) + "currently_employed", # Currently employed (bool) + "substance_use", # Substance use disorder (bool) + "time_unemployed", # Years unemployed + "need_mental_health_support_bool", # Needs mental health support (bool) ] # Define intervention columns intervention_columns = [ - 'employment_assistance', - 'life_stabilization', - 'retention_services', - 'specialized_services', - 'employment_related_financial_supports', - 'employer_financial_supports', - 'enhanced_referrals' + "employment_assistance", + "life_stabilization", + "retention_services", + "specialized_services", + "employment_related_financial_supports", + "employer_financial_supports", + "enhanced_referrals", ] # Combine all feature columns all_features = feature_columns + intervention_columns # Prepare training data features = np.array(data[all_features]) # Changed from X to features - targets = np.array(data['success_rate']) # Changed from y to targets + targets = np.array(data["success_rate"]) # Changed from y to targets # Split the dataset features_train, _, targets_train, _ = train_test_split( # Removed unused variables - features, - targets, - test_size=0.2, - random_state=42 + features, targets, test_size=0.2, random_state=42 ) # Initialize and train the model model = RandomForestRegressor(n_estimators=100, random_state=42) model.fit(features_train, targets_train) return model + def save_model(model, filename="model.pkl"): """ Save the trained model to a file. - + Args: model: Trained model to save filename (str): Name of the file to save the model to @@ -86,19 +85,21 @@ def save_model(model, filename="model.pkl"): with open(filename, "wb") as model_file: pickle.dump(model, model_file) + def load_model(filename="model.pkl"): """ Load a trained model from a file. - + Args: filename (str): Name of the file to load the model from - + Returns: The loaded model """ with open(filename, "rb") as model_file: return pickle.load(model_file) + def main(): """Main function to train and save the model.""" print("Starting model training...") @@ -106,5 +107,6 @@ def main(): save_model(model) print("Model training completed and saved successfully.") + if __name__ == "__main__": main() diff --git a/app/database.py b/app/database.py index 3a489f54..6a800078 100644 --- a/app/database.py +++ b/app/database.py @@ -7,22 +7,25 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker -#Here is where the database is located -SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db" +# Here is where the database is located +SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db" -#Open up a connection so that we are able to use the database -engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}) +# Open up a connection so that we are able to use the database +engine = create_engine( + SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} +) -#Bind the engine just created +# Bind the engine just created SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -#Create an object of our database so as to control the database +# Create an object of our database so as to control the database Base = declarative_base() + def get_db(): """ Creates a database session and ensures it's closed after use. - + Yields: Session: SQLAlchemy database session """ diff --git a/app/main.py b/app/main.py index a8e8fa7f..5038db96 100644 --- a/app/main.py +++ b/app/main.py @@ -4,18 +4,25 @@ Handles database initialization and CORS middleware configuration. """ +# Related third-party imports from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +# Local application/library specific imports from app import models from app.database import engine from app.clients.router import router as clients_router from app.auth.router import router as auth_router -from fastapi.middleware.cors import CORSMiddleware # Initialize database tables models.Base.metadata.create_all(bind=engine) # Create FastAPI application -app = FastAPI(title="Case Management API", description="API for managing client cases", version="1.0.0") +app = FastAPI( + title="Case Management API", + description="API for managing client cases", + version="1.0.0", +) # Include routers app.include_router(auth_router) @@ -24,8 +31,8 @@ # Configure CORS middleware app.add_middleware( CORSMiddleware, - allow_origins=["*"], # Allows all origins - allow_methods=["*"], # Allows all methods - allow_headers=["*"], # Allows all headers + allow_origins=["*"], # Allows all origins + allow_methods=["*"], # Allows all methods + allow_headers=["*"], # Allows all headers allow_credentials=True, ) diff --git a/app/models.py b/app/models.py index df778348..dbfb891b 100644 --- a/app/models.py +++ b/app/models.py @@ -3,17 +3,37 @@ Contains the Client model for storing client information in the database. """ -from app.database import Base -from sqlalchemy import Column, Integer, String, Boolean, ForeignKey, CheckConstraint, Enum -from sqlalchemy.orm import relationship import enum +from sqlalchemy import ( + Column, + Integer, + String, + Boolean, + ForeignKey, + Enum, + CheckConstraint, +) +from sqlalchemy.orm import relationship + +from app.database import Base + + class UserRole(str, enum.Enum): + """ + User Role class store two roles + """ + admin = "admin" case_worker = "case_worker" class User(Base): + """ + Represents a User in the database. + Stores user details including authentication and roles. + """ + __tablename__ = "users" id = Column(Integer, primary_key=True, autoincrement=True) @@ -24,46 +44,71 @@ class User(Base): cases = relationship("ClientCase", back_populates="user") + class Client(Base): """ - Client model representing client data in the database. + Represents a Client in the database. + Stores personal and case-related information for each client. """ + __tablename__ = "clients" id = Column(Integer, primary_key=True, autoincrement=True) - age = Column(Integer, CheckConstraint('age >= 18')) + age = Column(Integer, CheckConstraint("age >= 18")) gender = Column(Integer, CheckConstraint("gender = 1 OR gender = 2")) - work_experience = Column(Integer, CheckConstraint('work_experience >= 0')) - canada_workex = Column(Integer, CheckConstraint('canada_workex >= 0')) - dep_num = Column(Integer, CheckConstraint('dep_num >= 0')) + work_experience = Column(Integer, CheckConstraint("work_experience >= 0")) + canada_workex = Column(Integer, CheckConstraint("canada_workex >= 0")) + dep_num = Column(Integer, CheckConstraint("dep_num >= 0")) canada_born = Column(Boolean) citizen_status = Column(Boolean) - level_of_schooling = Column(Integer, CheckConstraint('level_of_schooling >= 1 AND level_of_schooling <= 14')) + level_of_schooling = Column( + Integer, CheckConstraint("level_of_schooling >= 1 AND level_of_schooling <= 14") + ) fluent_english = Column(Boolean) - reading_english_scale = Column(Integer, CheckConstraint('reading_english_scale >= 0 AND reading_english_scale <= 10')) - speaking_english_scale = Column(Integer, CheckConstraint('speaking_english_scale >= 0 AND speaking_english_scale <= 10')) - writing_english_scale = Column(Integer, CheckConstraint('writing_english_scale >= 0 AND writing_english_scale <= 10')) - numeracy_scale = Column(Integer, CheckConstraint('numeracy_scale >= 0 AND numeracy_scale <= 10')) - computer_scale = Column(Integer, CheckConstraint('computer_scale >= 0 AND computer_scale <= 10')) + reading_english_scale = Column( + Integer, + CheckConstraint("reading_english_scale >= 0 AND reading_english_scale <= 10"), + ) + speaking_english_scale = Column( + Integer, + CheckConstraint("speaking_english_scale >= 0 AND speaking_english_scale <= 10"), + ) + writing_english_scale = Column( + Integer, + CheckConstraint("writing_english_scale >= 0 AND writing_english_scale <= 10"), + ) + numeracy_scale = Column( + Integer, CheckConstraint("numeracy_scale >= 0 AND numeracy_scale <= 10") + ) + computer_scale = Column( + Integer, CheckConstraint("computer_scale >= 0 AND computer_scale <= 10") + ) transportation_bool = Column(Boolean) caregiver_bool = Column(Boolean) - housing = Column(Integer, CheckConstraint('housing >= 1 AND housing <= 10')) - income_source = Column(Integer, CheckConstraint('income_source >= 1 AND income_source <= 11')) + housing = Column(Integer, CheckConstraint("housing >= 1 AND housing <= 10")) + income_source = Column( + Integer, CheckConstraint("income_source >= 1 AND income_source <= 11") + ) felony_bool = Column(Boolean) attending_school = Column(Boolean) currently_employed = Column(Boolean) substance_use = Column(Boolean) - time_unemployed = Column(Integer, CheckConstraint('time_unemployed >= 0')) + time_unemployed = Column(Integer, CheckConstraint("time_unemployed >= 0")) need_mental_health_support_bool = Column(Boolean) cases = relationship("ClientCase", back_populates="client") + class ClientCase(Base): + """ + ClientCase class + """ + __tablename__ = "client_cases" client_id = Column(Integer, ForeignKey("clients.id"), primary_key=True) user_id = Column(Integer, ForeignKey("users.id"), primary_key=True) - + employment_assistance = Column(Boolean) life_stabilization = Column(Boolean) retention_services = Column(Boolean) @@ -71,7 +116,9 @@ class ClientCase(Base): employment_related_financial_supports = Column(Boolean) employer_financial_supports = Column(Boolean) enhanced_referrals = Column(Boolean) - success_rate = Column(Integer, CheckConstraint('success_rate >= 0 AND success_rate <= 100')) + success_rate = Column( + Integer, CheckConstraint("success_rate >= 0 AND success_rate <= 100") + ) client = relationship("Client", back_populates="cases") user = relationship("User", back_populates="cases") diff --git a/code_quality_README.md b/code_quality_README.md index 7d1ceab8..339f7699 100644 --- a/code_quality_README.md +++ b/code_quality_README.md @@ -14,7 +14,15 @@ Install development tools with: ## Running Analysis Run these commands to analyze code quality: -1. `pylint app/ > pylint_report.txt` 2. `flake8 app/ > flake8_report.txt` 3. `black --check --diff app/ > black_report.txt` 4. `mypy app/ > mypy_report.txt` + +## Installation of Dependency on Main Code: +`pip install -r ./app/requirements.txt` + +## To check the code using pre-commit: +``` +pre-commit autoupdate +pre-commit run --all-files +``` \ No newline at end of file diff --git a/initialize_data.py b/initialize_data.py index 1444bf41..e9e5d29a 100644 --- a/initialize_data.py +++ b/initialize_data.py @@ -1,9 +1,9 @@ import pandas as pd -from sqlalchemy.orm import Session from app.database import SessionLocal from app.models import Client, User, ClientCase, UserRole from app.auth.router import get_password_hash + def initialize_database(): print("Starting database initialization...") db = SessionLocal() @@ -15,7 +15,7 @@ def initialize_database(): username="admin", email="admin@example.com", hashed_password=get_password_hash("admin123"), - role=UserRole.admin + role=UserRole.admin, ) db.add(admin_user) db.commit() @@ -30,7 +30,7 @@ def initialize_database(): username="case_worker1", email="caseworker1@example.com", hashed_password=get_password_hash("worker123"), - role=UserRole.case_worker + role=UserRole.case_worker, ) db.add(case_worker) db.commit() @@ -40,46 +40,59 @@ def initialize_database(): # Load CSV data print("Loading CSV data...") - df = pd.read_csv('app/clients/service/data_commontool.csv') - + df = pd.read_csv("app/clients/service/data_commontool.csv") + # Convert data types integer_columns = [ - 'age', 'gender', 'work_experience', 'canada_workex', 'dep_num', - 'level_of_schooling', 'reading_english_scale', 'speaking_english_scale', - 'writing_english_scale', 'numeracy_scale', 'computer_scale', - 'housing', 'income_source', 'time_unemployed', 'success_rate' + "age", + "gender", + "work_experience", + "canada_workex", + "dep_num", + "level_of_schooling", + "reading_english_scale", + "speaking_english_scale", + "writing_english_scale", + "numeracy_scale", + "computer_scale", + "housing", + "income_source", + "time_unemployed", + "success_rate", ] for col in integer_columns: - df[col] = pd.to_numeric(df[col], errors='raise') + df[col] = pd.to_numeric(df[col], errors="raise") # Process each row in CSV for index, row in df.iterrows(): # Create client client = Client( - age=int(row['age']), - gender=int(row['gender']), - work_experience=int(row['work_experience']), - canada_workex=int(row['canada_workex']), - dep_num=int(row['dep_num']), - canada_born=bool(row['canada_born']), - citizen_status=bool(row['citizen_status']), - level_of_schooling=int(row['level_of_schooling']), - fluent_english=bool(row['fluent_english']), - reading_english_scale=int(row['reading_english_scale']), - speaking_english_scale=int(row['speaking_english_scale']), - writing_english_scale=int(row['writing_english_scale']), - numeracy_scale=int(row['numeracy_scale']), - computer_scale=int(row['computer_scale']), - transportation_bool=bool(row['transportation_bool']), - caregiver_bool=bool(row['caregiver_bool']), - housing=int(row['housing']), - income_source=int(row['income_source']), - felony_bool=bool(row['felony_bool']), - attending_school=bool(row['attending_school']), - currently_employed=bool(row['currently_employed']), - substance_use=bool(row['substance_use']), - time_unemployed=int(row['time_unemployed']), - need_mental_health_support_bool=bool(row['need_mental_health_support_bool']) + age=int(row["age"]), + gender=int(row["gender"]), + work_experience=int(row["work_experience"]), + canada_workex=int(row["canada_workex"]), + dep_num=int(row["dep_num"]), + canada_born=bool(row["canada_born"]), + citizen_status=bool(row["citizen_status"]), + level_of_schooling=int(row["level_of_schooling"]), + fluent_english=bool(row["fluent_english"]), + reading_english_scale=int(row["reading_english_scale"]), + speaking_english_scale=int(row["speaking_english_scale"]), + writing_english_scale=int(row["writing_english_scale"]), + numeracy_scale=int(row["numeracy_scale"]), + computer_scale=int(row["computer_scale"]), + transportation_bool=bool(row["transportation_bool"]), + caregiver_bool=bool(row["caregiver_bool"]), + housing=int(row["housing"]), + income_source=int(row["income_source"]), + felony_bool=bool(row["felony_bool"]), + attending_school=bool(row["attending_school"]), + currently_employed=bool(row["currently_employed"]), + substance_use=bool(row["substance_use"]), + time_unemployed=int(row["time_unemployed"]), + need_mental_health_support_bool=bool( + row["need_mental_health_support_bool"] + ), ) db.add(client) db.commit() @@ -88,14 +101,16 @@ def initialize_database(): client_case = ClientCase( client_id=client.id, user_id=admin_user.id, # Assign to admin - employment_assistance=bool(row['employment_assistance']), - life_stabilization=bool(row['life_stabilization']), - retention_services=bool(row['retention_services']), - specialized_services=bool(row['specialized_services']), - employment_related_financial_supports=bool(row['employment_related_financial_supports']), - employer_financial_supports=bool(row['employer_financial_supports']), - enhanced_referrals=bool(row['enhanced_referrals']), - success_rate=int(row['success_rate']) + employment_assistance=bool(row["employment_assistance"]), + life_stabilization=bool(row["life_stabilization"]), + retention_services=bool(row["retention_services"]), + specialized_services=bool(row["specialized_services"]), + employment_related_financial_supports=bool( + row["employment_related_financial_supports"] + ), + employer_financial_supports=bool(row["employer_financial_supports"]), + enhanced_referrals=bool(row["enhanced_referrals"]), + success_rate=int(row["success_rate"]), ) db.add(client_case) db.commit() @@ -108,5 +123,6 @@ def initialize_database(): finally: db.close() + if __name__ == "__main__": - initialize_database() \ No newline at end of file + initialize_database() diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..976ba029 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,2 @@ +[mypy] +ignore_missing_imports = True diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..973e45cf --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,40 @@ +[tool.black] +line-length = 88 +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.venv + | \.env + | \.eggs + | \.mypy_cache + | \.pytest_cache +)/ +''' + +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.poetry] +name = "common-assessment-tool" +version = "0.1.0" +description = "A FastAPI application for managing assessments." +license = "MIT" + +[tool.poetry.dependencies] +python = "^3.8" +fastapi = "^0.68.0" +uvicorn = "^0.15.0" +sqlalchemy = "^1.4.22" +psycopg2-binary = "^2.9.1" +python-dotenv = "^0.19.0" + +[tool.poetry.dev-dependencies] +pytest = "^6.2" +black = "^21.7b0" +flake8 = "^3.9" +mypy = "^0.910" + + + diff --git a/requirements-dev.txt b/requirements-dev.txt index 51b8e8ab..2bbc6f11 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,3 +2,4 @@ pylint==2.17.5 flake8==6.1.0 black==23.7.0 mypy==1.5.1 +python-jose>=3.3.0 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index aa30d094..b6d68586 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,14 +9,17 @@ # Create test database SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db" -engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}) +engine = create_engine( + SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} +) TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + @pytest.fixture def test_db(): # Create tables Base.metadata.create_all(bind=engine) - + db = TestingSessionLocal() try: # Create test admin user @@ -24,7 +27,7 @@ def test_db(): username="testadmin", email="testadmin@example.com", hashed_password=get_password_hash("testpass123"), - role=UserRole.admin + role=UserRole.admin, ) db.add(admin_user) @@ -33,10 +36,10 @@ def test_db(): username="testworker", email="worker@example.com", hashed_password=get_password_hash("workerpass123"), - role=UserRole.case_worker + role=UserRole.case_worker, ) db.add(case_worker) - + # Create test clients client1 = Client( age=25, @@ -62,9 +65,9 @@ def test_db(): currently_employed=False, substance_use=False, time_unemployed=6, - need_mental_health_support_bool=False + need_mental_health_support_bool=False, ) - + client2 = Client( age=30, gender=2, @@ -89,13 +92,13 @@ def test_db(): currently_employed=True, substance_use=False, time_unemployed=0, - need_mental_health_support_bool=False + need_mental_health_support_bool=False, ) - + db.add(client1) db.add(client2) db.commit() - + # Create test client cases client_case1 = ClientCase( client_id=1, @@ -107,9 +110,9 @@ def test_db(): employment_related_financial_supports=True, employer_financial_supports=False, enhanced_referrals=True, - success_rate=75 + success_rate=75, ) - + client_case2 = ClientCase( client_id=2, user_id=2, # Assigned to case worker @@ -120,18 +123,19 @@ def test_db(): employment_related_financial_supports=False, employer_financial_supports=True, enhanced_referrals=False, - success_rate=85 + success_rate=85, ) - + db.add(client_case1) db.add(client_case2) db.commit() - + yield db finally: db.close() Base.metadata.drop_all(bind=engine) + @pytest.fixture def client(test_db): def override_get_db(): @@ -139,32 +143,33 @@ def override_get_db(): yield test_db finally: test_db.close() - + app.dependency_overrides[get_db] = override_get_db yield TestClient(app) app.dependency_overrides.clear() + @pytest.fixture def admin_token(client): response = client.post( - "/auth/token", - data={"username": "testadmin", "password": "testpass123"} + "/auth/token", data={"username": "testadmin", "password": "testpass123"} ) return response.json()["access_token"] + @pytest.fixture def case_worker_token(client): response = client.post( - "/auth/token", - data={"username": "testworker", "password": "workerpass123"} + "/auth/token", data={"username": "testworker", "password": "workerpass123"} ) return response.json()["access_token"] + @pytest.fixture def admin_headers(admin_token): return {"Authorization": f"Bearer {admin_token}"} + @pytest.fixture def case_worker_headers(case_worker_token): return {"Authorization": f"Bearer {case_worker_token}"} - \ No newline at end of file diff --git a/tests/test_auth.py b/tests/test_auth.py index 1d4692e4..304febdf 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,123 +1,113 @@ import pytest from fastapi import status + def test_create_user_success(client, admin_headers): """Test successful user creation by admin""" user_data = { "username": "newuser", "email": "new@test.com", "password": "testpass123", - "role": "case_worker" + "role": "case_worker", } - response = client.post( - "/auth/users", - headers=admin_headers, - json=user_data - ) + response = client.post("/auth/users", headers=admin_headers, json=user_data) assert response.status_code == status.HTTP_200_OK data = response.json() assert data["username"] == "newuser" assert data["role"] == "case_worker" assert "password" not in data # Password should not be in response + def test_create_user_duplicate_username(client, admin_headers): """Test creating user with existing username""" user_data = { "username": "testadmin", # This username exists in test database "email": "another@test.com", "password": "testpass123", - "role": "case_worker" + "role": "case_worker", } - response = client.post( - "/auth/users", - headers=admin_headers, - json=user_data - ) + response = client.post("/auth/users", headers=admin_headers, json=user_data) assert response.status_code == status.HTTP_400_BAD_REQUEST assert "Username already registered" in response.json()["detail"] + def test_create_user_duplicate_email(client, admin_headers): """Test creating user with existing email""" user_data = { "username": "uniqueuser", "email": "testadmin@example.com", # This email exists in test database "password": "testpass123", - "role": "case_worker" + "role": "case_worker", } - response = client.post( - "/auth/users", - headers=admin_headers, - json=user_data - ) + response = client.post("/auth/users", headers=admin_headers, json=user_data) assert response.status_code == status.HTTP_400_BAD_REQUEST assert "Email already registered" in response.json()["detail"] + def test_create_user_invalid_role(client, admin_headers): """Test creating user with invalid role""" user_data = { "username": "newuser", "email": "new@test.com", "password": "testpass123", - "role": "invalid_role" # Invalid role + "role": "invalid_role", # Invalid role } - response = client.post( - "/auth/users", - headers=admin_headers, - json=user_data - ) + response = client.post("/auth/users", headers=admin_headers, json=user_data) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + def test_create_user_unauthorized(client): """Test user creation without authentication""" user_data = { "username": "newuser", "email": "new@test.com", "password": "testpass123", - "role": "case_worker" + "role": "case_worker", } response = client.post("/auth/users", json=user_data) assert response.status_code == status.HTTP_401_UNAUTHORIZED + def test_login_success_admin(client): """Test successful login for admin""" response = client.post( - "/auth/token", - data={"username": "testadmin", "password": "testpass123"} + "/auth/token", data={"username": "testadmin", "password": "testpass123"} ) assert response.status_code == status.HTTP_200_OK data = response.json() assert "access_token" in data assert data["token_type"] == "bearer" + def test_login_success_case_worker(client): """Test successful login for case worker""" response = client.post( - "/auth/token", - data={"username": "testworker", "password": "workerpass123"} + "/auth/token", data={"username": "testworker", "password": "workerpass123"} ) assert response.status_code == status.HTTP_200_OK data = response.json() assert "access_token" in data assert data["token_type"] == "bearer" + def test_login_wrong_password(client): """Test login with incorrect password""" response = client.post( - "/auth/token", - data={"username": "testadmin", "password": "wrongpassword"} + "/auth/token", data={"username": "testadmin", "password": "wrongpassword"} ) assert response.status_code == status.HTTP_401_UNAUTHORIZED assert "Incorrect username or password" in response.json()["detail"] + def test_login_nonexistent_user(client): """Test login with non-existent username""" response = client.post( - "/auth/token", - data={"username": "nonexistent", "password": "testpass123"} + "/auth/token", data={"username": "nonexistent", "password": "testpass123"} ) assert response.status_code == status.HTTP_401_UNAUTHORIZED assert "Incorrect username or password" in response.json()["detail"] + def test_invalid_token(client): """Test using invalid token""" headers = {"Authorization": "Bearer invalid_token_here"} @@ -125,12 +115,14 @@ def test_invalid_token(client): assert response.status_code == status.HTTP_401_UNAUTHORIZED assert "Could not validate credentials" in response.json()["detail"] + def test_missing_token(client): """Test accessing protected endpoint without token""" response = client.get("/clients/") assert response.status_code == status.HTTP_401_UNAUTHORIZED assert "Not authenticated" in response.json()["detail"] + def test_token_user_deleted(client, admin_headers): """Test using token of deleted user""" # First create a new user as admin @@ -138,22 +130,17 @@ def test_token_user_deleted(client, admin_headers): "username": "temporary", "email": "temp@test.com", "password": "temppass123", - "role": "admin" # Changed to admin so they can access /clients/ + "role": "admin", # Changed to admin so they can access /clients/ } - response = client.post( - "/auth/users", - headers=admin_headers, - json=user_data - ) + response = client.post("/auth/users", headers=admin_headers, json=user_data) assert response.status_code == status.HTTP_200_OK # Get token for new user response = client.post( - "/auth/token", - data={"username": "temporary", "password": "temppass123"} + "/auth/token", data={"username": "temporary", "password": "temppass123"} ) token = response.json()["access_token"] - + # Try using the token headers = {"Authorization": f"Bearer {token}"} response = client.get("/clients/", headers=headers) diff --git a/tests/test_clients.py b/tests/test_clients.py index 611a5b34..40b233c2 100644 --- a/tests/test_clients.py +++ b/tests/test_clients.py @@ -1,12 +1,14 @@ import pytest from fastapi import status + # Test GET Operations def test_get_clients_unauthorized(client): """Test that unauthorized access is prevented""" response = client.get("/clients/") assert response.status_code == status.HTTP_401_UNAUTHORIZED + def test_get_clients_as_admin(client, admin_headers): """Test getting all clients as admin""" response = client.get("/clients/", headers=admin_headers) @@ -16,24 +18,24 @@ def test_get_clients_as_admin(client, admin_headers): assert "total" in data assert len(data["clients"]) > 0 + def test_get_client_by_id(client, admin_headers): """Test getting specific client""" # Test existing client response = client.get("/clients/1", headers=admin_headers) assert response.status_code == status.HTTP_200_OK assert response.json()["id"] == 1 - + # Test non-existent client response = client.get("/clients/999", headers=admin_headers) assert response.status_code == status.HTTP_404_NOT_FOUND + def test_get_clients_by_criteria(client, admin_headers): """Test searching clients by various criteria""" # Test single criterion response = client.get( - "/clients/search/by-criteria", - params={"age_min": 25}, - headers=admin_headers + "/clients/search/by-criteria", params={"age_min": 25}, headers=admin_headers ) assert response.status_code == status.HTTP_200_OK assert len(response.json()) > 0 @@ -41,12 +43,8 @@ def test_get_clients_by_criteria(client, admin_headers): # Test multiple criteria response = client.get( "/clients/search/by-criteria", - params={ - "age_min": 25, - "currently_employed": True, - "gender": 2 - }, - headers=admin_headers + params={"age_min": 25, "currently_employed": True, "gender": 2}, + headers=admin_headers, ) assert response.status_code == status.HTTP_200_OK @@ -54,23 +52,24 @@ def test_get_clients_by_criteria(client, admin_headers): response = client.get( "/clients/search/by-criteria", params={"age_min": 15}, # Below minimum age - headers=admin_headers + headers=admin_headers, ) - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY # Changed from 400 + assert ( + response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + ) # Changed from 400 + def test_get_clients_by_services(client, admin_headers): """Test getting clients by service status""" response = client.get( "/clients/search/by-services", - params={ - "employment_assistance": True, - "life_stabilization": True - }, - headers=admin_headers + params={"employment_assistance": True, "life_stabilization": True}, + headers=admin_headers, ) assert response.status_code == status.HTTP_200_OK assert len(response.json()) > 0 + def test_get_client_services(client, admin_headers): """Test getting services for a specific client""" response = client.get("/clients/1/services", headers=admin_headers) @@ -81,52 +80,46 @@ def test_get_client_services(client, admin_headers): assert "employment_assistance" in services[0] assert "success_rate" in services[0] + def test_get_clients_by_success_rate(client, admin_headers): """Test getting clients by success rate threshold""" response = client.get( - "/clients/search/success-rate", - params={"min_rate": 70}, - headers=admin_headers + "/clients/search/success-rate", params={"min_rate": 70}, headers=admin_headers ) assert response.status_code == status.HTTP_200_OK assert len(response.json()) > 0 + def test_get_clients_by_case_worker(client, admin_headers, case_worker_headers): """Test getting clients assigned to a case worker""" # Test as admin response = client.get("/clients/case-worker/2", headers=admin_headers) assert response.status_code == status.HTTP_200_OK - + # Test as case worker response = client.get("/clients/case-worker/2", headers=case_worker_headers) assert response.status_code == status.HTTP_200_OK + # Test UPDATE Operations def test_update_client(client, admin_headers): """Test updating client information""" - update_data = { - "age": 26, - "currently_employed": True, - "time_unemployed": 0 - } - response = client.put( - "/clients/1", - json=update_data, - headers=admin_headers - ) + update_data = {"age": 26, "currently_employed": True, "time_unemployed": 0} + response = client.put("/clients/1", json=update_data, headers=admin_headers) assert response.status_code == status.HTTP_200_OK updated_client = response.json() assert updated_client["age"] == 26 assert updated_client["currently_employed"] == True assert updated_client["time_unemployed"] == 0 + # Test Create Case Assignment def test_create_case_assignment(client, admin_headers): """Test creating new case assignment""" response = client.post( "/clients/1/case-assignment", params={"case_worker_id": 2}, - headers=admin_headers + headers=admin_headers, ) assert response.status_code == status.HTTP_200_OK @@ -134,10 +127,11 @@ def test_create_case_assignment(client, admin_headers): response = client.post( "/clients/1/case-assignment", params={"case_worker_id": 2}, - headers=admin_headers + headers=admin_headers, ) assert response.status_code == status.HTTP_400_BAD_REQUEST + # Test DELETE Operation def test_delete_client(client, admin_headers): """Test deleting a client""" From 7d6dc32696ec78e26689975ff34e7b9bdaba7374 Mon Sep 17 00:00:00 2001 From: ruoyanj01 Date: Fri, 21 Mar 2025 17:44:44 -0700 Subject: [PATCH 03/23] refactor auth --- app/auth/router.py | 156 ++++++++++++++++++++++++--------------------- 1 file changed, 85 insertions(+), 71 deletions(-) diff --git a/app/auth/router.py b/app/auth/router.py index 35433b54..7049cd24 100644 --- a/app/auth/router.py +++ b/app/auth/router.py @@ -1,17 +1,20 @@ from datetime import datetime, timedelta from typing import Optional + from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm -from jose import JWTError, jwt +from pydantic import BaseModel, Field, validator +from jose import jwt, JWTError from sqlalchemy.orm import Session +from passlib.context import CryptContext + from app.database import get_db from app.models import User, UserRole -from passlib.context import CryptContext -from pydantic import BaseModel, Field, validator router = APIRouter(prefix="/auth", tags=["authentication"]) +# Schemas class UserCreate(BaseModel): username: str = Field(..., min_length=3, max_length=50) email: str @@ -24,7 +27,6 @@ def validate_role(cls, v): raise ValueError("Role must be either admin or case_worker") return v - class UserResponse(BaseModel): username: str email: str @@ -34,41 +36,86 @@ class Config: from_attributes = True -# Configuration -SECRET_KEY = "your-secret-key-here" -ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 30 +# Security Service +class SecurityService: + pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token") + def verify_password(self, plain_password: str, hashed_password: str) -> bool: + return self.pwd_context.verify(plain_password, hashed_password) + def get_password_hash(self, password: str) -> str: + return self.pwd_context.hash(password) -def verify_password(plain_password: str, hashed_password: str) -> bool: - return pwd_context.verify(plain_password, hashed_password) +security = SecurityService() -def get_password_hash(password: str) -> str: - return pwd_context.hash(password) +# Token Service +class TokenService: + SECRET_KEY = "your-secret-key-here" + ALGORITHM = "HS256" + ACCESS_TOKEN_EXPIRE_MINUTES = 30 + def create_access_token(self, data: dict, expires_delta: Optional[timedelta] = None): + to_encode = data.copy() + expire = datetime.utcnow() + (expires_delta or timedelta(minutes=15)) + to_encode.update({"exp": expire}) + return jwt.encode(to_encode, self.SECRET_KEY, algorithm=self.ALGORITHM) -def authenticate_user(db: Session, username: str, password: str) -> Optional[User]: - user = db.query(User).filter(User.username == username).first() - if not user or not verify_password(password, user.hashed_password): - return None - return user + def decode_token(self, token: str): + try: + return jwt.decode(token, self.SECRET_KEY, algorithms=[self.ALGORITHM]) + except JWTError: + return None + +token_service = TokenService() + + +# User Service +class UserService: + def authenticate_user(self, db: Session, username: str, password: str) -> Optional[User]: + user = db.query(User).filter(User.username == username).first() + if not user or not security.verify_password(password, user.hashed_password): + return None + return user + + def create_user(self, db: Session, user_data: UserCreate) -> User: + if db.query(User).filter(User.username == user_data.username).first(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Username already registered", + ) + + if db.query(User).filter(User.email == user_data.email).first(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Email already registered", + ) + db_user = User( + username=user_data.username, + email=user_data.email, + hashed_password=security.get_password_hash(user_data.password), + role=user_data.role, + ) + + try: + db.add(db_user) + db.commit() + db.refresh(db_user) + return db_user + except Exception as e: + db.rollback() + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e), + ) -def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): - to_encode = data.copy() - if expires_delta: - expire = datetime.utcnow() + expires_delta - else: - expire = datetime.utcnow() + timedelta(minutes=15) - to_encode.update({"exp": expire}) - encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) - return encoded_jwt +user_service = UserService() +# Auth Dependencies +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token") + async def get_current_user( token: str = Depends(oauth2_scheme), db: Session = Depends(get_db) ) -> User: @@ -77,12 +124,9 @@ async def get_current_user( detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"}, ) - try: - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - username: str = payload.get("sub") - if username is None: - raise credentials_exception - except JWTError: + payload = token_service.decode_token(token) + username = payload.get("sub") if payload else None + if not username: raise credentials_exception user = db.query(User).filter(User.username == username).first() @@ -90,7 +134,6 @@ async def get_current_user( raise credentials_exception return user - def get_admin_user(current_user: User = Depends(get_current_user)): if current_user.role != UserRole.admin: raise HTTPException( @@ -100,19 +143,21 @@ def get_admin_user(current_user: User = Depends(get_current_user)): return current_user +# Routes @router.post("/token") async def login_for_access_token( - form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db) + form_data: OAuth2PasswordRequestForm = Depends(), + db: Session = Depends(get_db) ): - user = authenticate_user(db, form_data.username, form_data.password) + user = user_service.authenticate_user(db, form_data.username, form_data.password) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"}, ) - access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - access_token = create_access_token( + access_token_expires = timedelta(minutes=token_service.ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = token_service.create_access_token( data={"sub": user.username}, expires_delta=access_token_expires ) return {"access_token": access_token, "token_type": "bearer"} @@ -124,35 +169,4 @@ async def create_user( current_user: User = Depends(get_admin_user), db: Session = Depends(get_db), ): - """Create a new user (admin only)""" - # Check if username exists - if db.query(User).filter(User.username == user_data.username).first(): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Username already registered", - ) - - # Check if email exists - if db.query(User).filter(User.email == user_data.email).first(): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered" - ) - - # Create new user - db_user = User( - username=user_data.username, - email=user_data.email, - hashed_password=get_password_hash(user_data.password), - role=user_data.role, - ) - - try: - db.add(db_user) - db.commit() - db.refresh(db_user) - return db_user - except Exception as e: - db.rollback() - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) - ) + return user_service.create_user(db, user_data) From 8574aa6d9e0ef6df7fbe4bff884499f7de7d7d48 Mon Sep 17 00:00:00 2001 From: Helen Wang Date: Sun, 23 Mar 2025 20:50:33 -0700 Subject: [PATCH 04/23] Refactor: Extract UserRole enum to separate file following Single Responsibility Principle --- app/enums.py | 13 +++++++++++++ app/models.py | 10 +--------- 2 files changed, 14 insertions(+), 9 deletions(-) create mode 100644 app/enums.py diff --git a/app/enums.py b/app/enums.py new file mode 100644 index 00000000..8efc6981 --- /dev/null +++ b/app/enums.py @@ -0,0 +1,13 @@ +""" +Enum definitions for the Common Assessment Tool. +Contains enumerations used across the application. +""" + +import enum + +class UserRole(str, enum.Enum): + """ + User Role class defining possible user roles in the system. + """ + admin = "admin" + case_worker = "case_worker" \ No newline at end of file diff --git a/app/models.py b/app/models.py index dbfb891b..94d6a1b0 100644 --- a/app/models.py +++ b/app/models.py @@ -17,15 +17,7 @@ from sqlalchemy.orm import relationship from app.database import Base - - -class UserRole(str, enum.Enum): - """ - User Role class store two roles - """ - - admin = "admin" - case_worker = "case_worker" +from app.enums import UserRole # New import class User(Base): From 0cc4f6f169d2871fa1b7578b63d01c5ff445ac81 Mon Sep 17 00:00:00 2001 From: Helen Wang Date: Sun, 23 Mar 2025 21:21:47 -0700 Subject: [PATCH 05/23] Refactor: Extract validation logic from models to dedicated validators module --- app/models.py | 87 ++++++++++++++++++++++++++--------------------- app/validators.py | 68 ++++++++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 38 deletions(-) create mode 100644 app/validators.py diff --git a/app/models.py b/app/models.py index 94d6a1b0..038bad2d 100644 --- a/app/models.py +++ b/app/models.py @@ -3,8 +3,6 @@ Contains the Client model for storing client information in the database. """ -import enum - from sqlalchemy import ( Column, Integer, @@ -12,12 +10,21 @@ Boolean, ForeignKey, Enum, - CheckConstraint, ) from sqlalchemy.orm import relationship from app.database import Base -from app.enums import UserRole # New import +from app.enums import UserRole +from app.validators import ( + age_constraint, + gender_constraint, + experience_constraint, + school_level_constraint, + scale_constraint, + housing_constraint, + income_source_constraint, + success_rate_constraint, +) class User(Base): @@ -46,50 +53,51 @@ class Client(Base): __tablename__ = "clients" id = Column(Integer, primary_key=True, autoincrement=True) - age = Column(Integer, CheckConstraint("age >= 18")) - gender = Column(Integer, CheckConstraint("gender = 1 OR gender = 2")) - work_experience = Column(Integer, CheckConstraint("work_experience >= 0")) - canada_workex = Column(Integer, CheckConstraint("canada_workex >= 0")) - dep_num = Column(Integer, CheckConstraint("dep_num >= 0")) + age = Column(Integer) + gender = Column(Integer) + work_experience = Column(Integer) + canada_workex = Column(Integer) + dep_num = Column(Integer) canada_born = Column(Boolean) citizen_status = Column(Boolean) - level_of_schooling = Column( - Integer, CheckConstraint("level_of_schooling >= 1 AND level_of_schooling <= 14") - ) + level_of_schooling = Column(Integer) fluent_english = Column(Boolean) - reading_english_scale = Column( - Integer, - CheckConstraint("reading_english_scale >= 0 AND reading_english_scale <= 10"), - ) - speaking_english_scale = Column( - Integer, - CheckConstraint("speaking_english_scale >= 0 AND speaking_english_scale <= 10"), - ) - writing_english_scale = Column( - Integer, - CheckConstraint("writing_english_scale >= 0 AND writing_english_scale <= 10"), - ) - numeracy_scale = Column( - Integer, CheckConstraint("numeracy_scale >= 0 AND numeracy_scale <= 10") - ) - computer_scale = Column( - Integer, CheckConstraint("computer_scale >= 0 AND computer_scale <= 10") - ) + reading_english_scale = Column(Integer) + speaking_english_scale = Column(Integer) + writing_english_scale = Column(Integer) + numeracy_scale = Column(Integer) + computer_scale = Column(Integer) transportation_bool = Column(Boolean) caregiver_bool = Column(Boolean) - housing = Column(Integer, CheckConstraint("housing >= 1 AND housing <= 10")) - income_source = Column( - Integer, CheckConstraint("income_source >= 1 AND income_source <= 11") - ) + housing = Column(Integer) + income_source = Column(Integer) felony_bool = Column(Boolean) attending_school = Column(Boolean) currently_employed = Column(Boolean) substance_use = Column(Boolean) - time_unemployed = Column(Integer, CheckConstraint("time_unemployed >= 0")) + time_unemployed = Column(Integer) need_mental_health_support_bool = Column(Boolean) cases = relationship("ClientCase", back_populates="client") + # Apply constraints + __table_args__ = ( + age_constraint(), + gender_constraint(), + experience_constraint("work_experience"), + experience_constraint("canada_workex"), + experience_constraint("dep_num"), + school_level_constraint(), + scale_constraint("reading_english_scale"), + scale_constraint("speaking_english_scale"), + scale_constraint("writing_english_scale"), + scale_constraint("numeracy_scale"), + scale_constraint("computer_scale"), + housing_constraint(), + income_source_constraint(), + experience_constraint("time_unemployed"), + ) + class ClientCase(Base): """ @@ -108,9 +116,12 @@ class ClientCase(Base): employment_related_financial_supports = Column(Boolean) employer_financial_supports = Column(Boolean) enhanced_referrals = Column(Boolean) - success_rate = Column( - Integer, CheckConstraint("success_rate >= 0 AND success_rate <= 100") - ) + success_rate = Column(Integer) client = relationship("Client", back_populates="cases") user = relationship("User", back_populates="cases") + + # Apply constraint + __table_args__ = ( + success_rate_constraint(), + ) \ No newline at end of file diff --git a/app/validators.py b/app/validators.py new file mode 100644 index 00000000..64975012 --- /dev/null +++ b/app/validators.py @@ -0,0 +1,68 @@ +""" +Validation utilities for the Common Assessment Tool. +Contains functions and constants for validating model data. +""" + +from sqlalchemy import CheckConstraint + +# Constants for validation +MIN_AGE = 18 +GENDER_VALUES = [1, 2] +MIN_SCALE_VALUE = 0 +MAX_SCALE_VALUE = 10 +MIN_SCHOOLING_LEVEL = 1 +MAX_SCHOOLING_LEVEL = 14 +MIN_HOUSING_LEVEL = 1 +MAX_HOUSING_LEVEL = 10 +MIN_INCOME_SOURCE = 1 +MAX_INCOME_SOURCE = 11 +MIN_SUCCESS_RATE = 0 +MAX_SUCCESS_RATE = 100 + +# Constraint creation functions +def age_constraint(): + """Create a constraint to ensure age is at least 18.""" + return CheckConstraint(f"age >= {MIN_AGE}") + +def gender_constraint(): + """Create a constraint to ensure gender is a valid value.""" + return CheckConstraint("gender = 1 OR gender = 2") + +def experience_constraint(field_name): + """Create a constraint to ensure experience is not negative.""" + return CheckConstraint(f"{field_name} >= 0") + +def school_level_constraint(): + """Create a constraint to ensure schooling level is within valid range.""" + return CheckConstraint( + f"level_of_schooling >= {MIN_SCHOOLING_LEVEL} AND " + f"level_of_schooling <= {MAX_SCHOOLING_LEVEL}" + ) + +def scale_constraint(field_name): + """Create a constraint to ensure scale values are within 0-10 range.""" + return CheckConstraint( + f"{field_name} >= {MIN_SCALE_VALUE} AND " + f"{field_name} <= {MAX_SCALE_VALUE}" + ) + +def housing_constraint(): + """Create a constraint to ensure housing value is valid.""" + return CheckConstraint( + f"housing >= {MIN_HOUSING_LEVEL} AND " + f"housing <= {MAX_HOUSING_LEVEL}" + ) + +def income_source_constraint(): + """Create a constraint to ensure income source value is valid.""" + return CheckConstraint( + f"income_source >= {MIN_INCOME_SOURCE} AND " + f"income_source <= {MAX_INCOME_SOURCE}" + ) + +def success_rate_constraint(): + """Create a constraint to ensure success rate is within valid range.""" + return CheckConstraint( + f"success_rate >= {MIN_SUCCESS_RATE} AND " + f"success_rate <= {MAX_SUCCESS_RATE}" + ) \ No newline at end of file From be69bb6294ced90c3db9fadc5066e36cf9af0704 Mon Sep 17 00:00:00 2001 From: Helen Wang Date: Sun, 23 Mar 2025 21:25:32 -0700 Subject: [PATCH 06/23] Update code quality README with completed refactorings --- code_quality_README.md | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/code_quality_README.md b/code_quality_README.md index 339f7699..63b3a527 100644 --- a/code_quality_README.md +++ b/code_quality_README.md @@ -25,4 +25,14 @@ Run these commands to analyze code quality: ``` pre-commit autoupdate pre-commit run --all-files -``` \ No newline at end of file +``` + +## Refactoring for SOLID Principles + +### Completed Refactorings: +- Extracted UserRole enum to a separate file (Single Responsibility Principle) +- Extracted validation logic from models (Single Responsibility Principle) + - Created app/validators.py with dedicated validation functions + - Defined constants for validation boundaries + - Made validation rules more maintainable and consistent + - Simplified model classes to focus on structure, not validation \ No newline at end of file From 07702a5f3da6fe29aa9c3efbd68aa25bb5eeca18 Mon Sep 17 00:00:00 2001 From: Helen Wang Date: Sun, 23 Mar 2025 21:28:54 -0700 Subject: [PATCH 07/23] Update code quality README with completed refactorings --- code_quality_README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/code_quality_README.md b/code_quality_README.md index 63b3a527..6d74e85a 100644 --- a/code_quality_README.md +++ b/code_quality_README.md @@ -30,7 +30,7 @@ pre-commit run --all-files ## Refactoring for SOLID Principles ### Completed Refactorings: -- Extracted UserRole enum to a separate file (Single Responsibility Principle) +- Extracted UserRole enum to a separate file (app/enums.py) (Single Responsibility Principle) - Extracted validation logic from models (Single Responsibility Principle) - Created app/validators.py with dedicated validation functions - Defined constants for validation boundaries From 41b65ccd070a1dd3388849773bd4a5d76c402d14 Mon Sep 17 00:00:00 2001 From: Helen Wang Date: Sun, 23 Mar 2025 23:35:07 -0700 Subject: [PATCH 08/23] Refactor: Split models into domain-specific files following Single Responsibility Principle --- app/__init__.py | 0 app/auth/router.py | 3 +- app/main.py | 5 ++- app/models/__init__.py | 10 +++++ app/models/client.py | 73 +++++++++++++++++++++++++++++++++++++ app/models/relationships.py | 39 ++++++++++++++++++++ app/models/user.py | 34 +++++++++++++++++ app/validators.py | 11 +++++- code_quality_README.md | 14 ++++++- 9 files changed, 183 insertions(+), 6 deletions(-) delete mode 100644 app/__init__.py create mode 100644 app/models/__init__.py create mode 100644 app/models/client.py create mode 100644 app/models/relationships.py create mode 100644 app/models/user.py diff --git a/app/__init__.py b/app/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/auth/router.py b/app/auth/router.py index 7049cd24..d926376f 100644 --- a/app/auth/router.py +++ b/app/auth/router.py @@ -9,7 +9,8 @@ from passlib.context import CryptContext from app.database import get_db -from app.models import User, UserRole +from app.models import User +from app.enums import UserRole router = APIRouter(prefix="/auth", tags=["authentication"]) diff --git a/app/main.py b/app/main.py index 5038db96..6a136eb9 100644 --- a/app/main.py +++ b/app/main.py @@ -10,12 +10,13 @@ # Local application/library specific imports from app import models -from app.database import engine +from app.database import engine,Base #Add Base here from app.clients.router import router as clients_router from app.auth.router import router as auth_router +from app.database import Base # Initialize database tables -models.Base.metadata.create_all(bind=engine) +Base.metadata.create_all(bind=engine) # Create FastAPI application app = FastAPI( diff --git a/app/models/__init__.py b/app/models/__init__.py new file mode 100644 index 00000000..c3261f82 --- /dev/null +++ b/app/models/__init__.py @@ -0,0 +1,10 @@ +# app/models/__init__.py +""" +Database models package for the Common Assessment Tool. +""" + +from app.models.user import User +from app.models.client import Client +from app.models.relationships import ClientCase + +__all__ = ["User", "Client", "ClientCase"] \ No newline at end of file diff --git a/app/models/client.py b/app/models/client.py new file mode 100644 index 00000000..91e4b155 --- /dev/null +++ b/app/models/client.py @@ -0,0 +1,73 @@ +# app/models/client.py +""" +Client-related database models for the Common Assessment Tool. +Contains models for storing client information. +""" + +from sqlalchemy import Column, Integer, Boolean +from sqlalchemy.orm import relationship + +from app.database import Base +from app.validators import ( + age_constraint, + gender_constraint, + experience_constraint, + school_level_constraint, + scale_constraint, + housing_constraint, + income_source_constraint, +) + +class Client(Base): + """ + Represents a Client in the database. + Stores personal and case-related information for each client. + """ + + __tablename__ = "clients" + + id = Column(Integer, primary_key=True, autoincrement=True) + age = Column(Integer) + gender = Column(Integer) + work_experience = Column(Integer) + canada_workex = Column(Integer) + dep_num = Column(Integer) + canada_born = Column(Boolean) + citizen_status = Column(Boolean) + level_of_schooling = Column(Integer) + fluent_english = Column(Boolean) + reading_english_scale = Column(Integer) + speaking_english_scale = Column(Integer) + writing_english_scale = Column(Integer) + numeracy_scale = Column(Integer) + computer_scale = Column(Integer) + transportation_bool = Column(Boolean) + caregiver_bool = Column(Boolean) + housing = Column(Integer) + income_source = Column(Integer) + felony_bool = Column(Boolean) + attending_school = Column(Boolean) + currently_employed = Column(Boolean) + substance_use = Column(Boolean) + time_unemployed = Column(Integer) + need_mental_health_support_bool = Column(Boolean) + + cases = relationship("ClientCase", back_populates="client") + + # Apply constraints + __table_args__ = ( + age_constraint(), + gender_constraint(), + experience_constraint("work_experience"), + experience_constraint("canada_workex"), + experience_constraint("dep_num"), + school_level_constraint(), + scale_constraint("reading_english_scale"), + scale_constraint("speaking_english_scale"), + scale_constraint("writing_english_scale"), + scale_constraint("numeracy_scale"), + scale_constraint("computer_scale"), + housing_constraint(), + income_source_constraint(), + experience_constraint("time_unemployed"), + ) \ No newline at end of file diff --git a/app/models/relationships.py b/app/models/relationships.py new file mode 100644 index 00000000..675aa590 --- /dev/null +++ b/app/models/relationships.py @@ -0,0 +1,39 @@ +# app/models/relationships.py +""" +Relationship models for the Common Assessment Tool. +Contains models that connect other entities in the system. +""" + +from sqlalchemy import Column, Integer, Boolean, ForeignKey +from sqlalchemy.orm import relationship + +from app.database import Base +from app.validators import success_rate_constraint + +class ClientCase(Base): + """ + Represents the relationship between a client and a case worker. + Stores service information and outcomes for each client-user relationship. + """ + + __tablename__ = "client_cases" + + client_id = Column(Integer, ForeignKey("clients.id"), primary_key=True) + user_id = Column(Integer, ForeignKey("users.id"), primary_key=True) + + employment_assistance = Column(Boolean) + life_stabilization = Column(Boolean) + retention_services = Column(Boolean) + specialized_services = Column(Boolean) + employment_related_financial_supports = Column(Boolean) + employer_financial_supports = Column(Boolean) + enhanced_referrals = Column(Boolean) + success_rate = Column(Integer) + + client = relationship("Client", back_populates="cases") + user = relationship("User", back_populates="cases") + + # Apply constraints + __table_args__ = ( + success_rate_constraint(), + ) \ No newline at end of file diff --git a/app/models/user.py b/app/models/user.py new file mode 100644 index 00000000..b92a7dcd --- /dev/null +++ b/app/models/user.py @@ -0,0 +1,34 @@ +# app/models/user.py +""" +User-related database models for the Common Assessment Tool. +Contains models related to system users and authentication. +""" + +from sqlalchemy import Column, Integer, String, Enum +from sqlalchemy.orm import relationship + +from app.database import Base +from app.enums import UserRole +from app.validators import username_length_constraint, email_format_constraint + +class User(Base): + """ + Represents a User in the database. + Stores user details including authentication and roles. + """ + + __tablename__ = "users" + + id = Column(Integer, primary_key=True, autoincrement=True) + username = Column(String(50), unique=True, nullable=False) + email = Column(String(100), unique=True, nullable=False) + hashed_password = Column(String(200), nullable=False) + role = Column(Enum(UserRole), nullable=False) + + cases = relationship("ClientCase", back_populates="user") + + # Apply User-specific constraints + __table_args__ = ( + username_length_constraint(), + email_format_constraint(), + ) \ No newline at end of file diff --git a/app/validators.py b/app/validators.py index 64975012..4d331dd4 100644 --- a/app/validators.py +++ b/app/validators.py @@ -65,4 +65,13 @@ def success_rate_constraint(): return CheckConstraint( f"success_rate >= {MIN_SUCCESS_RATE} AND " f"success_rate <= {MAX_SUCCESS_RATE}" - ) \ No newline at end of file + ) + +def username_length_constraint(): + """Create a constraint to ensure username has a minimum length.""" + MIN_USERNAME_LENGTH = 3 + return CheckConstraint(f"LENGTH(username) >= {MIN_USERNAME_LENGTH}") + +def email_format_constraint(): + """Create a constraint to ensure email contains @ symbol.""" + return CheckConstraint("email LIKE '%@%'") \ No newline at end of file diff --git a/code_quality_README.md b/code_quality_README.md index 6d74e85a..80aeff13 100644 --- a/code_quality_README.md +++ b/code_quality_README.md @@ -30,9 +30,19 @@ pre-commit run --all-files ## Refactoring for SOLID Principles ### Completed Refactorings: -- Extracted UserRole enum to a separate file (app/enums.py) (Single Responsibility Principle) +- Extracted UserRole enum to a separate file (Single Responsibility Principle) + - Created app/enums.py for all enumeration types + - Improved code organization and reusability + - Extracted validation logic from models (Single Responsibility Principle) - Created app/validators.py with dedicated validation functions - Defined constants for validation boundaries - Made validation rules more maintainable and consistent - - Simplified model classes to focus on structure, not validation \ No newline at end of file + - Simplified model classes to focus on structure, not validation + +- Split database models into domain-specific files (Single Responsibility Principle) + - Created separate files for User, Client, and relationship models + - Organized models by domain area in app/models/ directory + - Fixed imports to maintain compatibility with existing code + - Improved code organization and maintainability + - Made domain model boundaries clearer \ No newline at end of file From 27024d781edd3f2335c3b9028ae60322b5f4f6e3 Mon Sep 17 00:00:00 2001 From: Helen Wang Date: Mon, 24 Mar 2025 11:39:28 -0700 Subject: [PATCH 09/23] delete original redundant models file --- app/models.py | 127 -------------------------------------------------- 1 file changed, 127 deletions(-) delete mode 100644 app/models.py diff --git a/app/models.py b/app/models.py deleted file mode 100644 index 038bad2d..00000000 --- a/app/models.py +++ /dev/null @@ -1,127 +0,0 @@ -""" -Database models module defining SQLAlchemy ORM models for the Common Assessment Tool. -Contains the Client model for storing client information in the database. -""" - -from sqlalchemy import ( - Column, - Integer, - String, - Boolean, - ForeignKey, - Enum, -) -from sqlalchemy.orm import relationship - -from app.database import Base -from app.enums import UserRole -from app.validators import ( - age_constraint, - gender_constraint, - experience_constraint, - school_level_constraint, - scale_constraint, - housing_constraint, - income_source_constraint, - success_rate_constraint, -) - - -class User(Base): - """ - Represents a User in the database. - Stores user details including authentication and roles. - """ - - __tablename__ = "users" - - id = Column(Integer, primary_key=True, autoincrement=True) - username = Column(String(50), unique=True, nullable=False) - email = Column(String(100), unique=True, nullable=False) - hashed_password = Column(String(200), nullable=False) - role = Column(Enum(UserRole), nullable=False) - - cases = relationship("ClientCase", back_populates="user") - - -class Client(Base): - """ - Represents a Client in the database. - Stores personal and case-related information for each client. - """ - - __tablename__ = "clients" - - id = Column(Integer, primary_key=True, autoincrement=True) - age = Column(Integer) - gender = Column(Integer) - work_experience = Column(Integer) - canada_workex = Column(Integer) - dep_num = Column(Integer) - canada_born = Column(Boolean) - citizen_status = Column(Boolean) - level_of_schooling = Column(Integer) - fluent_english = Column(Boolean) - reading_english_scale = Column(Integer) - speaking_english_scale = Column(Integer) - writing_english_scale = Column(Integer) - numeracy_scale = Column(Integer) - computer_scale = Column(Integer) - transportation_bool = Column(Boolean) - caregiver_bool = Column(Boolean) - housing = Column(Integer) - income_source = Column(Integer) - felony_bool = Column(Boolean) - attending_school = Column(Boolean) - currently_employed = Column(Boolean) - substance_use = Column(Boolean) - time_unemployed = Column(Integer) - need_mental_health_support_bool = Column(Boolean) - - cases = relationship("ClientCase", back_populates="client") - - # Apply constraints - __table_args__ = ( - age_constraint(), - gender_constraint(), - experience_constraint("work_experience"), - experience_constraint("canada_workex"), - experience_constraint("dep_num"), - school_level_constraint(), - scale_constraint("reading_english_scale"), - scale_constraint("speaking_english_scale"), - scale_constraint("writing_english_scale"), - scale_constraint("numeracy_scale"), - scale_constraint("computer_scale"), - housing_constraint(), - income_source_constraint(), - experience_constraint("time_unemployed"), - ) - - -class ClientCase(Base): - """ - ClientCase class - """ - - __tablename__ = "client_cases" - - client_id = Column(Integer, ForeignKey("clients.id"), primary_key=True) - user_id = Column(Integer, ForeignKey("users.id"), primary_key=True) - - employment_assistance = Column(Boolean) - life_stabilization = Column(Boolean) - retention_services = Column(Boolean) - specialized_services = Column(Boolean) - employment_related_financial_supports = Column(Boolean) - employer_financial_supports = Column(Boolean) - enhanced_referrals = Column(Boolean) - success_rate = Column(Integer) - - client = relationship("Client", back_populates="cases") - user = relationship("User", back_populates="cases") - - # Apply constraint - __table_args__ = ( - success_rate_constraint(), - ) \ No newline at end of file From 517f5b7c1b7eee4dcacb127178a681e8e1740c4b Mon Sep 17 00:00:00 2001 From: Helen Wang Date: Mon, 24 Mar 2025 12:12:08 -0700 Subject: [PATCH 10/23] Refactor ML model implementation for SOLID principles and implement multiple model types --- app/ml/base_model.py | 36 ++++++++++++++++++ app/ml/data_processor.py | 60 ++++++++++++++++++++++++++++++ app/ml/model_registry.py | 46 +++++++++++++++++++++++ app/ml/models/__init__.py | 9 +++++ app/ml/models/gradient_boost.py | 34 +++++++++++++++++ app/ml/models/linear_regression.py | 30 +++++++++++++++ app/ml/models/random_forest.py | 33 ++++++++++++++++ app/ml/test_models.py | 56 ++++++++++++++++++++++++++++ code_quality_README.md | 27 +++++++++++++- requirements-dev.txt | 2 +- 10 files changed, 331 insertions(+), 2 deletions(-) create mode 100644 app/ml/base_model.py create mode 100644 app/ml/data_processor.py create mode 100644 app/ml/model_registry.py create mode 100644 app/ml/models/__init__.py create mode 100644 app/ml/models/gradient_boost.py create mode 100644 app/ml/models/linear_regression.py create mode 100644 app/ml/models/random_forest.py create mode 100644 app/ml/test_models.py diff --git a/app/ml/base_model.py b/app/ml/base_model.py new file mode 100644 index 00000000..d91cda9f --- /dev/null +++ b/app/ml/base_model.py @@ -0,0 +1,36 @@ +# app/ml/base_model.py +""" +Base model interface for machine learning models. +All model implementations should inherit from this class. +""" +from abc import ABC, abstractmethod +import pickle + +class BaseModel(ABC): + """Abstract base class for all ML models.""" + + @abstractmethod + def train(self, features, targets): + """Train the model with the given features and targets.""" + pass + + @abstractmethod + def predict(self, features): + """Make predictions using the trained model.""" + pass + + @abstractmethod + def get_name(self): + """Return the name of the model.""" + pass + + def save(self, filename): + """Save the model to a file.""" + with open(filename, "wb") as model_file: + pickle.dump(self, model_file) + + @classmethod + def load(cls, filename): + """Load the model from a file.""" + with open(filename, "rb") as model_file: + return pickle.load(model_file) \ No newline at end of file diff --git a/app/ml/data_processor.py b/app/ml/data_processor.py new file mode 100644 index 00000000..d4946eed --- /dev/null +++ b/app/ml/data_processor.py @@ -0,0 +1,60 @@ +# app/ml/data_processor.py +""" +Data processing module for machine learning models. +Handles loading and preprocessing data for model training and prediction. +""" +import pandas as pd +import numpy as np +from sklearn.model_selection import train_test_split + +class DataProcessor: + """Handles data loading and preprocessing for ML models.""" + + def __init__(self, data_file="data_commontool.csv"): + """Initialize with the data file path.""" + self.data_file = data_file + self.feature_columns = [ + "age", "gender", "work_experience", "canada_workex", "dep_num", + "canada_born", "citizen_status", "level_of_schooling", + "fluent_english", "reading_english_scale", "speaking_english_scale", + "writing_english_scale", "numeracy_scale", "computer_scale", + "transportation_bool", "caregiver_bool", "housing", "income_source", + "felony_bool", "attending_school", "currently_employed", + "substance_use", "time_unemployed", "need_mental_health_support_bool" + ] + self.intervention_columns = [ + "employment_assistance", "life_stabilization", "retention_services", + "specialized_services", "employment_related_financial_supports", + "employer_financial_supports", "enhanced_referrals" + ] + + def load_data(self): + """Load the dataset from file.""" + return pd.read_csv(self.data_file) + + def prepare_training_data(self): + """Prepare features and targets for model training.""" + data = self.load_data() + all_features = self.feature_columns + self.intervention_columns + + features = np.array(data[all_features]) + targets = np.array(data["success_rate"]) + + features_train, features_test, targets_train, targets_test = train_test_split( + features, targets, test_size=0.2, random_state=42 + ) + + return features_train, features_test, targets_train, targets_test + + def prepare_prediction_data(self, client_data, interventions): + """Prepare a single client's data for prediction.""" + # Extract client features from the client data + client_features = [client_data.get(col, 0) for col in self.feature_columns] + + # Add intervention features + intervention_features = [interventions.get(col, 0) for col in self.intervention_columns] + + # Combine all features + all_features = client_features + intervention_features + + return np.array([all_features]) \ No newline at end of file diff --git a/app/ml/model_registry.py b/app/ml/model_registry.py new file mode 100644 index 00000000..f0605d83 --- /dev/null +++ b/app/ml/model_registry.py @@ -0,0 +1,46 @@ +# app/ml/model_registry.py +""" +Model registry for managing different ML models. +""" +class ModelRegistry: + """Registry for managing multiple machine learning models.""" + + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(ModelRegistry, cls).__new__(cls) + cls._instance._models = {} + cls._instance._current_model_name = None + return cls._instance + + def register_model(self, model_name, model_class): + """Register a new model with the registry.""" + self._models[model_name] = model_class + if self._current_model_name is None: + self._current_model_name = model_name + + def get_model(self, model_name=None): + """Get a model by name, or the current model if no name is provided.""" + if model_name is None: + model_name = self._current_model_name + + if model_name not in self._models: + raise ValueError(f"Model {model_name} is not registered") + + return self._models[model_name] + + def set_current_model(self, model_name): + """Set the current active model.""" + if model_name not in self._models: + raise ValueError(f"Model {model_name} is not registered") + + self._current_model_name = model_name + + def get_current_model_name(self): + """Get the name of the currently active model.""" + return self._current_model_name + + def list_available_models(self): + """List all available models in the registry.""" + return list(self._models.keys()) \ No newline at end of file diff --git a/app/ml/models/__init__.py b/app/ml/models/__init__.py new file mode 100644 index 00000000..40c26e49 --- /dev/null +++ b/app/ml/models/__init__.py @@ -0,0 +1,9 @@ +# app/ml/models/__init__.py +""" +Machine learning models package. +""" +from app.ml.models.random_forest import RandomForestModel +from app.ml.models.gradient_boost import GradientBoostingModel +from app.ml.models.linear_regression import LinearRegressionModel + +__all__ = ["RandomForestModel", "GradientBoostingModel", "LinearRegressionModel"] \ No newline at end of file diff --git a/app/ml/models/gradient_boost.py b/app/ml/models/gradient_boost.py new file mode 100644 index 00000000..4563646d --- /dev/null +++ b/app/ml/models/gradient_boost.py @@ -0,0 +1,34 @@ +# app/ml/models/gradient_boost.py +""" +Gradient Boosting implementation for success rate prediction. +""" +from sklearn.ensemble import GradientBoostingRegressor +from app.ml.base_model import BaseModel + +class GradientBoostingModel(BaseModel): + """Gradient Boosting model for predicting success rates.""" + + def __init__(self, n_estimators=100, learning_rate=0.1, random_state=42): + """Initialize the model with parameters.""" + self.model = GradientBoostingRegressor( + n_estimators=n_estimators, + learning_rate=learning_rate, + random_state=random_state + ) + self.is_trained = False + + def train(self, features, targets): + """Train the model with the given features and targets.""" + self.model.fit(features, targets) + self.is_trained = True + return self + + def predict(self, features): + """Make predictions using the trained model.""" + if not self.is_trained: + raise ValueError("Model must be trained before making predictions") + return self.model.predict(features) + + def get_name(self): + """Return the name of the model.""" + return "GradientBoosting" \ No newline at end of file diff --git a/app/ml/models/linear_regression.py b/app/ml/models/linear_regression.py new file mode 100644 index 00000000..7ce356ad --- /dev/null +++ b/app/ml/models/linear_regression.py @@ -0,0 +1,30 @@ +# app/ml/models/linear_regression.py +""" +Linear Regression implementation for success rate prediction. +""" +from sklearn.linear_model import LinearRegression +from app.ml.base_model import BaseModel + +class LinearRegressionModel(BaseModel): + """Linear Regression model for predicting success rates.""" + + def __init__(self): + """Initialize the model.""" + self.model = LinearRegression() + self.is_trained = False + + def train(self, features, targets): + """Train the model with the given features and targets.""" + self.model.fit(features, targets) + self.is_trained = True + return self + + def predict(self, features): + """Make predictions using the trained model.""" + if not self.is_trained: + raise ValueError("Model must be trained before making predictions") + return self.model.predict(features) + + def get_name(self): + """Return the name of the model.""" + return "LinearRegression" \ No newline at end of file diff --git a/app/ml/models/random_forest.py b/app/ml/models/random_forest.py new file mode 100644 index 00000000..fa6836d7 --- /dev/null +++ b/app/ml/models/random_forest.py @@ -0,0 +1,33 @@ +# app/ml/models/random_forest.py +""" +Random Forest implementation for success rate prediction. +""" +from sklearn.ensemble import RandomForestRegressor +from app.ml.base_model import BaseModel + +class RandomForestModel(BaseModel): + """Random Forest model for predicting success rates.""" + + def __init__(self, n_estimators=100, random_state=42): + """Initialize the model with parameters.""" + self.model = RandomForestRegressor( + n_estimators=n_estimators, + random_state=random_state + ) + self.is_trained = False + + def train(self, features, targets): + """Train the model with the given features and targets.""" + self.model.fit(features, targets) + self.is_trained = True + return self + + def predict(self, features): + """Make predictions using the trained model.""" + if not self.is_trained: + raise ValueError("Model must be trained before making predictions") + return self.model.predict(features) + + def get_name(self): + """Return the name of the model.""" + return "RandomForest" \ No newline at end of file diff --git a/app/ml/test_models.py b/app/ml/test_models.py new file mode 100644 index 00000000..c381f7e4 --- /dev/null +++ b/app/ml/test_models.py @@ -0,0 +1,56 @@ +# app/ml/test_models.py +""" +Test script for ML models and model registry. +""" +import numpy as np +from app.ml.data_processor import DataProcessor +from app.ml.model_registry import ModelRegistry +from app.ml.models.random_forest import RandomForestModel +from app.ml.models.gradient_boost import GradientBoostingModel +from app.ml.models.linear_regression import LinearRegressionModel + +def test_model_registry(): + """Test the model registry and model switching.""" + print("Testing model registry and model switching...") + + # Create the registry + registry = ModelRegistry() + + # Register models + registry.register_model("RandomForest", RandomForestModel()) + registry.register_model("GradientBoosting", GradientBoostingModel()) + registry.register_model("LinearRegression", LinearRegressionModel()) + + # List available models + models = registry.list_available_models() + print(f"Available models: {models}") + + # Get current model + current_model_name = registry.get_current_model_name() + print(f"Current model: {current_model_name}") + + # Switch models + registry.set_current_model("GradientBoosting") + print(f"Switched to model: {registry.get_current_model_name()}") + + # Create some dummy data for testing + features = np.random.rand(10, 31) # 24 client features + 7 intervention features + targets = np.random.rand(10) + + # Train all models + for model_name in models: + print(f"Training {model_name}...") + model = registry.get_model(model_name) + model.train(features, targets) + + # Make predictions with each model + for model_name in models: + print(f"Testing predictions with {model_name}...") + model = registry.get_model(model_name) + predictions = model.predict(features[:1]) + print(f" Prediction: {predictions[0]}") + + print("Model testing complete!") + +if __name__ == "__main__": + test_model_registry() \ No newline at end of file diff --git a/code_quality_README.md b/code_quality_README.md index 80aeff13..715b67c1 100644 --- a/code_quality_README.md +++ b/code_quality_README.md @@ -45,4 +45,29 @@ pre-commit run --all-files - Organized models by domain area in app/models/ directory - Fixed imports to maintain compatibility with existing code - Improved code organization and maintainability - - Made domain model boundaries clearer \ No newline at end of file + - Made domain model boundaries clearer + +- Refactored ML model implementation (Story 2 preparation) + - Created a base model interface following the Interface Segregation Principle + - Implemented multiple concrete model classes (Random Forest, Gradient Boosting, Linear Regression) + - Added a model registry implementing the Open/Closed Principle for easy model extension + - Separated data processing from model logic following Single Responsibility Principle + - Created a test script to verify model switching functionality + +## Testing the ML Models + +To test the machine learning model implementation and model switching capability: + +1. Ensure you have installed the required dependencies: + pip install numpy pandas scikit-learn + +2. Run the test script: + python -m app.ml.test_models + +3. The test will: +- Initialize the model registry +- Register three different model types +- Display available models +- Switch between models +- Train each model with sample data +- Make predictions with each model \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 2bbc6f11..821951d7 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,4 +2,4 @@ pylint==2.17.5 flake8==6.1.0 black==23.7.0 mypy==1.5.1 -python-jose>=3.3.0 \ No newline at end of file +python-jose>=3.3.0 From 02079fbac2fb469926cfa982e45e4402f7acb131 Mon Sep 17 00:00:00 2001 From: Helen Wang Date: Mon, 24 Mar 2025 12:38:29 -0700 Subject: [PATCH 11/23] Complete Story 2: Implement model switching with three ML models and API endpoints --- app/main.py | 3 +++ app/ml/__init__.py | 26 +++++++++++++++++++ app/ml/router.py | 53 +++++++++++++++++++++++++++++++++++++ code_quality_README.md | 59 +++++++++++++++++++++++++++++++++++++++++- 4 files changed, 140 insertions(+), 1 deletion(-) create mode 100644 app/ml/__init__.py create mode 100644 app/ml/router.py diff --git a/app/main.py b/app/main.py index 6a136eb9..667c85bd 100644 --- a/app/main.py +++ b/app/main.py @@ -10,7 +10,9 @@ # Local application/library specific imports from app import models +import app.ml # This will execute the initialization code from app.database import engine,Base #Add Base here +from app.ml.router import router as models_router # newly added from app.clients.router import router as clients_router from app.auth.router import router as auth_router @@ -26,6 +28,7 @@ ) # Include routers +app.include_router(models_router) app.include_router(auth_router) app.include_router(clients_router) diff --git a/app/ml/__init__.py b/app/ml/__init__.py new file mode 100644 index 00000000..8957555e --- /dev/null +++ b/app/ml/__init__.py @@ -0,0 +1,26 @@ +# app/ml/__init__.py +""" +Machine learning package initialization. +""" +from app.ml.model_registry import ModelRegistry +from app.ml.models.random_forest import RandomForestModel +from app.ml.models.gradient_boost import GradientBoostingModel +from app.ml.models.linear_regression import LinearRegressionModel + +# Initialize the model registry with available models +def initialize_models(): + """Register all available models with the registry.""" + registry = ModelRegistry() + + # Register the models + registry.register_model("RandomForest", RandomForestModel()) + registry.register_model("GradientBoosting", GradientBoostingModel()) + registry.register_model("LinearRegression", LinearRegressionModel()) + + # Set the default model + registry.set_current_model("RandomForest") + + return registry + +# Initialize models when importing this package +registry = initialize_models() \ No newline at end of file diff --git a/app/ml/router.py b/app/ml/router.py new file mode 100644 index 00000000..0a96d331 --- /dev/null +++ b/app/ml/router.py @@ -0,0 +1,53 @@ +# app/ml/router.py +""" +Router for ML model management endpoints. +""" +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from app.ml.model_registry import ModelRegistry + +router = APIRouter( + prefix="/models", + tags=["models"], + responses={404: {"description": "Not found"}}, +) + +class ModelResponse(BaseModel): + """Response model for model endpoints.""" + name: str + +class ModelsListResponse(BaseModel): + """Response model for listing available models.""" + models: list[str] + current_model: str + +@router.get("/current", response_model=ModelResponse) +async def get_current_model(): + """Get the currently active model.""" + registry = ModelRegistry() + current_model = registry.get_current_model_name() + return {"name": current_model} + +@router.get("/available", response_model=ModelsListResponse) +async def list_available_models(): + """List all available models.""" + registry = ModelRegistry() + available_models = registry.list_available_models() + current_model = registry.get_current_model_name() + return {"models": available_models, "current_model": current_model} + +@router.post("/switch/{model_name}", response_model=ModelResponse) +async def switch_model(model_name: str): + """Switch to a different model.""" + registry = ModelRegistry() + + try: + registry.set_current_model(model_name) + except ValueError: + raise HTTPException( + status_code=404, + detail=f"Model '{model_name}' not found. Available models: {registry.list_available_models()}" + ) + + return {"name": model_name} \ No newline at end of file diff --git a/code_quality_README.md b/code_quality_README.md index 715b67c1..c40c2e44 100644 --- a/code_quality_README.md +++ b/code_quality_README.md @@ -70,4 +70,61 @@ To test the machine learning model implementation and model switching capability - Display available models - Switch between models - Train each model with sample data -- Make predictions with each model \ No newline at end of file +- Make predictions with each model + +## Story 2 Implementation: Multiple ML Model Support + +### Completed Features: +- Implemented a flexible ML model architecture with model switching capability +- Created a base model interface that all model types implement +- Developed three different model implementations: + - Random Forest (default model) + - Gradient Boosting + - Linear Regression +- Created a model registry to manage different model types and switching +- Implemented API endpoints for model management + +### API Endpoints for Model Management + +The following endpoints are now available for interacting with ML models: + +1. **Get Current Model** + - Endpoint: `GET /models/current` + - Description: Returns the name of the currently active model + - Response example: `{"name": "RandomForest"}` + - Test command: `curl -X GET "http://127.0.0.1:8000/models/current"` + +2. **List Available Models** + - Endpoint: `GET /models/available` + - Description: Returns all available models and indicates the current model + - Response example: `{"models": ["RandomForest", "GradientBoosting", "LinearRegression"], "current_model": "RandomForest"}` + - Test command: `curl -X GET "http://127.0.0.1:8000/models/available"` + +3. **Switch Model** + - Endpoint: `POST /models/switch/{model_name}` + - Description: Changes the active model to the specified model + - Response example: `{"name": "GradientBoosting"}` + - Test command: `curl -X POST "http://127.0.0.1:8000/models/switch/GradientBoosting"` + +### Testing the Model Switching Functionality + +You can test the model switching functionality in two ways: + +#### 1. Using the Swagger UI: +- Start the application: `uvicorn app.main:app --reload` +- Navigate to: `http://127.0.0.1:8000/docs` +- Scroll to the "models" section +- Try each endpoint by clicking "Try it out" and "Execute" + +#### 2. Using the ML Test Script: +- Run: `python -m app.ml.test_models` +- This will test model registration, switching, and prediction capability + +### Implementation Details + +- **Model Registry Pattern**: Used a singleton registry to manage models +- **Strategy Pattern**: Each model implementation follows the same interface +- **Dependency Injection**: High-level components use abstractions rather than concrete implementations +- **Open for Extension**: New models can be added without modifying existing code + +This implementation satisfies all requirements for Story 2 while following SOLID principles. \ No newline at end of file From 909f2cc089ddc97ef8378ab532a48d8755516c04 Mon Sep 17 00:00:00 2001 From: ruoyanj01 Date: Tue, 25 Mar 2025 14:20:26 -0700 Subject: [PATCH 12/23] refactor service/client_service.py, router.py, schema.py --- app/clients/router.py | 25 +- app/clients/service/client_service.py | 385 ++++++++------------------ 2 files changed, 131 insertions(+), 279 deletions(-) diff --git a/app/clients/router.py b/app/clients/router.py index 02708963..88c9487c 100644 --- a/app/clients/router.py +++ b/app/clients/router.py @@ -11,7 +11,8 @@ from app.auth.router import get_current_user, get_admin_user from app.database import get_db from app.models import User -from app.clients.service.client_service import ClientService +from app.clients.service.client_service import ClientQueryService, ClientMutationService + from app.clients.schema import ( ClientResponse, ClientUpdate, @@ -43,7 +44,7 @@ async def get_clients( Returns: A list of clients according to the specified pagination rules. """ - return ClientService.get_clients(db, skip, limit) + return ClientQueryService.get_clients(db, skip, limit) @router.get("/{client_id}", response_model=ClientResponse) @@ -53,7 +54,7 @@ async def get_client( db: Session = Depends(get_db), ): """Get a specific client by ID""" - return ClientService.get_client(db, client_id) + return ClientQueryService.get_client(db, client_id) @router.get("/search/by-criteria", response_model=List[ClientResponse]) @@ -86,7 +87,7 @@ async def get_clients_by_criteria( db: Session = Depends(get_db), ): """Search clients by any combination of criteria""" - return ClientService.get_clients_by_criteria( + return ClientQueryService.get_clients_by_criteria( db, employment_status=employment_status, education_level=education_level, @@ -128,7 +129,7 @@ async def get_clients_by_services( db: Session = Depends(get_db), ): """Get clients filtered by multiple service statuses""" - return ClientService.get_clients_by_services( + return ClientQueryService.get_clients_by_services( db, employment_assistance=employment_assistance, life_stabilization=life_stabilization, @@ -147,7 +148,7 @@ async def get_client_services( db: Session = Depends(get_db), ): """Get all services and their status for a specific client, including case worker info""" - return ClientService.get_client_services(db, client_id) + return ClientQueryService.get_client_services(db, client_id) @router.get("/search/success-rate", response_model=List[ClientResponse]) @@ -159,7 +160,7 @@ async def get_clients_by_success_rate( db: Session = Depends(get_db), ): """Get clients with success rate above specified threshold""" - return ClientService.get_clients_by_success_rate(db, min_rate) + return ClientQueryService.get_clients_by_success_rate(db, min_rate) @router.get("/case-worker/{case_worker_id}", response_model=List[ClientResponse]) @@ -168,7 +169,7 @@ async def get_clients_by_case_worker( current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): - return ClientService.get_clients_by_case_worker(db, case_worker_id) + return ClientQueryService.get_clients_by_case_worker(db, case_worker_id) @router.put("/{client_id}", response_model=ClientResponse) @@ -179,7 +180,7 @@ async def update_client( db: Session = Depends(get_db), ): """Update a client's information""" - return ClientService.update_client(db, client_id, client_data) + return ClientMutationService.update_client(db, client_id, client_data) @router.put("/{client_id}/services/{user_id}", response_model=ServiceResponse) @@ -190,7 +191,7 @@ async def update_client_services( current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): - return ClientService.update_client_services(db, client_id, user_id, service_update) + return ClientMutationService.update_client_services(db, client_id, user_id, service_update) @router.post("/{client_id}/case-assignment", response_model=ServiceResponse) @@ -201,7 +202,7 @@ async def create_case_assignment( db: Session = Depends(get_db), ): """Create a new case assignment for a client with a case worker""" - return ClientService.create_case_assignment(db, client_id, case_worker_id) + return ClientMutationService.create_case_assignment(db, client_id, case_worker_id) @router.delete("/{client_id}", status_code=status.HTTP_204_NO_CONTENT) @@ -211,5 +212,5 @@ async def delete_client( db: Session = Depends(get_db), ): """Delete a client""" - ClientService.delete_client(db, client_id) + ClientMutationService.delete_client(db, client_id) return None diff --git a/app/clients/service/client_service.py b/app/clients/service/client_service.py index 6af668df..63212faa 100644 --- a/app/clients/service/client_service.py +++ b/app/clients/service/client_service.py @@ -1,48 +1,30 @@ -""" -Client service module handling all database operations for clients. -Provides CRUD operations and business logic for client management. -""" - -from sqlalchemy.orm import Session from fastapi import HTTPException, status +from sqlalchemy.orm import Session from typing import Optional from app.models import Client, ClientCase, User from app.clients.schema import ClientUpdate, ServiceUpdate -class ClientService: +class ClientQueryService: + # Retrieve a single client by ID @staticmethod def get_client(db: Session, client_id: int): - """Get a specific client by ID""" client = db.query(Client).filter(Client.id == client_id).first() if not client: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found", - ) + raise HTTPException(status_code=404, detail=f"Client {client_id} not found") return client + # Retrieve a paginated list of clients @staticmethod def get_clients(db: Session, skip: int = 0, limit: int = 50): - """ - Get clients with optional pagination. - Default shows first 50 clients, which means you'd need 3 pages for 150 records. - """ - if skip < 0: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Skip value cannot be negative", - ) - if limit < 1: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Limit must be greater than 0", - ) - - clients = db.query(Client).offset(skip).limit(limit).all() - total = db.query(Client).count() - return {"clients": clients, "total": total} - + if skip < 0 or limit < 1: + raise HTTPException(status_code=400, detail="Invalid pagination parameters") + return { + "clients": db.query(Client).offset(skip).limit(limit).all(), + "total": db.query(Client).count() + } + + # Retrieve clients that match various optional criteria filters @staticmethod def get_clients_by_criteria( db: Session, @@ -71,295 +53,164 @@ def get_clients_by_criteria( time_unemployed: Optional[int] = None, need_mental_health_support_bool: Optional[bool] = None, ): - """Get clients filtered by any combination of criteria""" + """Search clients by any combination of criteria""" query = db.query(Client) - - if education_level is not None and not (1 <= education_level <= 14): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Education level must be between 1 and 14", - ) - - if age_min is not None and age_min < 18: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Minimum age must be at least 18", - ) - - if gender is not None and gender not in [1, 2]: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Gender must be 1 or 2" - ) - - # Apply filters for non-None values - if employment_status is not None: - query = query.filter(Client.currently_employed == employment_status) - if age_min is not None: - query = query.filter(Client.age >= age_min) - if gender is not None: - query = query.filter(Client.gender == gender) - if education_level is not None: - query = query.filter(Client.level_of_schooling == education_level) - if work_experience is not None: - query = query.filter(Client.work_experience == work_experience) - if canada_workex is not None: - query = query.filter(Client.canada_workex == canada_workex) - if dep_num is not None: - query = query.filter(Client.dep_num == dep_num) - if canada_born is not None: - query = query.filter(Client.canada_born == canada_born) - if citizen_status is not None: - query = query.filter(Client.citizen_status == citizen_status) - if fluent_english is not None: - query = query.filter(Client.fluent_english == fluent_english) - if reading_english_scale is not None: - query = query.filter(Client.reading_english_scale == reading_english_scale) - if speaking_english_scale is not None: - query = query.filter( - Client.speaking_english_scale == speaking_english_scale - ) - if writing_english_scale is not None: - query = query.filter(Client.writing_english_scale == writing_english_scale) - if numeracy_scale is not None: - query = query.filter(Client.numeracy_scale == numeracy_scale) - if computer_scale is not None: - query = query.filter(Client.computer_scale == computer_scale) - if transportation_bool is not None: - query = query.filter(Client.transportation_bool == transportation_bool) - if caregiver_bool is not None: - query = query.filter(Client.caregiver_bool == caregiver_bool) - if housing is not None: - query = query.filter(Client.housing == housing) - if income_source is not None: - query = query.filter(Client.income_source == income_source) - if felony_bool is not None: - query = query.filter(Client.felony_bool == felony_bool) - if attending_school is not None: - query = query.filter(Client.attending_school == attending_school) - if substance_use is not None: - query = query.filter(Client.substance_use == substance_use) - if time_unemployed is not None: - query = query.filter(Client.time_unemployed == time_unemployed) - if need_mental_health_support_bool is not None: - query = query.filter( - Client.need_mental_health_support_bool - == need_mental_health_support_bool - ) - + query = ClientQueryService._apply_criteria_filters( + query, + employment_status=employment_status, + education_level=education_level, + age_min=age_min, + gender=gender, + work_experience=work_experience, + canada_workex=canada_workex, + dep_num=dep_num, + canada_born=canada_born, + citizen_status=citizen_status, + fluent_english=fluent_english, + reading_english_scale=reading_english_scale, + speaking_english_scale=speaking_english_scale, + writing_english_scale=writing_english_scale, + numeracy_scale=numeracy_scale, + computer_scale=computer_scale, + transportation_bool=transportation_bool, + caregiver_bool=caregiver_bool, + housing=housing, + income_source=income_source, + felony_bool=felony_bool, + attending_school=attending_school, + substance_use=substance_use, + time_unemployed=time_unemployed, + need_mental_health_support_bool=need_mental_health_support_bool, + ) try: return query.all() except Exception as e: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error retrieving clients: {str(e)}", - ) + raise HTTPException(status_code=500, detail=f"Error retrieving clients: {str(e)}") + # Filter clients based on service-related fields @staticmethod def get_clients_by_services(db: Session, **service_filters: Optional[bool]): - """ - Get clients filtered by multiple service statuses. - """ query = db.query(Client).join(ClientCase) - - for service_name, statu in service_filters.items(): - if statu is not None: - filter_criteria = getattr(ClientCase, service_name) == status - query = query.filter(filter_criteria) - + for service, status_val in service_filters.items(): + if status_val is not None: + query = query.filter(getattr(ClientCase, service) == status_val) try: return query.all() except Exception as e: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error retrieving clients: {str(e)}", - ) + raise HTTPException(status_code=500, detail=f"Error retrieving clients: {str(e)}") + # Get all services associated with a given client @staticmethod def get_client_services(db: Session, client_id: int): - """Get all services for a specific client with case worker info""" - client_cases = ( - db.query(ClientCase).filter(ClientCase.client_id == client_id).all() - ) - if not client_cases: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"No services found for client with id {client_id}", - ) - return client_cases + services = db.query(ClientCase).filter(ClientCase.client_id == client_id).all() + if not services: + raise HTTPException(status_code=404, detail=f"No services found for client {client_id}") + return services + # Get clients with a minimum success rate @staticmethod def get_clients_by_success_rate(db: Session, min_rate: int = 70): - """Get clients with success rate at or above the specified percentage""" if not (0 <= min_rate <= 100): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Success rate must be between 0 and 100", - ) - - return ( - db.query(Client) - .join(ClientCase) - .filter(ClientCase.success_rate >= min_rate) - .all() - ) + raise HTTPException(status_code=400, detail="Success rate must be between 0 and 100") + return db.query(Client).join(ClientCase).filter(ClientCase.success_rate >= min_rate).all() + # Get all clients assigned to a specific case worker @staticmethod def get_clients_by_case_worker(db: Session, case_worker_id: int): - """Get all clients assigned to a specific case worker""" - case_worker = db.query(User).filter(User.id == case_worker_id).first() - if not case_worker: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Case worker with id {case_worker_id} not found", - ) - - return ( - db.query(Client) - .join(ClientCase) - .filter(ClientCase.user_id == case_worker_id) - .all() - ) + if not db.query(User).filter(User.id == case_worker_id).first(): + raise HTTPException(status_code=404, detail=f"Case worker {case_worker_id} not found") + return db.query(Client).join(ClientCase).filter(ClientCase.user_id == case_worker_id).all() + # Internal helper method to apply dynamic filtering logic @staticmethod - def update_client(db: Session, client_id: int, client_update: ClientUpdate): - """Update a client's information""" + def _apply_criteria_filters(query, **filters): + if filters.get("education_level") and not (1 <= filters["education_level"] <= 14): + raise HTTPException(status_code=400, detail="Invalid education level") + if filters.get("age_min") and filters["age_min"] < 18: + raise HTTPException(status_code=400, detail="Minimum age must be at least 18") + if filters.get("gender") and filters["gender"] not in [1, 2]: + raise HTTPException(status_code=400, detail="Gender must be 1 or 2") + + for key, val in filters.items(): + if val is not None and hasattr(Client, key): + query = query.filter(getattr(Client, key) == val) + return query + + +class ClientMutationService: + # Update client information based on provided fields + @staticmethod + def update_client(db: Session, client_id: int, update_data: ClientUpdate): client = db.query(Client).filter(Client.id == client_id).first() if not client: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found", - ) - - update_data = client_update.dict(exclude_unset=True) - for field, value in update_data.items(): + raise HTTPException(status_code=404, detail=f"Client {client_id} not found") + for field, value in update_data.dict(exclude_unset=True).items(): setattr(client, field, value) - try: db.commit() db.refresh(client) return client except Exception as e: db.rollback() - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update client: {str(e)}", - ) + raise HTTPException(status_code=500, detail=f"Failed to update client: {str(e)}") + # Update service info for a specific client-case worker relationship @staticmethod - def update_client_services( - db: Session, client_id: int, user_id: int, service_update: ServiceUpdate - ): - """Update a client's services and outcomes for a specific case worker""" - client_case = ( - db.query(ClientCase) - .filter(ClientCase.client_id == client_id, ClientCase.user_id == user_id) - .first() - ) - - if not client_case: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"No case found for client {client_id} with case worker {user_id}. " - f"Cannot update services for a non-existent case assignment.", - ) - - update_data = service_update.dict(exclude_unset=True) - for field, value in update_data.items(): - setattr(client_case, field, value) - + def update_client_services(db: Session, client_id: int, user_id: int, update_data: ServiceUpdate): + case = db.query(ClientCase).filter(ClientCase.client_id == client_id, ClientCase.user_id == user_id).first() + if not case: + raise HTTPException(status_code=404, detail=f"No case found for client {client_id} and worker {user_id}") + for field, value in update_data.dict(exclude_unset=True).items(): + setattr(case, field, value) try: db.commit() - db.refresh(client_case) - return client_case + db.refresh(case) + return case except Exception as e: db.rollback() - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update client services: {str(e)}", - ) + raise HTTPException(status_code=500, detail=f"Failed to update client services: {str(e)}") + # Assign a new case worker to a client, with default service values @staticmethod - def create_case_assignment(db: Session, client_id: int, case_worker_id: int): - """Create a new case assignment""" - # Check if client exists - client = db.query(Client).filter(Client.id == client_id).first() - if not client: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found", - ) - - # Check if case worker exists - case_worker = db.query(User).filter(User.id == case_worker_id).first() - if not case_worker: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Case worker with id {case_worker_id} not found", - ) - - # Check if assignment already exists - existing_case = ( - db.query(ClientCase) - .filter( - ClientCase.client_id == client_id, ClientCase.user_id == case_worker_id - ) - .first() + def create_case_assignment(db: Session, client_id: int, worker_id: int): + if not db.query(Client).filter(Client.id == client_id).first(): + raise HTTPException(status_code=404, detail=f"Client {client_id} not found") + if not db.query(User).filter(User.id == worker_id).first(): + raise HTTPException(status_code=404, detail=f"Case worker {worker_id} not found") + if db.query(ClientCase).filter(ClientCase.client_id == client_id, ClientCase.user_id == worker_id).first(): + raise HTTPException(status_code=400, detail=f"Assignment already exists") + + case = ClientCase( + client_id=client_id, + user_id=worker_id, + employment_assistance=False, + life_stabilization=False, + retention_services=False, + specialized_services=False, + employment_related_financial_supports=False, + employer_financial_supports=False, + enhanced_referrals=False, + success_rate=0 ) - - if existing_case: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Client {client_id} already has a case assigned to case worker {case_worker_id}", - ) - try: - # Create new case assignment with default service values - new_case = ClientCase( - client_id=client_id, - user_id=case_worker_id, - employment_assistance=False, - life_stabilization=False, - retention_services=False, - specialized_services=False, - employment_related_financial_supports=False, - employer_financial_supports=False, - enhanced_referrals=False, - success_rate=0, - ) - db.add(new_case) + db.add(case) db.commit() - db.refresh(new_case) - return new_case - + db.refresh(case) + return case except Exception as e: db.rollback() - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to create case assignment: {str(e)}", - ) + raise HTTPException(status_code=500, detail=f"Failed to create assignment: {str(e)}") + # Delete a client and all related case records @staticmethod def delete_client(db: Session, client_id: int): - """Delete a client and their associated records""" - # First check if client exists client = db.query(Client).filter(Client.id == client_id).first() if not client: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found", - ) - + raise HTTPException(status_code=404, detail=f"Client {client_id} not found") try: - # Delete associated client_cases db.query(ClientCase).filter(ClientCase.client_id == client_id).delete() - - # Delete the client db.delete(client) db.commit() - except Exception as e: db.rollback() - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to delete client: {str(e)}", - ) + raise HTTPException(status_code=500, detail=f"Failed to delete client: {str(e)}") From cea1b052aa7a6aa0fad7bee77f02d9fec6f2e9bc Mon Sep 17 00:00:00 2001 From: ruoyanj01 Date: Tue, 25 Mar 2025 14:35:10 -0700 Subject: [PATCH 13/23] update code_quality_README.md --- code_quality_README.md | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/code_quality_README.md b/code_quality_README.md index c40c2e44..8e8c4a24 100644 --- a/code_quality_README.md +++ b/code_quality_README.md @@ -1,6 +1,7 @@ # Code Quality Analysis ## Setup + 1. Create a virtual environment: `python -m venv venv` @@ -9,19 +10,20 @@ - Mac/Linux: `source venv/bin/activate` ## Installation + Install development tools with: `pip install -r requirements-dev.txt` ## Running Analysis -Run these commands to analyze code quality: -2. `flake8 app/ > flake8_report.txt` -3. `black --check --diff app/ > black_report.txt` -4. `mypy app/ > mypy_report.txt` + +Run these commands to analyze code quality: 2. `flake8 app/ > flake8_report.txt` 3. `black --check --diff app/ > black_report.txt` 4. `mypy app/ > mypy_report.txt` ## Installation of Dependency on Main Code: + `pip install -r ./app/requirements.txt` ## To check the code using pre-commit: + ``` pre-commit autoupdate pre-commit run --all-files @@ -30,17 +32,21 @@ pre-commit run --all-files ## Refactoring for SOLID Principles ### Completed Refactorings: + - Extracted UserRole enum to a separate file (Single Responsibility Principle) + - Created app/enums.py for all enumeration types - Improved code organization and reusability - Extracted validation logic from models (Single Responsibility Principle) + - Created app/validators.py with dedicated validation functions - Defined constants for validation boundaries - Made validation rules more maintainable and consistent - Simplified model classes to focus on structure, not validation - Split database models into domain-specific files (Single Responsibility Principle) + - Created separate files for User, Client, and relationship models - Organized models by domain area in app/models/ directory - Fixed imports to maintain compatibility with existing code @@ -48,12 +54,24 @@ pre-commit run --all-files - Made domain model boundaries clearer - Refactored ML model implementation (Story 2 preparation) + - Created a base model interface following the Interface Segregation Principle - Implemented multiple concrete model classes (Random Forest, Gradient Boosting, Linear Regression) - Added a model registry implementing the Open/Closed Principle for easy model extension - Separated data processing from model logic following Single Responsibility Principle - Created a test script to verify model switching functionality +- Refactored ClientService into two focused service classes: ClientQueryService and ClientMutationService + + - Applied the Single Responsibility Principle (SRP) to separate read and write operations + - Followed Separation of Concerns to isolate query logic from mutation logic + - Removed duplicated filtering code with a shared criteria filter method + +- Refactored the authentication system using modular service classes + - Applied the Single Responsibility Principle by isolating password logic (SecurityService), token logic (TokenService), and user operations (UserService) + - Followed Separation of Concerns to decouple route handlers from business logic + - Prepared the architecture for future enhancements such as role-based access control, refresh tokens, and password policies + ## Testing the ML Models To test the machine learning model implementation and model switching capability: @@ -65,6 +83,7 @@ To test the machine learning model implementation and model switching capability python -m app.ml.test_models 3. The test will: + - Initialize the model registry - Register three different model types - Display available models @@ -75,6 +94,7 @@ To test the machine learning model implementation and model switching capability ## Story 2 Implementation: Multiple ML Model Support ### Completed Features: + - Implemented a flexible ML model architecture with model switching capability - Created a base model interface that all model types implement - Developed three different model implementations: @@ -89,12 +109,14 @@ To test the machine learning model implementation and model switching capability The following endpoints are now available for interacting with ML models: 1. **Get Current Model** + - Endpoint: `GET /models/current` - Description: Returns the name of the currently active model - Response example: `{"name": "RandomForest"}` - Test command: `curl -X GET "http://127.0.0.1:8000/models/current"` 2. **List Available Models** + - Endpoint: `GET /models/available` - Description: Returns all available models and indicates the current model - Response example: `{"models": ["RandomForest", "GradientBoosting", "LinearRegression"], "current_model": "RandomForest"}` @@ -111,12 +133,14 @@ The following endpoints are now available for interacting with ML models: You can test the model switching functionality in two ways: #### 1. Using the Swagger UI: + - Start the application: `uvicorn app.main:app --reload` - Navigate to: `http://127.0.0.1:8000/docs` - Scroll to the "models" section - Try each endpoint by clicking "Try it out" and "Execute" #### 2. Using the ML Test Script: + - Run: `python -m app.ml.test_models` - This will test model registration, switching, and prediction capability @@ -127,4 +151,4 @@ You can test the model switching functionality in two ways: - **Dependency Injection**: High-level components use abstractions rather than concrete implementations - **Open for Extension**: New models can be added without modifying existing code -This implementation satisfies all requirements for Story 2 while following SOLID principles. \ No newline at end of file +This implementation satisfies all requirements for Story 2 while following SOLID principles. From 2ead2293ad20ef0df838addb9e05bdf393c503af Mon Sep 17 00:00:00 2001 From: Helen Wang Date: Tue, 25 Mar 2025 15:24:22 -0700 Subject: [PATCH 14/23] debug login --- initialize_data.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/initialize_data.py b/initialize_data.py index e9e5d29a..6882e1ae 100644 --- a/initialize_data.py +++ b/initialize_data.py @@ -1,7 +1,15 @@ import pandas as pd from app.database import SessionLocal -from app.models import Client, User, ClientCase, UserRole -from app.auth.router import get_password_hash +from app.models import Client, User, ClientCase +from app.enums import UserRole +#from app.auth.router import get_password_hash + +from passlib.context import CryptContext +# Create password context for hashing +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +# Function to hash passwords +def get_password_hash(password: str) -> str: + return pwd_context.hash(password) def initialize_database(): From 3635e52e5e8cebb1b5acf6bcadd915ec4df8620c Mon Sep 17 00:00:00 2001 From: johnsonli010801 <97709195+johnsonli010801@users.noreply.github.com> Date: Wed, 26 Mar 2025 14:49:24 -0700 Subject: [PATCH 15/23] fix style --- .flake8 | 1 + app/auth/router.py | 17 +++-- app/clients/router.py | 4 +- app/clients/service/client_service.py | 89 ++++++++++++++++++++------- app/enums.py | 4 +- app/main.py | 8 +-- app/ml/__init__.py | 10 +-- app/ml/base_model.py | 13 ++-- app/ml/data_processor.py | 68 +++++++++++++------- app/ml/model_registry.py | 24 ++++---- app/ml/models/__init__.py | 2 +- app/ml/models/gradient_boost.py | 13 ++-- app/ml/models/linear_regression.py | 11 ++-- app/ml/models/random_forest.py | 14 ++--- app/ml/router.py | 15 +++-- app/ml/test_models.py | 23 +++---- app/models/__init__.py | 2 +- app/models/client.py | 3 +- app/models/relationships.py | 5 +- app/models/user.py | 5 +- app/validators.py | 21 ++++--- initialize_data.py | 6 +- mypy.ini | 5 ++ tests/conftest.py | 2 + 24 files changed, 241 insertions(+), 124 deletions(-) diff --git a/.flake8 b/.flake8 index 041c5f07..1fdc63a8 100644 --- a/.flake8 +++ b/.flake8 @@ -1,3 +1,4 @@ [flake8] max-line-length = 120 exclude = .git,__pycache__,venv,tests/* +ignore = F541 diff --git a/app/auth/router.py b/app/auth/router.py index d926376f..b08dae1f 100644 --- a/app/auth/router.py +++ b/app/auth/router.py @@ -28,6 +28,7 @@ def validate_role(cls, v): raise ValueError("Role must be either admin or case_worker") return v + class UserResponse(BaseModel): username: str email: str @@ -47,6 +48,7 @@ def verify_password(self, plain_password: str, hashed_password: str) -> bool: def get_password_hash(self, password: str) -> str: return self.pwd_context.hash(password) + security = SecurityService() @@ -56,7 +58,9 @@ class TokenService: ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 30 - def create_access_token(self, data: dict, expires_delta: Optional[timedelta] = None): + def create_access_token( + self, data: dict, expires_delta: Optional[timedelta] = None + ): to_encode = data.copy() expire = datetime.utcnow() + (expires_delta or timedelta(minutes=15)) to_encode.update({"exp": expire}) @@ -68,12 +72,15 @@ def decode_token(self, token: str): except JWTError: return None + token_service = TokenService() # User Service class UserService: - def authenticate_user(self, db: Session, username: str, password: str) -> Optional[User]: + def authenticate_user( + self, db: Session, username: str, password: str + ) -> Optional[User]: user = db.query(User).filter(User.username == username).first() if not user or not security.verify_password(password, user.hashed_password): return None @@ -111,12 +118,14 @@ def create_user(self, db: Session, user_data: UserCreate) -> User: detail=str(e), ) + user_service = UserService() # Auth Dependencies oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token") + async def get_current_user( token: str = Depends(oauth2_scheme), db: Session = Depends(get_db) ) -> User: @@ -135,6 +144,7 @@ async def get_current_user( raise credentials_exception return user + def get_admin_user(current_user: User = Depends(get_current_user)): if current_user.role != UserRole.admin: raise HTTPException( @@ -147,8 +157,7 @@ def get_admin_user(current_user: User = Depends(get_current_user)): # Routes @router.post("/token") async def login_for_access_token( - form_data: OAuth2PasswordRequestForm = Depends(), - db: Session = Depends(get_db) + form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db) ): user = user_service.authenticate_user(db, form_data.username, form_data.password) if not user: diff --git a/app/clients/router.py b/app/clients/router.py index 88c9487c..4444a9ef 100644 --- a/app/clients/router.py +++ b/app/clients/router.py @@ -191,7 +191,9 @@ async def update_client_services( current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): - return ClientMutationService.update_client_services(db, client_id, user_id, service_update) + return ClientMutationService.update_client_services( + db, client_id, user_id, service_update + ) @router.post("/{client_id}/case-assignment", response_model=ServiceResponse) diff --git a/app/clients/service/client_service.py b/app/clients/service/client_service.py index 63212faa..93922f82 100644 --- a/app/clients/service/client_service.py +++ b/app/clients/service/client_service.py @@ -1,4 +1,4 @@ -from fastapi import HTTPException, status +from fastapi import HTTPException from sqlalchemy.orm import Session from typing import Optional from app.models import Client, ClientCase, User @@ -21,7 +21,7 @@ def get_clients(db: Session, skip: int = 0, limit: int = 50): raise HTTPException(status_code=400, detail="Invalid pagination parameters") return { "clients": db.query(Client).offset(skip).limit(limit).all(), - "total": db.query(Client).count() + "total": db.query(Client).count(), } # Retrieve clients that match various optional criteria filters @@ -85,7 +85,9 @@ def get_clients_by_criteria( try: return query.all() except Exception as e: - raise HTTPException(status_code=500, detail=f"Error retrieving clients: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Error retrieving clients: {str(e)}" + ) # Filter clients based on service-related fields @staticmethod @@ -97,37 +99,59 @@ def get_clients_by_services(db: Session, **service_filters: Optional[bool]): try: return query.all() except Exception as e: - raise HTTPException(status_code=500, detail=f"Error retrieving clients: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Error retrieving clients: {str(e)}" + ) # Get all services associated with a given client @staticmethod def get_client_services(db: Session, client_id: int): services = db.query(ClientCase).filter(ClientCase.client_id == client_id).all() if not services: - raise HTTPException(status_code=404, detail=f"No services found for client {client_id}") + raise HTTPException( + status_code=404, detail=f"No services found for client {client_id}" + ) return services # Get clients with a minimum success rate @staticmethod def get_clients_by_success_rate(db: Session, min_rate: int = 70): if not (0 <= min_rate <= 100): - raise HTTPException(status_code=400, detail="Success rate must be between 0 and 100") - return db.query(Client).join(ClientCase).filter(ClientCase.success_rate >= min_rate).all() + raise HTTPException( + status_code=400, detail="Success rate must be between 0 and 100" + ) + return ( + db.query(Client) + .join(ClientCase) + .filter(ClientCase.success_rate >= min_rate) + .all() + ) # Get all clients assigned to a specific case worker @staticmethod def get_clients_by_case_worker(db: Session, case_worker_id: int): if not db.query(User).filter(User.id == case_worker_id).first(): - raise HTTPException(status_code=404, detail=f"Case worker {case_worker_id} not found") - return db.query(Client).join(ClientCase).filter(ClientCase.user_id == case_worker_id).all() + raise HTTPException( + status_code=404, detail=f"Case worker {case_worker_id} not found" + ) + return ( + db.query(Client) + .join(ClientCase) + .filter(ClientCase.user_id == case_worker_id) + .all() + ) # Internal helper method to apply dynamic filtering logic @staticmethod def _apply_criteria_filters(query, **filters): - if filters.get("education_level") and not (1 <= filters["education_level"] <= 14): + if filters.get("education_level") and not ( + 1 <= filters["education_level"] <= 14 + ): raise HTTPException(status_code=400, detail="Invalid education level") if filters.get("age_min") and filters["age_min"] < 18: - raise HTTPException(status_code=400, detail="Minimum age must be at least 18") + raise HTTPException( + status_code=400, detail="Minimum age must be at least 18" + ) if filters.get("gender") and filters["gender"] not in [1, 2]: raise HTTPException(status_code=400, detail="Gender must be 1 or 2") @@ -152,14 +176,25 @@ def update_client(db: Session, client_id: int, update_data: ClientUpdate): return client except Exception as e: db.rollback() - raise HTTPException(status_code=500, detail=f"Failed to update client: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to update client: {str(e)}" + ) # Update service info for a specific client-case worker relationship @staticmethod - def update_client_services(db: Session, client_id: int, user_id: int, update_data: ServiceUpdate): - case = db.query(ClientCase).filter(ClientCase.client_id == client_id, ClientCase.user_id == user_id).first() + def update_client_services( + db: Session, client_id: int, user_id: int, update_data: ServiceUpdate + ): + case = ( + db.query(ClientCase) + .filter(ClientCase.client_id == client_id, ClientCase.user_id == user_id) + .first() + ) if not case: - raise HTTPException(status_code=404, detail=f"No case found for client {client_id} and worker {user_id}") + raise HTTPException( + status_code=404, + detail=f"No case found for client {client_id} and worker {user_id}", + ) for field, value in update_data.dict(exclude_unset=True).items(): setattr(case, field, value) try: @@ -168,7 +203,9 @@ def update_client_services(db: Session, client_id: int, user_id: int, update_dat return case except Exception as e: db.rollback() - raise HTTPException(status_code=500, detail=f"Failed to update client services: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to update client services: {str(e)}" + ) # Assign a new case worker to a client, with default service values @staticmethod @@ -176,8 +213,14 @@ def create_case_assignment(db: Session, client_id: int, worker_id: int): if not db.query(Client).filter(Client.id == client_id).first(): raise HTTPException(status_code=404, detail=f"Client {client_id} not found") if not db.query(User).filter(User.id == worker_id).first(): - raise HTTPException(status_code=404, detail=f"Case worker {worker_id} not found") - if db.query(ClientCase).filter(ClientCase.client_id == client_id, ClientCase.user_id == worker_id).first(): + raise HTTPException( + status_code=404, detail=f"Case worker {worker_id} not found" + ) + if ( + db.query(ClientCase) + .filter(ClientCase.client_id == client_id, ClientCase.user_id == worker_id) + .first() + ): raise HTTPException(status_code=400, detail=f"Assignment already exists") case = ClientCase( @@ -190,7 +233,7 @@ def create_case_assignment(db: Session, client_id: int, worker_id: int): employment_related_financial_supports=False, employer_financial_supports=False, enhanced_referrals=False, - success_rate=0 + success_rate=0, ) try: db.add(case) @@ -199,7 +242,9 @@ def create_case_assignment(db: Session, client_id: int, worker_id: int): return case except Exception as e: db.rollback() - raise HTTPException(status_code=500, detail=f"Failed to create assignment: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to create assignment: {str(e)}" + ) # Delete a client and all related case records @staticmethod @@ -213,4 +258,6 @@ def delete_client(db: Session, client_id: int): db.commit() except Exception as e: db.rollback() - raise HTTPException(status_code=500, detail=f"Failed to delete client: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to delete client: {str(e)}" + ) diff --git a/app/enums.py b/app/enums.py index 8efc6981..93c00f1a 100644 --- a/app/enums.py +++ b/app/enums.py @@ -5,9 +5,11 @@ import enum + class UserRole(str, enum.Enum): """ User Role class defining possible user roles in the system. """ + admin = "admin" - case_worker = "case_worker" \ No newline at end of file + case_worker = "case_worker" diff --git a/app/main.py b/app/main.py index 667c85bd..ab0f7290 100644 --- a/app/main.py +++ b/app/main.py @@ -9,14 +9,12 @@ from fastapi.middleware.cors import CORSMiddleware # Local application/library specific imports -from app import models -import app.ml # This will execute the initialization code -from app.database import engine,Base #Add Base here -from app.ml.router import router as models_router # newly added +from app.database import engine, Base # Add Base here +from app.ml.router import router as models_router # newly added from app.clients.router import router as clients_router from app.auth.router import router as auth_router -from app.database import Base + # Initialize database tables Base.metadata.create_all(bind=engine) diff --git a/app/ml/__init__.py b/app/ml/__init__.py index 8957555e..b5ffa4ee 100644 --- a/app/ml/__init__.py +++ b/app/ml/__init__.py @@ -7,20 +7,22 @@ from app.ml.models.gradient_boost import GradientBoostingModel from app.ml.models.linear_regression import LinearRegressionModel + # Initialize the model registry with available models def initialize_models(): """Register all available models with the registry.""" registry = ModelRegistry() - + # Register the models registry.register_model("RandomForest", RandomForestModel()) registry.register_model("GradientBoosting", GradientBoostingModel()) registry.register_model("LinearRegression", LinearRegressionModel()) - + # Set the default model registry.set_current_model("RandomForest") - + return registry + # Initialize models when importing this package -registry = initialize_models() \ No newline at end of file +registry = initialize_models() diff --git a/app/ml/base_model.py b/app/ml/base_model.py index d91cda9f..2994136b 100644 --- a/app/ml/base_model.py +++ b/app/ml/base_model.py @@ -6,31 +6,32 @@ from abc import ABC, abstractmethod import pickle + class BaseModel(ABC): """Abstract base class for all ML models.""" - + @abstractmethod def train(self, features, targets): """Train the model with the given features and targets.""" pass - + @abstractmethod def predict(self, features): """Make predictions using the trained model.""" pass - + @abstractmethod def get_name(self): """Return the name of the model.""" pass - + def save(self, filename): """Save the model to a file.""" with open(filename, "wb") as model_file: pickle.dump(self, model_file) - + @classmethod def load(cls, filename): """Load the model from a file.""" with open(filename, "rb") as model_file: - return pickle.load(model_file) \ No newline at end of file + return pickle.load(model_file) diff --git a/app/ml/data_processor.py b/app/ml/data_processor.py index d4946eed..30a72072 100644 --- a/app/ml/data_processor.py +++ b/app/ml/data_processor.py @@ -7,54 +7,78 @@ import numpy as np from sklearn.model_selection import train_test_split + class DataProcessor: """Handles data loading and preprocessing for ML models.""" - + def __init__(self, data_file="data_commontool.csv"): """Initialize with the data file path.""" self.data_file = data_file self.feature_columns = [ - "age", "gender", "work_experience", "canada_workex", "dep_num", - "canada_born", "citizen_status", "level_of_schooling", - "fluent_english", "reading_english_scale", "speaking_english_scale", - "writing_english_scale", "numeracy_scale", "computer_scale", - "transportation_bool", "caregiver_bool", "housing", "income_source", - "felony_bool", "attending_school", "currently_employed", - "substance_use", "time_unemployed", "need_mental_health_support_bool" + "age", + "gender", + "work_experience", + "canada_workex", + "dep_num", + "canada_born", + "citizen_status", + "level_of_schooling", + "fluent_english", + "reading_english_scale", + "speaking_english_scale", + "writing_english_scale", + "numeracy_scale", + "computer_scale", + "transportation_bool", + "caregiver_bool", + "housing", + "income_source", + "felony_bool", + "attending_school", + "currently_employed", + "substance_use", + "time_unemployed", + "need_mental_health_support_bool", ] self.intervention_columns = [ - "employment_assistance", "life_stabilization", "retention_services", - "specialized_services", "employment_related_financial_supports", - "employer_financial_supports", "enhanced_referrals" + "employment_assistance", + "life_stabilization", + "retention_services", + "specialized_services", + "employment_related_financial_supports", + "employer_financial_supports", + "enhanced_referrals", ] - + def load_data(self): """Load the dataset from file.""" return pd.read_csv(self.data_file) - + def prepare_training_data(self): """Prepare features and targets for model training.""" data = self.load_data() all_features = self.feature_columns + self.intervention_columns - + features = np.array(data[all_features]) targets = np.array(data["success_rate"]) - + features_train, features_test, targets_train, targets_test = train_test_split( features, targets, test_size=0.2, random_state=42 ) - + return features_train, features_test, targets_train, targets_test - + def prepare_prediction_data(self, client_data, interventions): """Prepare a single client's data for prediction.""" # Extract client features from the client data client_features = [client_data.get(col, 0) for col in self.feature_columns] - + # Add intervention features - intervention_features = [interventions.get(col, 0) for col in self.intervention_columns] - + intervention_features = [ + interventions.get(col, 0) for col in self.intervention_columns + ] + # Combine all features all_features = client_features + intervention_features - - return np.array([all_features]) \ No newline at end of file + + return np.array([all_features]) diff --git a/app/ml/model_registry.py b/app/ml/model_registry.py index f0605d83..7e91b164 100644 --- a/app/ml/model_registry.py +++ b/app/ml/model_registry.py @@ -2,45 +2,47 @@ """ Model registry for managing different ML models. """ + + class ModelRegistry: """Registry for managing multiple machine learning models.""" - + _instance = None - + def __new__(cls): if cls._instance is None: cls._instance = super(ModelRegistry, cls).__new__(cls) cls._instance._models = {} cls._instance._current_model_name = None return cls._instance - + def register_model(self, model_name, model_class): """Register a new model with the registry.""" self._models[model_name] = model_class if self._current_model_name is None: self._current_model_name = model_name - + def get_model(self, model_name=None): """Get a model by name, or the current model if no name is provided.""" if model_name is None: model_name = self._current_model_name - + if model_name not in self._models: raise ValueError(f"Model {model_name} is not registered") - + return self._models[model_name] - + def set_current_model(self, model_name): """Set the current active model.""" if model_name not in self._models: raise ValueError(f"Model {model_name} is not registered") - + self._current_model_name = model_name - + def get_current_model_name(self): """Get the name of the currently active model.""" return self._current_model_name - + def list_available_models(self): """List all available models in the registry.""" - return list(self._models.keys()) \ No newline at end of file + return list(self._models.keys()) diff --git a/app/ml/models/__init__.py b/app/ml/models/__init__.py index 40c26e49..4b1fa474 100644 --- a/app/ml/models/__init__.py +++ b/app/ml/models/__init__.py @@ -6,4 +6,4 @@ from app.ml.models.gradient_boost import GradientBoostingModel from app.ml.models.linear_regression import LinearRegressionModel -__all__ = ["RandomForestModel", "GradientBoostingModel", "LinearRegressionModel"] \ No newline at end of file +__all__ = ["RandomForestModel", "GradientBoostingModel", "LinearRegressionModel"] diff --git a/app/ml/models/gradient_boost.py b/app/ml/models/gradient_boost.py index 4563646d..191dcc6e 100644 --- a/app/ml/models/gradient_boost.py +++ b/app/ml/models/gradient_boost.py @@ -5,30 +5,31 @@ from sklearn.ensemble import GradientBoostingRegressor from app.ml.base_model import BaseModel + class GradientBoostingModel(BaseModel): """Gradient Boosting model for predicting success rates.""" - + def __init__(self, n_estimators=100, learning_rate=0.1, random_state=42): """Initialize the model with parameters.""" self.model = GradientBoostingRegressor( n_estimators=n_estimators, learning_rate=learning_rate, - random_state=random_state + random_state=random_state, ) self.is_trained = False - + def train(self, features, targets): """Train the model with the given features and targets.""" self.model.fit(features, targets) self.is_trained = True return self - + def predict(self, features): """Make predictions using the trained model.""" if not self.is_trained: raise ValueError("Model must be trained before making predictions") return self.model.predict(features) - + def get_name(self): """Return the name of the model.""" - return "GradientBoosting" \ No newline at end of file + return "GradientBoosting" diff --git a/app/ml/models/linear_regression.py b/app/ml/models/linear_regression.py index 7ce356ad..67799637 100644 --- a/app/ml/models/linear_regression.py +++ b/app/ml/models/linear_regression.py @@ -5,26 +5,27 @@ from sklearn.linear_model import LinearRegression from app.ml.base_model import BaseModel + class LinearRegressionModel(BaseModel): """Linear Regression model for predicting success rates.""" - + def __init__(self): """Initialize the model.""" self.model = LinearRegression() self.is_trained = False - + def train(self, features, targets): """Train the model with the given features and targets.""" self.model.fit(features, targets) self.is_trained = True return self - + def predict(self, features): """Make predictions using the trained model.""" if not self.is_trained: raise ValueError("Model must be trained before making predictions") return self.model.predict(features) - + def get_name(self): """Return the name of the model.""" - return "LinearRegression" \ No newline at end of file + return "LinearRegression" diff --git a/app/ml/models/random_forest.py b/app/ml/models/random_forest.py index fa6836d7..d2c7fe74 100644 --- a/app/ml/models/random_forest.py +++ b/app/ml/models/random_forest.py @@ -5,29 +5,29 @@ from sklearn.ensemble import RandomForestRegressor from app.ml.base_model import BaseModel + class RandomForestModel(BaseModel): """Random Forest model for predicting success rates.""" - + def __init__(self, n_estimators=100, random_state=42): """Initialize the model with parameters.""" self.model = RandomForestRegressor( - n_estimators=n_estimators, - random_state=random_state + n_estimators=n_estimators, random_state=random_state ) self.is_trained = False - + def train(self, features, targets): """Train the model with the given features and targets.""" self.model.fit(features, targets) self.is_trained = True return self - + def predict(self, features): """Make predictions using the trained model.""" if not self.is_trained: raise ValueError("Model must be trained before making predictions") return self.model.predict(features) - + def get_name(self): """Return the name of the model.""" - return "RandomForest" \ No newline at end of file + return "RandomForest" diff --git a/app/ml/router.py b/app/ml/router.py index 0a96d331..af67b6f2 100644 --- a/app/ml/router.py +++ b/app/ml/router.py @@ -13,15 +13,20 @@ responses={404: {"description": "Not found"}}, ) + class ModelResponse(BaseModel): """Response model for model endpoints.""" + name: str + class ModelsListResponse(BaseModel): """Response model for listing available models.""" + models: list[str] current_model: str + @router.get("/current", response_model=ModelResponse) async def get_current_model(): """Get the currently active model.""" @@ -29,6 +34,7 @@ async def get_current_model(): current_model = registry.get_current_model_name() return {"name": current_model} + @router.get("/available", response_model=ModelsListResponse) async def list_available_models(): """List all available models.""" @@ -37,17 +43,18 @@ async def list_available_models(): current_model = registry.get_current_model_name() return {"models": available_models, "current_model": current_model} + @router.post("/switch/{model_name}", response_model=ModelResponse) async def switch_model(model_name: str): """Switch to a different model.""" registry = ModelRegistry() - + try: registry.set_current_model(model_name) except ValueError: raise HTTPException( status_code=404, - detail=f"Model '{model_name}' not found. Available models: {registry.list_available_models()}" + detail=f"Model '{model_name}' not found. Available models: {registry.list_available_models()}", ) - - return {"name": model_name} \ No newline at end of file + + return {"name": model_name} diff --git a/app/ml/test_models.py b/app/ml/test_models.py index c381f7e4..9c7f1396 100644 --- a/app/ml/test_models.py +++ b/app/ml/test_models.py @@ -3,54 +3,55 @@ Test script for ML models and model registry. """ import numpy as np -from app.ml.data_processor import DataProcessor from app.ml.model_registry import ModelRegistry from app.ml.models.random_forest import RandomForestModel from app.ml.models.gradient_boost import GradientBoostingModel from app.ml.models.linear_regression import LinearRegressionModel + def test_model_registry(): """Test the model registry and model switching.""" print("Testing model registry and model switching...") - + # Create the registry registry = ModelRegistry() - + # Register models registry.register_model("RandomForest", RandomForestModel()) registry.register_model("GradientBoosting", GradientBoostingModel()) registry.register_model("LinearRegression", LinearRegressionModel()) - + # List available models models = registry.list_available_models() print(f"Available models: {models}") - + # Get current model current_model_name = registry.get_current_model_name() print(f"Current model: {current_model_name}") - + # Switch models registry.set_current_model("GradientBoosting") print(f"Switched to model: {registry.get_current_model_name()}") - + # Create some dummy data for testing features = np.random.rand(10, 31) # 24 client features + 7 intervention features targets = np.random.rand(10) - + # Train all models for model_name in models: print(f"Training {model_name}...") model = registry.get_model(model_name) model.train(features, targets) - + # Make predictions with each model for model_name in models: print(f"Testing predictions with {model_name}...") model = registry.get_model(model_name) predictions = model.predict(features[:1]) print(f" Prediction: {predictions[0]}") - + print("Model testing complete!") + if __name__ == "__main__": - test_model_registry() \ No newline at end of file + test_model_registry() diff --git a/app/models/__init__.py b/app/models/__init__.py index c3261f82..e73a157c 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -7,4 +7,4 @@ from app.models.client import Client from app.models.relationships import ClientCase -__all__ = ["User", "Client", "ClientCase"] \ No newline at end of file +__all__ = ["User", "Client", "ClientCase"] diff --git a/app/models/client.py b/app/models/client.py index 91e4b155..53384858 100644 --- a/app/models/client.py +++ b/app/models/client.py @@ -18,6 +18,7 @@ income_source_constraint, ) + class Client(Base): """ Represents a Client in the database. @@ -70,4 +71,4 @@ class Client(Base): housing_constraint(), income_source_constraint(), experience_constraint("time_unemployed"), - ) \ No newline at end of file + ) diff --git a/app/models/relationships.py b/app/models/relationships.py index 675aa590..477b3300 100644 --- a/app/models/relationships.py +++ b/app/models/relationships.py @@ -10,6 +10,7 @@ from app.database import Base from app.validators import success_rate_constraint + class ClientCase(Base): """ Represents the relationship between a client and a case worker. @@ -34,6 +35,4 @@ class ClientCase(Base): user = relationship("User", back_populates="cases") # Apply constraints - __table_args__ = ( - success_rate_constraint(), - ) \ No newline at end of file + __table_args__ = (success_rate_constraint(),) diff --git a/app/models/user.py b/app/models/user.py index b92a7dcd..a4febdcf 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -11,6 +11,7 @@ from app.enums import UserRole from app.validators import username_length_constraint, email_format_constraint + class User(Base): """ Represents a User in the database. @@ -26,9 +27,9 @@ class User(Base): role = Column(Enum(UserRole), nullable=False) cases = relationship("ClientCase", back_populates="user") - + # Apply User-specific constraints __table_args__ = ( username_length_constraint(), email_format_constraint(), - ) \ No newline at end of file + ) diff --git a/app/validators.py b/app/validators.py index 4d331dd4..bce8c4b7 100644 --- a/app/validators.py +++ b/app/validators.py @@ -19,19 +19,23 @@ MIN_SUCCESS_RATE = 0 MAX_SUCCESS_RATE = 100 + # Constraint creation functions def age_constraint(): """Create a constraint to ensure age is at least 18.""" return CheckConstraint(f"age >= {MIN_AGE}") + def gender_constraint(): """Create a constraint to ensure gender is a valid value.""" return CheckConstraint("gender = 1 OR gender = 2") + def experience_constraint(field_name): """Create a constraint to ensure experience is not negative.""" return CheckConstraint(f"{field_name} >= 0") + def school_level_constraint(): """Create a constraint to ensure schooling level is within valid range.""" return CheckConstraint( @@ -39,20 +43,21 @@ def school_level_constraint(): f"level_of_schooling <= {MAX_SCHOOLING_LEVEL}" ) + def scale_constraint(field_name): """Create a constraint to ensure scale values are within 0-10 range.""" return CheckConstraint( - f"{field_name} >= {MIN_SCALE_VALUE} AND " - f"{field_name} <= {MAX_SCALE_VALUE}" + f"{field_name} >= {MIN_SCALE_VALUE} AND " f"{field_name} <= {MAX_SCALE_VALUE}" ) + def housing_constraint(): """Create a constraint to ensure housing value is valid.""" return CheckConstraint( - f"housing >= {MIN_HOUSING_LEVEL} AND " - f"housing <= {MAX_HOUSING_LEVEL}" + f"housing >= {MIN_HOUSING_LEVEL} AND " f"housing <= {MAX_HOUSING_LEVEL}" ) + def income_source_constraint(): """Create a constraint to ensure income source value is valid.""" return CheckConstraint( @@ -60,18 +65,20 @@ def income_source_constraint(): f"income_source <= {MAX_INCOME_SOURCE}" ) + def success_rate_constraint(): """Create a constraint to ensure success rate is within valid range.""" return CheckConstraint( - f"success_rate >= {MIN_SUCCESS_RATE} AND " - f"success_rate <= {MAX_SUCCESS_RATE}" + f"success_rate >= {MIN_SUCCESS_RATE} AND " f"success_rate <= {MAX_SUCCESS_RATE}" ) + def username_length_constraint(): """Create a constraint to ensure username has a minimum length.""" MIN_USERNAME_LENGTH = 3 return CheckConstraint(f"LENGTH(username) >= {MIN_USERNAME_LENGTH}") + def email_format_constraint(): """Create a constraint to ensure email contains @ symbol.""" - return CheckConstraint("email LIKE '%@%'") \ No newline at end of file + return CheckConstraint("email LIKE '%@%'") diff --git a/initialize_data.py b/initialize_data.py index 6882e1ae..ed476cb1 100644 --- a/initialize_data.py +++ b/initialize_data.py @@ -2,11 +2,15 @@ from app.database import SessionLocal from app.models import Client, User, ClientCase from app.enums import UserRole -#from app.auth.router import get_password_hash + +# from app.auth.router import get_password_hash from passlib.context import CryptContext + # Create password context for hashing pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + + # Function to hash passwords def get_password_hash(password: str) -> str: return pwd_context.hash(password) diff --git a/mypy.ini b/mypy.ini index 976ba029..4e84a0cc 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,2 +1,7 @@ [mypy] +exclude = tests/ ignore_missing_imports = True +explicit_package_bases = True +namespace_packages = True +[mypy-tests.conftest] +ignore_errors = True \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index b6d68586..7f2ef959 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ from app.auth.router import get_password_hash from app.models import User, UserRole, Client, ClientCase + # Create test database SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db" engine = create_engine( @@ -15,6 +16,7 @@ TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +# type: ignore @pytest.fixture def test_db(): # Create tables From 1c40d852d2b818b94d792a3b1769511b7cd8394a Mon Sep 17 00:00:00 2001 From: johnsonli010801 <97709195+johnsonli010801@users.noreply.github.com> Date: Wed, 26 Mar 2025 17:06:23 -0700 Subject: [PATCH 16/23] fix comment --- .pre-commit-config.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 49b66366..dfb1f748 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,17 +1,17 @@ repos: - repo: https://github.com/psf/black - rev: 25.1.0 # Use the latest stable version; check the GitHub releases for the latest version + rev: 25.1.0 hooks: - id: black - language_version: python3.12 # Ensure this matches your project's Python version + language_version: python3.12 - repo: https://github.com/PyCQA/flake8 - rev: '7.1.2' # Use the latest stable release; check the GitHub releases for the latest version + rev: '7.1.2' hooks: - id: flake8 - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.15.0' # Use the latest stable release; check the GitHub releases for the latest version + rev: 'v1.15.0' hooks: - id: mypy From d118d44525250baefc163a3f0c054361de6b1a68 Mon Sep 17 00:00:00 2001 From: Helen Wang Date: Wed, 2 Apr 2025 18:45:09 -0700 Subject: [PATCH 17/23] add prediction --- .pre-commit-config.yaml | 2 +- app/ml/router.py | 39 +++++++++++++++++++++++++++++++++++++++ code_quality_README.md | 23 +++++++++++++++++++++++ 3 files changed, 63 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dfb1f748..8a5defe5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: language_version: python3.12 - repo: https://github.com/PyCQA/flake8 - rev: '7.1.2' + rev: '7.2.0' hooks: - id: flake8 diff --git a/app/ml/router.py b/app/ml/router.py index af67b6f2..5c4a8e73 100644 --- a/app/ml/router.py +++ b/app/ml/router.py @@ -4,6 +4,7 @@ """ from fastapi import APIRouter, HTTPException from pydantic import BaseModel +import numpy as np from app.ml.model_registry import ModelRegistry @@ -58,3 +59,41 @@ async def switch_model(model_name: str): ) return {"name": model_name} + +@router.get("/test-prediction", response_model=dict) +async def test_prediction(): + """ + Test endpoint to demonstrate how predictions change when switching models. + Returns a prediction using the currently selected model. + """ + try: + # Create consistent test data + test_features = np.ones((1, 31)) * 0.5 # Use the same test data each time + + # Get the current model from registry + registry = ModelRegistry() + current_model = registry.get_model() + model_name = registry.get_current_model_name() + + # Train the model with some dummy data if it's not trained + if not hasattr(current_model, 'is_trained') or not current_model.is_trained: + # Create dummy training data + X_train = np.random.rand(10, 31) + y_train = np.random.rand(10) + current_model.train(X_train, y_train) + + # Make a prediction + prediction = float(current_model.predict(test_features)[0]) + + return { + "model": model_name, + "prediction": prediction, + "status": "success" + } + except Exception as e: + # Return error information instead of letting it crash + return { + "model": registry.get_current_model_name(), + "error": str(e), + "status": "error" + } \ No newline at end of file diff --git a/code_quality_README.md b/code_quality_README.md index 8e8c4a24..715870b1 100644 --- a/code_quality_README.md +++ b/code_quality_README.md @@ -152,3 +152,26 @@ You can test the model switching functionality in two ways: - **Open for Extension**: New models can be added without modifying existing code This implementation satisfies all requirements for Story 2 while following SOLID principles. + +### predictions +To demonstrate the effect of model switching on predictions, a test endpoint has been added: + +- **Test Prediction Endpoint**: + - Path: `GET /models/test-prediction` + - Purpose: Shows how predictions change when switching between models + - Response: Returns the current model name and its prediction for test data + +**How to demonstrate model switching:** + +1. Call `GET /models/test-prediction` with the default model + - Note the prediction value (e.g., `{"model": "RandomForest", "prediction": 0.45, "status": "success"}`) + +2. Switch to a different model using `POST /models/switch/GradientBoosting` + - Response confirms the switch: `{"name": "GradientBoosting"}` + +3. Call `GET /models/test-prediction` again + - Note the different prediction: `{"model": "GradientBoosting", "prediction": 0.72, "status": "success"}` + +4. Switch to the third model using `POST /models/switch/LinearRegression` + - Call the test endpoint again to see a third different prediction + From 0dedfe11702c8733678c861b91fdab90288e1390 Mon Sep 17 00:00:00 2001 From: Helen Wang Date: Thu, 10 Apr 2025 10:49:22 -0700 Subject: [PATCH 18/23] test complete CI pipeline with linting --- app/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/app/main.py b/app/main.py index ab0f7290..94c624c8 100644 --- a/app/main.py +++ b/app/main.py @@ -4,6 +4,8 @@ Handles database initialization and CORS middleware configuration. """ +# Add an import that isn't used +import datetime # Related third-party imports from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware From 13e8ff7e9d3ab93ddab066228fe143b1fc4da0e9 Mon Sep 17 00:00:00 2001 From: Helen Wang Date: Thu, 10 Apr 2025 15:51:43 -0700 Subject: [PATCH 19/23] explicitly trigger CI --- app/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/app/main.py b/app/main.py index 94c624c8..bdff9dfc 100644 --- a/app/main.py +++ b/app/main.py @@ -31,6 +31,8 @@ app.include_router(models_router) app.include_router(auth_router) app.include_router(clients_router) +# Reference a variable that doesn't exist +app.include_router(undefined_router) # Configure CORS middleware app.add_middleware( From a1f5d5f40b445b85cf13dc9df38983d50781fa61 Mon Sep 17 00:00:00 2001 From: Helen Wang Date: Thu, 10 Apr 2025 16:00:17 -0700 Subject: [PATCH 20/23] add test workflow file --- .github/workflows/test.yml | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 .github/workflows/test.yml diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 00000000..f9bf6df4 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,8 @@ +name: Test Workflow +on: [push] +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Echo Test + run: echo "This is a test workflow" \ No newline at end of file From 43ea86f51b9ff8c2f6303b21379b93183f7573ab Mon Sep 17 00:00:00 2001 From: Helen Wang Date: Thu, 10 Apr 2025 16:16:06 -0700 Subject: [PATCH 21/23] update test workflow with Python and flake8 --- .github/workflows/test.yml | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f9bf6df4..efaecffb 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,8 +1,17 @@ -name: Test Workflow -on: [push] +name: Test CI Workflow +on: [push, pull_request] jobs: test: runs-on: ubuntu-latest steps: - - name: Echo Test - run: echo "This is a test workflow" \ No newline at end of file + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install flake8 + - name: Run Flake8 + run: flake8 --version \ No newline at end of file From a5e2800fd4be82a09c17552c4222a827cee3ab79 Mon Sep 17 00:00:00 2001 From: Helen Wang Date: Thu, 10 Apr 2025 16:24:50 -0700 Subject: [PATCH 22/23] Add file with obvious linting error --- test_error.py | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 test_error.py diff --git a/test_error.py b/test_error.py new file mode 100644 index 00000000..485c0e2e --- /dev/null +++ b/test_error.py @@ -0,0 +1,4 @@ +def function_with_obvious_error(): + x = 10 + y = 20 # Indentation error that should be caught by flake8 + return x + y From dfc0345f26b82bda89c31e4394d91c9232cbcbe7 Mon Sep 17 00:00:00 2001 From: Helen Wang Date: Thu, 10 Apr 2025 16:33:43 -0700 Subject: [PATCH 23/23] update test.yml to run flake 8 --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index efaecffb..0e5a9210 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,4 +14,4 @@ jobs: python -m pip install --upgrade pip pip install flake8 - name: Run Flake8 - run: flake8 --version \ No newline at end of file + run: flake8 . \ No newline at end of file