diff --git a/.env b/.env new file mode 100644 index 00000000..f292cb8d --- /dev/null +++ b/.env @@ -0,0 +1,3 @@ +SECRET_KEY = "your-secret-key-here" +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 30 \ No newline at end of file diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 00000000..dee81e31 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,637 @@ +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist= + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. +ignore=CVS + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +ignore-paths= + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +ignore-patterns=^\.# + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.10 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# Add paths to the list of the source roots. Supports globbing patterns. The +# source root is an absolute path or a path relative to the current working +# directory used to determine a package namespace for modules located under the +# source root. +source-roots= + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type alias names. If left empty, type +# alias names will be checked with the set naming style. +#typealias-rgx= + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + asyncSetUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=100 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + use-implicit-booleaness-not-comparison-to-string, + use-implicit-booleaness-not-comparison-to-zero, + missing-module-docstring, + missing-class-docstring, + missing-function-docstring + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable= + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are: text, parseable, colorized, +# json2 (improved json format), json (old json format) and msvs (visual +# studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +#output-format=colorized + +# Tells whether to display a full report or only the messages. +reports=yes + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. No available dictionaries : You need to install +# both the python package and the system dependency for enchant to work. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io diff --git a/README.md b/README.md index b34d6d6b..9f3af03f 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -Team TicTech +Team SuperSonics via TicTech Project -- Feature Development Backend: Create CRUD API's for Client @@ -22,15 +22,22 @@ This also has an API file to interact with the front end, and logic in order to -------------------------How to Use------------------------- 1. In the virtual environment you've created for this project, install all dependencies in requirements.txt (pip install -r requirements.txt) -2. Run the app (uvicorn app.main:app --reload) +2. Create a .env file with the following fields: +```markdown +SECRET_KEY = "your-secret-key-here" +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 30 +``` -3. Load data into database (python initialize_data.py) +3. Run the app (uvicorn app.main:app --reload) -4. Go to SwaggerUI (http://127.0.0.1:8000/docs) +4. Load data into database (python initialize_data.py) -4. Log in as admin (username: admin password: admin123) +5. Go to SwaggerUI (http://127.0.0.1:8000/docs) -5. Click on each endpoint to use +6. Log in as admin (username: admin password: admin123) + +7. Click on each endpoint to use -Create User (Only users in admin role can create new users. The role field needs to be either "admin" or "case_worker") -Get clients (Display all the clients that are in the database) diff --git a/app/auth/router.py b/app/auth/router.py index 229ee71d..edc48165 100644 --- a/app/auth/router.py +++ b/app/auth/router.py @@ -1,28 +1,35 @@ from datetime import datetime, timedelta from typing import Optional + from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from jose import JWTError, jwt +from passlib.context import CryptContext +from pydantic import BaseModel, Field, validator from sqlalchemy.orm import Session + from app.database import get_db from app.models import User, UserRole -from passlib.context import CryptContext -from pydantic import BaseModel, Field, validator + +from dotenv import load_dotenv +import os router = APIRouter(prefix="/auth", tags=["authentication"]) + class UserCreate(BaseModel): username: str = Field(..., min_length=3, max_length=50) email: str password: str role: UserRole - @validator('role') + @validator("role") def validate_role(cls, v): if v not in [UserRole.admin, UserRole.case_worker]: - raise ValueError('Role must be either admin or case_worker') + raise ValueError("Role must be either admin or case_worker") return v + class UserResponse(BaseModel): username: str email: str @@ -31,26 +38,32 @@ class UserResponse(BaseModel): class Config: from_attributes = True -# Configuration -SECRET_KEY = "your-secret-key-here" -ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 30 + +# Load configuration from .env +load_dotenv() +SECRET_KEY = os.getenv("SECRET_KEY") +ALGORITHM = os.getenv("ALGORITHM") +ACCESS_TOKEN_EXPIRE_MINUTES = os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token") + def verify_password(plain_password: str, hashed_password: str) -> bool: return pwd_context.verify(plain_password, hashed_password) + def get_password_hash(password: str) -> str: return pwd_context.hash(password) + def authenticate_user(db: Session, username: str, password: str) -> Optional[User]: user = db.query(User).filter(User.username == username).first() if not user or not verify_password(password, user.hashed_password): return None return user + def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): to_encode = data.copy() if expires_delta: @@ -61,9 +74,9 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt + async def get_current_user( - token: str = Depends(oauth2_scheme), - db: Session = Depends(get_db) + token: str = Depends(oauth2_scheme), db: Session = Depends(get_db) ) -> User: credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -77,24 +90,25 @@ async def get_current_user( raise credentials_exception except JWTError: raise credentials_exception - + user = db.query(User).filter(User.username == username).first() if user is None: raise credentials_exception return user + def get_admin_user(current_user: User = Depends(get_current_user)): if current_user.role != UserRole.admin: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Only admin users can perform this operation" + detail="Only admin users can perform this operation", ) return current_user + @router.post("/token") async def login_for_access_token( - form_data: OAuth2PasswordRequestForm = Depends(), - db: Session = Depends(get_db) + form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db) ): user = authenticate_user(db, form_data.username, form_data.password) if not user: @@ -109,25 +123,24 @@ async def login_for_access_token( ) return {"access_token": access_token, "token_type": "bearer"} + @router.post("/users", response_model=UserResponse) async def create_user( user_data: UserCreate, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Create a new user (admin only)""" # Check if username exists if db.query(User).filter(User.username == user_data.username).first(): raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Username already registered" + status_code=status.HTTP_400_BAD_REQUEST, detail="Username already registered" ) - + # Check if email exists if db.query(User).filter(User.email == user_data.email).first(): raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Email already registered" + status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered" ) # Create new user @@ -135,9 +148,9 @@ async def create_user( username=user_data.username, email=user_data.email, hashed_password=get_password_hash(user_data.password), - role=user_data.role + role=user_data.role, ) - + try: db.add(db_user) db.commit() @@ -145,7 +158,4 @@ async def create_user( return db_user except Exception as e: db.rollback() - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(e) - ) + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) diff --git a/app/clients/router.py b/app/clients/router.py index 4ecc83e4..45c385e5 100644 --- a/app/clients/router.py +++ b/app/clients/router.py @@ -3,42 +3,44 @@ Handles all HTTP requests for client operations including create, read, update, and delete. """ -from fastapi import APIRouter, Depends, HTTPException, status, Query -from sqlalchemy.orm import Session from typing import List, Optional -from app.auth.router import get_current_user, get_admin_user -from app.models import User, UserRole -from app.database import get_db -from app.clients.service.client_service import ClientService +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy.orm import Session + +from app.auth.router import get_admin_user, get_current_user from app.clients.schema import ( - ClientResponse, - ClientUpdate, ClientListResponse, + ClientResponse, + ClientUpdate, ServiceResponse, - ServiceUpdate + ServiceUpdate, ) +from app.clients.service.client_service import ClientService +from app.database import get_db +from app.models import User, UserRole router = APIRouter(prefix="/clients", tags=["clients"]) + @router.get("/", response_model=ClientListResponse) async def get_clients( - current_user: User = Depends(get_admin_user), + current_user: User = Depends(get_admin_user), skip: int = Query(default=0, ge=0, description="Number of records to skip"), limit: int = Query(default=50, ge=1, le=150, description="Maximum number of records to return"), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): return ClientService.get_clients(db, skip, limit) + @router.get("/{client_id}", response_model=ClientResponse) async def get_client( - client_id: int, - current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + client_id: int, current_user: User = Depends(get_admin_user), db: Session = Depends(get_db) ): """Get a specific client by ID""" return ClientService.get_client(db, client_id) + @router.get("/search/by-criteria", response_model=List[ClientResponse]) async def get_clients_by_criteria( employment_status: Optional[bool] = None, @@ -66,7 +68,7 @@ async def get_clients_by_criteria( time_unemployed: Optional[int] = Query(None, ge=0), need_mental_health_support_bool: Optional[bool] = None, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Search clients by any combination of criteria""" return ClientService.get_clients_by_criteria( @@ -94,9 +96,10 @@ async def get_clients_by_criteria( attending_school=attending_school, substance_use=substance_use, time_unemployed=time_unemployed, - need_mental_health_support_bool=need_mental_health_support_bool + need_mental_health_support_bool=need_mental_health_support_bool, ) + @router.get("/search/by-services", response_model=List[ClientResponse]) async def get_clients_by_services( employment_assistance: Optional[bool] = None, @@ -107,7 +110,7 @@ async def get_clients_by_services( employer_financial_supports: Optional[bool] = None, enhanced_referrals: Optional[bool] = None, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Get clients filtered by multiple service statuses""" return ClientService.get_clients_by_services( @@ -118,70 +121,73 @@ async def get_clients_by_services( specialized_services=specialized_services, employment_related_financial_supports=employment_related_financial_supports, employer_financial_supports=employer_financial_supports, - enhanced_referrals=enhanced_referrals + enhanced_referrals=enhanced_referrals, ) + @router.get("/{client_id}/services", response_model=List[ServiceResponse]) async def get_client_services( - client_id: int, - current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + client_id: int, current_user: User = Depends(get_admin_user), db: Session = Depends(get_db) ): """Get all services and their status for a specific client, including case worker info""" return ClientService.get_client_services(db, client_id) + @router.get("/search/success-rate", response_model=List[ClientResponse]) async def get_clients_by_success_rate( min_rate: int = Query(70, ge=0, le=100, description="Minimum success rate percentage"), current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Get clients with success rate above specified threshold""" return ClientService.get_clients_by_success_rate(db, min_rate) + @router.get("/case-worker/{case_worker_id}", response_model=List[ClientResponse]) async def get_clients_by_case_worker( case_worker_id: int, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ): return ClientService.get_clients_by_case_worker(db, case_worker_id) + @router.put("/{client_id}", response_model=ClientResponse) async def update_client( client_id: int, client_data: ClientUpdate, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Update a client's information""" return ClientService.update_client(db, client_id, client_data) + @router.put("/{client_id}/services/{user_id}", response_model=ServiceResponse) async def update_client_services( client_id: int, user_id: int, service_update: ServiceUpdate, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ): return ClientService.update_client_services(db, client_id, user_id, service_update) + @router.post("/{client_id}/case-assignment", response_model=ServiceResponse) async def create_case_assignment( client_id: int, case_worker_id: int = Query(..., description="Case worker ID to assign"), current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Create a new case assignment for a client with a case worker""" return ClientService.create_case_assignment(db, client_id, case_worker_id) + @router.delete("/{client_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_client( - client_id: int, - current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + client_id: int, current_user: User = Depends(get_admin_user), db: Session = Depends(get_db) ): """Delete a client""" ClientService.delete_client(db, client_id) diff --git a/app/clients/schema.py b/app/clients/schema.py index cff28897..0120d1cf 100644 --- a/app/clients/schema.py +++ b/app/clients/schema.py @@ -3,22 +3,27 @@ Defines schemas for client data, predictions, and API responses. """ +from enum import IntEnum +from typing import List, Optional + # Standard library imports from pydantic import BaseModel, Field, validator -from typing import Optional, List -from enum import IntEnum + from app.models import UserRole + # Enums for validation class Gender(IntEnum): MALE = 1 FEMALE = 2 + class PredictionInput(BaseModel): """ Schema for prediction input data containing all client assessment fields. Used for making predictions about client outcomes. """ + age: int gender: str work_experience: int @@ -44,6 +49,7 @@ class PredictionInput(BaseModel): time_unemployed: int need_mental_health_support_bool: str + class ClientBase(BaseModel): age: int = Field(ge=18, description="Age of client, must be 18 or older") gender: Gender = Field(description="Gender: 1 for male, 2 for female") @@ -96,16 +102,18 @@ class Config: "currently_employed": False, "substance_use": False, "time_unemployed": 6, - "need_mental_health_support_bool": False + "need_mental_health_support_bool": False, } } + class ClientResponse(ClientBase): id: int class Config: from_attributes = True + class ClientUpdate(BaseModel): age: Optional[int] = Field(None, ge=18) gender: Optional[Gender] = None @@ -132,6 +140,7 @@ class ClientUpdate(BaseModel): time_unemployed: Optional[int] = Field(None, ge=0) need_mental_health_support_bool: Optional[bool] = None + class ServiceResponse(BaseModel): client_id: int user_id: int @@ -147,6 +156,7 @@ class ServiceResponse(BaseModel): class Config: from_attributes = True + class ServiceUpdate(BaseModel): employment_assistance: Optional[bool] = None life_stabilization: Optional[bool] = None @@ -157,6 +167,7 @@ class ServiceUpdate(BaseModel): enhanced_referrals: Optional[bool] = None success_rate: Optional[int] = Field(None, ge=0, le=100) + class ClientListResponse(BaseModel): clients: List[ClientResponse] total: int diff --git a/app/clients/service/client_service.py b/app/clients/service/client_service.py index 86c3ef4a..16115ec1 100644 --- a/app/clients/service/client_service.py +++ b/app/clients/service/client_service.py @@ -3,14 +3,87 @@ Provides CRUD operations and business logic for client management. """ -from sqlalchemy.orm import Session -from sqlalchemy import and_ +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + from fastapi import HTTPException, status -from typing import List, Optional, Dict, Any +from sqlalchemy import and_ +from sqlalchemy.orm import Session + +from app.clients.schema import ClientUpdate, ServiceResponse, ServiceUpdate from app.models import Client, ClientCase, User -from app.clients.schema import ClientUpdate, ServiceUpdate, ServiceResponse -class ClientService: + +class InterfaceClientQueryService(ABC): + """Interface for client query operations""" + + @abstractmethod + def get_client(self, db: Session, client_id: int) -> Client: + """Get a specific client by ID""" + pass + + @abstractmethod + def get_clients(self, db: Session, skip: int, limit: int) -> Dict[str, Any]: + """Get clients with optional pagination.""" + pass + + @abstractmethod + def get_clients_by_criteria(self, db: Session, **criteria) -> List[Client]: + """Get clients filtered by any combination of criteria""" + pass + + @abstractmethod + def get_clients_by_services(self, db: Session, **service_filters) -> List[Client]: + """Get clients filtered by multiple service statuses.""" + + pass + + @abstractmethod + def get_client_services(self, db: Session, client_id: int) -> List[ClientCase]: + pass + + @abstractmethod + def get_clients_by_success_rate(self, db: Session, min_rate: int) -> List[Client]: + pass + + @abstractmethod + def get_clients_by_case_worker(self, db: Session, case_worker_id: int) -> List[Client]: + pass + + +class InterfaceClientManagementService(ABC): + """Interface for client management operations""" + + @abstractmethod + def update_client( + self, db: Session, client_id: int, client_update: ClientUpdate + ) -> ClientUpdate: + """Update a client's information""" + pass + + @abstractmethod + def update_client_services( + self, db: Session, client_id: int, user_id: int, service_update: ServiceUpdate + ) -> ClientCase: + """Update a client's services and outcomes for a specific caseworker""" + pass + + @abstractmethod + def create_case_assignment( + self, db: Session, client_id: int, case_worker_id: int + ) -> ClientCase: + """Create a new case assignment""" + pass + + @abstractmethod + def delete_client(self, db: Session, client_id: int) -> None: + """Delete a client and their associated records""" + pass + + +class ClientQueryService(InterfaceClientQueryService): + """Implementation of client query service""" + @staticmethod def get_client(db: Session, client_id: int): """Get a specific client by ID""" @@ -18,7 +91,7 @@ def get_client(db: Session, client_id: int): if not client: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" + detail=f"Client with id {client_id} not found", ) return client @@ -30,15 +103,13 @@ def get_clients(db: Session, skip: int = 0, limit: int = 50): """ if skip < 0: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Skip value cannot be negative" + status_code=status.HTTP_400_BAD_REQUEST, detail="Skip value cannot be negative" ) if limit < 1: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Limit must be greater than 0" + status_code=status.HTTP_400_BAD_REQUEST, detail="Limit must be greater than 0" ) - + clients = db.query(Client).offset(skip).limit(limit).all() total = db.query(Client).count() return {"clients": clients, "total": total} @@ -69,27 +140,25 @@ def get_clients_by_criteria( attending_school: Optional[bool] = None, substance_use: Optional[bool] = None, time_unemployed: Optional[int] = None, - need_mental_health_support_bool: Optional[bool] = None + need_mental_health_support_bool: Optional[bool] = None, ): """Get clients filtered by any combination of criteria""" query = db.query(Client) - + if education_level is not None and not (1 <= education_level <= 14): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Education level must be between 1 and 14" + detail="Education level must be between 1 and 14", ) - + if age_min is not None and age_min < 18: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Minimum age must be at least 18" + status_code=status.HTTP_400_BAD_REQUEST, detail="Minimum age must be at least 18" ) if gender is not None and gender not in [1, 2]: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Gender must be 1 or 2" + status_code=status.HTTP_400_BAD_REQUEST, detail="Gender must be 1 or 2" ) # Apply filters for non-None values @@ -140,47 +209,46 @@ def get_clients_by_criteria( if time_unemployed is not None: query = query.filter(Client.time_unemployed == time_unemployed) if need_mental_health_support_bool is not None: - query = query.filter(Client.need_mental_health_support_bool == need_mental_health_support_bool) + query = query.filter( + Client.need_mental_health_support_bool == need_mental_health_support_bool + ) try: return query.all() except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error retrieving clients: {str(e)}" + detail=f"Error retrieving clients: {str(e)}", ) @staticmethod - def get_clients_by_services( - db: Session, - **service_filters: Optional[bool] - ): + def get_clients_by_services(db: Session, **service_filters: Optional[bool]): """ Get clients filtered by multiple service statuses. """ query = db.query(Client).join(ClientCase) - + for service_name, status in service_filters.items(): if status is not None: filter_criteria = getattr(ClientCase, service_name) == status query = query.filter(filter_criteria) - + try: return query.all() except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error retrieving clients: {str(e)}" + detail=f"Error retrieving clients: {str(e)}", ) @staticmethod def get_client_services(db: Session, client_id: int): - """Get all services for a specific client with case worker info""" + """Get all services for a specific client with caseworker info""" client_cases = db.query(ClientCase).filter(ClientCase.client_id == client_id).all() if not client_cases: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"No services found for client with id {client_id}" + detail=f"No services found for client with id {client_id}", ) return client_cases @@ -190,12 +258,10 @@ def get_clients_by_success_rate(db: Session, min_rate: int = 70): if not (0 <= min_rate <= 100): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Success rate must be between 0 and 100" + detail="Success rate must be between 0 and 100", ) - - return db.query(Client).join(ClientCase).filter( - ClientCase.success_rate >= min_rate - ).all() + + return db.query(Client).join(ClientCase).filter(ClientCase.success_rate >= min_rate).all() @staticmethod def get_clients_by_case_worker(db: Session, case_worker_id: int): @@ -204,12 +270,14 @@ def get_clients_by_case_worker(db: Session, case_worker_id: int): if not case_worker: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Case worker with id {case_worker_id} not found" + detail=f"Case worker with id {case_worker_id} not found", ) - - return db.query(Client).join(ClientCase).filter( - ClientCase.user_id == case_worker_id - ).all() + + return db.query(Client).join(ClientCase).filter(ClientCase.user_id == case_worker_id).all() + + +class ClientManagementService(InterfaceClientManagementService): + """Implementation of client management service""" @staticmethod def update_client(db: Session, client_id: int, client_update: ClientUpdate): @@ -218,7 +286,7 @@ def update_client(db: Session, client_id: int, client_update: ClientUpdate): if not client: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" + detail=f"Client with id {client_id} not found", ) update_data = client_update.dict(exclude_unset=True) @@ -233,27 +301,25 @@ def update_client(db: Session, client_id: int, client_update: ClientUpdate): db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update client: {str(e)}" + detail=f"Failed to update client: {str(e)}", ) - + @staticmethod def update_client_services( - db: Session, - client_id: int, - user_id: int, - service_update: ServiceUpdate + db: Session, client_id: int, user_id: int, service_update: ServiceUpdate ): """Update a client's services and outcomes for a specific case worker""" - client_case = db.query(ClientCase).filter( - ClientCase.client_id == client_id, - ClientCase.user_id == user_id - ).first() - + client_case = ( + db.query(ClientCase) + .filter(ClientCase.client_id == client_id, ClientCase.user_id == user_id) + .first() + ) + if not client_case: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"No case found for client {client_id} with case worker {user_id}. " - f"Cannot update services for a non-existent case assignment." + f"Cannot update services for a non-existent case assignment.", ) update_data = service_update.dict(exclude_unset=True) @@ -268,43 +334,40 @@ def update_client_services( db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update client services: {str(e)}" + detail=f"Failed to update client services: {str(e)}", ) - + @staticmethod - def create_case_assignment( - db: Session, - client_id: int, - case_worker_id: int - ): + def create_case_assignment(db: Session, client_id: int, case_worker_id: int): """Create a new case assignment""" # Check if client exists client = db.query(Client).filter(Client.id == client_id).first() if not client: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" + detail=f"Client with id {client_id} not found", ) - # Check if case worker exists + # Check if caseworker exists case_worker = db.query(User).filter(User.id == case_worker_id).first() if not case_worker: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Case worker with id {case_worker_id} not found" + detail=f"Case worker with id {case_worker_id} not found", ) # Check if assignment already exists - existing_case = db.query(ClientCase).filter( - ClientCase.client_id == client_id, - ClientCase.user_id == case_worker_id - ).first() - + existing_case = ( + db.query(ClientCase) + .filter(ClientCase.client_id == client_id, ClientCase.user_id == case_worker_id) + .first() + ) + if existing_case: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Client {client_id} already has a case assigned to case worker {case_worker_id}" - ) + detail=f"Client {client_id} already has a case assigned to case worker {case_worker_id}", + ) try: # Create new case assignment with default service values @@ -318,7 +381,7 @@ def create_case_assignment( employment_related_financial_supports=False, employer_financial_supports=False, enhanced_referrals=False, - success_rate=0 + success_rate=0, ) db.add(new_case) db.commit() @@ -329,9 +392,9 @@ def create_case_assignment( db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to create case assignment: {str(e)}" + detail=f"Failed to create case assignment: {str(e)}", ) - + @staticmethod def delete_client(db: Session, client_id: int): """Delete a client and their associated records""" @@ -340,22 +403,77 @@ def delete_client(db: Session, client_id: int): if not client: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" + detail=f"Client with id {client_id} not found", ) try: # Delete associated client_cases - db.query(ClientCase).filter( - ClientCase.client_id == client_id - ).delete() - + db.query(ClientCase).filter(ClientCase.client_id == client_id).delete() + # Delete the client db.delete(client) db.commit() - + except Exception as e: db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to delete client: {str(e)}" + detail=f"Failed to delete client: {str(e)}", ) + + +class ClientService: + """ + Facade that maintains backward compatibility with the existing router. + Delegates to specialized service classes. + """ + + # Query methods + @staticmethod + def get_client(db: Session, client_id: int): + return ClientQueryService.get_client(db, client_id) + + @staticmethod + def get_clients(db: Session, skip: int = 0, limit: int = 50): + return ClientQueryService.get_clients(db, skip, limit) + + @staticmethod + def get_clients_by_criteria(db: Session, **criteria): + return ClientQueryService.get_clients_by_criteria(db, **criteria) + + @staticmethod + def get_clients_by_services(db: Session, **service_filters): + return ClientQueryService.get_clients_by_services(db, **service_filters) + + @staticmethod + def get_client_services(db: Session, client_id: int): + return ClientQueryService.get_client_services(db, client_id) + + @staticmethod + def get_clients_by_success_rate(db: Session, min_rate: int = 70): + return ClientQueryService.get_clients_by_success_rate(db, min_rate) + + @staticmethod + def get_clients_by_case_worker(db: Session, case_worker_id: int): + return ClientQueryService.get_clients_by_case_worker(db, case_worker_id) + + # Modification methods + @staticmethod + def update_client(db: Session, client_id: int, client_update: ClientUpdate): + return ClientManagementService.update_client(db, client_id, client_update) + + @staticmethod + def update_client_services( + db: Session, client_id: int, user_id: int, service_update: ServiceUpdate + ): + return ClientManagementService.update_client_services( + db, client_id, user_id, service_update + ) + + @staticmethod + def create_case_assignment(db: Session, client_id: int, case_worker_id: int): + return ClientManagementService.create_case_assignment(db, client_id, case_worker_id) + + @staticmethod + def delete_client(db: Session, client_id: int): + return ClientManagementService.delete_client(db, client_id) diff --git a/app/clients/service/logic.py b/app/clients/service/logic.py index c25b4217..290e1dd9 100644 --- a/app/clients/service/logic.py +++ b/app/clients/service/logic.py @@ -5,30 +5,33 @@ # Standard library imports import os -#import json -from itertools import product # Third-party imports import pickle + +# import json +from itertools import product + import numpy as np # Constants COLUMN_INTERVENTIONS = [ - 'Life Stabilization', - 'General Employment Assistance Services', - 'Retention Services', - 'Specialized Services', - 'Employment-Related Financial Supports for Job Seekers and Employers', - 'Employer Financial Supports', - 'Enhanced Referrals for Skills Development' + "Life Stabilization", + "General Employment Assistance Services", + "Retention Services", + "Specialized Services", + "Employment-Related Financial Supports for Job Seekers and Employers", + "Employer Financial Supports", + "Enhanced Referrals for Skills Development", ] # Load model CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) -MODEL_PATH = os.path.join(CURRENT_DIR, 'model.pkl') +MODEL_PATH = os.path.join(CURRENT_DIR, "model.pkl") with open(MODEL_PATH, "rb") as model_file: MODEL = pickle.load(model_file) + def clean_input_data(input_data): """ Clean and transform input data into model-compatible format. @@ -40,13 +43,30 @@ def clean_input_data(input_data): list: Cleaned and formatted data ready for model input """ columns = [ - "age", "gender", "work_experience", "canada_workex", "dep_num", - "canada_born", "citizen_status", "level_of_schooling", "fluent_english", - "reading_english_scale", "speaking_english_scale", "writing_english_scale", - "numeracy_scale", "computer_scale", "transportation_bool", "caregiver_bool", - "housing", "income_source", "felony_bool", "attending_school", - "currently_employed", "substance_use", "time_unemployed", - "need_mental_health_support_bool" + "age", + "gender", + "work_experience", + "canada_workex", + "dep_num", + "canada_born", + "citizen_status", + "level_of_schooling", + "fluent_english", + "reading_english_scale", + "speaking_english_scale", + "writing_english_scale", + "numeracy_scale", + "computer_scale", + "transportation_bool", + "caregiver_bool", + "housing", + "income_source", + "felony_bool", + "attending_school", + "currently_employed", + "substance_use", + "time_unemployed", + "need_mental_health_support_bool", ] demographics = {key: input_data[key] for key in columns} output = [] @@ -57,6 +77,7 @@ def clean_input_data(input_data): output.append(value) return output + def convert_text(text_data: str): """ Convert text answers from front end into numerical values. @@ -68,33 +89,47 @@ def convert_text(text_data: str): int: Converted numerical value """ categorical_mappings = [ + {"": 0, "true": 1, "false": 0, "no": 0, "yes": 1, "No": 0, "Yes": 1}, { - "": 0, "true": 1, "false": 0, "no": 0, "yes": 1, - "No": 0, "Yes": 1 - }, - { - "Grade 0-8": 1, "Grade 9": 2, "Grade 10": 3, "Grade 11": 4, - "Grade 12 or equivalent": 5, "OAC or Grade 13": 6, - "Some college": 7, "Some university": 8, "Some apprenticeship": 9, - "Certificate of Apprenticeship": 10, "Journeyperson": 11, - "Certificate/Diploma": 12, "Bachelor's degree": 13, - "Post graduate": 14 + "Grade 0-8": 1, + "Grade 9": 2, + "Grade 10": 3, + "Grade 11": 4, + "Grade 12 or equivalent": 5, + "OAC or Grade 13": 6, + "Some college": 7, + "Some university": 8, + "Some apprenticeship": 9, + "Certificate of Apprenticeship": 10, + "Journeyperson": 11, + "Certificate/Diploma": 12, + "Bachelor's degree": 13, + "Post graduate": 14, }, { - "Renting-private": 1, "Renting-subsidized": 2, - "Boarding or lodging": 3, "Homeowner": 4, - "Living with family/friend": 5, "Institution": 6, - "Temporary second residence": 7, "Band-owned home": 8, - "Homeless or transient": 9, "Emergency hostel": 10 + "Renting-private": 1, + "Renting-subsidized": 2, + "Boarding or lodging": 3, + "Homeowner": 4, + "Living with family/friend": 5, + "Institution": 6, + "Temporary second residence": 7, + "Band-owned home": 8, + "Homeless or transient": 9, + "Emergency hostel": 10, }, { - "No Source of Income": 1, "Employment Insurance": 2, + "No Source of Income": 1, + "Employment Insurance": 2, "Workplace Safety and Insurance Board": 3, "Ontario Works applied or receiving": 4, "Ontario Disability Support Program applied or receiving": 5, - "Dependent of someone receiving OW or ODSP": 6, "Crown Ward": 7, - "Employment": 8, "Self-Employment": 9, "Other (specify)": 10 - } + "Dependent of someone receiving OW or ODSP": 6, + "Crown Ward": 7, + "Employment": 8, + "Self-Employment": 9, + "Other (specify)": 10, + }, ] for category in categorical_mappings: if text_data in category: @@ -102,6 +137,7 @@ def convert_text(text_data: str): return int(text_data) if text_data.isnumeric() else text_data + def create_matrix(row_data): """ Create matrix of all possible intervention combinations. @@ -116,6 +152,7 @@ def create_matrix(row_data): perms = intervention_permutations(7) return np.concatenate((np.array(data), np.array(perms)), axis=1) + def intervention_permutations(num): """ Generate all possible intervention combinations. @@ -128,6 +165,7 @@ def intervention_permutations(num): """ return np.array(list(product([0, 1], repeat=num))) + def get_baseline_row(row_data): """ Create baseline row with no interventions. @@ -141,6 +179,7 @@ def get_baseline_row(row_data): base_interventions = np.zeros(7) return np.concatenate((np.array(row_data), base_interventions)) + def intervention_row_to_names(row_data): """ Convert intervention row to list of intervention names. @@ -153,6 +192,7 @@ def intervention_row_to_names(row_data): """ return [COLUMN_INTERVENTIONS[i] for i, value in enumerate(row_data) if value == 1] + def process_results(baseline_pred, results_matrix): """ Process model results into structured output. @@ -164,14 +204,9 @@ def process_results(baseline_pred, results_matrix): Returns: dict: Processed results with baseline and interventions """ - result_list = [ - (row[-1], intervention_row_to_names(row[:-1])) - for row in results_matrix - ] - return { - "baseline": baseline_pred[-1], - "interventions": result_list - } + result_list = [(row[-1], intervention_row_to_names(row[:-1])) for row in results_matrix] + return {"baseline": baseline_pred[-1], "interventions": result_list} + def interpret_and_calculate(input_data): """ @@ -194,19 +229,33 @@ def interpret_and_calculate(input_data): top_results = result_matrix[-3:, -8:] return process_results(baseline_prediction, top_results) + if __name__ == "__main__": test_data = { - "age": "23", "gender": "1", "work_experience": "1", - "canada_workex": "1", "dep_num": "0", "canada_born": "1", - "citizen_status": "2", "level_of_schooling": "2", - "fluent_english": "3", "reading_english_scale": "2", - "speaking_english_scale": "2", "writing_english_scale": "3", - "numeracy_scale": "2", "computer_scale": "3", - "transportation_bool": "2", "caregiver_bool": "1", - "housing": "1", "income_source": "5", "felony_bool": "1", - "attending_school": "0", "currently_employed": "1", - "substance_use": "1", "time_unemployed": "1", - "need_mental_health_support_bool": "1" + "age": "23", + "gender": "1", + "work_experience": "1", + "canada_workex": "1", + "dep_num": "0", + "canada_born": "1", + "citizen_status": "2", + "level_of_schooling": "2", + "fluent_english": "3", + "reading_english_scale": "2", + "speaking_english_scale": "2", + "writing_english_scale": "3", + "numeracy_scale": "2", + "computer_scale": "3", + "transportation_bool": "2", + "caregiver_bool": "1", + "housing": "1", + "income_source": "5", + "felony_bool": "1", + "attending_school": "0", + "currently_employed": "1", + "substance_use": "1", + "time_unemployed": "1", + "need_mental_health_support_bool": "1", } results = interpret_and_calculate(test_data) print(results) diff --git a/app/clients/service/ml_models.py b/app/clients/service/ml_models.py new file mode 100644 index 00000000..50957926 --- /dev/null +++ b/app/clients/service/ml_models.py @@ -0,0 +1,147 @@ +from abc import ABC, abstractmethod +from typing import List + +import numpy as np +from sklearn.ensemble import RandomForestRegressor +from sklearn.linear_model import LinearRegression +from sklearn.svm import SVR + + +class InterfaceBaseMLModel(ABC): + """Interface of a base ML Model""" + + @abstractmethod + def fit(self, X: np.ndarray, y: np.ndarray): + pass + + @abstractmethod + def predict(self, X: np.ndarray) -> np.ndarray: + pass + + def save(self, path: str): + import pickle + + with open(path, "wb") as f: + pickle.dump(self, f) + + @staticmethod + def load(path: str): + import pickle + + with open(path, "rb") as f: + return pickle.load(f) + + @abstractmethod + def __str__(self) -> str: + """Return the name of the model""" + pass + + +class LinearRegressionModel(InterfaceBaseMLModel): + def __init__(self): + self.model = LinearRegression() + + def fit(self, X, y): + self.model.fit(X, y) + + def predict(self, X): + return self.model.predict(X) + + def __str__(self): + return "Linear Regression" + + +class RandomForestModel(InterfaceBaseMLModel): + def __init__(self, n_estimators=100, random_state=42): + self.model = RandomForestRegressor(n_estimators=n_estimators, random_state=random_state) + + def fit(self, X, y): + self.model.fit(X, y) + + def predict(self, X): + return self.model.predict(X) + + def __str__(self): + return "Random Forest Regressor" + + +class SVMModel(InterfaceBaseMLModel): + def __init__(self): + self.model = SVR() + + def fit(self, X, y): + self.model.fit(X, y) + + def predict(self, X): + return self.model.predict(X) + + def __str__(self): + return "Support Vector Machine" + + +class InterfaceMLModelRepository(ABC): + """Interface for ML Models storage""" + + @abstractmethod + def list_models(self) -> List[InterfaceBaseMLModel]: + """Get list of all available models instances""" + pass + + @abstractmethod + def is_model_available(self, model_name: str) -> bool: + """Check if a model is valid""" + pass + + @abstractmethod + def get_model_instance(self, model_name: str) -> InterfaceBaseMLModel: + """Return an instance of the requested model""" + pass + + +class InterfaceMLModelManager(ABC): + """Interface for ML model management""" + + @abstractmethod + def get_current_model(self) -> InterfaceBaseMLModel: + """Get the current active ml model""" + pass + + @abstractmethod + def switch_model(self, model_name: str) -> bool: + """Switch between models""" + pass + + +class MLModelRepository(InterfaceMLModelRepository): + def __init__(self): + self._model_map = { + "Linear Regression": LinearRegressionModel, + "Random Forest Regressor": RandomForestModel, + "Support Vector Machine": SVMModel, + } + + def list_models(self) -> List[InterfaceBaseMLModel]: + return [model_class() for model_class in self._model_map.values()] + + def is_model_available(self, model_name: str) -> bool: + return model_name in self._model_map + + def get_model_instance(self, model_name: str) -> InterfaceBaseMLModel: + if not self.is_model_available(model_name): + raise ValueError(f"Model '{model_name}' is not available.") + return self._model_map[model_name]() + + +class MLModelManager(InterfaceMLModelManager): + def __init__(self, repository: InterfaceMLModelRepository): + self._repository = repository + self._current_model = repository.get_model_instance("Random Forest Regressor") + + def get_current_model(self) -> str: + return self._current_model + + def switch_model(self, model_name: str) -> bool: + if self._repository.is_model_available(model_name): + self._current_model = self._repository.get_model_instance(model_name) + return True + return False diff --git a/app/clients/service/ml_models_router.py b/app/clients/service/ml_models_router.py new file mode 100644 index 00000000..d1070153 --- /dev/null +++ b/app/clients/service/ml_models_router.py @@ -0,0 +1,30 @@ +from fastapi import APIRouter, HTTPException + +from app.clients.service.ml_models import MLModelManager, MLModelRepository + +router = APIRouter(prefix="/ml_models") +model_repository = MLModelRepository() +model_manager = MLModelManager(model_repository) + + +@router.get("/list") +def list_models(): + """List all available ML models""" + # return {"models": model_repository.list_models()} + return {"models": [str(model) for model in model_repository.list_models()]} + + +@router.post("/switch/{model_name}") +def switch_models(model_name: str): + """Switch between ML models""" + success = model_manager.switch_model(model_name) + if not success: + raise HTTPException(status_code=400, detail="Model switch failed") + return {"message": f"Model switched to {model_name}"} + + +@router.get("/current") +def current_model(): + """Get the current ML model""" + # return {"current_model": model_manager.get_current_model()} + return {"current_model": str(model_manager.get_current_model())} diff --git a/app/clients/service/model.py b/app/clients/service/model.py index b2406370..cba85686 100644 --- a/app/clients/service/model.py +++ b/app/clients/service/model.py @@ -1,110 +1,193 @@ """ Model training module for the Common Assessment Tool. Handles the preparation, training, and saving of the prediction model. +Pass in model name via command line """ +import os + # Standard library imports import pickle +import sys # Third-party imports import numpy as np import pandas as pd -from sklearn.model_selection import train_test_split + +# Local imports +from ml_models import ( + InterfaceBaseMLModel, + LinearRegressionModel, + MLModelRepository, + RandomForestModel, + SVMModel, +) +from sklearn import svm from sklearn.ensemble import RandomForestRegressor +from sklearn.linear_model import LinearRegression +from sklearn.model_selection import train_test_split + +repo = MLModelRepository() + +default_unformatted_model_path = "pretrained_models" + os.sep + "model_{}.pkl" + + +def get_model_by_name(model_type: str, n_estimators=100, random_state=42) -> InterfaceBaseMLModel: + model_map = { + "Linear Regression": LinearRegressionModel, + "Random Forest Regressor": lambda: RandomForestModel(n_estimators, random_state), + "Support Vector Machine": SVMModel, + } + + if model_type not in model_map: + print(f"ERROR! Invalid model type '{model_type}' passed in.") + print(f"Available models: {repo.list_models()}") + sys.exit(-1) -def prepare_models(): + constructor = model_map[model_type] + return constructor() if callable(constructor) else constructor() + + +def prepare_model_data(test_size=0.2, random_state=42): """ Prepare and train the Random Forest model using the dataset. - + Args: + test_size: The percent of the dataset to use as test data (rest will be used as train data) + random_state: The random state to generate train/test split with + Returns: RandomForestRegressor: Trained model for predicting success rates """ # Load dataset - data = pd.read_csv('data_commontool.csv') + data = pd.read_csv("data_commontool.csv") # Define feature columns feature_columns = [ - 'age', # Client's age - 'gender', # Client's gender (bool) - 'work_experience', # Years of work experience - 'canada_workex', # Years of work experience in Canada - 'dep_num', # Number of dependents - 'canada_born', # Born in Canada - 'citizen_status', # Citizenship status - 'level_of_schooling', # Highest level achieved (1-14) - 'fluent_english', # English fluency scale (1-10) - 'reading_english_scale', # Reading ability scale (1-10) - 'speaking_english_scale',# Speaking ability scale (1-10) - 'writing_english_scale', # Writing ability scale (1-10) - 'numeracy_scale', # Numeracy ability scale (1-10) - 'computer_scale', # Computer proficiency scale (1-10) - 'transportation_bool', # Needs transportation support (bool) - 'caregiver_bool', # Is primary caregiver (bool) - 'housing', # Housing situation (1-10) - 'income_source', # Source of income (1-10) - 'felony_bool', # Has a felony (bool) - 'attending_school', # Currently a student (bool) - 'currently_employed', # Currently employed (bool) - 'substance_use', # Substance use disorder (bool) - 'time_unemployed', # Years unemployed - 'need_mental_health_support_bool' # Needs mental health support (bool) + "age", # Client's age + "gender", # Client's gender (bool) + "work_experience", # Years of work experience + "canada_workex", # Years of work experience in Canada + "dep_num", # Number of dependents + "canada_born", # Born in Canada + "citizen_status", # Citizenship status + "level_of_schooling", # Highest level achieved (1-14) + "fluent_english", # English fluency scale (1-10) + "reading_english_scale", # Reading ability scale (1-10) + "speaking_english_scale", # Speaking ability scale (1-10) + "writing_english_scale", # Writing ability scale (1-10) + "numeracy_scale", # Numeracy ability scale (1-10) + "computer_scale", # Computer proficiency scale (1-10) + "transportation_bool", # Needs transportation support (bool) + "caregiver_bool", # Is primary caregiver (bool) + "housing", # Housing situation (1-10) + "income_source", # Source of income (1-10) + "felony_bool", # Has a felony (bool) + "attending_school", # Currently a student (bool) + "currently_employed", # Currently employed (bool) + "substance_use", # Substance use disorder (bool) + "time_unemployed", # Years unemployed + "need_mental_health_support_bool", # Needs mental health support (bool) ] # Define intervention columns intervention_columns = [ - 'employment_assistance', - 'life_stabilization', - 'retention_services', - 'specialized_services', - 'employment_related_financial_supports', - 'employer_financial_supports', - 'enhanced_referrals' + "employment_assistance", + "life_stabilization", + "retention_services", + "specialized_services", + "employment_related_financial_supports", + "employer_financial_supports", + "enhanced_referrals", ] # Combine all feature columns all_features = feature_columns + intervention_columns # Prepare training data features = np.array(data[all_features]) # Changed from X to features - targets = np.array(data['success_rate']) # Changed from y to targets + targets = np.array(data["success_rate"]) # Changed from y to targets # Split the dataset - features_train, _, targets_train, _ = train_test_split( # Removed unused variables + X_train, x_test, Y_train, y_test = train_test_split( + # Removed unused variables features, targets, - test_size=0.2, - random_state=42 + test_size=test_size, + random_state=random_state, ) - # Initialize and train the model - model = RandomForestRegressor(n_estimators=100, random_state=42) - model.fit(features_train, targets_train) + + return X_train, x_test, Y_train, y_test + + +def train_model( + X_train, Y_train, model_type, n_estimators=100, random_state=42 +) -> InterfaceBaseMLModel: + """ + Trains the model + Args: + X_train: Training features + targets_train: Target features + Y_train: Which model to create + n_estimators: Number estimators (for random forest) + random_state: Random state to train with (for random forest) + + Returns: A trained model of the type specified + + """ + model = get_model_by_name(model_type, n_estimators, random_state) + model.fit(X_train, Y_train) return model -def save_model(model, filename="model.pkl"): + +def get_true_file_name(model_type, filename): + """ + Takes a model type and file name, formats model type, and replaces spaces with underscores + Args: + model_type: The model type as a String + filename: The file name (should follow 'model_{}.pkl' format) + + Returns: The clean file name + """ + return filename.format(model_type).replace(" ", "_") + + +def save_model(model, model_type, filename=default_unformatted_model_path): """ Save the trained model to a file. - + Args: model: Trained model to save + model_type: The type of model being saved filename (str): Name of the file to save the model to """ - with open(filename, "wb") as model_file: + true_file_name = get_true_file_name(model_type, filename) + with open(true_file_name, "wb") as model_file: pickle.dump(model, model_file) -def load_model(filename="model.pkl"): + +def load_model(model_type, filename=default_unformatted_model_path): """ Load a trained model from a file. - + Args: + model_type: The type of model being loaded filename (str): Name of the file to load the model from - + Returns: The loaded model """ - with open(filename, "rb") as model_file: + true_file_name = get_true_file_name(model_type, filename) + with open(true_file_name, "rb") as model_file: return pickle.load(model_file) -def main(): + +def main(argv): """Main function to train and save the model.""" - print("Starting model training...") - model = prepare_models() - save_model(model) + # Get the model type from the command line arguments + model_type = argv[1] + + # Train and save the model + print("Starting model training for {} model...".format(model_type)) + X_train, x_test, Y_train, y_test = prepare_model_data() + model = train_model(X_train, Y_train, model_type) + save_model(model, model_type) print("Model training completed and saved successfully.") + if __name__ == "__main__": - main() + main(sys.argv) diff --git a/app/clients/service/pretrained_models/model_Linear_Regression.pkl b/app/clients/service/pretrained_models/model_Linear_Regression.pkl new file mode 100644 index 00000000..d264f1b9 Binary files /dev/null and b/app/clients/service/pretrained_models/model_Linear_Regression.pkl differ diff --git a/app/clients/service/pretrained_models/model_Random_Forest_Regressor.pkl b/app/clients/service/pretrained_models/model_Random_Forest_Regressor.pkl new file mode 100644 index 00000000..cde4340f Binary files /dev/null and b/app/clients/service/pretrained_models/model_Random_Forest_Regressor.pkl differ diff --git a/app/clients/service/pretrained_models/model_Support_Vector_Machine.pkl b/app/clients/service/pretrained_models/model_Support_Vector_Machine.pkl new file mode 100644 index 00000000..999b824c Binary files /dev/null and b/app/clients/service/pretrained_models/model_Support_Vector_Machine.pkl differ diff --git a/app/database.py b/app/database.py index 3a489f54..b5c8b948 100644 --- a/app/database.py +++ b/app/database.py @@ -7,22 +7,23 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker -#Here is where the database is located -SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db" +# Here is where the database is located +SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db" -#Open up a connection so that we are able to use the database +# Open up a connection so that we are able to use the database engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}) -#Bind the engine just created +# Bind the engine just created SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -#Create an object of our database so as to control the database +# Create an object of our database so as to control the database Base = declarative_base() + def get_db(): """ Creates a database session and ensures it's closed after use. - + Yields: Session: SQLAlchemy database session """ diff --git a/app/main.py b/app/main.py index a8e8fa7f..66746785 100644 --- a/app/main.py +++ b/app/main.py @@ -5,27 +5,32 @@ """ from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + from app import models -from app.database import engine -from app.clients.router import router as clients_router from app.auth.router import router as auth_router -from fastapi.middleware.cors import CORSMiddleware +from app.clients.router import router as clients_router +from app.clients.service.ml_models_router import router as ml_models_router +from app.database import engine # Initialize database tables models.Base.metadata.create_all(bind=engine) # Create FastAPI application -app = FastAPI(title="Case Management API", description="API for managing client cases", version="1.0.0") +app = FastAPI( + title="Case Management API", description="API for managing client cases", version="1.0.0" +) # Include routers app.include_router(auth_router) app.include_router(clients_router) +app.include_router(ml_models_router) # Configure CORS middleware app.add_middleware( CORSMiddleware, - allow_origins=["*"], # Allows all origins - allow_methods=["*"], # Allows all methods - allow_headers=["*"], # Allows all headers + allow_origins=["*"], # Allows all origins + allow_methods=["*"], # Allows all methods + allow_headers=["*"], # Allows all headers allow_credentials=True, ) diff --git a/app/models.py b/app/models.py index df778348..870210e5 100644 --- a/app/models.py +++ b/app/models.py @@ -3,11 +3,14 @@ Contains the Client model for storing client information in the database. """ -from app.database import Base -from sqlalchemy import Column, Integer, String, Boolean, ForeignKey, CheckConstraint, Enum -from sqlalchemy.orm import relationship import enum +from sqlalchemy import Boolean, CheckConstraint, Column, Enum, ForeignKey, Integer, String +from sqlalchemy.orm import relationship + +from app.database import Base + + class UserRole(str, enum.Enum): admin = "admin" case_worker = "case_worker" @@ -24,46 +27,61 @@ class User(Base): cases = relationship("ClientCase", back_populates="user") + class Client(Base): """ Client model representing client data in the database. """ + __tablename__ = "clients" id = Column(Integer, primary_key=True, autoincrement=True) - age = Column(Integer, CheckConstraint('age >= 18')) + age = Column(Integer, CheckConstraint("age >= 18")) gender = Column(Integer, CheckConstraint("gender = 1 OR gender = 2")) - work_experience = Column(Integer, CheckConstraint('work_experience >= 0')) - canada_workex = Column(Integer, CheckConstraint('canada_workex >= 0')) - dep_num = Column(Integer, CheckConstraint('dep_num >= 0')) + work_experience = Column(Integer, CheckConstraint("work_experience >= 0")) + canada_workex = Column(Integer, CheckConstraint("canada_workex >= 0")) + dep_num = Column(Integer, CheckConstraint("dep_num >= 0")) canada_born = Column(Boolean) citizen_status = Column(Boolean) - level_of_schooling = Column(Integer, CheckConstraint('level_of_schooling >= 1 AND level_of_schooling <= 14')) + level_of_schooling = Column( + Integer, CheckConstraint("level_of_schooling >= 1 AND level_of_schooling <= 14") + ) fluent_english = Column(Boolean) - reading_english_scale = Column(Integer, CheckConstraint('reading_english_scale >= 0 AND reading_english_scale <= 10')) - speaking_english_scale = Column(Integer, CheckConstraint('speaking_english_scale >= 0 AND speaking_english_scale <= 10')) - writing_english_scale = Column(Integer, CheckConstraint('writing_english_scale >= 0 AND writing_english_scale <= 10')) - numeracy_scale = Column(Integer, CheckConstraint('numeracy_scale >= 0 AND numeracy_scale <= 10')) - computer_scale = Column(Integer, CheckConstraint('computer_scale >= 0 AND computer_scale <= 10')) + reading_english_scale = Column( + Integer, CheckConstraint("reading_english_scale >= 0 AND reading_english_scale <= 10") + ) + speaking_english_scale = Column( + Integer, CheckConstraint("speaking_english_scale >= 0 AND speaking_english_scale <= 10") + ) + writing_english_scale = Column( + Integer, CheckConstraint("writing_english_scale >= 0 AND writing_english_scale <= 10") + ) + numeracy_scale = Column( + Integer, CheckConstraint("numeracy_scale >= 0 AND numeracy_scale <= 10") + ) + computer_scale = Column( + Integer, CheckConstraint("computer_scale >= 0 AND computer_scale <= 10") + ) transportation_bool = Column(Boolean) caregiver_bool = Column(Boolean) - housing = Column(Integer, CheckConstraint('housing >= 1 AND housing <= 10')) - income_source = Column(Integer, CheckConstraint('income_source >= 1 AND income_source <= 11')) + housing = Column(Integer, CheckConstraint("housing >= 1 AND housing <= 10")) + income_source = Column(Integer, CheckConstraint("income_source >= 1 AND income_source <= 11")) felony_bool = Column(Boolean) attending_school = Column(Boolean) currently_employed = Column(Boolean) substance_use = Column(Boolean) - time_unemployed = Column(Integer, CheckConstraint('time_unemployed >= 0')) + time_unemployed = Column(Integer, CheckConstraint("time_unemployed >= 0")) need_mental_health_support_bool = Column(Boolean) cases = relationship("ClientCase", back_populates="client") + class ClientCase(Base): __tablename__ = "client_cases" client_id = Column(Integer, ForeignKey("clients.id"), primary_key=True) user_id = Column(Integer, ForeignKey("users.id"), primary_key=True) - + employment_assistance = Column(Boolean) life_stabilization = Column(Boolean) retention_services = Column(Boolean) @@ -71,7 +89,7 @@ class ClientCase(Base): employment_related_financial_supports = Column(Boolean) employer_financial_supports = Column(Boolean) enhanced_referrals = Column(Boolean) - success_rate = Column(Integer, CheckConstraint('success_rate >= 0 AND success_rate <= 100')) + success_rate = Column(Integer, CheckConstraint("success_rate >= 0 AND success_rate <= 100")) client = relationship("Client", back_populates="cases") user = relationship("User", back_populates="cases") diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..6dd3a8a7 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,40 @@ +# Global options: +[mypy] +# Enable strict mode for better type checking +strict = false + +# Encourage but don't require type annotations +disallow_untyped_defs = false +# At least annotate return types +disallow_incomplete_defs = true +# Type check the body of functions without annotations +check_untyped_defs = true + +# Allow calling functions without type hints +disallow_untyped_calls = false + +# Be more permissive with 'Any' usage +disallow_any_unimported = false +disallow_any_explicit = false +disallow_any_generics = false +# Still warn about returning Any +warn_return_any = true + +# Don't enforce strict subclassing +disallow_subclassing_any = false + +# Still catch obvious issues +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = false +warn_unreachable = true + +# Allow redefinitions in some contexts +allow_redefinition = true + +# Module import settings +ignore_missing_imports = true +follow_imports = silent + +# Performance improvements +cache_dir = .mypy_cache \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..755e2239 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,89 @@ +# Overall Project Configuration +[project] +name = "Common_Assessment_Tool" +version = "1.0.0" +authors = [{name = "David Treadwell", email = "treadwell.d@northeastern.edu"}, {name = "Fran Li", email = "li.fengr@northeastern.edu"}, {name = "Steve Chen", email = "chen.steve2@northeastern.edu"}] +readme = "README.md" +license = "MIT" +dependencies = [ + 'annotated-types==0.7.0', + 'anyio==4.4.0', + 'certifi==2024.7.4', + 'click==8.1.7', + 'dnspython==2.6.1', + 'email_validator==2.2.0', + 'fastapi==0.112.2', + 'fastapi-cli==0.0.5', + 'h11==0.14.0', + 'httpcore==1.0.5', + 'httptools==0.6.1', + 'httpx==0.27.2', + 'idna==3.8', + 'Jinja2==3.1.4', + 'joblib==1.4.2', + 'markdown-it-py==3.0.0', + 'MarkupSafe==2.1.5', + 'mdurl==0.1.2', + 'numpy==2.1.0', + 'pandas==2.2.2', + 'pydantic==2.8.2', + 'pydantic_core==2.20.1', + 'Pygments==2.18.0', + 'python-dateutil==2.9.0.post0', + 'python-dotenv==1.0.1', + 'python-multipart==0.0.9', + 'pytz==2024.1', + 'PyYAML==6.0.2', + 'rich==13.8.0', + 'scikit-learn==1.5.1', + 'scipy==1.14.1', + 'shellingham==1.5.4', + 'six==1.16.0', + 'sniffio==1.3.1', + 'starlette==0.38.2', + 'threadpoolctl==3.5.0', + 'typer==0.12.5', + 'typing_extensions==4.12.2', + 'tzdata==2024.1', + 'uvicorn==0.30.6', + 'watchfiles==0.23.0', + 'websockets==13.0' +] +requires-python = ">=3.10" + +# Project urls +[project.urls] +Repository = "https://github.com/dtread4/CommonAssessmentTool" + +# Build system +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +# Packages finder +[tool.setuptools.packages.find] +where = ["."] + +# Optional dependency configuration +[project.optional-dependencies] +dev = ["black", "isort"] +extra = ['uvloop==0.20.0'] + +# Black Configuration +[tool.black] +line-length = 100 +include = '\.pyi?$' +skip-magic-trailing-comma = true +target-version = ['py310'] + +# isort Configuration +[tool.isort] +profile = "black" +line_length = 100 +known_first_party = ["app"] +multi_line_output = 3 +force_grid_wrap = 0 +combine_as_imports = true +include_trailing_comma = true +force_single_line = false +skip = ["venv", ".venv", "migrations"] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 93d35fbf..c5910a3b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -141,4 +141,4 @@ webencodings==0.5.1 websocket-client==1.5.1 websockets==11.0.3 widgetsnbextension==4.0.7 - +mypy=1.15.1