diff --git a/.env b/.env new file mode 100644 index 00000000..f292cb8d --- /dev/null +++ b/.env @@ -0,0 +1,3 @@ +SECRET_KEY = "your-secret-key-here" +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 30 \ No newline at end of file diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 00000000..dee81e31 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,637 @@ +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist= + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. +ignore=CVS + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +ignore-paths= + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +ignore-patterns=^\.# + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.10 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# Add paths to the list of the source roots. Supports globbing patterns. The +# source root is an absolute path or a path relative to the current working +# directory used to determine a package namespace for modules located under the +# source root. +source-roots= + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type alias names. If left empty, type +# alias names will be checked with the set naming style. +#typealias-rgx= + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + asyncSetUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=100 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + use-implicit-booleaness-not-comparison-to-string, + use-implicit-booleaness-not-comparison-to-zero, + missing-module-docstring, + missing-class-docstring, + missing-function-docstring + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable= + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are: text, parseable, colorized, +# json2 (improved json format), json (old json format) and msvs (visual +# studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +#output-format=colorized + +# Tells whether to display a full report or only the messages. +reports=yes + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. No available dictionaries : You need to install +# both the python package and the system dependency for enchant to work. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io diff --git a/README.md b/README.md index b34d6d6b..70b0c190 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -Team TicTech +Team SuperSonics via TicTech Project -- Feature Development Backend: Create CRUD API's for Client @@ -22,15 +22,22 @@ This also has an API file to interact with the front end, and logic in order to -------------------------How to Use------------------------- 1. In the virtual environment you've created for this project, install all dependencies in requirements.txt (pip install -r requirements.txt) -2. Run the app (uvicorn app.main:app --reload) +2. Create a .env file with the following fields: +```markdown +SECRET_KEY = "your-secret-key-here" +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 30 +``` -3. Load data into database (python initialize_data.py) +3. Run the app (uvicorn app.main:app --reload) -4. Go to SwaggerUI (http://127.0.0.1:8000/docs) +4. Go to SwaggerUI (http://127.0.0.1:8000/docs) -4. Log in as admin (username: admin password: admin123) +5. Load data into database (python initialize_data.py) (if receiving an error, make sure the app is running and open, then try again) -5. Click on each endpoint to use +6. Log in as admin (username: admin password: admin123) + +7. Click on each endpoint to use -Create User (Only users in admin role can create new users. The role field needs to be either "admin" or "case_worker") -Get clients (Display all the clients that are in the database) @@ -55,3 +62,20 @@ This also has an API file to interact with the front end, and logic in order to -Create case assignment (Allow authorized users to create a new case assignment.) +## Docker Instructions +1. Follow installation guide from Docker: https://www.docker.com/blog/how-to-dockerize-your-python-applications/ +2. WINDOWS-SPECIFIC: Ensure virtualization is enabled in your system BIOS, or Docker cannot run +3. Open the Docker Desktop application +4. In a command prompt, navigate to the CommonAssessmentTool repo's directory on your machine (assumes you already cloned from GitHub) and run the command below (make sure the period at the end is included!): +``` +docker build -t common_assessment_tool . +``` +5. Now run with the following Docker command: +``` +docker run --rm -p 8000:8000 common_assessment_tool +``` +6. Follow the steps to run the Swagger UI as described above (clicking link in step 5 should take you to the UI) +7. To run using Docker-Compose, run the command below in the CommonAssessmentTool repo's directory +``` +docker compose up +``` \ No newline at end of file diff --git a/app/auth/router.py b/app/auth/router.py index 229ee71d..50e07890 100644 --- a/app/auth/router.py +++ b/app/auth/router.py @@ -1,28 +1,35 @@ from datetime import datetime, timedelta from typing import Optional + from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from jose import JWTError, jwt +from passlib.context import CryptContext +from pydantic import BaseModel, Field, validator from sqlalchemy.orm import Session + from app.database import get_db from app.models import User, UserRole -from passlib.context import CryptContext -from pydantic import BaseModel, Field, validator + +from dotenv import load_dotenv +import os router = APIRouter(prefix="/auth", tags=["authentication"]) + class UserCreate(BaseModel): username: str = Field(..., min_length=3, max_length=50) email: str password: str role: UserRole - @validator('role') + @validator("role") def validate_role(cls, v): if v not in [UserRole.admin, UserRole.case_worker]: - raise ValueError('Role must be either admin or case_worker') + raise ValueError("Role must be either admin or case_worker") return v + class UserResponse(BaseModel): username: str email: str @@ -31,26 +38,32 @@ class UserResponse(BaseModel): class Config: from_attributes = True -# Configuration -SECRET_KEY = "your-secret-key-here" -ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 30 + +# Load configuration from .env +load_dotenv() +SECRET_KEY = os.getenv("SECRET_KEY") +ALGORITHM = os.getenv("ALGORITHM") +ACCESS_TOKEN_EXPIRE_MINUTES = os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token") + def verify_password(plain_password: str, hashed_password: str) -> bool: return pwd_context.verify(plain_password, hashed_password) + def get_password_hash(password: str) -> str: return pwd_context.hash(password) + def authenticate_user(db: Session, username: str, password: str) -> Optional[User]: user = db.query(User).filter(User.username == username).first() if not user or not verify_password(password, user.hashed_password): return None return user + def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): to_encode = data.copy() if expires_delta: @@ -61,9 +74,9 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt + async def get_current_user( - token: str = Depends(oauth2_scheme), - db: Session = Depends(get_db) + token: str = Depends(oauth2_scheme), db: Session = Depends(get_db) ) -> User: credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -77,24 +90,25 @@ async def get_current_user( raise credentials_exception except JWTError: raise credentials_exception - + user = db.query(User).filter(User.username == username).first() if user is None: raise credentials_exception return user + def get_admin_user(current_user: User = Depends(get_current_user)): if current_user.role != UserRole.admin: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Only admin users can perform this operation" + detail="Only admin users can perform this operation", ) return current_user + @router.post("/token") async def login_for_access_token( - form_data: OAuth2PasswordRequestForm = Depends(), - db: Session = Depends(get_db) + form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db) ): user = authenticate_user(db, form_data.username, form_data.password) if not user: @@ -103,31 +117,30 @@ async def login_for_access_token( detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"}, ) - access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token_expires = timedelta(minutes=int(ACCESS_TOKEN_EXPIRE_MINUTES)) access_token = create_access_token( data={"sub": user.username}, expires_delta=access_token_expires ) return {"access_token": access_token, "token_type": "bearer"} + @router.post("/users", response_model=UserResponse) async def create_user( user_data: UserCreate, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Create a new user (admin only)""" # Check if username exists if db.query(User).filter(User.username == user_data.username).first(): raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Username already registered" + status_code=status.HTTP_400_BAD_REQUEST, detail="Username already registered" ) - + # Check if email exists if db.query(User).filter(User.email == user_data.email).first(): raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Email already registered" + status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered" ) # Create new user @@ -135,9 +148,9 @@ async def create_user( username=user_data.username, email=user_data.email, hashed_password=get_password_hash(user_data.password), - role=user_data.role + role=user_data.role, ) - + try: db.add(db_user) db.commit() @@ -145,7 +158,4 @@ async def create_user( return db_user except Exception as e: db.rollback() - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(e) - ) + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) diff --git a/app/clients/router.py b/app/clients/router.py index 4ecc83e4..45c385e5 100644 --- a/app/clients/router.py +++ b/app/clients/router.py @@ -3,42 +3,44 @@ Handles all HTTP requests for client operations including create, read, update, and delete. """ -from fastapi import APIRouter, Depends, HTTPException, status, Query -from sqlalchemy.orm import Session from typing import List, Optional -from app.auth.router import get_current_user, get_admin_user -from app.models import User, UserRole -from app.database import get_db -from app.clients.service.client_service import ClientService +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy.orm import Session + +from app.auth.router import get_admin_user, get_current_user from app.clients.schema import ( - ClientResponse, - ClientUpdate, ClientListResponse, + ClientResponse, + ClientUpdate, ServiceResponse, - ServiceUpdate + ServiceUpdate, ) +from app.clients.service.client_service import ClientService +from app.database import get_db +from app.models import User, UserRole router = APIRouter(prefix="/clients", tags=["clients"]) + @router.get("/", response_model=ClientListResponse) async def get_clients( - current_user: User = Depends(get_admin_user), + current_user: User = Depends(get_admin_user), skip: int = Query(default=0, ge=0, description="Number of records to skip"), limit: int = Query(default=50, ge=1, le=150, description="Maximum number of records to return"), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): return ClientService.get_clients(db, skip, limit) + @router.get("/{client_id}", response_model=ClientResponse) async def get_client( - client_id: int, - current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + client_id: int, current_user: User = Depends(get_admin_user), db: Session = Depends(get_db) ): """Get a specific client by ID""" return ClientService.get_client(db, client_id) + @router.get("/search/by-criteria", response_model=List[ClientResponse]) async def get_clients_by_criteria( employment_status: Optional[bool] = None, @@ -66,7 +68,7 @@ async def get_clients_by_criteria( time_unemployed: Optional[int] = Query(None, ge=0), need_mental_health_support_bool: Optional[bool] = None, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Search clients by any combination of criteria""" return ClientService.get_clients_by_criteria( @@ -94,9 +96,10 @@ async def get_clients_by_criteria( attending_school=attending_school, substance_use=substance_use, time_unemployed=time_unemployed, - need_mental_health_support_bool=need_mental_health_support_bool + need_mental_health_support_bool=need_mental_health_support_bool, ) + @router.get("/search/by-services", response_model=List[ClientResponse]) async def get_clients_by_services( employment_assistance: Optional[bool] = None, @@ -107,7 +110,7 @@ async def get_clients_by_services( employer_financial_supports: Optional[bool] = None, enhanced_referrals: Optional[bool] = None, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Get clients filtered by multiple service statuses""" return ClientService.get_clients_by_services( @@ -118,70 +121,73 @@ async def get_clients_by_services( specialized_services=specialized_services, employment_related_financial_supports=employment_related_financial_supports, employer_financial_supports=employer_financial_supports, - enhanced_referrals=enhanced_referrals + enhanced_referrals=enhanced_referrals, ) + @router.get("/{client_id}/services", response_model=List[ServiceResponse]) async def get_client_services( - client_id: int, - current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + client_id: int, current_user: User = Depends(get_admin_user), db: Session = Depends(get_db) ): """Get all services and their status for a specific client, including case worker info""" return ClientService.get_client_services(db, client_id) + @router.get("/search/success-rate", response_model=List[ClientResponse]) async def get_clients_by_success_rate( min_rate: int = Query(70, ge=0, le=100, description="Minimum success rate percentage"), current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Get clients with success rate above specified threshold""" return ClientService.get_clients_by_success_rate(db, min_rate) + @router.get("/case-worker/{case_worker_id}", response_model=List[ClientResponse]) async def get_clients_by_case_worker( case_worker_id: int, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ): return ClientService.get_clients_by_case_worker(db, case_worker_id) + @router.put("/{client_id}", response_model=ClientResponse) async def update_client( client_id: int, client_data: ClientUpdate, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Update a client's information""" return ClientService.update_client(db, client_id, client_data) + @router.put("/{client_id}/services/{user_id}", response_model=ServiceResponse) async def update_client_services( client_id: int, user_id: int, service_update: ServiceUpdate, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ): return ClientService.update_client_services(db, client_id, user_id, service_update) + @router.post("/{client_id}/case-assignment", response_model=ServiceResponse) async def create_case_assignment( client_id: int, case_worker_id: int = Query(..., description="Case worker ID to assign"), current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Create a new case assignment for a client with a case worker""" return ClientService.create_case_assignment(db, client_id, case_worker_id) + @router.delete("/{client_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_client( - client_id: int, - current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + client_id: int, current_user: User = Depends(get_admin_user), db: Session = Depends(get_db) ): """Delete a client""" ClientService.delete_client(db, client_id) diff --git a/app/clients/schema.py b/app/clients/schema.py index cff28897..0120d1cf 100644 --- a/app/clients/schema.py +++ b/app/clients/schema.py @@ -3,22 +3,27 @@ Defines schemas for client data, predictions, and API responses. """ +from enum import IntEnum +from typing import List, Optional + # Standard library imports from pydantic import BaseModel, Field, validator -from typing import Optional, List -from enum import IntEnum + from app.models import UserRole + # Enums for validation class Gender(IntEnum): MALE = 1 FEMALE = 2 + class PredictionInput(BaseModel): """ Schema for prediction input data containing all client assessment fields. Used for making predictions about client outcomes. """ + age: int gender: str work_experience: int @@ -44,6 +49,7 @@ class PredictionInput(BaseModel): time_unemployed: int need_mental_health_support_bool: str + class ClientBase(BaseModel): age: int = Field(ge=18, description="Age of client, must be 18 or older") gender: Gender = Field(description="Gender: 1 for male, 2 for female") @@ -96,16 +102,18 @@ class Config: "currently_employed": False, "substance_use": False, "time_unemployed": 6, - "need_mental_health_support_bool": False + "need_mental_health_support_bool": False, } } + class ClientResponse(ClientBase): id: int class Config: from_attributes = True + class ClientUpdate(BaseModel): age: Optional[int] = Field(None, ge=18) gender: Optional[Gender] = None @@ -132,6 +140,7 @@ class ClientUpdate(BaseModel): time_unemployed: Optional[int] = Field(None, ge=0) need_mental_health_support_bool: Optional[bool] = None + class ServiceResponse(BaseModel): client_id: int user_id: int @@ -147,6 +156,7 @@ class ServiceResponse(BaseModel): class Config: from_attributes = True + class ServiceUpdate(BaseModel): employment_assistance: Optional[bool] = None life_stabilization: Optional[bool] = None @@ -157,6 +167,7 @@ class ServiceUpdate(BaseModel): enhanced_referrals: Optional[bool] = None success_rate: Optional[int] = Field(None, ge=0, le=100) + class ClientListResponse(BaseModel): clients: List[ClientResponse] total: int diff --git a/app/clients/service/client_service.py b/app/clients/service/client_service.py index 86c3ef4a..16115ec1 100644 --- a/app/clients/service/client_service.py +++ b/app/clients/service/client_service.py @@ -3,14 +3,87 @@ Provides CRUD operations and business logic for client management. """ -from sqlalchemy.orm import Session -from sqlalchemy import and_ +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + from fastapi import HTTPException, status -from typing import List, Optional, Dict, Any +from sqlalchemy import and_ +from sqlalchemy.orm import Session + +from app.clients.schema import ClientUpdate, ServiceResponse, ServiceUpdate from app.models import Client, ClientCase, User -from app.clients.schema import ClientUpdate, ServiceUpdate, ServiceResponse -class ClientService: + +class InterfaceClientQueryService(ABC): + """Interface for client query operations""" + + @abstractmethod + def get_client(self, db: Session, client_id: int) -> Client: + """Get a specific client by ID""" + pass + + @abstractmethod + def get_clients(self, db: Session, skip: int, limit: int) -> Dict[str, Any]: + """Get clients with optional pagination.""" + pass + + @abstractmethod + def get_clients_by_criteria(self, db: Session, **criteria) -> List[Client]: + """Get clients filtered by any combination of criteria""" + pass + + @abstractmethod + def get_clients_by_services(self, db: Session, **service_filters) -> List[Client]: + """Get clients filtered by multiple service statuses.""" + + pass + + @abstractmethod + def get_client_services(self, db: Session, client_id: int) -> List[ClientCase]: + pass + + @abstractmethod + def get_clients_by_success_rate(self, db: Session, min_rate: int) -> List[Client]: + pass + + @abstractmethod + def get_clients_by_case_worker(self, db: Session, case_worker_id: int) -> List[Client]: + pass + + +class InterfaceClientManagementService(ABC): + """Interface for client management operations""" + + @abstractmethod + def update_client( + self, db: Session, client_id: int, client_update: ClientUpdate + ) -> ClientUpdate: + """Update a client's information""" + pass + + @abstractmethod + def update_client_services( + self, db: Session, client_id: int, user_id: int, service_update: ServiceUpdate + ) -> ClientCase: + """Update a client's services and outcomes for a specific caseworker""" + pass + + @abstractmethod + def create_case_assignment( + self, db: Session, client_id: int, case_worker_id: int + ) -> ClientCase: + """Create a new case assignment""" + pass + + @abstractmethod + def delete_client(self, db: Session, client_id: int) -> None: + """Delete a client and their associated records""" + pass + + +class ClientQueryService(InterfaceClientQueryService): + """Implementation of client query service""" + @staticmethod def get_client(db: Session, client_id: int): """Get a specific client by ID""" @@ -18,7 +91,7 @@ def get_client(db: Session, client_id: int): if not client: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" + detail=f"Client with id {client_id} not found", ) return client @@ -30,15 +103,13 @@ def get_clients(db: Session, skip: int = 0, limit: int = 50): """ if skip < 0: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Skip value cannot be negative" + status_code=status.HTTP_400_BAD_REQUEST, detail="Skip value cannot be negative" ) if limit < 1: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Limit must be greater than 0" + status_code=status.HTTP_400_BAD_REQUEST, detail="Limit must be greater than 0" ) - + clients = db.query(Client).offset(skip).limit(limit).all() total = db.query(Client).count() return {"clients": clients, "total": total} @@ -69,27 +140,25 @@ def get_clients_by_criteria( attending_school: Optional[bool] = None, substance_use: Optional[bool] = None, time_unemployed: Optional[int] = None, - need_mental_health_support_bool: Optional[bool] = None + need_mental_health_support_bool: Optional[bool] = None, ): """Get clients filtered by any combination of criteria""" query = db.query(Client) - + if education_level is not None and not (1 <= education_level <= 14): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Education level must be between 1 and 14" + detail="Education level must be between 1 and 14", ) - + if age_min is not None and age_min < 18: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Minimum age must be at least 18" + status_code=status.HTTP_400_BAD_REQUEST, detail="Minimum age must be at least 18" ) if gender is not None and gender not in [1, 2]: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Gender must be 1 or 2" + status_code=status.HTTP_400_BAD_REQUEST, detail="Gender must be 1 or 2" ) # Apply filters for non-None values @@ -140,47 +209,46 @@ def get_clients_by_criteria( if time_unemployed is not None: query = query.filter(Client.time_unemployed == time_unemployed) if need_mental_health_support_bool is not None: - query = query.filter(Client.need_mental_health_support_bool == need_mental_health_support_bool) + query = query.filter( + Client.need_mental_health_support_bool == need_mental_health_support_bool + ) try: return query.all() except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error retrieving clients: {str(e)}" + detail=f"Error retrieving clients: {str(e)}", ) @staticmethod - def get_clients_by_services( - db: Session, - **service_filters: Optional[bool] - ): + def get_clients_by_services(db: Session, **service_filters: Optional[bool]): """ Get clients filtered by multiple service statuses. """ query = db.query(Client).join(ClientCase) - + for service_name, status in service_filters.items(): if status is not None: filter_criteria = getattr(ClientCase, service_name) == status query = query.filter(filter_criteria) - + try: return query.all() except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error retrieving clients: {str(e)}" + detail=f"Error retrieving clients: {str(e)}", ) @staticmethod def get_client_services(db: Session, client_id: int): - """Get all services for a specific client with case worker info""" + """Get all services for a specific client with caseworker info""" client_cases = db.query(ClientCase).filter(ClientCase.client_id == client_id).all() if not client_cases: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"No services found for client with id {client_id}" + detail=f"No services found for client with id {client_id}", ) return client_cases @@ -190,12 +258,10 @@ def get_clients_by_success_rate(db: Session, min_rate: int = 70): if not (0 <= min_rate <= 100): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Success rate must be between 0 and 100" + detail="Success rate must be between 0 and 100", ) - - return db.query(Client).join(ClientCase).filter( - ClientCase.success_rate >= min_rate - ).all() + + return db.query(Client).join(ClientCase).filter(ClientCase.success_rate >= min_rate).all() @staticmethod def get_clients_by_case_worker(db: Session, case_worker_id: int): @@ -204,12 +270,14 @@ def get_clients_by_case_worker(db: Session, case_worker_id: int): if not case_worker: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Case worker with id {case_worker_id} not found" + detail=f"Case worker with id {case_worker_id} not found", ) - - return db.query(Client).join(ClientCase).filter( - ClientCase.user_id == case_worker_id - ).all() + + return db.query(Client).join(ClientCase).filter(ClientCase.user_id == case_worker_id).all() + + +class ClientManagementService(InterfaceClientManagementService): + """Implementation of client management service""" @staticmethod def update_client(db: Session, client_id: int, client_update: ClientUpdate): @@ -218,7 +286,7 @@ def update_client(db: Session, client_id: int, client_update: ClientUpdate): if not client: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" + detail=f"Client with id {client_id} not found", ) update_data = client_update.dict(exclude_unset=True) @@ -233,27 +301,25 @@ def update_client(db: Session, client_id: int, client_update: ClientUpdate): db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update client: {str(e)}" + detail=f"Failed to update client: {str(e)}", ) - + @staticmethod def update_client_services( - db: Session, - client_id: int, - user_id: int, - service_update: ServiceUpdate + db: Session, client_id: int, user_id: int, service_update: ServiceUpdate ): """Update a client's services and outcomes for a specific case worker""" - client_case = db.query(ClientCase).filter( - ClientCase.client_id == client_id, - ClientCase.user_id == user_id - ).first() - + client_case = ( + db.query(ClientCase) + .filter(ClientCase.client_id == client_id, ClientCase.user_id == user_id) + .first() + ) + if not client_case: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"No case found for client {client_id} with case worker {user_id}. " - f"Cannot update services for a non-existent case assignment." + f"Cannot update services for a non-existent case assignment.", ) update_data = service_update.dict(exclude_unset=True) @@ -268,43 +334,40 @@ def update_client_services( db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update client services: {str(e)}" + detail=f"Failed to update client services: {str(e)}", ) - + @staticmethod - def create_case_assignment( - db: Session, - client_id: int, - case_worker_id: int - ): + def create_case_assignment(db: Session, client_id: int, case_worker_id: int): """Create a new case assignment""" # Check if client exists client = db.query(Client).filter(Client.id == client_id).first() if not client: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" + detail=f"Client with id {client_id} not found", ) - # Check if case worker exists + # Check if caseworker exists case_worker = db.query(User).filter(User.id == case_worker_id).first() if not case_worker: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Case worker with id {case_worker_id} not found" + detail=f"Case worker with id {case_worker_id} not found", ) # Check if assignment already exists - existing_case = db.query(ClientCase).filter( - ClientCase.client_id == client_id, - ClientCase.user_id == case_worker_id - ).first() - + existing_case = ( + db.query(ClientCase) + .filter(ClientCase.client_id == client_id, ClientCase.user_id == case_worker_id) + .first() + ) + if existing_case: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Client {client_id} already has a case assigned to case worker {case_worker_id}" - ) + detail=f"Client {client_id} already has a case assigned to case worker {case_worker_id}", + ) try: # Create new case assignment with default service values @@ -318,7 +381,7 @@ def create_case_assignment( employment_related_financial_supports=False, employer_financial_supports=False, enhanced_referrals=False, - success_rate=0 + success_rate=0, ) db.add(new_case) db.commit() @@ -329,9 +392,9 @@ def create_case_assignment( db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to create case assignment: {str(e)}" + detail=f"Failed to create case assignment: {str(e)}", ) - + @staticmethod def delete_client(db: Session, client_id: int): """Delete a client and their associated records""" @@ -340,22 +403,77 @@ def delete_client(db: Session, client_id: int): if not client: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" + detail=f"Client with id {client_id} not found", ) try: # Delete associated client_cases - db.query(ClientCase).filter( - ClientCase.client_id == client_id - ).delete() - + db.query(ClientCase).filter(ClientCase.client_id == client_id).delete() + # Delete the client db.delete(client) db.commit() - + except Exception as e: db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to delete client: {str(e)}" + detail=f"Failed to delete client: {str(e)}", ) + + +class ClientService: + """ + Facade that maintains backward compatibility with the existing router. + Delegates to specialized service classes. + """ + + # Query methods + @staticmethod + def get_client(db: Session, client_id: int): + return ClientQueryService.get_client(db, client_id) + + @staticmethod + def get_clients(db: Session, skip: int = 0, limit: int = 50): + return ClientQueryService.get_clients(db, skip, limit) + + @staticmethod + def get_clients_by_criteria(db: Session, **criteria): + return ClientQueryService.get_clients_by_criteria(db, **criteria) + + @staticmethod + def get_clients_by_services(db: Session, **service_filters): + return ClientQueryService.get_clients_by_services(db, **service_filters) + + @staticmethod + def get_client_services(db: Session, client_id: int): + return ClientQueryService.get_client_services(db, client_id) + + @staticmethod + def get_clients_by_success_rate(db: Session, min_rate: int = 70): + return ClientQueryService.get_clients_by_success_rate(db, min_rate) + + @staticmethod + def get_clients_by_case_worker(db: Session, case_worker_id: int): + return ClientQueryService.get_clients_by_case_worker(db, case_worker_id) + + # Modification methods + @staticmethod + def update_client(db: Session, client_id: int, client_update: ClientUpdate): + return ClientManagementService.update_client(db, client_id, client_update) + + @staticmethod + def update_client_services( + db: Session, client_id: int, user_id: int, service_update: ServiceUpdate + ): + return ClientManagementService.update_client_services( + db, client_id, user_id, service_update + ) + + @staticmethod + def create_case_assignment(db: Session, client_id: int, case_worker_id: int): + return ClientManagementService.create_case_assignment(db, client_id, case_worker_id) + + @staticmethod + def delete_client(db: Session, client_id: int): + return ClientManagementService.delete_client(db, client_id) diff --git a/app/clients/service/logic.py b/app/clients/service/logic.py index c25b4217..290e1dd9 100644 --- a/app/clients/service/logic.py +++ b/app/clients/service/logic.py @@ -5,30 +5,33 @@ # Standard library imports import os -#import json -from itertools import product # Third-party imports import pickle + +# import json +from itertools import product + import numpy as np # Constants COLUMN_INTERVENTIONS = [ - 'Life Stabilization', - 'General Employment Assistance Services', - 'Retention Services', - 'Specialized Services', - 'Employment-Related Financial Supports for Job Seekers and Employers', - 'Employer Financial Supports', - 'Enhanced Referrals for Skills Development' + "Life Stabilization", + "General Employment Assistance Services", + "Retention Services", + "Specialized Services", + "Employment-Related Financial Supports for Job Seekers and Employers", + "Employer Financial Supports", + "Enhanced Referrals for Skills Development", ] # Load model CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) -MODEL_PATH = os.path.join(CURRENT_DIR, 'model.pkl') +MODEL_PATH = os.path.join(CURRENT_DIR, "model.pkl") with open(MODEL_PATH, "rb") as model_file: MODEL = pickle.load(model_file) + def clean_input_data(input_data): """ Clean and transform input data into model-compatible format. @@ -40,13 +43,30 @@ def clean_input_data(input_data): list: Cleaned and formatted data ready for model input """ columns = [ - "age", "gender", "work_experience", "canada_workex", "dep_num", - "canada_born", "citizen_status", "level_of_schooling", "fluent_english", - "reading_english_scale", "speaking_english_scale", "writing_english_scale", - "numeracy_scale", "computer_scale", "transportation_bool", "caregiver_bool", - "housing", "income_source", "felony_bool", "attending_school", - "currently_employed", "substance_use", "time_unemployed", - "need_mental_health_support_bool" + "age", + "gender", + "work_experience", + "canada_workex", + "dep_num", + "canada_born", + "citizen_status", + "level_of_schooling", + "fluent_english", + "reading_english_scale", + "speaking_english_scale", + "writing_english_scale", + "numeracy_scale", + "computer_scale", + "transportation_bool", + "caregiver_bool", + "housing", + "income_source", + "felony_bool", + "attending_school", + "currently_employed", + "substance_use", + "time_unemployed", + "need_mental_health_support_bool", ] demographics = {key: input_data[key] for key in columns} output = [] @@ -57,6 +77,7 @@ def clean_input_data(input_data): output.append(value) return output + def convert_text(text_data: str): """ Convert text answers from front end into numerical values. @@ -68,33 +89,47 @@ def convert_text(text_data: str): int: Converted numerical value """ categorical_mappings = [ + {"": 0, "true": 1, "false": 0, "no": 0, "yes": 1, "No": 0, "Yes": 1}, { - "": 0, "true": 1, "false": 0, "no": 0, "yes": 1, - "No": 0, "Yes": 1 - }, - { - "Grade 0-8": 1, "Grade 9": 2, "Grade 10": 3, "Grade 11": 4, - "Grade 12 or equivalent": 5, "OAC or Grade 13": 6, - "Some college": 7, "Some university": 8, "Some apprenticeship": 9, - "Certificate of Apprenticeship": 10, "Journeyperson": 11, - "Certificate/Diploma": 12, "Bachelor's degree": 13, - "Post graduate": 14 + "Grade 0-8": 1, + "Grade 9": 2, + "Grade 10": 3, + "Grade 11": 4, + "Grade 12 or equivalent": 5, + "OAC or Grade 13": 6, + "Some college": 7, + "Some university": 8, + "Some apprenticeship": 9, + "Certificate of Apprenticeship": 10, + "Journeyperson": 11, + "Certificate/Diploma": 12, + "Bachelor's degree": 13, + "Post graduate": 14, }, { - "Renting-private": 1, "Renting-subsidized": 2, - "Boarding or lodging": 3, "Homeowner": 4, - "Living with family/friend": 5, "Institution": 6, - "Temporary second residence": 7, "Band-owned home": 8, - "Homeless or transient": 9, "Emergency hostel": 10 + "Renting-private": 1, + "Renting-subsidized": 2, + "Boarding or lodging": 3, + "Homeowner": 4, + "Living with family/friend": 5, + "Institution": 6, + "Temporary second residence": 7, + "Band-owned home": 8, + "Homeless or transient": 9, + "Emergency hostel": 10, }, { - "No Source of Income": 1, "Employment Insurance": 2, + "No Source of Income": 1, + "Employment Insurance": 2, "Workplace Safety and Insurance Board": 3, "Ontario Works applied or receiving": 4, "Ontario Disability Support Program applied or receiving": 5, - "Dependent of someone receiving OW or ODSP": 6, "Crown Ward": 7, - "Employment": 8, "Self-Employment": 9, "Other (specify)": 10 - } + "Dependent of someone receiving OW or ODSP": 6, + "Crown Ward": 7, + "Employment": 8, + "Self-Employment": 9, + "Other (specify)": 10, + }, ] for category in categorical_mappings: if text_data in category: @@ -102,6 +137,7 @@ def convert_text(text_data: str): return int(text_data) if text_data.isnumeric() else text_data + def create_matrix(row_data): """ Create matrix of all possible intervention combinations. @@ -116,6 +152,7 @@ def create_matrix(row_data): perms = intervention_permutations(7) return np.concatenate((np.array(data), np.array(perms)), axis=1) + def intervention_permutations(num): """ Generate all possible intervention combinations. @@ -128,6 +165,7 @@ def intervention_permutations(num): """ return np.array(list(product([0, 1], repeat=num))) + def get_baseline_row(row_data): """ Create baseline row with no interventions. @@ -141,6 +179,7 @@ def get_baseline_row(row_data): base_interventions = np.zeros(7) return np.concatenate((np.array(row_data), base_interventions)) + def intervention_row_to_names(row_data): """ Convert intervention row to list of intervention names. @@ -153,6 +192,7 @@ def intervention_row_to_names(row_data): """ return [COLUMN_INTERVENTIONS[i] for i, value in enumerate(row_data) if value == 1] + def process_results(baseline_pred, results_matrix): """ Process model results into structured output. @@ -164,14 +204,9 @@ def process_results(baseline_pred, results_matrix): Returns: dict: Processed results with baseline and interventions """ - result_list = [ - (row[-1], intervention_row_to_names(row[:-1])) - for row in results_matrix - ] - return { - "baseline": baseline_pred[-1], - "interventions": result_list - } + result_list = [(row[-1], intervention_row_to_names(row[:-1])) for row in results_matrix] + return {"baseline": baseline_pred[-1], "interventions": result_list} + def interpret_and_calculate(input_data): """ @@ -194,19 +229,33 @@ def interpret_and_calculate(input_data): top_results = result_matrix[-3:, -8:] return process_results(baseline_prediction, top_results) + if __name__ == "__main__": test_data = { - "age": "23", "gender": "1", "work_experience": "1", - "canada_workex": "1", "dep_num": "0", "canada_born": "1", - "citizen_status": "2", "level_of_schooling": "2", - "fluent_english": "3", "reading_english_scale": "2", - "speaking_english_scale": "2", "writing_english_scale": "3", - "numeracy_scale": "2", "computer_scale": "3", - "transportation_bool": "2", "caregiver_bool": "1", - "housing": "1", "income_source": "5", "felony_bool": "1", - "attending_school": "0", "currently_employed": "1", - "substance_use": "1", "time_unemployed": "1", - "need_mental_health_support_bool": "1" + "age": "23", + "gender": "1", + "work_experience": "1", + "canada_workex": "1", + "dep_num": "0", + "canada_born": "1", + "citizen_status": "2", + "level_of_schooling": "2", + "fluent_english": "3", + "reading_english_scale": "2", + "speaking_english_scale": "2", + "writing_english_scale": "3", + "numeracy_scale": "2", + "computer_scale": "3", + "transportation_bool": "2", + "caregiver_bool": "1", + "housing": "1", + "income_source": "5", + "felony_bool": "1", + "attending_school": "0", + "currently_employed": "1", + "substance_use": "1", + "time_unemployed": "1", + "need_mental_health_support_bool": "1", } results = interpret_and_calculate(test_data) print(results) diff --git a/app/clients/service/ml_models.py b/app/clients/service/ml_models.py new file mode 100644 index 00000000..50957926 --- /dev/null +++ b/app/clients/service/ml_models.py @@ -0,0 +1,147 @@ +from abc import ABC, abstractmethod +from typing import List + +import numpy as np +from sklearn.ensemble import RandomForestRegressor +from sklearn.linear_model import LinearRegression +from sklearn.svm import SVR + + +class InterfaceBaseMLModel(ABC): + """Interface of a base ML Model""" + + @abstractmethod + def fit(self, X: np.ndarray, y: np.ndarray): + pass + + @abstractmethod + def predict(self, X: np.ndarray) -> np.ndarray: + pass + + def save(self, path: str): + import pickle + + with open(path, "wb") as f: + pickle.dump(self, f) + + @staticmethod + def load(path: str): + import pickle + + with open(path, "rb") as f: + return pickle.load(f) + + @abstractmethod + def __str__(self) -> str: + """Return the name of the model""" + pass + + +class LinearRegressionModel(InterfaceBaseMLModel): + def __init__(self): + self.model = LinearRegression() + + def fit(self, X, y): + self.model.fit(X, y) + + def predict(self, X): + return self.model.predict(X) + + def __str__(self): + return "Linear Regression" + + +class RandomForestModel(InterfaceBaseMLModel): + def __init__(self, n_estimators=100, random_state=42): + self.model = RandomForestRegressor(n_estimators=n_estimators, random_state=random_state) + + def fit(self, X, y): + self.model.fit(X, y) + + def predict(self, X): + return self.model.predict(X) + + def __str__(self): + return "Random Forest Regressor" + + +class SVMModel(InterfaceBaseMLModel): + def __init__(self): + self.model = SVR() + + def fit(self, X, y): + self.model.fit(X, y) + + def predict(self, X): + return self.model.predict(X) + + def __str__(self): + return "Support Vector Machine" + + +class InterfaceMLModelRepository(ABC): + """Interface for ML Models storage""" + + @abstractmethod + def list_models(self) -> List[InterfaceBaseMLModel]: + """Get list of all available models instances""" + pass + + @abstractmethod + def is_model_available(self, model_name: str) -> bool: + """Check if a model is valid""" + pass + + @abstractmethod + def get_model_instance(self, model_name: str) -> InterfaceBaseMLModel: + """Return an instance of the requested model""" + pass + + +class InterfaceMLModelManager(ABC): + """Interface for ML model management""" + + @abstractmethod + def get_current_model(self) -> InterfaceBaseMLModel: + """Get the current active ml model""" + pass + + @abstractmethod + def switch_model(self, model_name: str) -> bool: + """Switch between models""" + pass + + +class MLModelRepository(InterfaceMLModelRepository): + def __init__(self): + self._model_map = { + "Linear Regression": LinearRegressionModel, + "Random Forest Regressor": RandomForestModel, + "Support Vector Machine": SVMModel, + } + + def list_models(self) -> List[InterfaceBaseMLModel]: + return [model_class() for model_class in self._model_map.values()] + + def is_model_available(self, model_name: str) -> bool: + return model_name in self._model_map + + def get_model_instance(self, model_name: str) -> InterfaceBaseMLModel: + if not self.is_model_available(model_name): + raise ValueError(f"Model '{model_name}' is not available.") + return self._model_map[model_name]() + + +class MLModelManager(InterfaceMLModelManager): + def __init__(self, repository: InterfaceMLModelRepository): + self._repository = repository + self._current_model = repository.get_model_instance("Random Forest Regressor") + + def get_current_model(self) -> str: + return self._current_model + + def switch_model(self, model_name: str) -> bool: + if self._repository.is_model_available(model_name): + self._current_model = self._repository.get_model_instance(model_name) + return True + return False diff --git a/app/clients/service/ml_models_router.py b/app/clients/service/ml_models_router.py new file mode 100644 index 00000000..d1070153 --- /dev/null +++ b/app/clients/service/ml_models_router.py @@ -0,0 +1,30 @@ +from fastapi import APIRouter, HTTPException + +from app.clients.service.ml_models import MLModelManager, MLModelRepository + +router = APIRouter(prefix="/ml_models") +model_repository = MLModelRepository() +model_manager = MLModelManager(model_repository) + + +@router.get("/list") +def list_models(): + """List all available ML models""" + # return {"models": model_repository.list_models()} + return {"models": [str(model) for model in model_repository.list_models()]} + + +@router.post("/switch/{model_name}") +def switch_models(model_name: str): + """Switch between ML models""" + success = model_manager.switch_model(model_name) + if not success: + raise HTTPException(status_code=400, detail="Model switch failed") + return {"message": f"Model switched to {model_name}"} + + +@router.get("/current") +def current_model(): + """Get the current ML model""" + # return {"current_model": model_manager.get_current_model()} + return {"current_model": str(model_manager.get_current_model())} diff --git a/app/clients/service/model.py b/app/clients/service/model.py index b2406370..cba85686 100644 --- a/app/clients/service/model.py +++ b/app/clients/service/model.py @@ -1,110 +1,193 @@ """ Model training module for the Common Assessment Tool. Handles the preparation, training, and saving of the prediction model. +Pass in model name via command line """ +import os + # Standard library imports import pickle +import sys # Third-party imports import numpy as np import pandas as pd -from sklearn.model_selection import train_test_split + +# Local imports +from ml_models import ( + InterfaceBaseMLModel, + LinearRegressionModel, + MLModelRepository, + RandomForestModel, + SVMModel, +) +from sklearn import svm from sklearn.ensemble import RandomForestRegressor +from sklearn.linear_model import LinearRegression +from sklearn.model_selection import train_test_split + +repo = MLModelRepository() + +default_unformatted_model_path = "pretrained_models" + os.sep + "model_{}.pkl" + + +def get_model_by_name(model_type: str, n_estimators=100, random_state=42) -> InterfaceBaseMLModel: + model_map = { + "Linear Regression": LinearRegressionModel, + "Random Forest Regressor": lambda: RandomForestModel(n_estimators, random_state), + "Support Vector Machine": SVMModel, + } + + if model_type not in model_map: + print(f"ERROR! Invalid model type '{model_type}' passed in.") + print(f"Available models: {repo.list_models()}") + sys.exit(-1) -def prepare_models(): + constructor = model_map[model_type] + return constructor() if callable(constructor) else constructor() + + +def prepare_model_data(test_size=0.2, random_state=42): """ Prepare and train the Random Forest model using the dataset. - + Args: + test_size: The percent of the dataset to use as test data (rest will be used as train data) + random_state: The random state to generate train/test split with + Returns: RandomForestRegressor: Trained model for predicting success rates """ # Load dataset - data = pd.read_csv('data_commontool.csv') + data = pd.read_csv("data_commontool.csv") # Define feature columns feature_columns = [ - 'age', # Client's age - 'gender', # Client's gender (bool) - 'work_experience', # Years of work experience - 'canada_workex', # Years of work experience in Canada - 'dep_num', # Number of dependents - 'canada_born', # Born in Canada - 'citizen_status', # Citizenship status - 'level_of_schooling', # Highest level achieved (1-14) - 'fluent_english', # English fluency scale (1-10) - 'reading_english_scale', # Reading ability scale (1-10) - 'speaking_english_scale',# Speaking ability scale (1-10) - 'writing_english_scale', # Writing ability scale (1-10) - 'numeracy_scale', # Numeracy ability scale (1-10) - 'computer_scale', # Computer proficiency scale (1-10) - 'transportation_bool', # Needs transportation support (bool) - 'caregiver_bool', # Is primary caregiver (bool) - 'housing', # Housing situation (1-10) - 'income_source', # Source of income (1-10) - 'felony_bool', # Has a felony (bool) - 'attending_school', # Currently a student (bool) - 'currently_employed', # Currently employed (bool) - 'substance_use', # Substance use disorder (bool) - 'time_unemployed', # Years unemployed - 'need_mental_health_support_bool' # Needs mental health support (bool) + "age", # Client's age + "gender", # Client's gender (bool) + "work_experience", # Years of work experience + "canada_workex", # Years of work experience in Canada + "dep_num", # Number of dependents + "canada_born", # Born in Canada + "citizen_status", # Citizenship status + "level_of_schooling", # Highest level achieved (1-14) + "fluent_english", # English fluency scale (1-10) + "reading_english_scale", # Reading ability scale (1-10) + "speaking_english_scale", # Speaking ability scale (1-10) + "writing_english_scale", # Writing ability scale (1-10) + "numeracy_scale", # Numeracy ability scale (1-10) + "computer_scale", # Computer proficiency scale (1-10) + "transportation_bool", # Needs transportation support (bool) + "caregiver_bool", # Is primary caregiver (bool) + "housing", # Housing situation (1-10) + "income_source", # Source of income (1-10) + "felony_bool", # Has a felony (bool) + "attending_school", # Currently a student (bool) + "currently_employed", # Currently employed (bool) + "substance_use", # Substance use disorder (bool) + "time_unemployed", # Years unemployed + "need_mental_health_support_bool", # Needs mental health support (bool) ] # Define intervention columns intervention_columns = [ - 'employment_assistance', - 'life_stabilization', - 'retention_services', - 'specialized_services', - 'employment_related_financial_supports', - 'employer_financial_supports', - 'enhanced_referrals' + "employment_assistance", + "life_stabilization", + "retention_services", + "specialized_services", + "employment_related_financial_supports", + "employer_financial_supports", + "enhanced_referrals", ] # Combine all feature columns all_features = feature_columns + intervention_columns # Prepare training data features = np.array(data[all_features]) # Changed from X to features - targets = np.array(data['success_rate']) # Changed from y to targets + targets = np.array(data["success_rate"]) # Changed from y to targets # Split the dataset - features_train, _, targets_train, _ = train_test_split( # Removed unused variables + X_train, x_test, Y_train, y_test = train_test_split( + # Removed unused variables features, targets, - test_size=0.2, - random_state=42 + test_size=test_size, + random_state=random_state, ) - # Initialize and train the model - model = RandomForestRegressor(n_estimators=100, random_state=42) - model.fit(features_train, targets_train) + + return X_train, x_test, Y_train, y_test + + +def train_model( + X_train, Y_train, model_type, n_estimators=100, random_state=42 +) -> InterfaceBaseMLModel: + """ + Trains the model + Args: + X_train: Training features + targets_train: Target features + Y_train: Which model to create + n_estimators: Number estimators (for random forest) + random_state: Random state to train with (for random forest) + + Returns: A trained model of the type specified + + """ + model = get_model_by_name(model_type, n_estimators, random_state) + model.fit(X_train, Y_train) return model -def save_model(model, filename="model.pkl"): + +def get_true_file_name(model_type, filename): + """ + Takes a model type and file name, formats model type, and replaces spaces with underscores + Args: + model_type: The model type as a String + filename: The file name (should follow 'model_{}.pkl' format) + + Returns: The clean file name + """ + return filename.format(model_type).replace(" ", "_") + + +def save_model(model, model_type, filename=default_unformatted_model_path): """ Save the trained model to a file. - + Args: model: Trained model to save + model_type: The type of model being saved filename (str): Name of the file to save the model to """ - with open(filename, "wb") as model_file: + true_file_name = get_true_file_name(model_type, filename) + with open(true_file_name, "wb") as model_file: pickle.dump(model, model_file) -def load_model(filename="model.pkl"): + +def load_model(model_type, filename=default_unformatted_model_path): """ Load a trained model from a file. - + Args: + model_type: The type of model being loaded filename (str): Name of the file to load the model from - + Returns: The loaded model """ - with open(filename, "rb") as model_file: + true_file_name = get_true_file_name(model_type, filename) + with open(true_file_name, "rb") as model_file: return pickle.load(model_file) -def main(): + +def main(argv): """Main function to train and save the model.""" - print("Starting model training...") - model = prepare_models() - save_model(model) + # Get the model type from the command line arguments + model_type = argv[1] + + # Train and save the model + print("Starting model training for {} model...".format(model_type)) + X_train, x_test, Y_train, y_test = prepare_model_data() + model = train_model(X_train, Y_train, model_type) + save_model(model, model_type) print("Model training completed and saved successfully.") + if __name__ == "__main__": - main() + main(sys.argv) diff --git a/app/clients/service/pretrained_models/model_Linear_Regression.pkl b/app/clients/service/pretrained_models/model_Linear_Regression.pkl new file mode 100644 index 00000000..d264f1b9 Binary files /dev/null and b/app/clients/service/pretrained_models/model_Linear_Regression.pkl differ diff --git a/app/clients/service/pretrained_models/model_Random_Forest_Regressor.pkl b/app/clients/service/pretrained_models/model_Random_Forest_Regressor.pkl new file mode 100644 index 00000000..cde4340f Binary files /dev/null and b/app/clients/service/pretrained_models/model_Random_Forest_Regressor.pkl differ diff --git a/app/clients/service/pretrained_models/model_Support_Vector_Machine.pkl b/app/clients/service/pretrained_models/model_Support_Vector_Machine.pkl new file mode 100644 index 00000000..999b824c Binary files /dev/null and b/app/clients/service/pretrained_models/model_Support_Vector_Machine.pkl differ diff --git a/app/database.py b/app/database.py index 3a489f54..b5c8b948 100644 --- a/app/database.py +++ b/app/database.py @@ -7,22 +7,23 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker -#Here is where the database is located -SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db" +# Here is where the database is located +SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db" -#Open up a connection so that we are able to use the database +# Open up a connection so that we are able to use the database engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}) -#Bind the engine just created +# Bind the engine just created SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -#Create an object of our database so as to control the database +# Create an object of our database so as to control the database Base = declarative_base() + def get_db(): """ Creates a database session and ensures it's closed after use. - + Yields: Session: SQLAlchemy database session """ diff --git a/app/main.py b/app/main.py index a8e8fa7f..66746785 100644 --- a/app/main.py +++ b/app/main.py @@ -5,27 +5,32 @@ """ from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + from app import models -from app.database import engine -from app.clients.router import router as clients_router from app.auth.router import router as auth_router -from fastapi.middleware.cors import CORSMiddleware +from app.clients.router import router as clients_router +from app.clients.service.ml_models_router import router as ml_models_router +from app.database import engine # Initialize database tables models.Base.metadata.create_all(bind=engine) # Create FastAPI application -app = FastAPI(title="Case Management API", description="API for managing client cases", version="1.0.0") +app = FastAPI( + title="Case Management API", description="API for managing client cases", version="1.0.0" +) # Include routers app.include_router(auth_router) app.include_router(clients_router) +app.include_router(ml_models_router) # Configure CORS middleware app.add_middleware( CORSMiddleware, - allow_origins=["*"], # Allows all origins - allow_methods=["*"], # Allows all methods - allow_headers=["*"], # Allows all headers + allow_origins=["*"], # Allows all origins + allow_methods=["*"], # Allows all methods + allow_headers=["*"], # Allows all headers allow_credentials=True, ) diff --git a/app/models.py b/app/models.py index df778348..870210e5 100644 --- a/app/models.py +++ b/app/models.py @@ -3,11 +3,14 @@ Contains the Client model for storing client information in the database. """ -from app.database import Base -from sqlalchemy import Column, Integer, String, Boolean, ForeignKey, CheckConstraint, Enum -from sqlalchemy.orm import relationship import enum +from sqlalchemy import Boolean, CheckConstraint, Column, Enum, ForeignKey, Integer, String +from sqlalchemy.orm import relationship + +from app.database import Base + + class UserRole(str, enum.Enum): admin = "admin" case_worker = "case_worker" @@ -24,46 +27,61 @@ class User(Base): cases = relationship("ClientCase", back_populates="user") + class Client(Base): """ Client model representing client data in the database. """ + __tablename__ = "clients" id = Column(Integer, primary_key=True, autoincrement=True) - age = Column(Integer, CheckConstraint('age >= 18')) + age = Column(Integer, CheckConstraint("age >= 18")) gender = Column(Integer, CheckConstraint("gender = 1 OR gender = 2")) - work_experience = Column(Integer, CheckConstraint('work_experience >= 0')) - canada_workex = Column(Integer, CheckConstraint('canada_workex >= 0')) - dep_num = Column(Integer, CheckConstraint('dep_num >= 0')) + work_experience = Column(Integer, CheckConstraint("work_experience >= 0")) + canada_workex = Column(Integer, CheckConstraint("canada_workex >= 0")) + dep_num = Column(Integer, CheckConstraint("dep_num >= 0")) canada_born = Column(Boolean) citizen_status = Column(Boolean) - level_of_schooling = Column(Integer, CheckConstraint('level_of_schooling >= 1 AND level_of_schooling <= 14')) + level_of_schooling = Column( + Integer, CheckConstraint("level_of_schooling >= 1 AND level_of_schooling <= 14") + ) fluent_english = Column(Boolean) - reading_english_scale = Column(Integer, CheckConstraint('reading_english_scale >= 0 AND reading_english_scale <= 10')) - speaking_english_scale = Column(Integer, CheckConstraint('speaking_english_scale >= 0 AND speaking_english_scale <= 10')) - writing_english_scale = Column(Integer, CheckConstraint('writing_english_scale >= 0 AND writing_english_scale <= 10')) - numeracy_scale = Column(Integer, CheckConstraint('numeracy_scale >= 0 AND numeracy_scale <= 10')) - computer_scale = Column(Integer, CheckConstraint('computer_scale >= 0 AND computer_scale <= 10')) + reading_english_scale = Column( + Integer, CheckConstraint("reading_english_scale >= 0 AND reading_english_scale <= 10") + ) + speaking_english_scale = Column( + Integer, CheckConstraint("speaking_english_scale >= 0 AND speaking_english_scale <= 10") + ) + writing_english_scale = Column( + Integer, CheckConstraint("writing_english_scale >= 0 AND writing_english_scale <= 10") + ) + numeracy_scale = Column( + Integer, CheckConstraint("numeracy_scale >= 0 AND numeracy_scale <= 10") + ) + computer_scale = Column( + Integer, CheckConstraint("computer_scale >= 0 AND computer_scale <= 10") + ) transportation_bool = Column(Boolean) caregiver_bool = Column(Boolean) - housing = Column(Integer, CheckConstraint('housing >= 1 AND housing <= 10')) - income_source = Column(Integer, CheckConstraint('income_source >= 1 AND income_source <= 11')) + housing = Column(Integer, CheckConstraint("housing >= 1 AND housing <= 10")) + income_source = Column(Integer, CheckConstraint("income_source >= 1 AND income_source <= 11")) felony_bool = Column(Boolean) attending_school = Column(Boolean) currently_employed = Column(Boolean) substance_use = Column(Boolean) - time_unemployed = Column(Integer, CheckConstraint('time_unemployed >= 0')) + time_unemployed = Column(Integer, CheckConstraint("time_unemployed >= 0")) need_mental_health_support_bool = Column(Boolean) cases = relationship("ClientCase", back_populates="client") + class ClientCase(Base): __tablename__ = "client_cases" client_id = Column(Integer, ForeignKey("clients.id"), primary_key=True) user_id = Column(Integer, ForeignKey("users.id"), primary_key=True) - + employment_assistance = Column(Boolean) life_stabilization = Column(Boolean) retention_services = Column(Boolean) @@ -71,7 +89,7 @@ class ClientCase(Base): employment_related_financial_supports = Column(Boolean) employer_financial_supports = Column(Boolean) enhanced_referrals = Column(Boolean) - success_rate = Column(Integer, CheckConstraint('success_rate >= 0 AND success_rate <= 100')) + success_rate = Column(Integer, CheckConstraint("success_rate >= 0 AND success_rate <= 100")) client = relationship("Client", back_populates="cases") user = relationship("User", back_populates="cases") diff --git a/app/requirements.txt b/app/requirements.txt index 57fc8e6e..ddcf118a 100644 --- a/app/requirements.txt +++ b/app/requirements.txt @@ -38,6 +38,5 @@ typer==0.12.5 typing_extensions==4.12.2 tzdata==2024.1 uvicorn==0.30.6 -uvloop==0.20.0 watchfiles==0.23.0 websockets==13.0 diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 00000000..d3b893b6 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,11 @@ +services: + web: + build: . + command: ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] + ports: + - "8000:8000" + develop: + watch: + - action: sync + path: . + target: /code \ No newline at end of file diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..6dd3a8a7 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,40 @@ +# Global options: +[mypy] +# Enable strict mode for better type checking +strict = false + +# Encourage but don't require type annotations +disallow_untyped_defs = false +# At least annotate return types +disallow_incomplete_defs = true +# Type check the body of functions without annotations +check_untyped_defs = true + +# Allow calling functions without type hints +disallow_untyped_calls = false + +# Be more permissive with 'Any' usage +disallow_any_unimported = false +disallow_any_explicit = false +disallow_any_generics = false +# Still warn about returning Any +warn_return_any = true + +# Don't enforce strict subclassing +disallow_subclassing_any = false + +# Still catch obvious issues +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = false +warn_unreachable = true + +# Allow redefinitions in some contexts +allow_redefinition = true + +# Module import settings +ignore_missing_imports = true +follow_imports = silent + +# Performance improvements +cache_dir = .mypy_cache \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..820b4f6f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,50 @@ +# Overall Project Configuration +[project] +name = "Common_Assessment_Tool" +version = "1.0.0" +authors = [{name = "David Treadwell", email = "treadwell.d@northeastern.edu"}, {name = "Fran Li", email = "li.fengr@northeastern.edu"}, {name = "Steve Chen", email = "chen.steve2@northeastern.edu"}] +readme = "README.md" +license = "MIT" +dynamic = ["dependencies"] +requires-python = ">=3.10" + +# Set up dependencies from requirements.txt file +[tool.setuptools.dynamic] +dependencies = {file = ["requirements.txt"]} + +# Project urls +[project.urls] +Repository = "https://github.com/dtread4/CommonAssessmentTool" + +# Build system +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +# Packages finder +[tool.setuptools.packages.find] +where = ["."] + +# Optional dependency configuration +[project.optional-dependencies] +dev = ["black", "isort"] +extra = ['uvloop==0.20.0'] + +# Black Configuration +[tool.black] +line-length = 100 +include = '\.pyi?$' +skip-magic-trailing-comma = true +target-version = ['py310'] + +# isort Configuration +[tool.isort] +profile = "black" +line_length = 100 +known_first_party = ["app"] +multi_line_output = 3 +force_grid_wrap = 0 +combine_as_imports = true +include_trailing_comma = true +force_single_line = false +skip = ["venv", ".venv", "migrations"] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 93d35fbf..d098b41c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -133,7 +133,6 @@ tzdata==2023.3 uri-template==1.2.0 urllib3==2.0.7 uvicorn==0.23.2 -uvloop==0.17.0 watchfiles==0.20.0 wcwidth==0.2.8 webcolors==1.13 @@ -141,4 +140,3 @@ webencodings==0.5.1 websocket-client==1.5.1 websockets==11.0.3 widgetsnbextension==4.0.7 -