diff --git a/.env b/.env new file mode 100644 index 00000000..14d7b937 --- /dev/null +++ b/.env @@ -0,0 +1,3 @@ +SECRET_KEY = "" +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 30 \ No newline at end of file diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml deleted file mode 100644 index b801c2d3..00000000 --- a/.github/workflows/cd.yml +++ /dev/null @@ -1,47 +0,0 @@ -name: CI/CD Pipeline - -on: - push: - branches: [master, main] - pull_request: - branches: [master, main] - -jobs: - test: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.11" - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements.txt - - - name: Run Tests - run: | - python -m pytest tests/ - - deploy: - needs: test # This ensures deploy only runs if tests pass - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - - name: Build Docker image - run: docker build -t common-assessment-tool . - - - name: Run Docker container - run: | - docker run -d -p 8000:8000 common-assessment-tool - sleep 10 # Wait for container to start - - - name: Test Docker container - run: | - curl http://localhost:8000/docs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml deleted file mode 100644 index 30c81bdb..00000000 --- a/.github/workflows/ci.yml +++ /dev/null @@ -1,42 +0,0 @@ -name: Python CI Pipeline - -on: - push: - branches: [master, main] - pull_request: - branches: [master, main] - -jobs: - test: - runs-on: ubuntu-latest # Use the latest Ubuntu runner - - steps: - - name: Checkout Code - uses: actions/checkout@v4 # Checkout the repository - - - name: Set up Python - uses: actions/setup-python@v5 # Set up Python environment - with: - python-version: "3.11" - - - name: Install dependencies - run: | - python -m pip install --upgrade pip # Upgrade pip to the latest version - pip install setuptools wheel - pip install -r requirements.txt # Install dependencies from requirements.txt - pip install pylint pytest - - - name: Run Tests - run: | - python -m pytest tests/ - - - name: Print Success Message - run: | - echo "CI Pipeline completed successfully!" - echo "========================" - echo "✓ Code checked out" - echo "✓ Python environment set up" - echo "✓ Dependencies installed" - echo "✓ Tests executed" - echo "✓ Linting completed" - echo "========================" diff --git a/.github/workflows/ci_cd.yml b/.github/workflows/ci_cd.yml new file mode 100644 index 00000000..bbc9c21b --- /dev/null +++ b/.github/workflows/ci_cd.yml @@ -0,0 +1,159 @@ +name: CI-CD Pipeline + +on: + push: + branches: [ master, main, dev ] + + pull_request: + branches: [ master, main ] + +jobs: + test_code: + runs-on: ubuntu-latest # Use the latest Ubuntu runner + + steps: + - name: Checkout Code + uses: actions/checkout@v4 # Checkout the repository + + - name: Set up Python + uses: actions/setup-python@v5 # Set up Python environment + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip # Upgrade pip to the latest version + pip install setuptools wheel + pip install -r requirements.txt # Install dependencies from requirements.txt + pip install pylint pytest black isort + + - name: Run Code Formatting with Black # Format the entire repo + run: | + black . + + - name: Run Code Formatting with isort + run: | + isort . + + - name: Run Linter + run: | + pylint ./app ./tests + + - name: Run Tests + run: | + python -m pytest tests/ + + - name: Print Success Message + if: success() + run: | + echo "Confirmed code formatting and functionality successfully!" + echo "========================" + echo "✓ Code checked out" + echo "✓ Python environment set up" + echo "✓ Dependencies installed" + echo "✓ Code formatting checked with Black" + echo "✓ Code formatting checked with isort" + echo "✓ Linting completed" + echo "✓ Tests completed" + echo "========================" + + test_docker_setup: + runs-on: ubuntu-latest # Use the latest Ubuntu runner + + steps: + - name: Checkout Code + uses: actions/checkout@v4 # Checkout the repository + + - name: Set up Python + uses: actions/setup-python@v5 # Set up Python environment + with: + python-version: "3.11" + + - name: Hadolint Action Check Dockerfile Syntax + uses: hadolint/hadolint-action@v3.1.0 + with: + dockerfile: ./Dockerfile + + - name: Build Docker Image + run: | + docker build -t common-assessment-tool . + + - name: Run Docker container + run: | + docker run -d --name common-assessment-container -p 8000:8000 common-assessment-tool + sleep 10 + + - name: Test Docker container + run: | + curl --fail http://localhost:8000/docs || { + echo "Health check failed" + docker logs common-assessment-tool + exit 1 + } + + - name: Stop Docker container + run: docker stop common-assessment-container + + - name: Print Success Message + if: success() + run: | + echo "Confirmed Docker image can be built and run successfully!" + echo "========================" + echo "✓ Code checked out" + echo "✓ Python environment set up" + echo "✓ Docker file syntax checked" + echo "✓ Docker container built" + echo "✓ Docker container run" + echo "✓ Docker container tested with good health" + echo "✓ Docker container stopped" + echo "========================" + + # Source for how deployment was set up + # https://dev.to/s3cloudhub/automate-docker-deployments-push-your-images-to-ec2-with-github-actions-3a3j + deploy: + needs: [test_code, test_docker_setup] # This ensures deploy only runs if tests pass + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Login to DockerHub + run: | + echo "${{ secrets.DOCKERHUB_TOKEN }}" | docker login -u ${{ secrets.DOCKERHUB_USERNAME }} --password-stdin + + - name: Build Docker image + run: | + docker build -t common_assessment_tool . + + - name: Push image to Docker Hub + run: | + docker tag common_assessment_tool ${{ secrets.DOCKERHUB_USERNAME }}/common_assessment_tool:latest + docker push ${{ secrets.DOCKERHUB_USERNAME }}/common_assessment_tool:latest + + - name: Install SSH Key + uses: webfactory/ssh-agent@v0.9.1 + with: + ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} + + - name: Deploy Docker image to EC2 + run: | + ssh -o StrictHostKeyChecking=no ${{ secrets.EC2_USER }}@${{ secrets.EC2_INSTANCE_IP }} << 'EOF' + docker pull ${{ secrets.DOCKERHUB_USERNAME }}/common_assessment_tool:latest + docker stop $(docker ps -a -q) || true + docker rm $(docker ps -a -q) || true + docker run -d -p 8000:8000 ${{ secrets.DOCKERHUB_USERNAME }}/common_assessment_tool:latest + EOF + + - name: Print Success Message + if: success() + run: | + echo "Deployed updates to EC2 instance successfully!" + echo "========================" + echo "✓ Code checked out" + echo "✓ Logged in to Docker Hub" + echo "✓ Docker image built" + echo "✓ Docker image pushed to Docker Hub" + echo "✓ Installed EC2 SSH key" + echo "✓ Docker image deployed successfully" + echo "========================" \ No newline at end of file diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 00000000..dee81e31 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,637 @@ +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist= + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. +ignore=CVS + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +ignore-paths= + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +ignore-patterns=^\.# + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.10 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# Add paths to the list of the source roots. Supports globbing patterns. The +# source root is an absolute path or a path relative to the current working +# directory used to determine a package namespace for modules located under the +# source root. +source-roots= + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type alias names. If left empty, type +# alias names will be checked with the set naming style. +#typealias-rgx= + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + asyncSetUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=100 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + use-implicit-booleaness-not-comparison-to-string, + use-implicit-booleaness-not-comparison-to-zero, + missing-module-docstring, + missing-class-docstring, + missing-function-docstring + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable= + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are: text, parseable, colorized, +# json2 (improved json format), json (old json format) and msvs (visual +# studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +#output-format=colorized + +# Tells whether to display a full report or only the messages. +reports=yes + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. No available dictionaries : You need to install +# both the python package and the system dependency for enchant to work. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io diff --git a/README.md b/README.md index b34d6d6b..99eb6714 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -Team TicTech +Team SuperSonics via TicTech Project -- Feature Development Backend: Create CRUD API's for Client @@ -22,15 +22,22 @@ This also has an API file to interact with the front end, and logic in order to -------------------------How to Use------------------------- 1. In the virtual environment you've created for this project, install all dependencies in requirements.txt (pip install -r requirements.txt) -2. Run the app (uvicorn app.main:app --reload) +2. Create a .env file with the following fields: +```markdown +SECRET_KEY = "" +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 30 +``` -3. Load data into database (python initialize_data.py) +3. Run the app (uvicorn app.main:app --reload) -4. Go to SwaggerUI (http://127.0.0.1:8000/docs) +4. Go to SwaggerUI (http://127.0.0.1:8000/docs) -4. Log in as admin (username: admin password: admin123) +5. Load data into database (python initialize_data.py) (if receiving an error, make sure the app is running and open, then try again) -5. Click on each endpoint to use +6. Log in as admin (username: admin password: admin123) + +7. Click on each endpoint to use -Create User (Only users in admin role can create new users. The role field needs to be either "admin" or "case_worker") -Get clients (Display all the clients that are in the database) @@ -55,3 +62,32 @@ This also has an API file to interact with the front end, and logic in order to -Create case assignment (Allow authorized users to create a new case assignment.) +## Docker Instructions +1. Follow installation guide from Docker: https://www.docker.com/blog/how-to-dockerize-your-python-applications/ +2. WINDOWS-SPECIFIC: Ensure virtualization is enabled in your system BIOS, or Docker cannot run +3. Open the Docker Desktop application +4. In a command prompt, navigate to the CommonAssessmentTool repo's directory on your machine (assumes you already cloned from GitHub) and run the command below (make sure the period at the end is included!): +``` +docker build -t common_assessment_tool . +``` +5. Now run with the following Docker command: +``` +docker run --rm -p 8000:8000 common_assessment_tool +``` +6. Follow the steps to run the Swagger UI as described above (clicking link in step 5 should take you to the UI) +7. To run using Docker-Compose in the foreground, run the command below in the CommonAssessmentTool repo's directory +``` +docker compose up +``` +8. To run using Docker-Compose in the background, run the command below in the CommonAssessmentTool repo's directory +``` +docker compose up -d +``` +9. If running using the background command, you can stop the container gracefully with the following command: +``` +docker compose stop +``` + +## Access public address +Backend application is now deployed to the AWS Cloud. +Access the backend application from the endpoint by clicking: http://ec2-54-165-172-227.compute-1.amazonaws.com:8000/docs diff --git a/app/auth/router.py b/app/auth/router.py index 229ee71d..ed66c7a0 100644 --- a/app/auth/router.py +++ b/app/auth/router.py @@ -1,56 +1,72 @@ +# pylint: disable=unused-argument, no-self-argument, too-few-public-methods +import os from datetime import datetime, timedelta from typing import Optional + +from dotenv import load_dotenv + from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from jose import JWTError, jwt +from passlib.context import CryptContext +from pydantic import BaseModel, Field, field_validator, ConfigDict from sqlalchemy.orm import Session + +from dotenv import load_dotenv + from app.database import get_db from app.models import User, UserRole -from passlib.context import CryptContext -from pydantic import BaseModel, Field, validator router = APIRouter(prefix="/auth", tags=["authentication"]) + class UserCreate(BaseModel): username: str = Field(..., min_length=3, max_length=50) email: str password: str role: UserRole - @validator('role') + @field_validator("role") + @classmethod def validate_role(cls, v): - if v not in [UserRole.admin, UserRole.case_worker]: - raise ValueError('Role must be either admin or case_worker') + if v not in [UserRole.ADMIN, UserRole.CASE_WORKER]: + raise ValueError("Role must be either admin or case_worker") return v + class UserResponse(BaseModel): username: str email: str role: UserRole - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) + -# Configuration -SECRET_KEY = "your-secret-key-here" -ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 30 +# Load configuration from .env +load_dotenv() +SECRET_KEY = os.getenv("SECRET_KEY") +ALGORITHM = os.getenv("ALGORITHM") +ACCESS_TOKEN_EXPIRE_MINUTES = os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token") + def verify_password(plain_password: str, hashed_password: str) -> bool: return pwd_context.verify(plain_password, hashed_password) + def get_password_hash(password: str) -> str: return pwd_context.hash(password) + def authenticate_user(db: Session, username: str, password: str) -> Optional[User]: user = db.query(User).filter(User.username == username).first() if not user or not verify_password(password, user.hashed_password): return None return user + def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): to_encode = data.copy() if expires_delta: @@ -61,9 +77,9 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt + async def get_current_user( - token: str = Depends(oauth2_scheme), - db: Session = Depends(get_db) + token: str = Depends(oauth2_scheme), db: Session = Depends(get_db) ) -> User: credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -75,26 +91,27 @@ async def get_current_user( username: str = payload.get("sub") if username is None: raise credentials_exception - except JWTError: - raise credentials_exception - + except JWTError as exc: + raise credentials_exception from exc + user = db.query(User).filter(User.username == username).first() if user is None: raise credentials_exception return user + def get_admin_user(current_user: User = Depends(get_current_user)): - if current_user.role != UserRole.admin: + if current_user.role != UserRole.ADMIN: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Only admin users can perform this operation" + detail="Only admin users can perform this operation", ) return current_user + @router.post("/token") async def login_for_access_token( - form_data: OAuth2PasswordRequestForm = Depends(), - db: Session = Depends(get_db) + form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db) ): user = authenticate_user(db, form_data.username, form_data.password) if not user: @@ -103,31 +120,30 @@ async def login_for_access_token( detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"}, ) - access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token_expires = timedelta(minutes=int(ACCESS_TOKEN_EXPIRE_MINUTES)) access_token = create_access_token( data={"sub": user.username}, expires_delta=access_token_expires ) return {"access_token": access_token, "token_type": "bearer"} + @router.post("/users", response_model=UserResponse) async def create_user( user_data: UserCreate, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Create a new user (admin only)""" # Check if username exists if db.query(User).filter(User.username == user_data.username).first(): raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Username already registered" + status_code=status.HTTP_400_BAD_REQUEST, detail="Username already registered" ) - + # Check if email exists if db.query(User).filter(User.email == user_data.email).first(): raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Email already registered" + status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered" ) # Create new user @@ -135,9 +151,9 @@ async def create_user( username=user_data.username, email=user_data.email, hashed_password=get_password_hash(user_data.password), - role=user_data.role + role=user_data.role, ) - + try: db.add(db_user) db.commit() @@ -145,7 +161,4 @@ async def create_user( return db_user except Exception as e: db.rollback() - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(e) - ) + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e diff --git a/app/clients/router.py b/app/clients/router.py index 4ecc83e4..b08db093 100644 --- a/app/clients/router.py +++ b/app/clients/router.py @@ -2,43 +2,45 @@ Router module for client-related endpoints. Handles all HTTP requests for client operations including create, read, update, and delete. """ +# pylint: disable=unused-argument, too-many-arguments, too-many-positional-arguments, too-many-locals +from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, status, Query +from fastapi import APIRouter, Depends, Query, status from sqlalchemy.orm import Session -from typing import List, Optional -from app.auth.router import get_current_user, get_admin_user -from app.models import User, UserRole -from app.database import get_db -from app.clients.service.client_service import ClientService +from app.auth.router import get_admin_user, get_current_user from app.clients.schema import ( - ClientResponse, - ClientUpdate, ClientListResponse, + ClientResponse, + ClientUpdate, ServiceResponse, - ServiceUpdate + ServiceUpdate, ) +from app.clients.service.client_service import ClientService +from app.database import get_db +from app.models import User router = APIRouter(prefix="/clients", tags=["clients"]) + @router.get("/", response_model=ClientListResponse) async def get_clients( - current_user: User = Depends(get_admin_user), + current_user: User = Depends(get_admin_user), skip: int = Query(default=0, ge=0, description="Number of records to skip"), limit: int = Query(default=50, ge=1, le=150, description="Maximum number of records to return"), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): return ClientService.get_clients(db, skip, limit) + @router.get("/{client_id}", response_model=ClientResponse) async def get_client( - client_id: int, - current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + client_id: int, current_user: User = Depends(get_admin_user), db: Session = Depends(get_db) ): """Get a specific client by ID""" return ClientService.get_client(db, client_id) + @router.get("/search/by-criteria", response_model=List[ClientResponse]) async def get_clients_by_criteria( employment_status: Optional[bool] = None, @@ -66,7 +68,7 @@ async def get_clients_by_criteria( time_unemployed: Optional[int] = Query(None, ge=0), need_mental_health_support_bool: Optional[bool] = None, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Search clients by any combination of criteria""" return ClientService.get_clients_by_criteria( @@ -94,9 +96,10 @@ async def get_clients_by_criteria( attending_school=attending_school, substance_use=substance_use, time_unemployed=time_unemployed, - need_mental_health_support_bool=need_mental_health_support_bool + need_mental_health_support_bool=need_mental_health_support_bool, ) + @router.get("/search/by-services", response_model=List[ClientResponse]) async def get_clients_by_services( employment_assistance: Optional[bool] = None, @@ -107,7 +110,7 @@ async def get_clients_by_services( employer_financial_supports: Optional[bool] = None, enhanced_referrals: Optional[bool] = None, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Get clients filtered by multiple service statuses""" return ClientService.get_clients_by_services( @@ -118,70 +121,73 @@ async def get_clients_by_services( specialized_services=specialized_services, employment_related_financial_supports=employment_related_financial_supports, employer_financial_supports=employer_financial_supports, - enhanced_referrals=enhanced_referrals + enhanced_referrals=enhanced_referrals, ) + @router.get("/{client_id}/services", response_model=List[ServiceResponse]) async def get_client_services( - client_id: int, - current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + client_id: int, current_user: User = Depends(get_admin_user), db: Session = Depends(get_db) ): """Get all services and their status for a specific client, including case worker info""" return ClientService.get_client_services(db, client_id) + @router.get("/search/success-rate", response_model=List[ClientResponse]) async def get_clients_by_success_rate( min_rate: int = Query(70, ge=0, le=100, description="Minimum success rate percentage"), current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Get clients with success rate above specified threshold""" return ClientService.get_clients_by_success_rate(db, min_rate) + @router.get("/case-worker/{case_worker_id}", response_model=List[ClientResponse]) async def get_clients_by_case_worker( case_worker_id: int, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ): return ClientService.get_clients_by_case_worker(db, case_worker_id) + @router.put("/{client_id}", response_model=ClientResponse) async def update_client( client_id: int, client_data: ClientUpdate, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Update a client's information""" return ClientService.update_client(db, client_id, client_data) + @router.put("/{client_id}/services/{user_id}", response_model=ServiceResponse) async def update_client_services( client_id: int, user_id: int, service_update: ServiceUpdate, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ): return ClientService.update_client_services(db, client_id, user_id, service_update) + @router.post("/{client_id}/case-assignment", response_model=ServiceResponse) async def create_case_assignment( client_id: int, case_worker_id: int = Query(..., description="Case worker ID to assign"), current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Create a new case assignment for a client with a case worker""" return ClientService.create_case_assignment(db, client_id, case_worker_id) + @router.delete("/{client_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_client( - client_id: int, - current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + client_id: int, current_user: User = Depends(get_admin_user), db: Session = Depends(get_db) ): """Delete a client""" ClientService.delete_client(db, client_id) diff --git a/app/clients/schema.py b/app/clients/schema.py index cff28897..6160ad47 100644 --- a/app/clients/schema.py +++ b/app/clients/schema.py @@ -2,23 +2,26 @@ Pydantic models for data validation and serialization. Defines schemas for client data, predictions, and API responses. """ +# pylint: disable=too-few-public-methods +from enum import IntEnum +from typing import List, Optional # Standard library imports -from pydantic import BaseModel, Field, validator -from typing import Optional, List -from enum import IntEnum -from app.models import UserRole +from pydantic import BaseModel, Field, ConfigDict + # Enums for validation class Gender(IntEnum): MALE = 1 FEMALE = 2 + class PredictionInput(BaseModel): """ Schema for prediction input data containing all client assessment fields. Used for making predictions about client outcomes. """ + age: int gender: str work_experience: int @@ -44,6 +47,7 @@ class PredictionInput(BaseModel): time_unemployed: int need_mental_health_support_bool: str + class ClientBase(BaseModel): age: int = Field(ge=18, description="Age of client, must be 18 or older") gender: Gender = Field(description="Gender: 1 for male, 2 for female") @@ -70,7 +74,7 @@ class ClientBase(BaseModel): time_unemployed: int = Field(ge=0, description="Time unemployed in months") need_mental_health_support_bool: bool = Field(description="Needs mental health support") - class Config: + model_config = ConfigDict( json_schema_extra = { "example": { "age": 25, @@ -96,15 +100,17 @@ class Config: "currently_employed": False, "substance_use": False, "time_unemployed": 6, - "need_mental_health_support_bool": False + "need_mental_health_support_bool": False, } } + ) + class ClientResponse(ClientBase): id: int - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) + class ClientUpdate(BaseModel): age: Optional[int] = Field(None, ge=18) @@ -132,6 +138,7 @@ class ClientUpdate(BaseModel): time_unemployed: Optional[int] = Field(None, ge=0) need_mental_health_support_bool: Optional[bool] = None + class ServiceResponse(BaseModel): client_id: int user_id: int @@ -144,8 +151,8 @@ class ServiceResponse(BaseModel): enhanced_referrals: bool success_rate: int = Field(ge=0, le=100) - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) + class ServiceUpdate(BaseModel): employment_assistance: Optional[bool] = None @@ -157,6 +164,7 @@ class ServiceUpdate(BaseModel): enhanced_referrals: Optional[bool] = None success_rate: Optional[int] = Field(None, ge=0, le=100) + class ClientListResponse(BaseModel): clients: List[ClientResponse] total: int diff --git a/app/clients/service/client_service.py b/app/clients/service/client_service.py index 86c3ef4a..be6c98f2 100644 --- a/app/clients/service/client_service.py +++ b/app/clients/service/client_service.py @@ -2,15 +2,78 @@ Client service module handling all database operations for clients. Provides CRUD operations and business logic for client management. """ +# pylint: disable=arguments-differ, arguments-renamed, too-many-arguments, too-many-positional-arguments, too-many-locals +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional -from sqlalchemy.orm import Session -from sqlalchemy import and_ from fastapi import HTTPException, status -from typing import List, Optional, Dict, Any +from sqlalchemy.orm import Session + +from app.clients.schema import ClientUpdate, ServiceUpdate from app.models import Client, ClientCase, User -from app.clients.schema import ClientUpdate, ServiceUpdate, ServiceResponse -class ClientService: + +class InterfaceClientQueryService(ABC): + """Interface for client query operations""" + + @abstractmethod + def get_client(self, db: Session, client_id: int) -> Client: + """Get a specific client by ID""" + + @abstractmethod + def get_clients(self, db: Session, skip: int, limit: int) -> Dict[str, Any]: + """Get clients with optional pagination.""" + + @abstractmethod + def get_clients_by_criteria(self, db: Session, **criteria) -> List[Client]: + """Get clients filtered by any combination of criteria""" + + @abstractmethod + def get_clients_by_services(self, db: Session, **service_filters) -> List[Client]: + """Get clients filtered by multiple service statuses.""" + + @abstractmethod + def get_client_services(self, db: Session, client_id: int) -> List[ClientCase]: + """Get client's services""" + + @abstractmethod + def get_clients_by_success_rate(self, db: Session, min_rate: int) -> List[Client]: + "Get clients filtered by success rate" + + @abstractmethod + def get_clients_by_case_worker(self, db: Session, case_worker_id: int) -> List[Client]: + "Get clients filtered by case worker" + + +class InterfaceClientManagementService(ABC): + """Interface for client management operations""" + + @abstractmethod + def update_client( + self, db: Session, client_id: int, client_update: ClientUpdate + ) -> ClientUpdate: + """Update a client's information""" + + @abstractmethod + def update_client_services( + self, db: Session, client_id: int, user_id: int, service_update: ServiceUpdate + ) -> ClientCase: + """Update a client's services and outcomes for a specific caseworker""" + + @abstractmethod + def create_case_assignment( + self, db: Session, client_id: int, case_worker_id: int + ) -> ClientCase: + """Create a new case assignment""" + + @abstractmethod + def delete_client(self, db: Session, client_id: int) -> None: + """Delete a client and their associated records""" + + +class ClientQueryService(InterfaceClientQueryService): + """Implementation of client query service""" + @staticmethod def get_client(db: Session, client_id: int): """Get a specific client by ID""" @@ -18,7 +81,7 @@ def get_client(db: Session, client_id: int): if not client: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" + detail=f"Client with id {client_id} not found", ) return client @@ -30,15 +93,13 @@ def get_clients(db: Session, skip: int = 0, limit: int = 50): """ if skip < 0: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Skip value cannot be negative" + status_code=status.HTTP_400_BAD_REQUEST, detail="Skip value cannot be negative" ) if limit < 1: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Limit must be greater than 0" + status_code=status.HTTP_400_BAD_REQUEST, detail="Limit must be greater than 0" ) - + clients = db.query(Client).offset(skip).limit(limit).all() total = db.query(Client).count() return {"clients": clients, "total": total} @@ -69,133 +130,111 @@ def get_clients_by_criteria( attending_school: Optional[bool] = None, substance_use: Optional[bool] = None, time_unemployed: Optional[int] = None, - need_mental_health_support_bool: Optional[bool] = None + need_mental_health_support_bool: Optional[bool] = None, ): """Get clients filtered by any combination of criteria""" - query = db.query(Client) - - if education_level is not None and not (1 <= education_level <= 14): + if education_level is not None and not 1 <= education_level <= 14: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Education level must be between 1 and 14" + detail="Education level must be between 1 and 14", ) - + if age_min is not None and age_min < 18: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Minimum age must be at least 18" + status_code=status.HTTP_400_BAD_REQUEST, detail="Minimum age must be at least 18" ) if gender is not None and gender not in [1, 2]: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Gender must be 1 or 2" + status_code=status.HTTP_400_BAD_REQUEST, detail="Gender must be 1 or 2" ) # Apply filters for non-None values - if employment_status is not None: - query = query.filter(Client.currently_employed == employment_status) - if age_min is not None: - query = query.filter(Client.age >= age_min) - if gender is not None: - query = query.filter(Client.gender == gender) - if education_level is not None: - query = query.filter(Client.level_of_schooling == education_level) - if work_experience is not None: - query = query.filter(Client.work_experience == work_experience) - if canada_workex is not None: - query = query.filter(Client.canada_workex == canada_workex) - if dep_num is not None: - query = query.filter(Client.dep_num == dep_num) - if canada_born is not None: - query = query.filter(Client.canada_born == canada_born) - if citizen_status is not None: - query = query.filter(Client.citizen_status == citizen_status) - if fluent_english is not None: - query = query.filter(Client.fluent_english == fluent_english) - if reading_english_scale is not None: - query = query.filter(Client.reading_english_scale == reading_english_scale) - if speaking_english_scale is not None: - query = query.filter(Client.speaking_english_scale == speaking_english_scale) - if writing_english_scale is not None: - query = query.filter(Client.writing_english_scale == writing_english_scale) - if numeracy_scale is not None: - query = query.filter(Client.numeracy_scale == numeracy_scale) - if computer_scale is not None: - query = query.filter(Client.computer_scale == computer_scale) - if transportation_bool is not None: - query = query.filter(Client.transportation_bool == transportation_bool) - if caregiver_bool is not None: - query = query.filter(Client.caregiver_bool == caregiver_bool) - if housing is not None: - query = query.filter(Client.housing == housing) - if income_source is not None: - query = query.filter(Client.income_source == income_source) - if felony_bool is not None: - query = query.filter(Client.felony_bool == felony_bool) - if attending_school is not None: - query = query.filter(Client.attending_school == attending_school) - if substance_use is not None: - query = query.filter(Client.substance_use == substance_use) - if time_unemployed is not None: - query = query.filter(Client.time_unemployed == time_unemployed) - if need_mental_health_support_bool is not None: - query = query.filter(Client.need_mental_health_support_bool == need_mental_health_support_bool) + filters = [] + + criteria_map = { + Client.currently_employed: employment_status, + Client.age: age_min, + Client.gender: gender, + Client.level_of_schooling: education_level, + Client.work_experience: work_experience, + Client.canada_workex: canada_workex, + Client.dep_num: dep_num, + Client.canada_born: canada_born, + Client.citizen_status: citizen_status, + Client.fluent_english: fluent_english, + Client.reading_english_scale: reading_english_scale, + Client.speaking_english_scale: speaking_english_scale, + Client.writing_english_scale: writing_english_scale, + Client.numeracy_scale: numeracy_scale, + Client.computer_scale: computer_scale, + Client.transportation_bool: transportation_bool, + Client.caregiver_bool: caregiver_bool, + Client.housing: housing, + Client.income_source: income_source, + Client.felony_bool: felony_bool, + Client.attending_school: attending_school, + Client.substance_use: substance_use, + Client.time_unemployed: time_unemployed, + Client.need_mental_health_support_bool: need_mental_health_support_bool, + } + + for column, value in criteria_map.items(): + if value is not None: + if column == Client.age: + filters.append(column >= value) + else: + filters.append(column == value) try: - return query.all() + return db.query(Client).filter(*filters).all() except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error retrieving clients: {str(e)}" - ) + detail=f"Error retrieving clients: {str(e)}", + ) from e @staticmethod - def get_clients_by_services( - db: Session, - **service_filters: Optional[bool] - ): + def get_clients_by_services(db: Session, **service_filters: Optional[bool]): """ Get clients filtered by multiple service statuses. """ query = db.query(Client).join(ClientCase) - - for service_name, status in service_filters.items(): - if status is not None: - filter_criteria = getattr(ClientCase, service_name) == status + + for service_name, service_status in service_filters.items(): + if service_status is not None: + filter_criteria = getattr(ClientCase, service_name) == service_status query = query.filter(filter_criteria) - + try: return query.all() except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error retrieving clients: {str(e)}" - ) + detail=f"Error retrieving clients: {str(e)}", + ) from e @staticmethod def get_client_services(db: Session, client_id: int): - """Get all services for a specific client with case worker info""" + """Get all services for a specific client with caseworker info""" client_cases = db.query(ClientCase).filter(ClientCase.client_id == client_id).all() if not client_cases: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"No services found for client with id {client_id}" + detail=f"No services found for client with id {client_id}", ) return client_cases @staticmethod def get_clients_by_success_rate(db: Session, min_rate: int = 70): """Get clients with success rate at or above the specified percentage""" - if not (0 <= min_rate <= 100): + if not 0 <= min_rate <= 100: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Success rate must be between 0 and 100" + detail="Success rate must be between 0 and 100", ) - - return db.query(Client).join(ClientCase).filter( - ClientCase.success_rate >= min_rate - ).all() + + return db.query(Client).join(ClientCase).filter(ClientCase.success_rate >= min_rate).all() @staticmethod def get_clients_by_case_worker(db: Session, case_worker_id: int): @@ -204,12 +243,14 @@ def get_clients_by_case_worker(db: Session, case_worker_id: int): if not case_worker: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Case worker with id {case_worker_id} not found" + detail=f"Case worker with id {case_worker_id} not found", ) - - return db.query(Client).join(ClientCase).filter( - ClientCase.user_id == case_worker_id - ).all() + + return db.query(Client).join(ClientCase).filter(ClientCase.user_id == case_worker_id).all() + + +class ClientManagementService(InterfaceClientManagementService): + """Implementation of client management service""" @staticmethod def update_client(db: Session, client_id: int, client_update: ClientUpdate): @@ -218,10 +259,10 @@ def update_client(db: Session, client_id: int, client_update: ClientUpdate): if not client: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" + detail=f"Client with id {client_id} not found", ) - update_data = client_update.dict(exclude_unset=True) + update_data = client_update.model_dump(exclude_unset=True) for field, value in update_data.items(): setattr(client, field, value) @@ -233,27 +274,25 @@ def update_client(db: Session, client_id: int, client_update: ClientUpdate): db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update client: {str(e)}" - ) - + detail=f"Failed to update client: {str(e)}", + ) from e + @staticmethod def update_client_services( - db: Session, - client_id: int, - user_id: int, - service_update: ServiceUpdate + db: Session, client_id: int, user_id: int, service_update: ServiceUpdate ): """Update a client's services and outcomes for a specific case worker""" - client_case = db.query(ClientCase).filter( - ClientCase.client_id == client_id, - ClientCase.user_id == user_id - ).first() - + client_case = ( + db.query(ClientCase) + .filter(ClientCase.client_id == client_id, ClientCase.user_id == user_id) + .first() + ) + if not client_case: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"No case found for client {client_id} with case worker {user_id}. " - f"Cannot update services for a non-existent case assignment." + f"Cannot update services for a non-existent case assignment.", ) update_data = service_update.dict(exclude_unset=True) @@ -268,43 +307,43 @@ def update_client_services( db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update client services: {str(e)}" - ) - + detail=f"Failed to update client services: {str(e)}", + ) from e + @staticmethod - def create_case_assignment( - db: Session, - client_id: int, - case_worker_id: int - ): + def create_case_assignment(db: Session, client_id: int, case_worker_id: int): """Create a new case assignment""" # Check if client exists client = db.query(Client).filter(Client.id == client_id).first() if not client: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" + detail=f"Client with id {client_id} not found", ) - # Check if case worker exists + # Check if caseworker exists case_worker = db.query(User).filter(User.id == case_worker_id).first() if not case_worker: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Case worker with id {case_worker_id} not found" + detail=f"Case worker with id {case_worker_id} not found", ) # Check if assignment already exists - existing_case = db.query(ClientCase).filter( - ClientCase.client_id == client_id, - ClientCase.user_id == case_worker_id - ).first() - + existing_case = ( + db.query(ClientCase) + .filter(ClientCase.client_id == client_id, ClientCase.user_id == case_worker_id) + .first() + ) + if existing_case: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Client {client_id} already has a case assigned to case worker {case_worker_id}" - ) + detail=( + f"Client {client_id} already has a case assigned " + f"to case worker {case_worker_id}" + ), + ) try: # Create new case assignment with default service values @@ -318,7 +357,7 @@ def create_case_assignment( employment_related_financial_supports=False, employer_financial_supports=False, enhanced_referrals=False, - success_rate=0 + success_rate=0, ) db.add(new_case) db.commit() @@ -329,9 +368,9 @@ def create_case_assignment( db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to create case assignment: {str(e)}" - ) - + detail=f"Failed to create case assignment: {str(e)}", + ) from e + @staticmethod def delete_client(db: Session, client_id: int): """Delete a client and their associated records""" @@ -340,22 +379,77 @@ def delete_client(db: Session, client_id: int): if not client: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" + detail=f"Client with id {client_id} not found", ) try: # Delete associated client_cases - db.query(ClientCase).filter( - ClientCase.client_id == client_id - ).delete() - + db.query(ClientCase).filter(ClientCase.client_id == client_id).delete() + # Delete the client db.delete(client) db.commit() - + except Exception as e: db.rollback() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to delete client: {str(e)}" - ) + detail=f"Failed to delete client: {str(e)}", + ) from e + + +class ClientService: + """ + Facade that maintains backward compatibility with the existing router. + Delegates to specialized service classes. + """ + + # Query methods + @staticmethod + def get_client(db: Session, client_id: int): + return ClientQueryService.get_client(db, client_id) + + @staticmethod + def get_clients(db: Session, skip: int = 0, limit: int = 50): + return ClientQueryService.get_clients(db, skip, limit) + + @staticmethod + def get_clients_by_criteria(db: Session, **criteria): + return ClientQueryService.get_clients_by_criteria(db, **criteria) + + @staticmethod + def get_clients_by_services(db: Session, **service_filters): + return ClientQueryService.get_clients_by_services(db, **service_filters) + + @staticmethod + def get_client_services(db: Session, client_id: int): + return ClientQueryService.get_client_services(db, client_id) + + @staticmethod + def get_clients_by_success_rate(db: Session, min_rate: int = 70): + return ClientQueryService.get_clients_by_success_rate(db, min_rate) + + @staticmethod + def get_clients_by_case_worker(db: Session, case_worker_id: int): + return ClientQueryService.get_clients_by_case_worker(db, case_worker_id) + + # Modification methods + @staticmethod + def update_client(db: Session, client_id: int, client_update: ClientUpdate): + return ClientManagementService.update_client(db, client_id, client_update) + + @staticmethod + def update_client_services( + db: Session, client_id: int, user_id: int, service_update: ServiceUpdate + ): + return ClientManagementService.update_client_services( + db, client_id, user_id, service_update + ) + + @staticmethod + def create_case_assignment(db: Session, client_id: int, case_worker_id: int): + return ClientManagementService.create_case_assignment(db, client_id, case_worker_id) + + @staticmethod + def delete_client(db: Session, client_id: int): + return ClientManagementService.delete_client(db, client_id) diff --git a/app/clients/service/constants.py b/app/clients/service/constants.py new file mode 100644 index 00000000..01d02094 --- /dev/null +++ b/app/clients/service/constants.py @@ -0,0 +1,36 @@ +COLUMNS_FIELDS = [ + "age", # Client's age + "gender", # Client's gender (bool) + "work_experience", # Years of work experience + "canada_workex", # Years of work experience in Canada + "dep_num", # Number of dependents + "canada_born", # Born in Canada + "citizen_status", # Citizenship status + "level_of_schooling", # Highest level achieved (1-14) + "fluent_english", # English fluency scale (1-10) + "reading_english_scale", # Reading ability scale (1-10) + "speaking_english_scale", # Speaking ability scale (1-10) + "writing_english_scale", # Writing ability scale (1-10) + "numeracy_scale", # Numeracy ability scale (1-10) + "computer_scale", # Computer proficiency scale (1-10) + "transportation_bool", # Needs transportation support (bool) + "caregiver_bool", # Is primary caregiver (bool) + "housing", # Housing situation (1-10) + "income_source", # Source of income (1-10) + "felony_bool", # Has a felony (bool) + "attending_school", # Currently a student (bool) + "currently_employed", # Currently employed (bool) + "substance_use", # Substance use disorder (bool) + "time_unemployed", # Years unemployed + "need_mental_health_support_bool", # Needs mental health support (bool) +] + +INTERVENTION_FIELDS = [ + "employment_assistance", + "life_stabilization", + "retention_services", + "specialized_services", + "employment_related_financial_supports", + "employer_financial_supports", + "enhanced_referrals" +] diff --git a/app/clients/service/logic.py b/app/clients/service/logic.py index c25b4217..c2968ceb 100644 --- a/app/clients/service/logic.py +++ b/app/clients/service/logic.py @@ -5,30 +5,35 @@ # Standard library imports import os -#import json -from itertools import product # Third-party imports import pickle + +# import json +from itertools import product + import numpy as np +from app.clients.service.constants import COLUMNS_FIELDS + # Constants COLUMN_INTERVENTIONS = [ - 'Life Stabilization', - 'General Employment Assistance Services', - 'Retention Services', - 'Specialized Services', - 'Employment-Related Financial Supports for Job Seekers and Employers', - 'Employer Financial Supports', - 'Enhanced Referrals for Skills Development' + "Life Stabilization", + "General Employment Assistance Services", + "Retention Services", + "Specialized Services", + "Employment-Related Financial Supports for Job Seekers and Employers", + "Employer Financial Supports", + "Enhanced Referrals for Skills Development", ] # Load model CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) -MODEL_PATH = os.path.join(CURRENT_DIR, 'model.pkl') +MODEL_PATH = os.path.join(CURRENT_DIR, "model.pkl") with open(MODEL_PATH, "rb") as model_file: MODEL = pickle.load(model_file) + def clean_input_data(input_data): """ Clean and transform input data into model-compatible format. @@ -39,24 +44,17 @@ def clean_input_data(input_data): Returns: list: Cleaned and formatted data ready for model input """ - columns = [ - "age", "gender", "work_experience", "canada_workex", "dep_num", - "canada_born", "citizen_status", "level_of_schooling", "fluent_english", - "reading_english_scale", "speaking_english_scale", "writing_english_scale", - "numeracy_scale", "computer_scale", "transportation_bool", "caregiver_bool", - "housing", "income_source", "felony_bool", "attending_school", - "currently_employed", "substance_use", "time_unemployed", - "need_mental_health_support_bool" - ] - demographics = {key: input_data[key] for key in columns} + + demographics = {key: input_data[key] for key in COLUMNS_FIELDS} output = [] - for column in columns: + for column in COLUMNS_FIELDS: value = demographics.get(column, None) if isinstance(value, str): value = convert_text(value) # Removed 'column' from here as it wasn't used output.append(value) return output + def convert_text(text_data: str): """ Convert text answers from front end into numerical values. @@ -68,33 +66,47 @@ def convert_text(text_data: str): int: Converted numerical value """ categorical_mappings = [ + {"": 0, "true": 1, "false": 0, "no": 0, "yes": 1, "No": 0, "Yes": 1}, { - "": 0, "true": 1, "false": 0, "no": 0, "yes": 1, - "No": 0, "Yes": 1 - }, - { - "Grade 0-8": 1, "Grade 9": 2, "Grade 10": 3, "Grade 11": 4, - "Grade 12 or equivalent": 5, "OAC or Grade 13": 6, - "Some college": 7, "Some university": 8, "Some apprenticeship": 9, - "Certificate of Apprenticeship": 10, "Journeyperson": 11, - "Certificate/Diploma": 12, "Bachelor's degree": 13, - "Post graduate": 14 + "Grade 0-8": 1, + "Grade 9": 2, + "Grade 10": 3, + "Grade 11": 4, + "Grade 12 or equivalent": 5, + "OAC or Grade 13": 6, + "Some college": 7, + "Some university": 8, + "Some apprenticeship": 9, + "Certificate of Apprenticeship": 10, + "Journeyperson": 11, + "Certificate/Diploma": 12, + "Bachelor's degree": 13, + "Post graduate": 14, }, { - "Renting-private": 1, "Renting-subsidized": 2, - "Boarding or lodging": 3, "Homeowner": 4, - "Living with family/friend": 5, "Institution": 6, - "Temporary second residence": 7, "Band-owned home": 8, - "Homeless or transient": 9, "Emergency hostel": 10 + "Renting-private": 1, + "Renting-subsidized": 2, + "Boarding or lodging": 3, + "Homeowner": 4, + "Living with family/friend": 5, + "Institution": 6, + "Temporary second residence": 7, + "Band-owned home": 8, + "Homeless or transient": 9, + "Emergency hostel": 10, }, { - "No Source of Income": 1, "Employment Insurance": 2, + "No Source of Income": 1, + "Employment Insurance": 2, "Workplace Safety and Insurance Board": 3, "Ontario Works applied or receiving": 4, "Ontario Disability Support Program applied or receiving": 5, - "Dependent of someone receiving OW or ODSP": 6, "Crown Ward": 7, - "Employment": 8, "Self-Employment": 9, "Other (specify)": 10 - } + "Dependent of someone receiving OW or ODSP": 6, + "Crown Ward": 7, + "Employment": 8, + "Self-Employment": 9, + "Other (specify)": 10, + }, ] for category in categorical_mappings: if text_data in category: @@ -102,6 +114,7 @@ def convert_text(text_data: str): return int(text_data) if text_data.isnumeric() else text_data + def create_matrix(row_data): """ Create matrix of all possible intervention combinations. @@ -116,6 +129,7 @@ def create_matrix(row_data): perms = intervention_permutations(7) return np.concatenate((np.array(data), np.array(perms)), axis=1) + def intervention_permutations(num): """ Generate all possible intervention combinations. @@ -128,6 +142,7 @@ def intervention_permutations(num): """ return np.array(list(product([0, 1], repeat=num))) + def get_baseline_row(row_data): """ Create baseline row with no interventions. @@ -141,6 +156,7 @@ def get_baseline_row(row_data): base_interventions = np.zeros(7) return np.concatenate((np.array(row_data), base_interventions)) + def intervention_row_to_names(row_data): """ Convert intervention row to list of intervention names. @@ -153,6 +169,7 @@ def intervention_row_to_names(row_data): """ return [COLUMN_INTERVENTIONS[i] for i, value in enumerate(row_data) if value == 1] + def process_results(baseline_pred, results_matrix): """ Process model results into structured output. @@ -164,14 +181,9 @@ def process_results(baseline_pred, results_matrix): Returns: dict: Processed results with baseline and interventions """ - result_list = [ - (row[-1], intervention_row_to_names(row[:-1])) - for row in results_matrix - ] - return { - "baseline": baseline_pred[-1], - "interventions": result_list - } + result_list = [(row[-1], intervention_row_to_names(row[:-1])) for row in results_matrix] + return {"baseline": baseline_pred[-1], "interventions": result_list} + def interpret_and_calculate(input_data): """ @@ -194,19 +206,33 @@ def interpret_and_calculate(input_data): top_results = result_matrix[-3:, -8:] return process_results(baseline_prediction, top_results) + if __name__ == "__main__": test_data = { - "age": "23", "gender": "1", "work_experience": "1", - "canada_workex": "1", "dep_num": "0", "canada_born": "1", - "citizen_status": "2", "level_of_schooling": "2", - "fluent_english": "3", "reading_english_scale": "2", - "speaking_english_scale": "2", "writing_english_scale": "3", - "numeracy_scale": "2", "computer_scale": "3", - "transportation_bool": "2", "caregiver_bool": "1", - "housing": "1", "income_source": "5", "felony_bool": "1", - "attending_school": "0", "currently_employed": "1", - "substance_use": "1", "time_unemployed": "1", - "need_mental_health_support_bool": "1" + "age": "23", + "gender": "1", + "work_experience": "1", + "canada_workex": "1", + "dep_num": "0", + "canada_born": "1", + "citizen_status": "2", + "level_of_schooling": "2", + "fluent_english": "3", + "reading_english_scale": "2", + "speaking_english_scale": "2", + "writing_english_scale": "3", + "numeracy_scale": "2", + "computer_scale": "3", + "transportation_bool": "2", + "caregiver_bool": "1", + "housing": "1", + "income_source": "5", + "felony_bool": "1", + "attending_school": "0", + "currently_employed": "1", + "substance_use": "1", + "time_unemployed": "1", + "need_mental_health_support_bool": "1", } results = interpret_and_calculate(test_data) print(results) diff --git a/app/clients/service/ml_models.py b/app/clients/service/ml_models.py new file mode 100644 index 00000000..ca59d4f8 --- /dev/null +++ b/app/clients/service/ml_models.py @@ -0,0 +1,180 @@ +import os +from abc import ABC, abstractmethod +from typing import List + +import pickle +import numpy as np +from sklearn.ensemble import RandomForestRegressor +from sklearn.linear_model import LinearRegression +from sklearn.svm import SVR + +from app.clients.service.model_helper import get_all_feature_columns, get_true_file_name + +default_unformatted_model_path = os.path.join( + os.path.dirname(__file__), "pretrained_models", "model_{}.pkl" +) + + +class InterfaceBaseMLModel(ABC): + """Interface of a base ML Model""" + + def __init__(self): + self.feature_columns = get_all_feature_columns() + + @abstractmethod + def fit(self, features: np.ndarray, targets: np.ndarray): + """Fit the model to provided data""" + + @abstractmethod + def predict(self, features: np.ndarray) -> np.ndarray: + """Predict using the fitted model""" + + def save(self, path: str): + with open(path, "wb") as f: + pickle.dump(self, f) + + @staticmethod + def load(path: str): + with open(path, "rb") as f: + return pickle.load(f) + + @abstractmethod + def __str__(self) -> str: + """Return the name of the model""" + + def load_if_trained(self): + pass + + +class LinearRegressionModel(InterfaceBaseMLModel): + def __init__(self): + super().__init__() + self.model = LinearRegression() + + def fit(self, features, targets): + self.model.fit(features, targets) + + def predict(self, features): + return self.model.predict(features) + + def __str__(self): + return "Linear Regression" + + def load_if_trained(self): + path = get_true_file_name(str(self), default_unformatted_model_path) + print(f"Attempting to load model from: {path}") + if os.path.exists(path): + print("Model file exists, loading...") + self.model = InterfaceBaseMLModel.load(path) + else: + print(f"Model file not found at {path}") + + +class RandomForestModel(InterfaceBaseMLModel): + def __init__(self, n_estimators=100, random_state=42): + super().__init__() + self.model = RandomForestRegressor(n_estimators=n_estimators, random_state=random_state) + + def fit(self, features, targets): + self.model.fit(features, targets) + + def predict(self, features): + return self.model.predict(features) + + def __str__(self): + return "Random Forest Regressor" + + def load_if_trained(self): + path = get_true_file_name(str(self), default_unformatted_model_path) + print(f"Attempting to load model from: {path}") + if os.path.exists(path): + print("Model file exists, loading...") + self.model = InterfaceBaseMLModel.load(path) + else: + print(f"Model file not found at {path}") + + +class SVMModel(InterfaceBaseMLModel): + def __init__(self): + super().__init__() + self.model = SVR() + + def fit(self, features, targets): + self.model.fit(features, targets) + + def predict(self, features): + return self.model.predict(features) + + def __str__(self): + return "SVM" + + def load_if_trained(self): + path = get_true_file_name(str(self), default_unformatted_model_path) + print(f"Attempting to load model from: {path}") + if os.path.exists(path): + print("Model file exists, loading...") + self.model = InterfaceBaseMLModel.load(path) + else: + print(f"Model file not found at {path}") + + +class InterfaceMLModelRepository(ABC): + """Interface for ML Models storage""" + @abstractmethod + def list_models(self) -> List[InterfaceBaseMLModel]: + """Get list of all available models instances""" + + @abstractmethod + def is_model_available(self, model_name: str) -> bool: + """Check if a model is valid""" + + @abstractmethod + def get_model_instance(self, model_name: str) -> InterfaceBaseMLModel: + """Return an instance of the requested model""" + + +class InterfaceMLModelManager(ABC): + """Interface for ML model management""" + + @abstractmethod + def get_current_model(self) -> InterfaceBaseMLModel: + """Get the current active ml model""" + + @abstractmethod + def switch_model(self, model_name: str) -> bool: + """Switch between models""" + + +class MLModelRepository(InterfaceMLModelRepository): + def __init__(self): + self._model_map = { + "Linear Regression": LinearRegressionModel, + "Random Forest Regressor": RandomForestModel, + "Support Vector Machine": SVMModel, + } + + def list_models(self) -> List[InterfaceBaseMLModel]: + return [model_class() for model_class in self._model_map.values()] + + def is_model_available(self, model_name: str) -> bool: + return model_name in self._model_map + + def get_model_instance(self, model_name: str) -> InterfaceBaseMLModel: + if not self.is_model_available(model_name): + raise ValueError(f"Model '{model_name}' is not available.") + return self._model_map[model_name]() + + +class MLModelManager(InterfaceMLModelManager): + def __init__(self, repository: InterfaceMLModelRepository): + self._repository = repository + self._current_model = repository.get_model_instance("Random Forest Regressor") + + def get_current_model(self) -> InterfaceBaseMLModel: + return self._current_model + + def switch_model(self, model_name: str) -> bool: + if self._repository.is_model_available(model_name): + self._current_model = self._repository.get_model_instance(model_name) + return True + return False diff --git a/app/clients/service/ml_models_router.py b/app/clients/service/ml_models_router.py new file mode 100644 index 00000000..7d250d61 --- /dev/null +++ b/app/clients/service/ml_models_router.py @@ -0,0 +1,65 @@ +import numpy as np +from fastapi import APIRouter, HTTPException + +from app.clients.service.ml_models import MLModelManager, MLModelRepository, \ + InterfaceBaseMLModel +from app.clients.service.models import PredictionFeatures, PredictionRequest + +router = APIRouter(prefix="/ml_models", tags=["model"]) +model_repository = MLModelRepository() +model_manager = MLModelManager(model_repository) + + +@router.get("/list") +def list_models(): + """List all available ML models""" + # return {"models": model_repository.list_models()} + return {"models": [str(model) for model in model_repository.list_models()]} + + +@router.post("/switch/{model_name}") +def switch_models(model_name: str): + """Switch between ML models""" + success = model_manager.switch_model(model_name) + if not success: + raise HTTPException(status_code=400, detail="Model switch failed") + return {"message": f"Model switched to {model_name}"} + + +@router.get("/current") +def current_model(): + """Get the current ML model""" + # return {"current_model": model_manager.get_current_model()} + return {"current_model": str(model_manager.get_current_model())} + + +@router.post("/predict/{model_name}") +def predict_with_model_name(features: PredictionFeatures, model_name: str): + """Predict based on a given ML model name""" + model = model_repository.get_model_instance(model_name) + # model.load_if_trained() + return predict_model(model, features) + + +@router.post("/predict") +def predict_with_current_model(features: PredictionFeatures): + """Predict based on current ML model""" + model = model_manager.get_current_model() + # model.load_if_trained() + return predict_model(model, features) + + +def predict_model(model: InterfaceBaseMLModel, features: PredictionFeatures): + """Predict based on given ML model""" + model.load_if_trained() + prediction_request = PredictionRequest.from_structured_features(features) + try: + prediction = model.predict(np.array([prediction_request.features])) + return { + "model": str(model), + "input": prediction_request.features, + "prediction": prediction.tolist(), + } + except Exception as e: + raise HTTPException(status_code=500, + detail=f"Prediction failed: {str(e)}") from e diff --git a/app/clients/service/model.py b/app/clients/service/model.py index b2406370..86bd782a 100644 --- a/app/clients/service/model.py +++ b/app/clients/service/model.py @@ -1,110 +1,159 @@ """ Model training module for the Common Assessment Tool. Handles the preparation, training, and saving of the prediction model. +Pass in model name via command line """ +import os + # Standard library imports import pickle +import sys # Third-party imports import numpy as np import pandas as pd + +# Local imports +# from sklearn import svm +# from sklearn.ensemble import RandomForestRegressor +# from sklearn.linear_model import LinearRegression from sklearn.model_selection import train_test_split -from sklearn.ensemble import RandomForestRegressor +from app.clients.service.constants import COLUMNS_FIELDS, INTERVENTION_FIELDS +from .ml_models import ( + InterfaceBaseMLModel, + LinearRegressionModel, + MLModelRepository, + RandomForestModel, + SVMModel, +) + +repo = MLModelRepository() + +DEFAULT_UNFORMATTED_MODEL_PATH = "pretrained_models" + os.sep + "model_{}.pkl" + + +def get_model_by_name(model_type: str, n_estimators=100, random_state=42) -> InterfaceBaseMLModel: + model_map = { + "Linear Regression": LinearRegressionModel, + "Random Forest Regressor": lambda: RandomForestModel(n_estimators, random_state), + "Support Vector Machine": SVMModel, + } + + if model_type not in model_map: + print(f"ERROR! Invalid model type '{model_type}' passed in.") + print(f"Available models: {repo.list_models()}") + sys.exit(-1) + + constructor = model_map[model_type] + return constructor() if callable(constructor) else constructor() + -def prepare_models(): +def prepare_model_data(test_size=0.2, random_state=42): """ Prepare and train the Random Forest model using the dataset. - + Args: + test_size: The percent of the dataset to use as test data (rest will be used as train data) + random_state: The random state to generate train/test split with + Returns: RandomForestRegressor: Trained model for predicting success rates """ # Load dataset - data = pd.read_csv('data_commontool.csv') - # Define feature columns - feature_columns = [ - 'age', # Client's age - 'gender', # Client's gender (bool) - 'work_experience', # Years of work experience - 'canada_workex', # Years of work experience in Canada - 'dep_num', # Number of dependents - 'canada_born', # Born in Canada - 'citizen_status', # Citizenship status - 'level_of_schooling', # Highest level achieved (1-14) - 'fluent_english', # English fluency scale (1-10) - 'reading_english_scale', # Reading ability scale (1-10) - 'speaking_english_scale',# Speaking ability scale (1-10) - 'writing_english_scale', # Writing ability scale (1-10) - 'numeracy_scale', # Numeracy ability scale (1-10) - 'computer_scale', # Computer proficiency scale (1-10) - 'transportation_bool', # Needs transportation support (bool) - 'caregiver_bool', # Is primary caregiver (bool) - 'housing', # Housing situation (1-10) - 'income_source', # Source of income (1-10) - 'felony_bool', # Has a felony (bool) - 'attending_school', # Currently a student (bool) - 'currently_employed', # Currently employed (bool) - 'substance_use', # Substance use disorder (bool) - 'time_unemployed', # Years unemployed - 'need_mental_health_support_bool' # Needs mental health support (bool) - ] - # Define intervention columns - intervention_columns = [ - 'employment_assistance', - 'life_stabilization', - 'retention_services', - 'specialized_services', - 'employment_related_financial_supports', - 'employer_financial_supports', - 'enhanced_referrals' - ] + data = pd.read_csv("data_commontool.csv") + # Combine all feature columns - all_features = feature_columns + intervention_columns + all_features = COLUMNS_FIELDS + INTERVENTION_FIELDS # Prepare training data - features = np.array(data[all_features]) # Changed from X to features - targets = np.array(data['success_rate']) # Changed from y to targets + features = np.array(data[all_features]) # Input features for the model + targets = np.array(data["success_rate"]) # Target variable # Split the dataset - features_train, _, targets_train, _ = train_test_split( # Removed unused variables + feature_train, feature_test, target_train, target_test = train_test_split( + # Removed unused variables features, targets, - test_size=0.2, - random_state=42 + test_size=test_size, + random_state=random_state, ) - # Initialize and train the model - model = RandomForestRegressor(n_estimators=100, random_state=42) - model.fit(features_train, targets_train) + + return feature_train, feature_test, target_train, target_test + + +def train_model( + feature_train, target_train, model_type, n_estimators=100, random_state=42 +) -> InterfaceBaseMLModel: + """ + Trains the model + Args: + feature_train: Training features + target_train: Target features + model_type: Which model to create + n_estimators: Number estimators (for random forest) + random_state: Random state to train with (for random forest) + + Returns: A trained model of the type specified + + """ + model = get_model_by_name(model_type, n_estimators, random_state) + model.fit(feature_train, target_train) return model -def save_model(model, filename="model.pkl"): + +def get_true_file_name(model_type, filename): + """ + Takes a model type and file name, formats model type, and replaces spaces with underscores + Args: + model_type: The model type as a String + filename: The file name (should follow 'model_{}.pkl' format) + + Returns: The clean file name + """ + return filename.format(model_type).replace(" ", "_") + + +def save_model(model, model_type, filename=DEFAULT_UNFORMATTED_MODEL_PATH): """ Save the trained model to a file. - + Args: model: Trained model to save + model_type: The type of model being saved filename (str): Name of the file to save the model to """ - with open(filename, "wb") as model_file: + true_file_name = get_true_file_name(model_type, filename) + with open(true_file_name, "wb") as model_file: pickle.dump(model, model_file) -def load_model(filename="model.pkl"): + +def load_model(model_type, filename=DEFAULT_UNFORMATTED_MODEL_PATH): """ Load a trained model from a file. - + Args: + model_type: The type of model being loaded filename (str): Name of the file to load the model from - + Returns: The loaded model """ - with open(filename, "rb") as model_file: + true_file_name = get_true_file_name(model_type, filename) + with open(true_file_name, "rb") as model_file: return pickle.load(model_file) -def main(): + +def main(argv): """Main function to train and save the model.""" - print("Starting model training...") - model = prepare_models() - save_model(model) + # Get the model type from the command line arguments + model_type = argv[1] + + # Train and save the model + print(f"Starting model training for {model_type} model...") + # feature_train, feature_test, target_train, target_test = prepare_model_data() + feature_train, _, target_train, _ = prepare_model_data() + model = train_model(feature_train, target_train, model_type) + save_model(model, model_type) print("Model training completed and saved successfully.") + if __name__ == "__main__": - main() + main(sys.argv) diff --git a/app/clients/service/model_helper.py b/app/clients/service/model_helper.py new file mode 100644 index 00000000..125cec76 --- /dev/null +++ b/app/clients/service/model_helper.py @@ -0,0 +1,21 @@ +from app.clients.service.constants import COLUMNS_FIELDS, INTERVENTION_FIELDS + + +def get_feature_columns(): + """Get all feature columns""" + return COLUMNS_FIELDS + + +def get_intervention_columns(): + """Get all intervention columns""" + return INTERVENTION_FIELDS + + +def get_all_feature_columns(): + """Get all feature columns""" + return get_feature_columns() + get_intervention_columns() + + +def get_true_file_name(model_type, filename): + """Format pickle file name""" + return filename.format(model_type).replace(" ", "_") diff --git a/app/clients/service/models.py b/app/clients/service/models.py new file mode 100644 index 00000000..904ec86d --- /dev/null +++ b/app/clients/service/models.py @@ -0,0 +1,107 @@ +from typing import List + +from pydantic import BaseModel, Field + + +class PredictionFeatures(BaseModel): + """Template class prediction class""" + + age: float = Field(..., description="Client's age", example=30) + gender: float = Field(..., description="Client's gender (1 for male, 0 for female)", example=1) + work_experience: float = Field(..., description="Years of work experience", example=5) + canada_workex: float = Field(..., description="Years of work experience in Canada", example=2) + dep_num: float = Field(..., description="Number of dependents", example=1) + canada_born: float = Field(..., description="Born in Canada (1 for yes, 0 for no)", example=0) + citizen_status: float = Field(..., description="Citizenship status", example=1) + level_of_schooling: float = Field(..., description="Highest level achieved (1-14)", example=8) + fluent_english: float = Field(..., description="English fluency scale (1-10)", example=7) + reading_english_scale: float = Field(..., description="Reading ability scale (1-10)", example=6) + speaking_english_scale: float = Field( + ..., description="Speaking ability scale (1-10)", example=6 + ) + writing_english_scale: float = Field(..., description="Writing ability scale (1-10)", example=5) + numeracy_scale: float = Field(..., description="Numeracy ability scale (1-10)", example=7) + computer_scale: float = Field(..., description="Computer proficiency scale (1-10)", example=6) + transportation_bool: float = Field( + ..., description="Needs transportation support (1 for yes, 0 for no)", example=0 + ) + caregiver_bool: float = Field( + ..., description="Is primary caregiver (1 for yes, 0 for no)", example=0 + ) + housing: float = Field(..., description="Housing situation (1-10)", example=3) + income_source: float = Field(..., description="Source of income (1-10)", example=2) + felony_bool: float = Field(..., description="Has a felony (1 for yes, 0 for no)", example=0) + attending_school: float = Field( + ..., description="Currently a student (1 for yes, 0 for no)", example=0 + ) + currently_employed: float = Field( + ..., description="Currently employed (1 for yes, 0 for no)", example=0 + ) + substance_use: float = Field( + ..., description="Substance use disorder (1 for yes, 0 for no)", example=0 + ) + time_unemployed: float = Field(..., description="Years unemployed", example=1) + need_mental_health_support_bool: float = Field( + ..., description="Needs mental health support (1 for yes, 0 for no)", example=0 + ) + # Intervention columns + employment_assistance: float = Field( + ..., description="Employment assistance intervention", example=1 + ) + life_stabilization: float = Field(..., description="Life stabilization intervention", example=1) + retention_services: float = Field(..., description="Retention services intervention", example=0) + specialized_services: float = Field( + ..., description="Specialized services intervention", example=0 + ) + employment_related_financial_supports: float = Field( + ..., description="Employment related financial supports", example=1 + ) + employer_financial_supports: float = Field( + ..., description="Employer financial supports", example=0 + ) + enhanced_referrals: float = Field(..., description="Enhanced referrals", example=0) + + +class PredictionRequest(BaseModel): + """Template class for prediction request""" + + features: List[float] = Field( + ..., description="List of 31 features in specific order for model prediction" + ) + + @classmethod + def from_structured_features(cls, structured_features: PredictionFeatures): + features = [ + structured_features.age, + structured_features.gender, + structured_features.work_experience, + structured_features.canada_workex, + structured_features.dep_num, + structured_features.canada_born, + structured_features.citizen_status, + structured_features.level_of_schooling, + structured_features.fluent_english, + structured_features.reading_english_scale, + structured_features.speaking_english_scale, + structured_features.writing_english_scale, + structured_features.numeracy_scale, + structured_features.computer_scale, + structured_features.transportation_bool, + structured_features.caregiver_bool, + structured_features.housing, + structured_features.income_source, + structured_features.felony_bool, + structured_features.attending_school, + structured_features.currently_employed, + structured_features.substance_use, + structured_features.time_unemployed, + structured_features.need_mental_health_support_bool, + structured_features.employment_assistance, + structured_features.life_stabilization, + structured_features.retention_services, + structured_features.specialized_services, + structured_features.employment_related_financial_supports, + structured_features.employer_financial_supports, + structured_features.enhanced_referrals, + ] + return cls(features=features) diff --git a/app/clients/service/pretrained_models/model_Linear_Regression.pkl b/app/clients/service/pretrained_models/model_Linear_Regression.pkl new file mode 100644 index 00000000..d264f1b9 Binary files /dev/null and b/app/clients/service/pretrained_models/model_Linear_Regression.pkl differ diff --git a/app/clients/service/pretrained_models/model_Random_Forest_Regressor.pkl b/app/clients/service/pretrained_models/model_Random_Forest_Regressor.pkl new file mode 100644 index 00000000..cde4340f Binary files /dev/null and b/app/clients/service/pretrained_models/model_Random_Forest_Regressor.pkl differ diff --git a/app/clients/service/pretrained_models/model_Support_Vector_Machine.pkl b/app/clients/service/pretrained_models/model_Support_Vector_Machine.pkl new file mode 100644 index 00000000..999b824c Binary files /dev/null and b/app/clients/service/pretrained_models/model_Support_Vector_Machine.pkl differ diff --git a/app/database.py b/app/database.py index 3a489f54..cb932579 100644 --- a/app/database.py +++ b/app/database.py @@ -4,25 +4,26 @@ """ from sqlalchemy import create_engine -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import declarative_base from sqlalchemy.orm import sessionmaker -#Here is where the database is located -SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db" +# Here is where the database is located +SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db" -#Open up a connection so that we are able to use the database +# Open up a connection so that we are able to use the database engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}) -#Bind the engine just created +# Bind the engine just created SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -#Create an object of our database so as to control the database +# Create an object of our database so as to control the database Base = declarative_base() + def get_db(): """ Creates a database session and ensures it's closed after use. - + Yields: Session: SQLAlchemy database session """ diff --git a/app/main.py b/app/main.py index a8e8fa7f..66746785 100644 --- a/app/main.py +++ b/app/main.py @@ -5,27 +5,32 @@ """ from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + from app import models -from app.database import engine -from app.clients.router import router as clients_router from app.auth.router import router as auth_router -from fastapi.middleware.cors import CORSMiddleware +from app.clients.router import router as clients_router +from app.clients.service.ml_models_router import router as ml_models_router +from app.database import engine # Initialize database tables models.Base.metadata.create_all(bind=engine) # Create FastAPI application -app = FastAPI(title="Case Management API", description="API for managing client cases", version="1.0.0") +app = FastAPI( + title="Case Management API", description="API for managing client cases", version="1.0.0" +) # Include routers app.include_router(auth_router) app.include_router(clients_router) +app.include_router(ml_models_router) # Configure CORS middleware app.add_middleware( CORSMiddleware, - allow_origins=["*"], # Allows all origins - allow_methods=["*"], # Allows all methods - allow_headers=["*"], # Allows all headers + allow_origins=["*"], # Allows all origins + allow_methods=["*"], # Allows all methods + allow_headers=["*"], # Allows all headers allow_credentials=True, ) diff --git a/app/models.py b/app/models.py index df778348..4fc4b8ba 100644 --- a/app/models.py +++ b/app/models.py @@ -2,15 +2,18 @@ Database models module defining SQLAlchemy ORM models for the Common Assessment Tool. Contains the Client model for storing client information in the database. """ +# pylint: disable=too-few-public-methods +import enum -from app.database import Base -from sqlalchemy import Column, Integer, String, Boolean, ForeignKey, CheckConstraint, Enum +from sqlalchemy import Boolean, CheckConstraint, Column, Enum, ForeignKey, Integer, String from sqlalchemy.orm import relationship -import enum + +from app.database import Base + class UserRole(str, enum.Enum): - admin = "admin" - case_worker = "case_worker" + ADMIN = "admin" + CASE_WORKER = "case_worker" class User(Base): @@ -24,46 +27,61 @@ class User(Base): cases = relationship("ClientCase", back_populates="user") + class Client(Base): """ Client model representing client data in the database. """ + __tablename__ = "clients" id = Column(Integer, primary_key=True, autoincrement=True) - age = Column(Integer, CheckConstraint('age >= 18')) + age = Column(Integer, CheckConstraint("age >= 18")) gender = Column(Integer, CheckConstraint("gender = 1 OR gender = 2")) - work_experience = Column(Integer, CheckConstraint('work_experience >= 0')) - canada_workex = Column(Integer, CheckConstraint('canada_workex >= 0')) - dep_num = Column(Integer, CheckConstraint('dep_num >= 0')) + work_experience = Column(Integer, CheckConstraint("work_experience >= 0")) + canada_workex = Column(Integer, CheckConstraint("canada_workex >= 0")) + dep_num = Column(Integer, CheckConstraint("dep_num >= 0")) canada_born = Column(Boolean) citizen_status = Column(Boolean) - level_of_schooling = Column(Integer, CheckConstraint('level_of_schooling >= 1 AND level_of_schooling <= 14')) + level_of_schooling = Column( + Integer, CheckConstraint("level_of_schooling >= 1 AND level_of_schooling <= 14") + ) fluent_english = Column(Boolean) - reading_english_scale = Column(Integer, CheckConstraint('reading_english_scale >= 0 AND reading_english_scale <= 10')) - speaking_english_scale = Column(Integer, CheckConstraint('speaking_english_scale >= 0 AND speaking_english_scale <= 10')) - writing_english_scale = Column(Integer, CheckConstraint('writing_english_scale >= 0 AND writing_english_scale <= 10')) - numeracy_scale = Column(Integer, CheckConstraint('numeracy_scale >= 0 AND numeracy_scale <= 10')) - computer_scale = Column(Integer, CheckConstraint('computer_scale >= 0 AND computer_scale <= 10')) + reading_english_scale = Column( + Integer, CheckConstraint("reading_english_scale >= 0 AND reading_english_scale <= 10") + ) + speaking_english_scale = Column( + Integer, CheckConstraint("speaking_english_scale >= 0 AND speaking_english_scale <= 10") + ) + writing_english_scale = Column( + Integer, CheckConstraint("writing_english_scale >= 0 AND writing_english_scale <= 10") + ) + numeracy_scale = Column( + Integer, CheckConstraint("numeracy_scale >= 0 AND numeracy_scale <= 10") + ) + computer_scale = Column( + Integer, CheckConstraint("computer_scale >= 0 AND computer_scale <= 10") + ) transportation_bool = Column(Boolean) caregiver_bool = Column(Boolean) - housing = Column(Integer, CheckConstraint('housing >= 1 AND housing <= 10')) - income_source = Column(Integer, CheckConstraint('income_source >= 1 AND income_source <= 11')) + housing = Column(Integer, CheckConstraint("housing >= 1 AND housing <= 10")) + income_source = Column(Integer, CheckConstraint("income_source >= 1 AND income_source <= 11")) felony_bool = Column(Boolean) attending_school = Column(Boolean) currently_employed = Column(Boolean) substance_use = Column(Boolean) - time_unemployed = Column(Integer, CheckConstraint('time_unemployed >= 0')) + time_unemployed = Column(Integer, CheckConstraint("time_unemployed >= 0")) need_mental_health_support_bool = Column(Boolean) cases = relationship("ClientCase", back_populates="client") + class ClientCase(Base): __tablename__ = "client_cases" client_id = Column(Integer, ForeignKey("clients.id"), primary_key=True) user_id = Column(Integer, ForeignKey("users.id"), primary_key=True) - + employment_assistance = Column(Boolean) life_stabilization = Column(Boolean) retention_services = Column(Boolean) @@ -71,7 +89,7 @@ class ClientCase(Base): employment_related_financial_supports = Column(Boolean) employer_financial_supports = Column(Boolean) enhanced_referrals = Column(Boolean) - success_rate = Column(Integer, CheckConstraint('success_rate >= 0 AND success_rate <= 100')) + success_rate = Column(Integer, CheckConstraint("success_rate >= 0 AND success_rate <= 100")) client = relationship("Client", back_populates="cases") user = relationship("User", back_populates="cases") diff --git a/app/requirements.txt b/app/requirements.txt index 57fc8e6e..ddcf118a 100644 --- a/app/requirements.txt +++ b/app/requirements.txt @@ -38,6 +38,5 @@ typer==0.12.5 typing_extensions==4.12.2 tzdata==2024.1 uvicorn==0.30.6 -uvloop==0.20.0 watchfiles==0.23.0 websockets==13.0 diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 00000000..1f70f5fb --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,12 @@ +services: + web: + build: . + ports: + - "8000:8000" + develop: + watch: + - action: sync + path: . + target: /code + app: + env_file: .env \ No newline at end of file diff --git a/initialize_data.py b/initialize_data.py index 1444bf41..34c77a68 100644 --- a/initialize_data.py +++ b/initialize_data.py @@ -4,6 +4,7 @@ from app.models import Client, User, ClientCase, UserRole from app.auth.router import get_password_hash + def initialize_database(): print("Starting database initialization...") db = SessionLocal() @@ -15,7 +16,7 @@ def initialize_database(): username="admin", email="admin@example.com", hashed_password=get_password_hash("admin123"), - role=UserRole.admin + role=UserRole.ADMIN, ) db.add(admin_user) db.commit() @@ -30,7 +31,7 @@ def initialize_database(): username="case_worker1", email="caseworker1@example.com", hashed_password=get_password_hash("worker123"), - role=UserRole.case_worker + role=UserRole.CASE_WORKER, ) db.add(case_worker) db.commit() @@ -40,46 +41,57 @@ def initialize_database(): # Load CSV data print("Loading CSV data...") - df = pd.read_csv('app/clients/service/data_commontool.csv') - + df = pd.read_csv("app/clients/service/data_commontool.csv") + # Convert data types integer_columns = [ - 'age', 'gender', 'work_experience', 'canada_workex', 'dep_num', - 'level_of_schooling', 'reading_english_scale', 'speaking_english_scale', - 'writing_english_scale', 'numeracy_scale', 'computer_scale', - 'housing', 'income_source', 'time_unemployed', 'success_rate' + "age", + "gender", + "work_experience", + "canada_workex", + "dep_num", + "level_of_schooling", + "reading_english_scale", + "speaking_english_scale", + "writing_english_scale", + "numeracy_scale", + "computer_scale", + "housing", + "income_source", + "time_unemployed", + "success_rate", ] for col in integer_columns: - df[col] = pd.to_numeric(df[col], errors='raise') + df[col] = pd.to_numeric(df[col], errors="raise") # Process each row in CSV for index, row in df.iterrows(): # Create client client = Client( - age=int(row['age']), - gender=int(row['gender']), - work_experience=int(row['work_experience']), - canada_workex=int(row['canada_workex']), - dep_num=int(row['dep_num']), - canada_born=bool(row['canada_born']), - citizen_status=bool(row['citizen_status']), - level_of_schooling=int(row['level_of_schooling']), - fluent_english=bool(row['fluent_english']), - reading_english_scale=int(row['reading_english_scale']), - speaking_english_scale=int(row['speaking_english_scale']), - writing_english_scale=int(row['writing_english_scale']), - numeracy_scale=int(row['numeracy_scale']), - computer_scale=int(row['computer_scale']), - transportation_bool=bool(row['transportation_bool']), - caregiver_bool=bool(row['caregiver_bool']), - housing=int(row['housing']), - income_source=int(row['income_source']), - felony_bool=bool(row['felony_bool']), - attending_school=bool(row['attending_school']), - currently_employed=bool(row['currently_employed']), - substance_use=bool(row['substance_use']), - time_unemployed=int(row['time_unemployed']), - need_mental_health_support_bool=bool(row['need_mental_health_support_bool']) + age=int(row["age"]), + gender=int(row["gender"]), + work_experience=int(row["work_experience"]), + canada_workex=int(row["canada_workex"]), + dep_num=int(row["dep_num"]), + canada_born=bool(row["canada_born"]), + citizen_status=bool(row["citizen_status"]), + level_of_schooling=int(row["level_of_schooling"]), + fluent_english=bool(row["fluent_english"]), + reading_english_scale=int(row["reading_english_scale"]), + speaking_english_scale=int(row["speaking_english_scale"]), + writing_english_scale=int(row["writing_english_scale"]), + numeracy_scale=int(row["numeracy_scale"]), + computer_scale=int(row["computer_scale"]), + transportation_bool=bool(row["transportation_bool"]), + caregiver_bool=bool(row["caregiver_bool"]), + housing=int(row["housing"]), + income_source=int(row["income_source"]), + felony_bool=bool(row["felony_bool"]), + attending_school=bool(row["attending_school"]), + currently_employed=bool(row["currently_employed"]), + substance_use=bool(row["substance_use"]), + time_unemployed=int(row["time_unemployed"]), + need_mental_health_support_bool=bool(row["need_mental_health_support_bool"]), ) db.add(client) db.commit() @@ -88,14 +100,16 @@ def initialize_database(): client_case = ClientCase( client_id=client.id, user_id=admin_user.id, # Assign to admin - employment_assistance=bool(row['employment_assistance']), - life_stabilization=bool(row['life_stabilization']), - retention_services=bool(row['retention_services']), - specialized_services=bool(row['specialized_services']), - employment_related_financial_supports=bool(row['employment_related_financial_supports']), - employer_financial_supports=bool(row['employer_financial_supports']), - enhanced_referrals=bool(row['enhanced_referrals']), - success_rate=int(row['success_rate']) + employment_assistance=bool(row["employment_assistance"]), + life_stabilization=bool(row["life_stabilization"]), + retention_services=bool(row["retention_services"]), + specialized_services=bool(row["specialized_services"]), + employment_related_financial_supports=bool( + row["employment_related_financial_supports"] + ), + employer_financial_supports=bool(row["employer_financial_supports"]), + enhanced_referrals=bool(row["enhanced_referrals"]), + success_rate=int(row["success_rate"]), ) db.add(client_case) db.commit() @@ -108,5 +122,6 @@ def initialize_database(): finally: db.close() + if __name__ == "__main__": - initialize_database() \ No newline at end of file + initialize_database() diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..6dd3a8a7 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,40 @@ +# Global options: +[mypy] +# Enable strict mode for better type checking +strict = false + +# Encourage but don't require type annotations +disallow_untyped_defs = false +# At least annotate return types +disallow_incomplete_defs = true +# Type check the body of functions without annotations +check_untyped_defs = true + +# Allow calling functions without type hints +disallow_untyped_calls = false + +# Be more permissive with 'Any' usage +disallow_any_unimported = false +disallow_any_explicit = false +disallow_any_generics = false +# Still warn about returning Any +warn_return_any = true + +# Don't enforce strict subclassing +disallow_subclassing_any = false + +# Still catch obvious issues +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = false +warn_unreachable = true + +# Allow redefinitions in some contexts +allow_redefinition = true + +# Module import settings +ignore_missing_imports = true +follow_imports = silent + +# Performance improvements +cache_dir = .mypy_cache \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..820b4f6f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,50 @@ +# Overall Project Configuration +[project] +name = "Common_Assessment_Tool" +version = "1.0.0" +authors = [{name = "David Treadwell", email = "treadwell.d@northeastern.edu"}, {name = "Fran Li", email = "li.fengr@northeastern.edu"}, {name = "Steve Chen", email = "chen.steve2@northeastern.edu"}] +readme = "README.md" +license = "MIT" +dynamic = ["dependencies"] +requires-python = ">=3.10" + +# Set up dependencies from requirements.txt file +[tool.setuptools.dynamic] +dependencies = {file = ["requirements.txt"]} + +# Project urls +[project.urls] +Repository = "https://github.com/dtread4/CommonAssessmentTool" + +# Build system +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +# Packages finder +[tool.setuptools.packages.find] +where = ["."] + +# Optional dependency configuration +[project.optional-dependencies] +dev = ["black", "isort"] +extra = ['uvloop==0.20.0'] + +# Black Configuration +[tool.black] +line-length = 100 +include = '\.pyi?$' +skip-magic-trailing-comma = true +target-version = ['py310'] + +# isort Configuration +[tool.isort] +profile = "black" +line_length = 100 +known_first_party = ["app"] +multi_line_output = 3 +force_grid_wrap = 0 +combine_as_imports = true +include_trailing_comma = true +force_single_line = false +skip = ["venv", ".venv", "migrations"] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 93d35fbf..b09bed08 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ appnope==0.1.3 argon2-cffi==21.3.0 argon2-cffi-bindings==21.2.0 arrow==1.2.3 -astroid==3.0.1 +astroid==3.3.9 asttokens==2.4.0 attrs==22.1.0 backcall==0.2.0 @@ -95,7 +95,7 @@ pydantic==2.4.2 pydantic-settings==2.0.3 pydantic_core==2.10.1 Pygments==2.16.1 -pylint==3.0.1 +pylint==3.3.6 pyrsistent==0.19.3 pytest==7.2.0 python-dateutil==2.8.2 @@ -133,7 +133,6 @@ tzdata==2023.3 uri-template==1.2.0 urllib3==2.0.7 uvicorn==0.23.2 -uvloop==0.17.0 watchfiles==0.20.0 wcwidth==0.2.8 webcolors==1.13 @@ -141,4 +140,3 @@ webencodings==0.5.1 websocket-client==1.5.1 websockets==11.0.3 widgetsnbextension==4.0.7 - diff --git a/tests/conftest.py b/tests/conftest.py index aa30d094..fffaf232 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +# pylint: disable=redefined-outer-name import pytest from fastapi.testclient import TestClient from sqlalchemy import create_engine @@ -12,11 +13,12 @@ engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}) TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + @pytest.fixture def test_db(): # Create tables Base.metadata.create_all(bind=engine) - + db = TestingSessionLocal() try: # Create test admin user @@ -24,7 +26,7 @@ def test_db(): username="testadmin", email="testadmin@example.com", hashed_password=get_password_hash("testpass123"), - role=UserRole.admin + role=UserRole.ADMIN, ) db.add(admin_user) @@ -33,10 +35,10 @@ def test_db(): username="testworker", email="worker@example.com", hashed_password=get_password_hash("workerpass123"), - role=UserRole.case_worker + role=UserRole.CASE_WORKER, ) db.add(case_worker) - + # Create test clients client1 = Client( age=25, @@ -62,9 +64,9 @@ def test_db(): currently_employed=False, substance_use=False, time_unemployed=6, - need_mental_health_support_bool=False + need_mental_health_support_bool=False, ) - + client2 = Client( age=30, gender=2, @@ -89,13 +91,13 @@ def test_db(): currently_employed=True, substance_use=False, time_unemployed=0, - need_mental_health_support_bool=False + need_mental_health_support_bool=False, ) - + db.add(client1) db.add(client2) db.commit() - + # Create test client cases client_case1 = ClientCase( client_id=1, @@ -107,9 +109,9 @@ def test_db(): employment_related_financial_supports=True, employer_financial_supports=False, enhanced_referrals=True, - success_rate=75 + success_rate=75, ) - + client_case2 = ClientCase( client_id=2, user_id=2, # Assigned to case worker @@ -120,18 +122,19 @@ def test_db(): employment_related_financial_supports=False, employer_financial_supports=True, enhanced_referrals=False, - success_rate=85 + success_rate=85, ) - + db.add(client_case1) db.add(client_case2) db.commit() - + yield db finally: db.close() Base.metadata.drop_all(bind=engine) + @pytest.fixture def client(test_db): def override_get_db(): @@ -139,32 +142,31 @@ def override_get_db(): yield test_db finally: test_db.close() - + app.dependency_overrides[get_db] = override_get_db yield TestClient(app) app.dependency_overrides.clear() + @pytest.fixture def admin_token(client): - response = client.post( - "/auth/token", - data={"username": "testadmin", "password": "testpass123"} - ) + response = client.post("/auth/token", data={"username": "testadmin", "password": "testpass123"}) return response.json()["access_token"] + @pytest.fixture def case_worker_token(client): response = client.post( - "/auth/token", - data={"username": "testworker", "password": "workerpass123"} + "/auth/token", data={"username": "testworker", "password": "workerpass123"} ) return response.json()["access_token"] + @pytest.fixture def admin_headers(admin_token): return {"Authorization": f"Bearer {admin_token}"} + @pytest.fixture def case_worker_headers(case_worker_token): return {"Authorization": f"Bearer {case_worker_token}"} - \ No newline at end of file diff --git a/tests/test_auth.py b/tests/test_auth.py index 1d4692e4..1c24a497 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,123 +1,110 @@ -import pytest from fastapi import status + def test_create_user_success(client, admin_headers): """Test successful user creation by admin""" user_data = { "username": "newuser", "email": "new@test.com", "password": "testpass123", - "role": "case_worker" + "role": "case_worker", } - response = client.post( - "/auth/users", - headers=admin_headers, - json=user_data - ) + response = client.post("/auth/users", headers=admin_headers, json=user_data) assert response.status_code == status.HTTP_200_OK data = response.json() assert data["username"] == "newuser" assert data["role"] == "case_worker" assert "password" not in data # Password should not be in response + def test_create_user_duplicate_username(client, admin_headers): """Test creating user with existing username""" user_data = { "username": "testadmin", # This username exists in test database "email": "another@test.com", "password": "testpass123", - "role": "case_worker" + "role": "case_worker", } - response = client.post( - "/auth/users", - headers=admin_headers, - json=user_data - ) + response = client.post("/auth/users", headers=admin_headers, json=user_data) assert response.status_code == status.HTTP_400_BAD_REQUEST assert "Username already registered" in response.json()["detail"] + def test_create_user_duplicate_email(client, admin_headers): """Test creating user with existing email""" user_data = { "username": "uniqueuser", "email": "testadmin@example.com", # This email exists in test database "password": "testpass123", - "role": "case_worker" + "role": "case_worker", } - response = client.post( - "/auth/users", - headers=admin_headers, - json=user_data - ) + response = client.post("/auth/users", headers=admin_headers, json=user_data) assert response.status_code == status.HTTP_400_BAD_REQUEST assert "Email already registered" in response.json()["detail"] + def test_create_user_invalid_role(client, admin_headers): """Test creating user with invalid role""" user_data = { "username": "newuser", "email": "new@test.com", "password": "testpass123", - "role": "invalid_role" # Invalid role + "role": "invalid_role", # Invalid role } - response = client.post( - "/auth/users", - headers=admin_headers, - json=user_data - ) + response = client.post("/auth/users", headers=admin_headers, json=user_data) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + def test_create_user_unauthorized(client): """Test user creation without authentication""" user_data = { "username": "newuser", "email": "new@test.com", "password": "testpass123", - "role": "case_worker" + "role": "case_worker", } response = client.post("/auth/users", json=user_data) assert response.status_code == status.HTTP_401_UNAUTHORIZED + def test_login_success_admin(client): """Test successful login for admin""" - response = client.post( - "/auth/token", - data={"username": "testadmin", "password": "testpass123"} - ) + response = client.post("/auth/token", data={"username": "testadmin", "password": "testpass123"}) assert response.status_code == status.HTTP_200_OK data = response.json() assert "access_token" in data assert data["token_type"] == "bearer" + def test_login_success_case_worker(client): """Test successful login for case worker""" response = client.post( - "/auth/token", - data={"username": "testworker", "password": "workerpass123"} + "/auth/token", data={"username": "testworker", "password": "workerpass123"} ) assert response.status_code == status.HTTP_200_OK data = response.json() assert "access_token" in data assert data["token_type"] == "bearer" + def test_login_wrong_password(client): """Test login with incorrect password""" response = client.post( - "/auth/token", - data={"username": "testadmin", "password": "wrongpassword"} + "/auth/token", data={"username": "testadmin", "password": "wrongpassword"} ) assert response.status_code == status.HTTP_401_UNAUTHORIZED assert "Incorrect username or password" in response.json()["detail"] + def test_login_nonexistent_user(client): """Test login with non-existent username""" response = client.post( - "/auth/token", - data={"username": "nonexistent", "password": "testpass123"} + "/auth/token", data={"username": "nonexistent", "password": "testpass123"} ) assert response.status_code == status.HTTP_401_UNAUTHORIZED assert "Incorrect username or password" in response.json()["detail"] + def test_invalid_token(client): """Test using invalid token""" headers = {"Authorization": "Bearer invalid_token_here"} @@ -125,12 +112,14 @@ def test_invalid_token(client): assert response.status_code == status.HTTP_401_UNAUTHORIZED assert "Could not validate credentials" in response.json()["detail"] + def test_missing_token(client): """Test accessing protected endpoint without token""" response = client.get("/clients/") assert response.status_code == status.HTTP_401_UNAUTHORIZED assert "Not authenticated" in response.json()["detail"] + def test_token_user_deleted(client, admin_headers): """Test using token of deleted user""" # First create a new user as admin @@ -138,22 +127,15 @@ def test_token_user_deleted(client, admin_headers): "username": "temporary", "email": "temp@test.com", "password": "temppass123", - "role": "admin" # Changed to admin so they can access /clients/ + "role": "admin", # Changed to admin so they can access /clients/ } - response = client.post( - "/auth/users", - headers=admin_headers, - json=user_data - ) + response = client.post("/auth/users", headers=admin_headers, json=user_data) assert response.status_code == status.HTTP_200_OK # Get token for new user - response = client.post( - "/auth/token", - data={"username": "temporary", "password": "temppass123"} - ) + response = client.post("/auth/token", data={"username": "temporary", "password": "temppass123"}) token = response.json()["access_token"] - + # Try using the token headers = {"Authorization": f"Bearer {token}"} response = client.get("/clients/", headers=headers) diff --git a/tests/test_clients.py b/tests/test_clients.py index 611a5b34..e6bd1cf4 100644 --- a/tests/test_clients.py +++ b/tests/test_clients.py @@ -1,12 +1,13 @@ -import pytest from fastapi import status + # Test GET Operations def test_get_clients_unauthorized(client): """Test that unauthorized access is prevented""" response = client.get("/clients/") assert response.status_code == status.HTTP_401_UNAUTHORIZED + def test_get_clients_as_admin(client, admin_headers): """Test getting all clients as admin""" response = client.get("/clients/", headers=admin_headers) @@ -16,24 +17,24 @@ def test_get_clients_as_admin(client, admin_headers): assert "total" in data assert len(data["clients"]) > 0 + def test_get_client_by_id(client, admin_headers): """Test getting specific client""" # Test existing client response = client.get("/clients/1", headers=admin_headers) assert response.status_code == status.HTTP_200_OK assert response.json()["id"] == 1 - + # Test non-existent client response = client.get("/clients/999", headers=admin_headers) assert response.status_code == status.HTTP_404_NOT_FOUND + def test_get_clients_by_criteria(client, admin_headers): """Test searching clients by various criteria""" # Test single criterion response = client.get( - "/clients/search/by-criteria", - params={"age_min": 25}, - headers=admin_headers + "/clients/search/by-criteria", params={"age_min": 25}, headers=admin_headers ) assert response.status_code == status.HTTP_200_OK assert len(response.json()) > 0 @@ -41,12 +42,8 @@ def test_get_clients_by_criteria(client, admin_headers): # Test multiple criteria response = client.get( "/clients/search/by-criteria", - params={ - "age_min": 25, - "currently_employed": True, - "gender": 2 - }, - headers=admin_headers + params={"age_min": 25, "currently_employed": True, "gender": 2}, + headers=admin_headers, ) assert response.status_code == status.HTTP_200_OK @@ -54,23 +51,22 @@ def test_get_clients_by_criteria(client, admin_headers): response = client.get( "/clients/search/by-criteria", params={"age_min": 15}, # Below minimum age - headers=admin_headers + headers=admin_headers, ) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY # Changed from 400 + def test_get_clients_by_services(client, admin_headers): """Test getting clients by service status""" response = client.get( "/clients/search/by-services", - params={ - "employment_assistance": True, - "life_stabilization": True - }, - headers=admin_headers + params={"employment_assistance": True, "life_stabilization": True}, + headers=admin_headers, ) assert response.status_code == status.HTTP_200_OK assert len(response.json()) > 0 + def test_get_client_services(client, admin_headers): """Test getting services for a specific client""" response = client.get("/clients/1/services", headers=admin_headers) @@ -81,63 +77,54 @@ def test_get_client_services(client, admin_headers): assert "employment_assistance" in services[0] assert "success_rate" in services[0] + def test_get_clients_by_success_rate(client, admin_headers): """Test getting clients by success rate threshold""" response = client.get( - "/clients/search/success-rate", - params={"min_rate": 70}, - headers=admin_headers + "/clients/search/success-rate", params={"min_rate": 70}, headers=admin_headers ) assert response.status_code == status.HTTP_200_OK assert len(response.json()) > 0 + def test_get_clients_by_case_worker(client, admin_headers, case_worker_headers): """Test getting clients assigned to a case worker""" # Test as admin response = client.get("/clients/case-worker/2", headers=admin_headers) assert response.status_code == status.HTTP_200_OK - + # Test as case worker response = client.get("/clients/case-worker/2", headers=case_worker_headers) assert response.status_code == status.HTTP_200_OK + # Test UPDATE Operations def test_update_client(client, admin_headers): """Test updating client information""" - update_data = { - "age": 26, - "currently_employed": True, - "time_unemployed": 0 - } - response = client.put( - "/clients/1", - json=update_data, - headers=admin_headers - ) + update_data = {"age": 26, "currently_employed": True, "time_unemployed": 0} + response = client.put("/clients/1", json=update_data, headers=admin_headers) assert response.status_code == status.HTTP_200_OK updated_client = response.json() assert updated_client["age"] == 26 - assert updated_client["currently_employed"] == True + assert updated_client["currently_employed"] assert updated_client["time_unemployed"] == 0 + # Test Create Case Assignment def test_create_case_assignment(client, admin_headers): """Test creating new case assignment""" response = client.post( - "/clients/1/case-assignment", - params={"case_worker_id": 2}, - headers=admin_headers + "/clients/1/case-assignment", params={"case_worker_id": 2}, headers=admin_headers ) assert response.status_code == status.HTTP_200_OK # Test duplicate assignment response = client.post( - "/clients/1/case-assignment", - params={"case_worker_id": 2}, - headers=admin_headers + "/clients/1/case-assignment", params={"case_worker_id": 2}, headers=admin_headers ) assert response.status_code == status.HTTP_400_BAD_REQUEST + # Test DELETE Operation def test_delete_client(client, admin_headers): """Test deleting a client"""