diff --git a/.github/workflows/docker-action.yml b/.github/workflows/docker-action.yml new file mode 100644 index 00000000..cd69ca5d --- /dev/null +++ b/.github/workflows/docker-action.yml @@ -0,0 +1,54 @@ +name: Docker-Action + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + build-test: + runs-on: ubuntu-latest + + steps: + # Step 1: Checkout the code + - name: Checkout code + uses: actions/checkout@v4 + + # Step 2: Log in to Docker Hub (optional, if pushing images) + # - name: Log in to Docker Hub + # uses: docker/login-action@v2 + # with: + # username: ${{ secrets.DOCKER_USERNAME }} + # password: ${{ secrets.DOCKER_PASSWORD }} + + # Step 3: Build the Docker image + - name: Build Docker image + run: | + docker build -t thoughtful-app . + + # Step 4: Run the Docker container + - name: Run Docker container + run: | + docker run -d --name thoughtful-app-container thoughtful-app + + # Step 5: Run tests inside the container + # - name: Run tests + # run: | + # docker exec thoughtful-app-container python /app/clients/service/model.py + # Step 5: Run tests inside the container + - name: Run tests + run: | + docker exec thoughtful-app-container python /app/tests/test.py + + # Step 6: Clean up + - name: Stop and remove container + run: | + docker stop thoughtful-app-container + docker rm thoughtful-app-container + + - name: Remove Docker image + run: | + docker rmi thoughtful-app diff --git a/.github/workflows/pylint-action.yml b/.github/workflows/pylint-action.yml new file mode 100644 index 00000000..d9640925 --- /dev/null +++ b/.github/workflows/pylint-action.yml @@ -0,0 +1,25 @@ +name: Pylint-Action + +on: [push] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10"] + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip 
install --upgrade pip + pip install pylint + pip install -r requirements.txt + - name: Analysing the code with pylint + run: | + echo "::log::Pylint processing" + pylint --disable=C,R,W app/*.py diff --git a/.gitignore b/.gitignore index 14d7fa72..2ef07e3e 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,21 @@ .idea __pycache__ .DS_Store + +# Ignore virtual environment folder +env/ +venv/ + +# Ignore Python cache files +__pycache__/ +*.pyc +*.pyo + +# Ignore Jupyter Notebook checkpoints +.ipynb_checkpoints/ + +# Ignore VS Code settings +.vscode/ + +# Ignore shell Scripts +*.sh \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..eea94456 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,15 @@ +# Use an official Python runtime as a parent image +FROM python:3.10-slim + +# Set the working directory in the container +WORKDIR /app + +# Copy the current directory contents into the container +COPY . /app + +# Install any needed packages specified in requirements.txt +RUN pip install --no-cache-dir -r requirements.txt + +# Run tests by default +# CMD ["python", "/app/clients/service/model.py"] +CMD ["sleep","1h"] diff --git a/README.md b/README.md index 0a48dc37..2cd448c8 100644 --- a/README.md +++ b/README.md @@ -3,3 +3,102 @@ This will contain the model used for the project that based on the input informa The model works off of dummy data of several combinations of clients alongside the interventions chosen for them as well as their success rate at finding a job afterward. The model will be updated by the case workers by inputing new data for clients with their updated outcome information, and it can be updated on a daily, weekly, or monthly basis. This also has an API file to interact with the front end, and logic in order to process the interventions coming from the front end. 
This includes functions to clean data, create a matrix of all possible combinations in order to get the ones with the highest increase of success, and output the results in a way the front end can interact with. + +## Run the FastAPI project + +- Clone the git repo + +```bash +# clone the project from git repo +git clone https://github.com/JiayangLJY/CommonAssessmentTool-Group-Thoughtful.git +``` + + + +* Prepare the virtual environment + +```bash +# cd to the project root dir +cd CommonAssessmentTool-Group-Thoughtful + +# create a virtual env for the current project +python3 -m venv venv + +# activate the virtual env +source venv/bin/activate + +# install dependency packages +pip install -r requirements.txt + +# export python path +export PYTHONPATH=$(pwd):$PYTHONPATH +``` + +> - to deactivate the current virtual env: +> +> ```bash +> deactivate +> ``` + + + +- Run the FastAPI application + +```bash +# cd to the {project}/app dir +cd app + +# run project in dev mode +fastapi dev main.py +``` + +- Then open the url `http://127.0.0.1:8000/docs` in the browser to view the doc page of the current running project + + + +## Sprint 1 + +4 tasks have been assigned to the group members: + +1. Configurable machine learning model + +- A config file will be included to configure the type and hyperparameter for ML model +- Create an abstraction layer for training and evaluation + +2. Maintenance on form data naming, ordering and converting + +- The naming and order of the form data will be uniformly maintained in the schema.py +- Text conversion will be encapsulated within the data model + +3. Unit tests for server-side data processing steps + +- Test model selection logic in `/predict` endpoint +- Test data processing logic of `interpret_and_calculate()` +- Test frontend validation and form submission +- Validate the input data + +4. 
Functional test for front-back end interaction + +- Update the frontend to support dynamic model selection and hyperparameter configuration +- Conduct comprehensive integration tests to ensure the new features work correctly across the API and frontend + + + +## Sprint 2 + +- Environments of backend and front end projects were successfully set up and both projects are now runnable on our local machines +- An integration test was performed on the `clients/predictions` API to make sure the API is working +- `field_serializer` from the Pydantic package was used to convert string type into numerical types +- We are investigating FastAPI docs and finding that we can use environment variables to configure some parameters in code +- Primary study showed that we need to retrieve some intermediate variables to suit our code with test + + + +## Sprint 3 + +- Different machine learning models, including LinearRegression, GradientBoostingRegressor and SVR were trained and tested to expand the choice in prediction. +- The task of data type conversion was refactored to be conducted by the `PredictionInput` class. +- Order of features (columns from the frontend form) was configured in the `.env` file. +- Unit tests were added to make sure the code works during the refactoring work. +- An integration test was performed on the `clients/predictions` API to make sure the API is working. 
+ diff --git a/app/clients/.env b/app/clients/.env new file mode 100644 index 00000000..cd205ec6 --- /dev/null +++ b/app/clients/.env @@ -0,0 +1,10 @@ +FEATURE_COLS_IN_SEQ=["age", "gender", "work_experience", "canada_workex", "dep_num", "canada_born", "citizen_status", "level_of_schooling", "fluent_english", "reading_english_scale", "speaking_english_scale", "writing_english_scale", "numeracy_scale", "computer_scale", "transportation_bool", "caregiver_bool", "housing", "income_source", "felony_bool", "attending_school", "currently_employed", "substance_use", "time_unemployed", "need_mental_health_support_bool"] +#FEATURE_COLS_IN_SEQ=["age", "gender", "canada_workex", "work_experience", "dep_num", "canada_born", "citizen_status", "level_of_schooling", "fluent_english", "reading_english_scale", "speaking_english_scale", "writing_english_scale", "numeracy_scale", "computer_scale", "transportation_bool", "caregiver_bool", "housing", "income_source", "felony_bool", "attending_school", "currently_employed", "substance_use", "time_unemployed", "need_mental_health_support_bool"] + + +# Options: RandomForestRegressor, LinearRegression, GradientBoostingRegressor, SVR +MODEL_TYPE=SVR +MODEL_OUTPUT_NAME=svr_model.pkl + +#MODEL_TYPE=RandomForestRegressor +#MODEL_OUTPUT_NAME=random_forest_model.pkl diff --git a/app/clients/router.py b/app/clients/router.py index f860c402..09a543ff 100644 --- a/app/clients/router.py +++ b/app/clients/router.py @@ -1,15 +1,13 @@ from fastapi import APIRouter -from fastapi.responses import HTMLResponse - from app.clients.service.logic import interpret_and_calculate from app.clients.schema import PredictionInput + router = APIRouter(prefix="/clients", tags=["clients"]) + @router.post("/predictions") async def predict(data: PredictionInput): print("HERE") - print(data.model_dump()) - return interpret_and_calculate(data.model_dump()) - - + print(data.model_dump(by_alias=True)) + return interpret_and_calculate(data.model_dump(by_alias=True)) diff --git 
a/app/clients/schema.py b/app/clients/schema.py index 6b56ad98..4ccc1fcb 100644 --- a/app/clients/schema.py +++ b/app/clients/schema.py @@ -1,27 +1,133 @@ -from pydantic import BaseModel +from pydantic import BaseModel, Field, field_serializer, model_serializer + class PredictionInput(BaseModel): - age: int - gender: str - work_experience: int - canada_workex: int - dep_num: int - canada_born: str - citizen_status: str - level_of_schooling: str - fluent_english: str - reading_english_scale: int - speaking_english_scale: int - writing_english_scale: int - numeracy_scale: int - computer_scale: int - transportation_bool: str - caregiver_bool: str - housing: str - income_source: str - felony_bool: str - attending_school: str - currently_employed: str - substance_use: str - time_unemployed: int - need_mental_health_support_bool: str + """ + PredictionInput is the validated input from webpage users + + The serialization_alias of each filed MUST be identical to the column name in dotenv file + + validation_alias is used for validating the json format data sending from front end + serialization_alias is used as column name when dumping the data model + + Fields in str format are converted into numerical types in @field_serializer decorator + """ + age: int = Field(..., validation_alias='age', serialization_alias='age') + gender: str = Field(..., validation_alias='gender', serialization_alias='gender') + work_experience: int = Field(..., validation_alias='work_experience', serialization_alias='work_experience') + canada_workex: int = Field(..., validation_alias='canada_workex', serialization_alias='canada_workex') + dep_num: int = Field(..., validation_alias='dep_num', serialization_alias='dep_num') + canada_born: bool = Field(..., validation_alias='canada_born', serialization_alias='canada_born') + citizen_status: str = Field(..., validation_alias='citizen_status', serialization_alias='citizen_status') + level_of_schooling: str = Field(..., 
validation_alias='level_of_schooling', serialization_alias='level_of_schooling') + fluent_english: bool = Field(..., validation_alias='fluent_english', serialization_alias='fluent_english') + reading_english_scale: int = Field(..., validation_alias='reading_english_scale', serialization_alias='reading_english_scale') + speaking_english_scale: int = Field(..., validation_alias='speaking_english_scale', serialization_alias='speaking_english_scale') + writing_english_scale: int = Field(..., validation_alias='writing_english_scale', serialization_alias='writing_english_scale') + numeracy_scale: int = Field(..., validation_alias='numeracy_scale', serialization_alias='numeracy_scale') + computer_scale: int = Field(..., validation_alias='computer_scale', serialization_alias='computer_scale') + transportation_bool: bool = Field(..., validation_alias='transportation_bool', serialization_alias='transportation_bool') + caregiver_bool: bool = Field(..., validation_alias='caregiver_bool', serialization_alias='caregiver_bool') + housing: str = Field(..., validation_alias='housing', serialization_alias='housing') + income_source: str = Field(..., validation_alias='income_source', serialization_alias='income_source') + felony_bool: bool = Field(..., validation_alias='felony_bool', serialization_alias='felony_bool') + attending_school: bool = Field(..., validation_alias='attending_school', serialization_alias='attending_school') + currently_employed: bool = Field(..., validation_alias='currently_employed', serialization_alias='currently_employed') + substance_use: bool = Field(..., validation_alias='substance_use', serialization_alias='substance_use') + time_unemployed: int = Field(..., validation_alias='time_unemployed', serialization_alias='time_unemployed') + need_mental_health_support_bool: bool = Field(..., validation_alias='need_mental_health_support_bool', serialization_alias='need_mental_health_support_bool') + + # The following field serializer converts specific fields 
into numerical types, + @field_serializer('income_source') + def serialize_income_source(self, income_source: str): + match income_source: + case 'No Source of Income': + return 1 + case 'Employment Insurance': + return 2 + case 'Workplace Safety and Insurance Board': + return 3 + case 'Ontario Works applied or receiving': + return 4 + case 'Ontario Disability Support Program applied or receiving': + return 5 + case 'Dependent of someone receiving OW or ODSP': + return 6 + case 'Crown Ward': + return 7 + case 'Employment': + return 8 + case 'Self-Employment': + return 9 + case 'Other (specify)': + return 10 + + @field_serializer('housing') + def serialize_housing(self, housing: str): + match housing: + case 'Renting-private': + return 1 + case 'Renting-subsidized': + return 2 + case 'Boarding or lodging': + return 3 + case 'Homeowner': + return 4 + case 'Living with family/friend': + return 5 + case 'Institution': + return 6 + case 'Temporary second residence': + return 7 + case 'Band-owned home': + return 8 + case 'Homeless or transient': + return 9 + case 'Emergency hostel': + return 10 + + @field_serializer('level_of_schooling') + def serialize_level_of_schooling(self, level_of_schooling: str): + match level_of_schooling: + case 'Grade 0-8': + return 1 + case 'Grade 9': + return 2 + case 'Grade 10': + return 3 + case 'Grade 11': + return 4 + case 'Grade 12 or equivalent': + return 5 + case 'OAC or Grade 13': + return 6 + case 'Some college': + return 7 + case 'Some university': + return 8 + case 'Some apprenticeship': + return 9 + case 'Certificate of Apprenticeship': + return 10 + case 'Journeyperson': + return 11 + case 'Certificate/Diploma': + return 12 + case 'Bachelor’s degree': + return 13 + case 'Post graduate': + return 14 + + @field_serializer('citizen_status') + def serialize_citizen_status(self, citizen_status: str): + match citizen_status: + case 'citizen': + return 0 + case 'permanent_resident': + return 1 + case 'temporary_resident': + return 2 + + 
@field_serializer('gender') + def serialize_age(self, gender: str): + return 1 if gender == 'M' else 2 diff --git a/app/clients/service/gradient_boosting_model.pkl b/app/clients/service/gradient_boosting_model.pkl new file mode 100644 index 00000000..45e44287 Binary files /dev/null and b/app/clients/service/gradient_boosting_model.pkl differ diff --git a/app/clients/service/logic.py b/app/clients/service/logic.py index 0fd826a5..31ad1180 100644 --- a/app/clients/service/logic.py +++ b/app/clients/service/logic.py @@ -1,168 +1,92 @@ -from typing import List -import pandas as pd -import json import numpy as np import pickle -from itertools import combinations_with_replacement from itertools import product +from app.clients.schema import PredictionInput +from app.clients.util import util_get_cols +import os +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + +# Get configuration from .env +MODEL_TYPE = os.getenv("MODEL_TYPE", "RandomForestRegressor") # Default: RandomForestRegressor +MODEL_OUTPUT_NAME = os.getenv("MODEL_OUTPUT_NAME", "random_forest_model.pkl") # Default: different.pkl + +# Dynamically load the model +try: + current_dir = os.path.dirname(os.path.abspath(__file__)) + model_path = os.path.join(current_dir, MODEL_OUTPUT_NAME) + with open(model_path, "rb") as model_file: + model = pickle.load(model_file) + print(f"Model of type {MODEL_TYPE} loaded successfully from {model_path}") +except FileNotFoundError: + print(f"Error: Model file not found at {model_path}. 
Please check the MODEL_OUTPUT_NAME in .env.") +except Exception as e: + print(f"An error occurred while loading the model: {e}") + column_intervention = [ 'Life Stabilization', 'General Employment Assistance Services', 'Retention Services', 'Specialized Services', - 'Employment-Related Financial Supports for Job Seekers and Employers', + 'Employment-Related Financial Supports for Job Seekers and Employers', 'Employer Financial Supports', 'Enhanced Referrals for Skills Development' ] -#loads the model into logic -import os +def convert_none_bool(value): + """convert None to 0, True to 1, False to 0""" + if value is None: + return 0 + if type(value) == bool: + return 1 if value is True else 0 + else: + return value -current_dir = os.path.dirname(os.path.abspath(__file__)) -filename = os.path.join(current_dir, 'model.pkl') -model = pickle.load(open(filename, "rb")) - - -def clean_input_data(data): - #translate input into wahtever we trained the model on, numerical data in a specific order - columns = ["age","gender","work_experience","canada_workex","dep_num", "canada_born", - "citizen_status", "level_of_schooling", "fluent_english", "reading_english_scale", - "speaking_english_scale", "writing_english_scale", "numeracy_scale", "computer_scale", - "transportation_bool", "caregiver_bool", "housing", "income_source", "felony_bool", "attending_school", - "currently_employed", "substance_use", "time_unemployed", "need_mental_health_support_bool"] - demographics = { - 'age': data['age'], - 'gender': data['gender'], - 'work_experience': data['work_experience'], - 'canada_workex': data['canada_workex'], - 'dep_num': data['dep_num'], - 'canada_born': data['canada_born'], - 'citizen_status': data['citizen_status'], - 'level_of_schooling': data['level_of_schooling'], - 'fluent_english': data['fluent_english'], - 'reading_english_scale': data['reading_english_scale'], - 'speaking_english_scale': data['speaking_english_scale'], - 'writing_english_scale': 
data['writing_english_scale'], - 'numeracy_scale': data['numeracy_scale'], - 'computer_scale': data['computer_scale'], - 'transportation_bool': data['transportation_bool'], - 'caregiver_bool': data['caregiver_bool'], - 'housing': data['housing'], - 'income_source': data['income_source'], - 'felony_bool': data['felony_bool'], - 'attending_school': data['attending_school'], - 'currently_employed': data['currently_employed'], - 'substance_use': data['substance_use'], - 'time_unemployed': data['time_unemployed'], - 'need_mental_health_support_bool': data['need_mental_health_support_bool'] - } - output = [] - for column in columns: - data = demographics.get(column, None) #default is None, and if you want to pass a value, can return any value - if isinstance(data, str): - data = convert_text(column, data) - output.append(data) - return output +def clean_input_data(data, features): + """retrieve values from {data} following in the ORDER defined by {features}""" + return [convert_none_bool(data.get(feat)) for feat in features] -def convert_text(column, data:str): - # Convert text answers from front end into digits - # TODO: ensure that categorical columns match the valid answers in FormNew.jsx (L131) - categorical_cols_integers = [ - { - "": 0, - "true": 1, - "false": 0, - "no": 0, - "yes": 1, - "No": 0, - "Yes": 1 - }, - { - 'Grade 0-8': 1, - 'Grade 9': 2, - 'Grade 10': 3, - 'Grade 11': 4, - 'Grade 12 or equivalent': 5, - 'OAC or Grade 13': 6, - 'Some college': 7, - 'Some university': 8, - 'Some apprenticeship': 9, - 'Certificate of Apprenticeship': 10, - 'Journeyperson': 11, - 'Certificate/Diploma': 12, - 'Bachelor’s degree': 13, - 'Post graduate': 14 - }, - { - 'Renting-private': 1, - 'Renting-subsidized': 2, - 'Boarding or lodging': 3, - 'Homeowner': 4, - 'Living with family/friend': 5, - 'Institution': 6, - 'Temporary second residence': 7, - 'Band-owned home': 8, - 'Homeless or transient': 9, - 'Emergency hostel': 10 - }, - { - 'No Source of Income': 1, - 'Employment 
Insurance': 2, - 'Workplace Safety and Insurance Board': 3, - 'Ontario Works applied or receiving': 4, - 'Ontario Disability Support Program applied or receiving': 5, - 'Dependent of someone receiving OW or ODSP': 6, - 'Crown Ward': 7, - 'Employment': 8, - 'Self-Employment': 9, - 'Other (specify)': 10 - } - ] - for category in categorical_cols_integers: - print(f"data: {data}") - print(f"column: {column}") - if data in category: - return category[data] - - if isinstance(data, str) and data.isnumeric(): - return int(data) - - return data - -#creates 128 possible combinations in order to run every possibility through model + +# creates 128 possible combinations in order to run every possibility through model def create_matrix(row): - data = [row.copy() for _ in range(128)] + data = [row.copy() for _ in range(128)] perms = intervention_permutations(7) data = np.array(data) perms = np.array(perms) - matrix = np.concatenate((data,perms), axis = 1) + matrix = np.concatenate((data, perms), axis=1) return np.array(matrix) -#create matrix of permutations of 1 and 0 of num length + + +# create matrix of permutations of 1 and 0 of num length def intervention_permutations(num): - perms = list(product([0,1],repeat=num)) + perms = list(product([0, 1], repeat=num)) return np.array(perms) + def get_baseline_row(row): print(type(row)) - base_interventions = np.array([0]*7) # no interventions + base_interventions = np.array([0] * 7) # no interventions row = np.array(row) print(row) print(type(row)) - line = np.concatenate((row,base_interventions)) + line = np.concatenate((row, base_interventions)) return line + def intervention_row_to_names(row): names = [] for i, value in enumerate(row): - if value == 1: + if value == 1: names.append(column_intervention[i]) return names + def process_results(baseline, results): - ##Example: """ { baseline_probability: 80 #baseline percentage point with no interventions @@ -173,71 +97,40 @@ def process_results(baseline, results): ] } """ - 
result_list= [] + result_list = [] for row in results: - percent = row[-1] + percent = row[-1] names = intervention_row_to_names(row) - result_list.append((percent,names)) + result_list.append((percent, names)) output = { - "baseline": baseline[-1], #if it's an array, want the value inside of the array + "baseline": baseline[-1], # if it's an array, want the value inside of the array "interventions": result_list, } return output + def interpret_and_calculate(data): - raw_data = clean_input_data(data) + raw_data = clean_input_data(data, util_get_cols()) baseline_row = get_baseline_row(raw_data) baseline_row = baseline_row.reshape(1, -1) - print("BASELINE ROW IS",baseline_row) + print("BASELINE ROW IS", baseline_row) intervention_rows = create_matrix(raw_data) + print("ML MODEL IS", MODEL_TYPE) baseline_prediction = model.predict(baseline_row) intervention_predictions = model.predict(intervention_rows) - intervention_predictions = intervention_predictions.reshape(-1, 1) #want shape to be a vertical column, not a row - result_matrix = np.concatenate((intervention_rows,intervention_predictions), axis = 1) ##CHANGED AXIS - + intervention_predictions = intervention_predictions.reshape(-1, 1) # want shape to be a vertical column, not a row + result_matrix = np.concatenate((intervention_rows, intervention_predictions), axis=1) # CHANGED AXIS + # sort this matrix based on prediction # print("RESULT SAMPLE::", result_matrix[:5]) - result_order = result_matrix[:,-1].argsort() #take all rows and only last column, gives back list of indexes sorted - result_matrix = result_matrix[result_order] #indexing the matrix by the order + result_order = result_matrix[:, -1].argsort() # take all rows and only last column, gives back list of indexes sorted + result_matrix = result_matrix[result_order] # indexing the matrix by the order # slice matrix to only top N results - result_matrix = result_matrix[-3:,-8:] #-8 for interventions and prediction, want top 3, 3 combinations of 
intervention + result_matrix = result_matrix[-3:, -8:] # -8 for interventions and prediction, want top 3, 3 combinations of intervention # post process results if needed ie make list of names for each row - results = process_results(baseline_prediction,result_matrix) + results = process_results(baseline_prediction, result_matrix) # build output dict print(f"RESULTS: {results}") return results - -if __name__ == "__main__": - print("running") - data = { - "age": "23", - "gender": "1", - "work_experience": "1", - "canada_workex": "1", - "dep_num": "0", - "canada_born": "1", - "citizen_status": "2", - "level_of_schooling": "2", - "fluent_english": "3", - "reading_english_scale": "2", - "speaking_english_scale": "2", - "writing_english_scale": "3", - "numeracy_scale": "2", - "computer_scale": "3", - "transportation_bool": "2", - "caregiver_bool": "1", - "housing": "1", - "income_source": "5", - "felony_bool": "1", - "attending_school": "0", - "currently_employed": "1", - "substance_use": "1", - "time_unemployed": "1", - "need_mental_health_support_bool": "1" - } - # print(data) - results = interpret_and_calculate(data) - print(results) - diff --git a/app/clients/service/model.py b/app/clients/service/model.py index 51369ac3..a7c196b7 100644 --- a/app/clients/service/model.py +++ b/app/clients/service/model.py @@ -1,39 +1,30 @@ +import os import pandas as pd -import json import numpy as np import pickle from sklearn.model_selection import train_test_split -from sklearn.ensemble import RandomForestRegressor +from app.clients.util import util_get_cols, get_model +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + +# Get configuration from .env +MODEL_TYPE = os.getenv("MODEL_TYPE", "RandomForestRegressor") # Default: RandomForestRegressor +MODEL_OUTPUT_NAME = os.getenv("MODEL_OUTPUT_NAME", "random_forest_model.pkl") # Default: different.pkl def prepare_models(): + """ + Prepare and train a machine learning model based on the configuration. 
+ """ # Load dataset and define the features and labels backendCode = pd.read_csv('data_commontool.csv') - # Define categorical columns and interventions - categorical_cols = ['age', - 'gender', #bool - 'work_experience', #years of work experience - 'canada_workex',#years of work experience in canada - 'dep_num', #number of dependents - 'canada_born', #born in canada - 'citizen_status', #citizen status - 'level_of_schooling', #highest level achieved (1-14) - 'fluent_english', #english level fluency, scale (1-10) - 'reading_english_scale', #reading scale (1-10) - 'speaking_english_scale', #speaking level comfort (1-10) - 'writing_english_scale', #writing scale (1-10) - 'numeracy_scale', #numeracy scale (1-10) - 'computer_scale', #computer use scale (1-10) - 'transportation_bool', #need transportation support (bool) - 'caregiver_bool', #is a primary care giver bool - 'housing', #housing situation 1-10 - 'income_source', #source of income 1-10 - 'felony_bool', #has a felony bool - 'attending_school', #currently a student bool - 'currently_employed', #currently employed bool - 'substance_use', #disorder, bool - 'time_unemployed', #number of years unemployed - 'need_mental_health_support_bool'] #need support + + # Load categorical columns dynamically + categorical_cols = util_get_cols() + + # Define interventions interventions = [ 'employment_assistance', 'life_stabilization', @@ -44,29 +35,46 @@ def prepare_models(): 'enhanced_referrals' ] categorical_cols.extend(interventions) + # Prepare training data X_categorical_baseline = backendCode[categorical_cols] - y_baseline = backendCode['success_rate'] + y_baseline = backendCode['success_rate'] # Assuming 'success_rate' is the target variable X_categorical_baseline = np.array(X_categorical_baseline) y_baseline = np.array(y_baseline) + + # Split data into training and testing sets X_train_baseline, X_test_baseline, y_train_baseline, y_test_baseline = train_test_split( - X_categorical_baseline, y_baseline, test_size=0.2, 
random_state=42) + X_categorical_baseline, y_baseline, test_size=0.2, random_state=42 + ) + + # Dynamically load the model based on configuration + model = get_model(MODEL_TYPE) - rf_model_baseline = RandomForestRegressor(n_estimators=100, random_state=42) - rf_model_baseline.fit(X_train_baseline, y_train_baseline) + # Train the model + print(f"Training {MODEL_TYPE} model...") + model.fit(X_train_baseline, y_train_baseline) # Example: Predicting on the test set - baseline_predictions = rf_model_baseline.predict(X_test_baseline) + # baseline_predictions = model.predict(X_test_baseline) + + return model - - return rf_model_baseline def main(): - print("Start model.") + """ + Main function to prepare and save the trained model. + """ + print("Starting model") model = prepare_models() - pickle.dump(model, open("model.pkl", "wb")) #saves model to the file name input, write binary - model = pickle.load(open("model.pkl", "rb")) #read binary + # Save the model to a file (configurable via .env) + pickle.dump(model, open(MODEL_OUTPUT_NAME, "wb")) + print(f"Model saved as {MODEL_OUTPUT_NAME}") + + # Optional: Reload the model to verify save/load functionality + model = pickle.load(open(MODEL_OUTPUT_NAME, "rb")) + print(f"Model reloaded successfully from {MODEL_OUTPUT_NAME}") + if __name__ == "__main__": main() diff --git a/app/clients/service/random_forest_model.pkl b/app/clients/service/random_forest_model.pkl new file mode 100644 index 00000000..61264e05 Binary files /dev/null and b/app/clients/service/random_forest_model.pkl differ diff --git a/app/clients/service/svr_model.pkl b/app/clients/service/svr_model.pkl new file mode 100644 index 00000000..4010fd32 Binary files /dev/null and b/app/clients/service/svr_model.pkl differ diff --git a/app/clients/util.py b/app/clients/util.py new file mode 100644 index 00000000..7a67a9e6 --- /dev/null +++ b/app/clients/util.py @@ -0,0 +1,27 @@ +import os +import json +from functools import lru_cache +from dotenv import load_dotenv 
+from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor +from sklearn.linear_model import LinearRegression +from sklearn.svm import SVR + +@lru_cache +def util_get_cols(): + """Read categorical columns from dotenv file.""" + load_dotenv() + return json.loads(os.getenv('FEATURE_COLS_IN_SEQ')) + + +def get_model(model_type): + """Dynamically load a machine learning model based on the configuration.""" + model_mapping = { + "RandomForestRegressor": RandomForestRegressor, + "LinearRegression": LinearRegression, + "GradientBoostingRegressor": GradientBoostingRegressor, + "SVR": SVR + } + if model_type not in model_mapping: + raise ValueError(f"Unsupported model type: {model_type}. Choose from {list(model_mapping.keys())}") + + return model_mapping[model_type]() diff --git a/requirements.txt b/requirements.txt index 1ccf75b7..4c2efce3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ argon2-cffi-bindings==21.2.0 arrow==1.2.3 astroid==3.0.1 asttokens==2.4.0 -attrs==22.1.0 +attrs==24.2.0 backcall==0.2.0 bcrypt==4.0.1 beautifulsoup4==4.12.2 @@ -32,8 +32,11 @@ fastapi==0.103.2 fastjsonschema==2.16.3 folium==0.14.0 fqdn==1.5.1 +greenlet==3.1.1 h11==0.14.0 +httpcore==1.0.6 httptools==0.6.0 +httpx==0.27.2 idna==3.4 iniconfig==1.1.1 ipykernel==6.22.0 @@ -71,6 +74,7 @@ nest-asyncio==1.5.6 notebook==6.5.4 notebook_shim==0.2.2 numpy==1.24.2 +outcome==1.3.0.post0 packaging==23.2 pandas==2.0.0 pandocfilters==1.5.0 @@ -96,7 +100,14 @@ pydantic_core==2.10.1 Pygments==2.16.1 pylint==3.0.1 pyrsistent==0.19.3 +PySocks==1.7.1 pytest==7.2.0 +pytest-base-url==2.1.0 +pytest-html==4.1.1 +pytest-metadata==3.1.1 +pytest-mock==3.14.0 +pytest-selenium==4.1.0 +pytest-variables==3.1.0 python-dateutil==2.8.2 python-dotenv==1.0.0 python-jose==3.3.0 @@ -113,13 +124,16 @@ rfc3986-validator==0.1.1 rsa==4.9 scikit-learn==1.4.2 scipy==1.13.0 +selenium==4.26.1 Send2Trash==1.8.0 six==1.16.0 sniffio==1.3.0 +sortedcontainers==2.4.0 soupsieve==2.4.1 SQLAlchemy==2.0.21 
stack-data==0.6.3 starlette==0.27.0 +tenacity==9.0.0 terminado==0.17.1 threadpoolctl==3.4.0 tinycss2==1.2.1 @@ -127,7 +141,9 @@ tomli==2.0.1 tomlkit==0.12.1 tornado==6.3.1 traitlets==5.11.2 -typing_extensions==4.8.0 +trio==0.27.0 +trio-websocket==0.11.1 +typing_extensions==4.12.2 tzdata==2023.3 uri-template==1.2.0 urllib3==2.0.7 @@ -136,7 +152,9 @@ uvloop==0.17.0 watchfiles==0.20.0 wcwidth==0.2.8 webcolors==1.13 +webdriver-manager==4.0.2 webencodings==0.5.1 -websocket-client==1.5.1 +websocket-client==1.8.0 websockets==11.0.3 widgetsnbextension==4.0.7 +wsproto==1.2.0 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..8a6419f7 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,4 @@ +# tests/conftest.py +import os +import sys +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) \ No newline at end of file diff --git a/tests/test.py b/tests/test.py index a911f0a2..9b3b0d8d 100644 --- a/tests/test.py +++ b/tests/test.py @@ -1,23 +1,1185 @@ -from logic import interpret_and_calculate -from itertools import combinations_with_replacement - -# def test_interpret_and_calculate(): -# print("running tests") -# data = {"23","1","1","1","1","0","1","2","2","3","2", -# "2","3","2","1","1","1","1","1","1","0","1","1","1" -# } -# result = interpret_and_calculate(data) -# print(data) +import json +import sys +import os +import warnings +from fastapi.testclient import TestClient +import pytest +from unittest.mock import Mock, patch +import numpy as np + +# Add the project root directory to Python path +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(project_root) +warnings.filterwarnings("ignore", category=DeprecationWarning) +from app.main import app +from app.clients.util import util_get_cols +from app.clients.service.logic import clean_input_data, interpret_and_calculate +from 
app.clients.schema import PredictionInput + +from selenium import webdriver +from selenium.webdriver.chrome.service import Service +from selenium.webdriver.chrome.options import Options +from webdriver_manager.chrome import ChromeDriverManager +from selenium.webdriver.common.by import By +from selenium.webdriver.support.ui import WebDriverWait +from selenium.webdriver.support import expected_conditions as EC +from selenium.webdriver.support.select import Select +from selenium.common.exceptions import NoSuchElementException, TimeoutException +import time + +# 7-3-1 Test Model Selection Logic in /predict Endpoint +# Test if different model_name values result in the correct model being instantiated. +# Verify that the default model is used when model_name is not specified. +# Simulate various input data and check if the output format is as expected. +# Mock different models to verify if the correct model class is being called. + +# Test data with string values for categorical fields +test_model = { + "age": 18, + "gender": "Female", + "work_experience": 3, + "canada_workex": 0, + "dep_num": 1, + "canada_born": True, + "citizen_status": "Citizen", + "level_of_schooling": "Bachelor", + "fluent_english": True, + "reading_english_scale": 3, + "speaking_english_scale": 1, + "writing_english_scale": 3, + "numeracy_scale": 0, + "computer_scale": 2, + "transportation_bool": False, + "caregiver_bool": True, + "housing": "Stable", + "income_source": "Employment", + "felony_bool": True, + "attending_school": False, + "currently_employed": True, + "substance_use": True, + "time_unemployed": 1, + "need_mental_health_support_bool": False +} + +# Test output (expected cleaned data) +test_output = [18, 1, 3, 0, 1, 1, 0, 5, 1, 3, 1, 3, 0, 2, 0, 1, 5, 1, 1, 0, 1, 1, 1, 0] + +# prediction result +test_prediction_result = { + "baseline": 67.6, + "interventions": [ + (68.7, ["Life Stabilization", "General Employment Assistance Services", + "Specialized Services", "Employment-Related 
Financial Supports for Job Seekers and Employers"]), + (68.7, ["Life Stabilization", "Specialized Services", + "Employment-Related Financial Supports for Job Seekers and Employers"]), + (69.0, ["Life Stabilization", "Specialized Services"]) + ] +} + +# ordered features for test +test_features = [ + 'age', 'gender', 'work_experience', 'canada_workex', 'dep_num', + 'canada_born', 'citizen_status', 'level_of_schooling', 'fluent_english', + 'reading_english_scale', 'speaking_english_scale', 'writing_english_scale', + 'numeracy_scale', 'computer_scale', 'transportation_bool', 'caregiver_bool', + 'housing', 'income_source', 'felony_bool', 'attending_school', + 'currently_employed', 'substance_use', 'time_unemployed', + 'need_mental_health_support_bool' +] + +@pytest.fixture +def mock_model(): + """Mock the model to return consistent predictions""" + with patch('app.clients.service.model.prepare_models') as mock: + mock_rf = Mock() + mock_rf.predict.return_value = np.array([67.6, 68.7, 68.7, 69.0]) + mock.return_value = mock_rf + yield mock + +def get_data(): + """Helper function to get prediction input data""" + return PredictionInput(**test_model).model_dump(by_alias=True) + +def test_dump_model(): + """test to check the dumped model""" + print("\n#################### test_dump_model() ####################") + refactor_converted = get_data() + # Compare only the fields that should match exactly + matching_fields = ['age', 'work_experience', 'canada_workex', 'dep_num', + 'reading_english_scale', 'speaking_english_scale', + 'writing_english_scale', 'numeracy_scale', 'computer_scale', + 'time_unemployed'] + + all_match = all(refactor_converted[field] == test_model[field] + for field in matching_fields) + + if all_match: + print("PASS\n") + else: + print("FAIL\n") + +def test_clean_input_data(): + """test data type conversion""" + print("\n#################### test_clean_input_data() ####################") + data = get_data() + refactor_output = clean_input_data(data, 
test_features) + if len(test_output) != len(refactor_output): + print("FAIL: len not equals\n") + return + for i in range(len(test_output)): + if test_output[i] != refactor_output[i]: + print(f"FAIL: the {i}th element not equals. origin:{test_output[i]}, refactor:{refactor_output[i]}\n") + return + print("PASS\n") + +def test_prediction(mock_model): + """test the whole prediction process""" + print("\n#################### test_prediction() ####################") + data = get_data() + result = interpret_and_calculate(data) + if result == test_prediction_result: + print("PASS\n") + else: + print("FAIL\n") + +def test_column_order(): + """Test if column order matches expected order""" + print("\n#################### test_column_order() ####################") + cols = util_get_cols() + if cols == test_features: + print("PASS") + else: + print("FAIL") + +# Additional tests for model selection logic +def test_model_output_format(mock_model): + """Test if model output follows expected format""" + print("\n#################### test_model_output_format() ####################") + data = get_data() + result = interpret_and_calculate(data) + + # Check result structure + if not isinstance(result, dict): + print("FAIL: Result is not a dictionary\n") + return + if "baseline" not in result or "interventions" not in result: + print("FAIL: Missing required keys in result\n") + return + if not isinstance(result["baseline"], float): + print("FAIL: Baseline is not a float\n") + return + if not isinstance(result["interventions"], list): + print("FAIL: Interventions is not a list\n") + return + + print("PASS\n") + +def test_intervention_combinations(mock_model): + """Test if model generates valid intervention combinations""" + print("\n#################### test_intervention_combinations() ####################") + data = get_data() + result = interpret_and_calculate(data) + + valid_interventions = [ + "Life Stabilization", + "General Employment Assistance Services", + "Retention 
Services", + "Specialized Services", + "Employment-Related Financial Supports for Job Seekers and Employers", + "Employer Financial Supports", + "Enhanced Referrals for Skills Development" + ] + + for _, interventions in result["interventions"]: + if not all(i in valid_interventions for i in interventions): + print("FAIL: Invalid intervention found\n") + return + + print("PASS\n") + +def test_prediction_values(mock_model): + """Test if prediction values are within valid range""" + print("\n#################### test_prediction_values() ####################") + data = get_data() + result = interpret_and_calculate(data) + + if not (0 <= result["baseline"] <= 100): + print("FAIL: Baseline prediction out of range\n") + return + + for prob, _ in result["interventions"]: + if not (0 <= prob <= 100): + print("FAIL: Intervention prediction out of range\n") + return + if prob < result["baseline"]: + print("FAIL: Intervention prediction lower than baseline\n") + return + + print("PASS\n") + + + +#7-3-2 Test Data Processing Logic of interpret_and_calculate Function +# Test the function's ability to handle different input data formats (e.g., numbers, strings,missing values). +# Simulate potential erroneous inputs, such as missing fields or type errors, and ensure that the function returns appropriate error messages. +# Verify that the function can correctly format input data into the format required by the model. 
def test_missing_fields():
    """Check that PredictionInput rejects a payload lacking required fields.

    ``age`` and ``gender`` are removed from a copy of ``test_model``; the
    resulting error text is expected to mention both field names.

    NOTE(review): like the other checks in this section, the outcome is only
    printed (PASS/FAIL); nothing is asserted, so a printed FAIL does not make
    the test fail under pytest.
    """
    print("\n#################### test_missing_fields() ####################")

    # Remove required fields
    invalid_data = test_model.copy()
    del invalid_data['age']
    del invalid_data['gender']

    try:
        PredictionInput(**invalid_data)
    except Exception as e:
        message = str(e)
        if "age" in message and "gender" in message:
            print("PASS\n")
        else:
            print(f"FAIL: Unexpected error message: {message}\n")
    else:
        print("FAIL: Should raise validation error for missing fields\n")


def test_invalid_types():
    """Check that wrongly typed values are rejected with the expected error tag.

    Each case overrides one field of ``test_model`` and looks for
    ``expected_error`` inside the lower-cased error text. Results are printed.
    """
    print("\n#################### test_invalid_types() ####################")

    test_cases = [
        {"field": "age", "value": "invalid_age", "expected_error": "type_error"},
        {"field": "gender", "value": 123, "expected_error": "string_type"},
        {"field": "canada_born", "value": "not_boolean", "expected_error": "bool_type"},
        {"field": "reading_english_scale", "value": "high", "expected_error": "type_error"},
    ]

    passes = 0
    total = len(test_cases)

    for case in test_cases:
        invalid_data = test_model.copy()
        invalid_data[case["field"]] = case["value"]

        try:
            PredictionInput(**invalid_data)
        except Exception as e:
            message = str(e)
            if case["expected_error"] in message.lower():
                passes += 1
            else:
                print(f"FAIL: Unexpected error message for {case['field']}: {message}\n")
        else:
            print(f"FAIL: Should raise validation error for {case['field']}\n")

    if passes == total:
        print("PASS: All type validations working\n")
    else:
        print(f"FAIL: {total - passes} type validations failed\n")


def test_empty_values():
    """Check that empty strings in categorical fields are rejected.

    A case passes when construction raises and the lower-cased error text
    contains "empty" or "blank". Results are printed.
    """
    print("\n#################### test_empty_values() ####################")

    test_cases = [
        {"field": "gender", "value": ""},
        {"field": "citizen_status", "value": ""},
        {"field": "housing", "value": ""},
    ]

    passes = 0
    total = len(test_cases)

    for case in test_cases:
        invalid_data = test_model.copy()
        invalid_data[case["field"]] = case["value"]

        try:
            PredictionInput(**invalid_data)
        except Exception as e:
            message = str(e)
            lowered = message.lower()
            if "empty" in lowered or "blank" in lowered:
                passes += 1
            else:
                print(f"FAIL: Unexpected error message for {case['field']}: {message}\n")
        else:
            print(f"FAIL: Should raise validation error for empty {case['field']}\n")

    if passes == total:
        print("PASS: All empty value validations working\n")
    else:
        print(f"FAIL: {total - passes} empty value validations failed\n")


def test_out_of_range_values():
    """Check that out-of-range numeric values are rejected.

    A case passes when construction raises and the lower-cased error text
    contains "range", "greater than" or "less than". Results are printed.
    """
    print("\n#################### test_out_of_range_values() ####################")

    test_cases = [
        {"field": "age", "value": -1},
        {"field": "age", "value": 150},
        {"field": "reading_english_scale", "value": 6},
        {"field": "time_unemployed", "value": -1},
    ]

    passes = 0
    total = len(test_cases)

    for case in test_cases:
        invalid_data = test_model.copy()
        invalid_data[case["field"]] = case["value"]

        try:
            PredictionInput(**invalid_data)
        except Exception as e:
            message = str(e)
            lowered = message.lower()
            if "range" in lowered or "greater than" in lowered or "less than" in lowered:
                passes += 1
            else:
                print(f"FAIL: Unexpected error message for {case['field']}: {message}\n")
        else:
            print(f"FAIL: Should raise validation error for out-of-range {case['field']}\n")

    if passes == total:
        print("PASS: All range validations working\n")
    else:
        print(f"FAIL: {total - passes} range validations failed\n")


def test_data_cleaning(mock_model):
    """Check that boolean-like inputs come out of clean_input_data as ints.

    Every spelling listed for a field is pushed through ``PredictionInput``
    and ``clean_input_data``; the cleaned value at the field's column index
    is then type-checked. Results are printed.
    """
    print("\n#################### test_data_cleaning() ####################")

    # Test boolean conversion
    boolean_variations = {
        "fluent_english": [True, "True", "true", 1, "1"],
        "transportation_bool": [False, "False", "false", 0, "0"],
    }

    passes = 0
    total = sum(len(values) for values in boolean_variations.values())

    for field, values in boolean_variations.items():
        for value in values:
            test_data = test_model.copy()
            test_data[field] = value

            try:
                data = PredictionInput(**test_data).model_dump(by_alias=True)
                cleaned_data = clean_input_data(data, util_get_cols())
                # NOTE(review): bool is a subclass of int, so a value that was
                # left as True/False also satisfies this check.
                if isinstance(cleaned_data[util_get_cols().index(field)], int):
                    passes += 1
                else:
                    print(f"FAIL: Boolean conversion failed for {field} = {value}\n")
            except Exception as e:
                print(f"FAIL: Error processing {field} = {value}: {str(e)}\n")

    if passes == total:
        print("PASS: All boolean conversions working\n")
    else:
        print(f"FAIL: {total - passes} boolean conversions failed\n")


def test_categorical_encoding():
    """Check that categorical fields are encoded to numbers.

    A field counts as passing only when every listed value is cleaned to an
    int or float. The pass counter is advanced once per field (not once per
    value) so that it is comparable with ``total``, the number of fields;
    counting per value made a fully successful run report a failure.
    Results are printed.
    """
    print("\n#################### test_categorical_encoding() ####################")

    # Test different categorical values
    categorical_variations = {
        "gender": ["Female", "Male", "Other"],
        "citizen_status": ["Citizen", "Permanent Resident", "Other"],
        "housing": ["Stable", "Temporary", "None"],
    }

    passes = 0
    total = len(categorical_variations)

    for field, values in categorical_variations.items():
        try:
            for value in values:
                test_data = test_model.copy()
                test_data[field] = value
                data = PredictionInput(**test_data).model_dump(by_alias=True)
                cleaned_data = clean_input_data(data, util_get_cols())
                encoded = cleaned_data[util_get_cols().index(field)]
                if not isinstance(encoded, (int, float)):
                    print(f"FAIL: Categorical encoding failed for {field} = {value}\n")
                    break
            else:
                # Reached only when no value of this field broke out above.
                passes += 1
        except Exception as e:
            print(f"FAIL: Error processing {field}: {str(e)}\n")

    if passes == total:
        print("PASS: All categorical encodings working\n")
    else:
        print(f"FAIL: {total - passes} categorical encodings failed\n")


def test_numeric_processing():
    """Check that numeric fields given as int, str or float are cleaned to numbers.

    Each variation is pushed through ``PredictionInput`` and
    ``clean_input_data``; the cleaned value must be an int or float.
    Results are printed.
    """
    print("\n#################### test_numeric_processing() ####################")

    numeric_variations = {
        'age': [18, "18", 18.0],
        'work_experience': [3, "3", 3.0],
        'dep_num': [1, "1", 1.0],
        'reading_english_scale': [3, "3", 3.0],
    }

    passes = 0
    total = sum(len(values) for values in numeric_variations.values())

    for field, values in numeric_variations.items():
        for value in values:
            test_data = test_model.copy()
            test_data[field] = value

            try:
                data = PredictionInput(**test_data).model_dump(by_alias=True)
                cleaned_data = clean_input_data(data, util_get_cols())
                if isinstance(cleaned_data[util_get_cols().index(field)], (int, float)):
                    passes += 1
                else:
                    print(f"FAIL: Numeric processing failed for {field} = {value}\n")
            except Exception as e:
                print(f"FAIL: Error processing {field} = {value}: {str(e)}\n")

    if passes == total:
        print("PASS: All numeric processing working\n")
    else:
        print(f"FAIL: {total - passes} numeric processing failed\n")


# 7-3-3 Test Frontend Validation and Form Submission
# Verify that validation rules for each input field in the form are enforced (e.g., age should be 18-65, required fields cannot be empty).
# Test the JavaScript submitForm() function to ensure it properly converts form data into JSON format and sends it to the /predict endpoint.
# Simulate different response statuses from the backend (e.g., 200, 400, 500) and check if the frontend handles these responses correctly and displays appropriate messages to the user.
@pytest.fixture
def driver():
    """Yield a Chrome WebDriver opened on the form page, and always quit it.

    Navigation and the implicit-wait setup sit inside ``try``/``finally`` so
    the browser started by ``webdriver.Chrome`` is shut down even when
    ``get()`` raises (for example when nothing is listening on
    localhost:3000), instead of leaving the process running.
    """
    chrome_options = Options()
    service = Service(ChromeDriverManager().install())
    browser = webdriver.Chrome(service=service, options=chrome_options)
    try:
        browser.get("http://localhost:3000/form")
        browser.implicitly_wait(2)
        yield browser
    finally:
        browser.quit()


def find_element_safe(driver, by, value, timeout=5):
    """Wait up to ``timeout`` seconds for an element and return it, or None.

    Args:
        driver: WebDriver to search in.
        by: Locator strategy (a ``By`` constant).
        value: Locator value.
        timeout: Seconds to wait for the element to be present.

    Returns:
        The located element, or None (after printing a message) when the
        wait times out or the element cannot be found.
    """
    try:
        return WebDriverWait(driver, timeout).until(
            EC.presence_of_element_located((by, value))
        )
    except (TimeoutException, NoSuchElementException):
        print(f"Element not found: {value}")
        return None


def test_required_fields(driver):
    """Submit the empty form and look for a helper-text error on a required field.

    NOTE(review): like the other browser checks here, the outcome is only
    printed (PASS/FAIL); nothing is asserted.
    """
    print("\n#################### test_required_fields() ####################")

    # Find submit button (Material-UI button)
    submit_button = find_element_safe(
        driver,
        By.CSS_SELECTOR,
        "button[type='submit']"
    )

    if not submit_button:
        print("FAIL: Submit button not found\n")
        return

    # Click submit with empty form
    submit_button.click()
    time.sleep(1)  # Wait for validation messages

    # Check for Material-UI error messages; stop at the first one found.
    required_fields = ["age", "gender", "work_experience", "level_of_schooling"]
    error_found = any(
        find_element_safe(
            driver,
            By.CSS_SELECTOR,
            f"[name='{field}'] + .MuiFormHelperText-root.Mui-error"
        )
        for field in required_fields
    )

    if error_found:
        print("PASS: Required field validation working\n")
    else:
        print("FAIL: No error messages found for required fields\n")


def test_numeric_validation(driver):
    """Type an invalid then a valid value into numeric inputs and check the error text.

    Each field contributes two checks: an error helper text must appear for
    the invalid value and must be absent for the valid one. A field whose
    input cannot be located contributes no passes. Results are printed.
    """
    print("\n#################### test_numeric_validation() ####################")

    numeric_tests = [
        {"name": "age", "invalid": "-1", "valid": "25", "min": 18, "max": 65},
        {"name": "reading_english_scale", "invalid": "11", "valid": "5", "min": 0, "max": 10},
        {"name": "speaking_english_scale", "invalid": "11", "valid": "5", "min": 0, "max": 10},
    ]

    passes = 0
    total = len(numeric_tests) * 2  # Testing both invalid and valid values

    for spec in numeric_tests:
        input_field = find_element_safe(
            driver,
            By.CSS_SELECTOR,
            f"input[name='{spec['name']}']"
        )

        if not input_field:
            continue

        error_selector = f"[name='{spec['name']}'] + .MuiFormHelperText-root.Mui-error"

        # Test invalid value
        input_field.clear()
        input_field.send_keys(spec["invalid"])
        # NOTE(review): intended to trigger validation (the original comment
        # says "blur"); clicking the same input presumably keeps focus on it.
        # Confirm which event the form validates on.
        input_field.click()

        # Check for error
        if find_element_safe(driver, By.CSS_SELECTOR, error_selector):
            passes += 1

        # Test valid value
        input_field.clear()
        input_field.send_keys(spec["valid"])
        input_field.click()

        # Error should be gone
        if not driver.find_elements(By.CSS_SELECTOR, error_selector):
            passes += 1

    if passes == total:
        print("PASS: Numeric validation working\n")
    else:
        print(f"FAIL: {total - passes} numeric validations failed\n")


def test_clear_form(driver):
    """Fill a few inputs, press the clear button and check the inputs were reset.

    A field counts as cleared when its ``value`` attribute equals "0";
    fields that cannot be located are ignored. Results are printed.
    """
    print("\n#################### test_clear_form() ####################")

    # Fill some fields
    test_data = {
        "age": "25",
        "work_experience": "3",
        "reading_english_scale": "5",
    }

    for name, value in test_data.items():
        element = find_element_safe(driver, By.NAME, name)
        if element:
            element.clear()
            element.send_keys(value)

    # Click clear button
    clear_button = find_element_safe(
        driver,
        By.CSS_SELECTOR,
        "button[color='secondary']"
    )

    if not clear_button:
        print("FAIL: Clear button not found\n")
        return

    clear_button.click()
    time.sleep(1)  # Wait for form to clear

    # Check if fields are cleared
    all_cleared = True
    for name in test_data:
        element = find_element_safe(driver, By.NAME, name)
        if element and element.get_attribute("value") != "0":
            all_cleared = False
            break

    if all_cleared:
        print("PASS: Form cleared successfully\n")
    else:
        print("FAIL: Form not properly cleared\n")


def test_checkbox_toggle(driver):
    """Click each checkbox once and check that its selected state flipped.

    Results are printed.
    """
    print("\n#################### test_checkbox_toggle() ####################")

    checkbox_fields = [
        "canada_born",
        "fluent_english",
        "transportation_bool",
        "caregiver_bool",
    ]

    passes = 0
    total = len(checkbox_fields)

    for field in checkbox_fields:
        checkbox = find_element_safe(
            driver,
            By.CSS_SELECTOR,
            f"input[name='{field}'][type='checkbox']"
        )

        if not checkbox:
            print(f"FAIL: Checkbox {field} not found\n")
            continue

        # Test toggle
        initial_state = checkbox.is_selected()
        checkbox.click()
        new_state = checkbox.is_selected()

        if initial_state != new_state:
            passes += 1
        else:
            print(f"FAIL: Checkbox {field} not toggling properly\n")

    if passes == total:
        print("PASS: All checkboxes working\n")
    else:
        print(f"FAIL: {total - passes} checkbox tests failed\n")


# 7-3-4 Test Input Data Validation for /predict Endpoint
# Test if the endpoint returns appropriate errors when required fields are missing.
# Simulate various types of input data (e.g., negative numbers, overly long strings, non-JSON formats) and ensure the endpoint can handle them and return clear error messages.
# Verify that type conversion in the input data is handled correctly (e.g., converting "Yes" and "No" to boolean values).
client = TestClient(app)


def test_valid_input():
    """POST the reference payload to /predict and report whether it is accepted.

    Prints PASS when the response is HTTP 200 and its JSON body contains both
    "baseline" and "interventions"; otherwise prints the decoded body.
    """
    print("\n#################### test_valid_input() ####################")
    response = client.post("/predict", json=test_model)

    if response.status_code == 200:
        body = response.json()
        if all(key in body for key in ("baseline", "interventions")):
            print("PASS\n")
            return
    print(f"FAIL: Invalid response for valid input: {response.json()}\n")


def test_missing_required_fields():
    """Drop one required field at a time and expect a 422 naming that field."""
    print("\n#################### test_missing_required_fields() ####################")

    required = (
        "age", "gender", "work_experience", "level_of_schooling",
        "reading_english_scale", "speaking_english_scale", "writing_english_scale",
    )
    failures = 0

    for name in required:
        payload = dict(test_model)
        del payload[name]

        response = client.post("/predict", json=payload)
        if response.status_code != 422:  # 422 is FastAPI's validation error status
            failures += 1
            print(f"FAIL: Incorrect status code for missing {name}: {response.status_code}\n")
            continue

        details = response.json().get("detail", [])
        if not any(name in str(err) for err in details):
            failures += 1
            print(f"FAIL: Missing appropriate error message for {name}\n")

    if failures:
        print(f"FAIL: {failures} missing field validations failed\n")
    else:
        print("PASS: All missing field validations working\n")


def test_invalid_numeric_values():
    """Send out-of-range numbers and expect a 422 whose detail has the error tag."""
    print("\n#################### test_invalid_numeric_values() ####################")

    # (field, value, tag expected somewhere in the error detail)
    scenarios = (
        ("age", -1, "value_error"),
        ("age", 150, "value_error"),
        ("work_experience", -5, "value_error"),
        ("reading_english_scale", 6, "value_error"),
        ("dep_num", -2, "value_error"),
    )
    failures = 0

    for name, bad_value, tag in scenarios:
        payload = dict(test_model)
        payload[name] = bad_value

        response = client.post("/predict", json=payload)
        if response.status_code != 422:
            failures += 1
            print(f"FAIL: Incorrect status code for invalid {name}: {response.status_code}\n")
            continue

        details = response.json().get("detail", [])
        if not any(tag in str(err) for err in details):
            failures += 1
            print(f"FAIL: Missing appropriate error message for {name} = {bad_value}\n")

    if failures:
        print(f"FAIL: {failures} numeric validations failed\n")
    else:
        print("PASS: All numeric validations working\n")


def test_invalid_string_lengths():
    """Send a 1001-character string in text fields and expect a length error (422)."""
    print("\n#################### test_invalid_string_lengths() ####################")

    oversized = "a" * 1001  # String longer than maximum allowed length
    failures = 0

    for name in ("gender", "citizen_status", "housing"):
        payload = dict(test_model)
        payload[name] = oversized

        response = client.post("/predict", json=payload)
        if response.status_code != 422:
            failures += 1
            print(f"FAIL: Incorrect status code for long {name}: {response.status_code}\n")
            continue

        details = response.json().get("detail", [])
        if not any("length" in str(err).lower() for err in details):
            failures += 1
            print(f"FAIL: Missing appropriate error message for long {name}\n")

    if failures:
        print(f"FAIL: {failures} string length validations failed\n")
    else:
        print("PASS: All string length validations working\n")


def test_boolean_conversion():
    """Send several spellings of true/false and expect each to be accepted (200)."""
    print("\n#################### test_boolean_conversion() ####################")

    spellings = (
        ("canada_born", ["true", "True", "1", True, "yes", "Yes"]),
        ("fluent_english", ["false", "False", "0", False, "no", "No"]),
    )
    failures = 0

    for name, variants in spellings:
        for variant in variants:
            payload = dict(test_model)
            payload[name] = variant

            response = client.post("/predict", json=payload)
            if response.status_code != 200:
                failures += 1
                print(f"FAIL: Boolean conversion failed for {name} = {variant}\n")

    if failures:
        print(f"FAIL: {failures} boolean conversions failed\n")
    else:
        print("PASS: All boolean conversions working\n")


def test_invalid_json():
    """Send malformed or wrongly shaped JSON bodies and expect 400 or 422."""
    print("\n#################### test_invalid_json() ####################")

    # (description, raw request body)
    bad_bodies = (
        ("Invalid JSON syntax", b"{invalid_json"),
        ("Plain text", b"not_json_at_all"),
        ("Valid JSON but wrong format", b"[1, 2, 3]"),
        ("Null value", b"null"),
    )
    failures = 0

    for description, raw in bad_bodies:
        response = client.post(
            "/predict",
            headers={"Content-Type": "application/json"},
            content=raw,  # raw bytes, bypassing JSON encoding
        )

        # Either status is acceptable for invalid JSON
        if response.status_code not in (400, 422):
            failures += 1
            print(f"FAIL: Incorrect status code for {description}: {response.status_code}\n")

    if failures:
        print(f"FAIL: {failures} invalid JSON cases failed\n")
    else:
        print("PASS: All invalid JSON cases handled correctly\n")


def test_content_type_validation():
    """Send bodies under non-JSON content types and expect 400 or 415."""
    print("\n#################### test_content_type_validation() ####################")

    # (Content-Type header, raw request body)
    attempts = (
        ("text/plain", b"plain text"),
        ("application/xml", b"data"),
        ("multipart/form-data", json.dumps(test_model).encode()),
    )
    failures = 0

    for content_type, raw in attempts:
        response = client.post(
            "/predict",
            headers={"Content-Type": content_type},
            content=raw,
        )

        if response.status_code not in (400, 415):
            failures += 1
            print(f"FAIL: Incorrect status code for content type {content_type}: {response.status_code}\n")

    if failures:
        print(f"FAIL: {failures} content type validations failed\n")
    else:
        print("PASS: All content type validations working\n")


def test_empty_request():
    """POST an empty JSON object and expect a 422 mentioning "required"."""
    print("\n#################### test_empty_request() ####################")

    response = client.post("/predict", json={})

    if response.status_code == 422:
        details = response.json().get("detail", [])
        if any("required" in str(err).lower() for err in details):
            print("PASS\n")
            return
    print(f"FAIL: Incorrect handling of empty request: {response.status_code}\n")


# def test_form_submission(driver):
#     """Test form submission"""
#     print("\n#################### test_form_submission() ####################")

#     try:
#         # Print page title for debugging
#         print(f"Page Title: {driver.title}")

#         # Fill text fields
#         text_fields = {
#             "age": "25",
#             "work_experience": "3",
#             "reading_english_scale": "5",
#             "speaking_english_scale": "5",
#             "writing_english_scale": "5"
#         }

#         for name, value in text_fields.items():
#             try:
#                 element = WebDriverWait(driver, 10).until(
#                     EC.presence_of_element_located((By.NAME, name))
#                 )
#                 element.clear()
#                 element.send_keys(value)
#                 print(f"Successfully filled
{name} with {value}") +# except Exception as e: +# print(f"Error filling {name}: {str(e)}") +# raise + +# # Handle Material-UI select fields +# select_fields = { +# "gender": "Female", +# "level_of_schooling": "Bachelor's degree" +# } + +# for name, value in select_fields.items(): +# try: +# # Print available elements for debugging +# print(f"\nLooking for {name} select field...") +# elements = driver.find_elements(By.CSS_SELECTOR, f'label:contains("{name.replace("_", " ").title()}")') +# print(f"Found {len(elements)} potential elements for {name}") + +# # Try different selectors for the select field +# select_selectors = [ +# f'//label[contains(text(), "{name.replace("_", " ").title()}")]/following-sibling::div', +# f'.MuiFormControl-root:has(label[contains(text(), "{name.replace("_", " ").title()}")]) .MuiSelect-select', +# f'[aria-label="{name.replace("_", " ").title()}"]' +# ] + +# select_element = None +# for selector in select_selectors: +# try: +# print(f"Trying selector: {selector}") +# if selector.startswith('//'): +# select_element = WebDriverWait(driver, 3).until( +# EC.element_to_be_clickable((By.XPATH, selector)) +# ) +# else: +# select_element = WebDriverWait(driver, 3).until( +# EC.element_to_be_clickable((By.CSS_SELECTOR, selector)) +# ) +# if select_element: +# print(f"Found select element for {name}") +# break +# except: +# continue + +# if not select_element: +# print(f"Could not find select element for {name}") +# # Try JavaScript click on the Select component +# js_script = f""" +# const labels = Array.from(document.querySelectorAll('label')); +# const label = labels.find(l => l.textContent.includes('{name.replace("_", " ").title()}')); +# if (label) {{ +# const select = label.parentElement.querySelector('.MuiSelect-select'); +# if (select) select.click(); +# }} +# """ +# driver.execute_script(js_script) +# time.sleep(1) +# else: +# driver.execute_script("arguments[0].scrollIntoView(true);", select_element) +# 
driver.execute_script("arguments[0].click();", select_element) +# time.sleep(1) + +# # Try to find the option +# print(f"Looking for option: {value}") +# option_selectors = [ +# f'//li[contains(@class, "MuiMenuItem-root") and text()="{value}"]', +# f'.MuiPopover-paper li[data-value="{value}"]', +# f'.MuiMenu-paper li:contains("{value}")' +# ] + +# option_found = False +# for selector in option_selectors: +# try: +# print(f"Trying option selector: {selector}") +# if selector.startswith('//'): +# option = WebDriverWait(driver, 3).until( +# EC.element_to_be_clickable((By.XPATH, selector)) +# ) +# else: +# option = WebDriverWait(driver, 3).until( +# EC.element_to_be_clickable((By.CSS_SELECTOR, selector)) +# ) +# driver.execute_script("arguments[0].click();", option) +# option_found = True +# print(f"Successfully selected option {value} for {name}") +# break +# except: +# continue + +# if not option_found: +# print(f"Could not find option {value} for {name}") +# raise Exception(f"Option {value} not found for {name}") + +# except Exception as e: +# print(f"Error handling select field {name}: {str(e)}") +# raise + +# time.sleep(1) + +# # Submit form +# try: +# # Try different submit button selectors +# submit_selectors = [ +# 'button[type="submit"]', +# 'button.MuiButton-containedPrimary', +# '//button[contains(text(), "Submit")]' +# ] + +# submit_button = None +# for selector in submit_selectors: +# try: +# print(f"Trying submit button selector: {selector}") +# if selector.startswith('//'): +# submit_button = WebDriverWait(driver, 3).until( +# EC.element_to_be_clickable((By.XPATH, selector)) +# ) +# else: +# submit_button = WebDriverWait(driver, 3).until( +# EC.element_to_be_clickable((By.CSS_SELECTOR, selector)) +# ) +# if submit_button: +# break +# except: +# continue + +# if submit_button: +# driver.execute_script("arguments[0].scrollIntoView(true);", submit_button) +# driver.execute_script("arguments[0].click();", submit_button) +# print("Form submitted") + +# 
WebDriverWait(driver, 10).until( +# lambda d: "/results" in d.current_url +# ) +# print("PASS: Form submitted successfully\n") +# else: +# print("FAIL: Submit button not found\n") +# raise Exception("Submit button not found") + +# except Exception as e: +# print(f"Error submitting form: {str(e)}") +# raise + +# except Exception as e: +# print(f"FAIL: Error during form submission: {str(e)}\n") +# raise e + + +# # Add debugging helper +# def print_elements(driver): +# """Helper function to print visible elements""" +# elements = driver.find_elements(By.CSS_SELECTOR, '*') +# for element in elements: +# try: +# if element.is_displayed(): +# tag_name = element.tag_name +# class_name = element.get_attribute('class') +# element_text = element.text +# print(f"Tag: {tag_name}, Class: {class_name}, Text: {element_text}") +# except: +# continue + +# # raw data from front end +# test_model = { +# "age": "18", +# "gender": "M", +# "work_experience": "3", +# "canada_workex": 0, +# "dep_num": "1", +# "canada_born": "true", +# "citizen_status": "citizen", +# "level_of_schooling": "Grade 12 or equivalent", +# "fluent_english": "true", +# "reading_english_scale": "3", +# "speaking_english_scale": "1", +# "writing_english_scale": "3", +# "numeracy_scale": 0, +# "computer_scale": "2", +# "transportation_bool": "false", +# "caregiver_bool": "true", +# "housing": "Living with family/friend", +# "income_source": "No Source of Income", +# "felony_bool": "true", +# "attending_school": "false", +# "currently_employed": "true", +# "substance_use": "true", +# "time_unemployed": "1", +# "need_mental_health_support_bool": "false" +# } + +# # the converted data +# test_converted = {'age': 18, 'gender': 1, 'work_experience': 3, 'canada_workex': 0, 'dep_num': 1, 'canada_born': True, +# 'citizen_status': 0, 'level_of_schooling': 5, 'fluent_english': True, 'reading_english_scale': 3, +# 'speaking_english_scale': 1, 'writing_english_scale': 3, 'numeracy_scale': 0, 'computer_scale': 2, +# 
'transportation_bool': False, 'caregiver_bool': True, 'housing': 5, 'income_source': 1, +# 'felony_bool': True, 'attending_school': False, 'currently_employed': True, 'substance_use': True, +# 'time_unemployed': 1, 'need_mental_health_support_bool': False} + +# # output after data converting +# test_output = [18, 1, 3, 0, 1, 1, 0, 5, 1, 3, 1, 3, 0, 2, 0, 1, 5, 1, 1, 0, 1, 1, 1, 0] + +# # prediction result +# test_prediction_result = { +# "baseline": 67.6, +# "interventions": [ +# (68.7, ["Life Stabilization", "General Employment Assistance Services", "Specialized Services", "Employment-Related Financial Supports for Job Seekers and Employers"]), +# (68.7, ["Life Stabilization", "Specialized Services", "Employment-Related Financial Supports for Job Seekers and Employers"]), +# (69.0, ["Life Stabilization", "Specialized Services"])] +# } + + + +# # ordered features for test +# test_features = ['age', 'gender', 'work_experience', 'canada_workex', 'dep_num', 'canada_born', 'citizen_status', +# 'level_of_schooling', 'fluent_english', 'reading_english_scale', 'speaking_english_scale', +# 'writing_english_scale', 'numeracy_scale', 'computer_scale', 'transportation_bool', +# 'caregiver_bool', 'housing', 'income_source', 'felony_bool', 'attending_school', +# 'currently_employed', 'substance_use', 'time_unemployed', 'need_mental_health_support_bool'] + + +# def get_data(): +# return PredictionInput(**test_model).model_dump(by_alias=True) + + +# def test_dump_model(): +# """test to check the dumped model""" +# print("\n#################### test_dump_model() ####################") +# refactor_converted = get_data() +# if refactor_converted == test_converted: +# print("PASS\n") +# else: +# print("FAIL\n") + + +# def test_clean_input_data(): +# """test data type conversion""" +# print("\n#################### test_clean_input_data() ####################") +# data = get_data() +# refactor_output = clean_input_data(data, test_features) + +# if len(test_output) != 
len(refactor_output): +# print("FAIL: len not equals\n") +# return + +# for i in range(len(test_output)): +# if test_output[i] != refactor_output[i]: +# print("FAIL: the {} th element not equals. origin:{}, refactor:{}\n".format(i, test_output[i], refactor_output[i])) +# return +# print("PASS\n") + + +# def test_prediction(): +# """test the whole prediction process""" +# print("\n#################### test_prediction() ####################") +# data = get_data() +# result = interpret_and_calculate(data) +# if result == test_prediction_result: +# print("PASS\n") +# else: +# print("FAIL\n") -# Cartesian product of [0, 1] repeated 2 times -result = list(product([0, 1], repeat=2)) +# #################### Test Data and Methods #################### -# Output: [(0, 0), (0, 1), (1, 0), (1, 1)] -print(result) +# # original order of columns: +# test_original_cols = ['age', 'gender', 'work_experience', 'canada_workex', 'dep_num', 'canada_born', 'citizen_status', +# 'level_of_schooling', 'fluent_english', 'reading_english_scale', 'speaking_english_scale', +# 'writing_english_scale', 'numeracy_scale', 'computer_scale', 'transportation_bool', +# 'caregiver_bool', 'housing', 'income_source', 'felony_bool', 'attending_school', +# 'currently_employed', 'substance_use', 'time_unemployed', 'need_mental_health_support_bool'] -result = list(combinations_with_replacement([0, 1], 2)) -# Output: [(0, 0), (0, 1), (1, 1)] -print(result) \ No newline at end of file +# def test_column_order(): +# print("\n#################### test_data_type_conversion() ####################") +# cols = util_get_cols() +# # print(cols) +# if cols == test_original_cols: +# print("PASS") +# else: +# print("FAIL") \ No newline at end of file