diff --git a/.env b/.env new file mode 100644 index 00000000..f8f43722 --- /dev/null +++ b/.env @@ -0,0 +1,16 @@ +# Database Configuration +DATABASE_URL="sqlite:///./sql_app.db" +# For PostgreSQL: DATABASE_URL="postgresql://username:password@localhost:5432/dbname" + +# Security +SECRET_KEY="your-secret-key-here" +ALGORITHM="HS256" +ACCESS_TOKEN_EXPIRE_MINUTES=30 + +# Application Settings +DEBUG=1 +ENVIRONMENT="development" + +# Default Admin User +ADMIN_USERNAME="admin" +ADMIN_PASSWORD="admin123" \ No newline at end of file diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index b801c2d3..0fa53b22 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -2,9 +2,13 @@ name: CI/CD Pipeline on: push: - branches: [master, main] + branches: + - master + - main pull_request: - branches: [master, main] + branches: + - master + - main jobs: test: @@ -28,20 +32,29 @@ jobs: python -m pytest tests/ deploy: - needs: test # This ensures deploy only runs if tests pass runs-on: ubuntu-latest + needs: test steps: - - uses: actions/checkout@v4 - - - name: Build Docker image - run: docker build -t common-assessment-tool . - - - name: Run Docker container - run: | - docker run -d -p 8000:8000 common-assessment-tool - sleep 10 # Wait for container to start - - - name: Test Docker container - run: | - curl http://localhost:8000/docs + - name: Checkout Code + uses: actions/checkout@v3 + + - name: Set Up SSH Key + shell: bash + run: | + echo "${{ secrets.EC2_KEY }}" > ~/my-common-assessment-tool.pem + chmod 600 ~/my-common-assessment-tool.pem + + - name: Deploy Code + shell: bash + run: | + # Copy all project files to the EC2 instance + scp -o StrictHostKeyChecking=no -i ~/my-common-assessment-tool.pem -r ./* ${{ secrets.EC2_USER }}@${{ secrets.EC2_HOST }}:/home/${{ secrets.EC2_USER }}/CommonAssessmentTool/ + + # SSH into the EC2 instance and deploy with cleanup + ssh -o StrictHostKeyChecking=no -i ~/my-common-assessment-tool.pem ${{ secrets.EC2_USER }}@${{ secrets.EC2_HOST }} " + cd /home/${{ secrets.EC2_USER }}/CommonAssessmentTool && + sudo docker-compose down -v && + sudo docker system prune -f --volumes && + sudo docker-compose up --build -d + " \ No newline at end of file diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 30c81bdb..abfd1b88 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,41 +2,84 @@ name: Python CI Pipeline on: push: - branches: [master, main] + branches: [main, master] pull_request: branches: [master, main] + jobs: - test: + build: runs-on: ubuntu-latest # Use the latest Ubuntu runner steps: + # Step 1: Checkout the code from the repository - name: Checkout Code uses: actions/checkout@v4 # Checkout the repository + # Step 2: Set up Python environment - name: Set up Python uses: actions/setup-python@v5 # Set up Python environment with: python-version: "3.11" + # Step 3: Install project dependencies, linters, formatters, and testing tools - name: Install dependencies run: | python -m pip install --upgrade pip # Upgrade pip to the latest version pip install setuptools wheel pip install -r requirements.txt # Install dependencies from requirements.txt - pip install pylint pytest + pip install pylint pytest black + + # Step 4: Run code quality checks with pylint + - name: Lint with pylint + run: | + pylint app/ tests/ --exit-zero + # Step 5: Check code formatting with black + - name: Check code formatting with black + run: | + black --check app/ tests/ + + # Step 6: Run tests with pytest - name: Run Tests run: | - python -m pytest tests/ + pytest tests/ + + # Step 7: Lint Dockerfile syntax (optional) + - name: Lint Dockerfile syntax + uses: hadolint/hadolint-action@v3.1.0 + with: + dockerfile: ./Dockerfile + + # Step 8: Build Docker Image + - name: Build Docker Image + run: docker build -t case-management-api . + + # Step 9: Run Docker Container + - name: Run Docker Container + run: | + docker run -d -p 8000:8000 --name test-container case-management-api + sleep 5 + + # Step 10: Test API Endpoint + - name: Test API Endpoint + run: curl --fail http://localhost:8000/docs + + # Step 11: Cleanup Docker Container + - name: Cleanup Docker Container + run: | + docker stop test-container + docker rm test-container + # Step 12: Print Success Message - name: Print Success Message + if: success() # Only runs if previous steps are successful run: | echo "CI Pipeline completed successfully!" echo "========================" echo "✓ Code checked out" echo "✓ Python environment set up" echo "✓ Dependencies installed" + echo "✓ Linting and formatting completed" echo "✓ Tests executed" - echo "✓ Linting completed" echo "========================" diff --git a/.gitignore b/.gitignore index 371e45c1..a3672bff 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ .venv venv/ env/ -.env + # IDE and System .idea diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 00000000..e06dc82a --- /dev/null +++ b/.pylintrc @@ -0,0 +1,634 @@ +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist= + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. +ignore=CVS + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +ignore-paths= + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +ignore-patterns=^\.# + +# List of module names for which member attributes should not be checked and +# will not be imported (useful for modules/projects where namespaces are +# manipulated during runtime and thus existing member attributes cannot be +# deduced by static analysis). It supports qualified module names, as well as +# Unix pattern matching. +ignored-modules= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.13 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# Add paths to the list of the source roots. Supports globbing patterns. The +# source root is an absolute path or a path relative to the current working +# directory used to determine a package namespace for modules located under the +# source root. +source-roots= + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type alias names. If left empty, type +# alias names will be checked with the set naming style. +#typealias-rgx= + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + asyncSetUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=100 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + use-implicit-booleaness-not-comparison-to-string, + use-implicit-booleaness-not-comparison-to-zero + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable= + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are: text, parseable, colorized, +# json2 (improved json format), json (old json format) and msvs (visual +# studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +#output-format= + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. No available dictionaries : You need to install +# both the python package and the system dependency for enchant to work. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io diff --git a/.python-version b/.python-version new file mode 100644 index 00000000..b6d8b761 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.11.8 diff --git a/README.md b/README.md index b34d6d6b..2132fbb7 100644 --- a/README.md +++ b/README.md @@ -55,3 +55,23 @@ This also has an API file to interact with the front end, and logic in order to -Create case assignment (Allow authorized users to create a new case assignment.) +-------------------------How to Run with Doker------------------------- +- Option 1: Using Docker +1. Build the Docker image: docker build -t common-assessment-tool . + +2. Run your container in detached mode: docker run -d -p 8000:8000 --name assessment-tool common-assessment-tool + +3. Then run the initialization script: docker exec -it assessment-tool python initialize_data.py + +4. Access the application at http://localhost:8000/docs + +5. Log in as admin (username: admin password: admin123) + +- Option 2: Using Docker Compose +1. Start the application in background mode: docker-compose up -d + +2. Run the initialization script: docker-compose exec backend python initialize_data.py + +3. Access the application at http://localhost:8000/docs + +4. Log in as admin (username: admin password: admin123) \ No newline at end of file diff --git a/app/auth/dependencies.py b/app/auth/dependencies.py new file mode 100644 index 00000000..0e941072 --- /dev/null +++ b/app/auth/dependencies.py @@ -0,0 +1,85 @@ +# app/auth/dependencies.py +""" +Authentication dependencies for FastAPI dependency injection. +Provides injectable dependencies for current user, admin user, etc. +""" +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from sqlalchemy.orm import Session + +from app.database import get_db +from app.models import User +from app.auth.repository import SQLAlchemyUserRepository +from app.auth.service import AuthorizationService +from app.auth.security import TokenService + +# OAuth2 scheme for token extraction +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token") + + +def get_user_repository(db: Session = Depends(get_db)): + """ + Get the user repository + + Args: + db: The database session + + Returns: + SQLAlchemyUserRepository: The user repository + """ + return SQLAlchemyUserRepository(db) + + +def get_authorization_service(repository=Depends(get_user_repository)): + """ + Get the authorization service + + Args: + repository: The user repository + + Returns: + AuthorizationService: The authorization service + """ + return AuthorizationService(repository) + + +async def get_current_user( + token: str = Depends(oauth2_scheme), + auth_service: AuthorizationService = Depends(get_authorization_service), +) -> User: + """ + Get the current user from the token + + Args: + token: The JWT token + auth_service: The authorization service + + Returns: + User: The current user + + Raises: + HTTPException: If token validation fails + """ + token_data = TokenService.decode_token(token) + return auth_service.get_current_user(token_data) + + +async def get_admin_user( + current_user: User = Depends(get_current_user), + auth_service: AuthorizationService = Depends(get_authorization_service), +) -> User: + """ + Ensure the current user is an admin + + Args: + current_user: The current user + auth_service: The authorization service + + Returns: + User: The current admin user + + Raises: + HTTPException: If user is not an admin + """ + auth_service.check_admin_role(current_user) + return current_user diff --git a/app/auth/repository.py b/app/auth/repository.py new file mode 100644 index 00000000..dc04ca49 --- /dev/null +++ b/app/auth/repository.py @@ -0,0 +1,85 @@ +""" +Repository for user data access operations. +Implements the repository pattern for user-related database operations. +""" + +from typing import Optional, Protocol, List +from sqlalchemy.orm import Session +from app.models import User, UserRole + + +class UserRepositoryProtocol(Protocol): + """Protocol defining the interface for user repositories""" + + def get_by_username(self, username: str) -> Optional[User]: ... + def get_by_email(self, email: str) -> Optional[User]: ... + def create(self, username: str, email: str, hashed_password: str, role: UserRole) -> User: ... + def get_all(self) -> List[User]: ... + + +class SQLAlchemyUserRepository: + """SQLAlchemy implementation of the user repository""" + + def __init__(self, db: Session): + self.db = db + + def get_by_username(self, username: str) -> Optional[User]: + """ + Get a user by username + + Args: + username: The username to search for + + Returns: + Optional[User]: The user if found, None otherwise + """ + return self.db.query(User).filter(User.username == username).first() + + def get_by_email(self, email: str) -> Optional[User]: + """ + Get a user by email + + Args: + email: The email to search for + + Returns: + Optional[User]: The user if found, None otherwise + """ + return self.db.query(User).filter(User.email == email).first() + + def create(self, username: str, email: str, hashed_password: str, role: UserRole) -> User: + """ + Create a new user + + Args: + username: The username + email: The email + hashed_password: The hashed password + role: The user role + + Returns: + User: The created user + + Raises: + Exception: If user creation fails + """ + + db_user = User(username=username, email=email, hashed_password=hashed_password, role=role) + + try: + self.db.add(db_user) + self.db.commit() + self.db.refresh(db_user) + return db_user + except Exception as e: + self.db.rollback() + raise e + + def get_all(self) -> List[User]: + """ + Get all users + + Returns: + List[User]: All users in the database + """ + return self.db.query(User).all() diff --git a/app/auth/router.py b/app/auth/router.py index 229ee71d..c25aff2b 100644 --- a/app/auth/router.py +++ b/app/auth/router.py @@ -1,151 +1,68 @@ -from datetime import datetime, timedelta -from typing import Optional +# app/auth/router.py +""" +Router for authentication endpoints. +Handles login, user creation, and other auth-related routes. +""" from fastapi import APIRouter, Depends, HTTPException, status -from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm -from jose import JWTError, jwt -from sqlalchemy.orm import Session -from app.database import get_db -from app.models import User, UserRole -from passlib.context import CryptContext -from pydantic import BaseModel, Field, validator +from fastapi.security import OAuth2PasswordRequestForm -router = APIRouter(prefix="/auth", tags=["authentication"]) - -class UserCreate(BaseModel): - username: str = Field(..., min_length=3, max_length=50) - email: str - password: str - role: UserRole - - @validator('role') - def validate_role(cls, v): - if v not in [UserRole.admin, UserRole.case_worker]: - raise ValueError('Role must be either admin or case_worker') - return v - -class UserResponse(BaseModel): - username: str - email: str - role: UserRole - - class Config: - from_attributes = True - -# Configuration -SECRET_KEY = "your-secret-key-here" -ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 30 - -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token") +from app.auth.service import AuthenticationService, UserCreate, UserResponse +from app.auth.dependencies import get_user_repository, get_admin_user +from app.models import User -def verify_password(plain_password: str, hashed_password: str) -> bool: - return pwd_context.verify(plain_password, hashed_password) - -def get_password_hash(password: str) -> str: - return pwd_context.hash(password) - -def authenticate_user(db: Session, username: str, password: str) -> Optional[User]: - user = db.query(User).filter(User.username == username).first() - if not user or not verify_password(password, user.hashed_password): - return None - return user - -def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): - to_encode = data.copy() - if expires_delta: - expire = datetime.utcnow() + expires_delta - else: - expire = datetime.utcnow() + timedelta(minutes=15) - to_encode.update({"exp": expire}) - encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) - return encoded_jwt - -async def get_current_user( - token: str = Depends(oauth2_scheme), - db: Session = Depends(get_db) -) -> User: - credentials_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) - try: - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - username: str = payload.get("sub") - if username is None: - raise credentials_exception - except JWTError: - raise credentials_exception - - user = db.query(User).filter(User.username == username).first() - if user is None: - raise credentials_exception - return user +router = APIRouter(prefix="/auth", tags=["authentication"]) -def get_admin_user(current_user: User = Depends(get_current_user)): - if current_user.role != UserRole.admin: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Only admin users can perform this operation" - ) - return current_user @router.post("/token") async def login_for_access_token( form_data: OAuth2PasswordRequestForm = Depends(), - db: Session = Depends(get_db) + auth_service: AuthenticationService = Depends( + lambda repo=Depends(get_user_repository): AuthenticationService(repo) + ), ): - user = authenticate_user(db, form_data.username, form_data.password) + """ + Login endpoint to get access token + + Args: + form_data: The login form data + auth_service: The authentication service + + Returns: + dict: Access token response + + Raises: + HTTPException: If authentication fails + """ + user = auth_service.authenticate_user(form_data.username, form_data.password) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"}, ) - access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - access_token = create_access_token( - data={"sub": user.username}, expires_delta=access_token_expires - ) - return {"access_token": access_token, "token_type": "bearer"} + return auth_service.create_access_token(user.username) + @router.post("/users", response_model=UserResponse) async def create_user( user_data: UserCreate, current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) + auth_service: AuthenticationService = Depends( + lambda repo=Depends(get_user_repository): AuthenticationService(repo) + ), ): - """Create a new user (admin only)""" - # Check if username exists - if db.query(User).filter(User.username == user_data.username).first(): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Username already registered" - ) - - # Check if email exists - if db.query(User).filter(User.email == user_data.email).first(): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Email already registered" - ) + """ + Create a new user (admin only) - # Create new user - db_user = User( - username=user_data.username, - email=user_data.email, - hashed_password=get_password_hash(user_data.password), - role=user_data.role - ) - - try: - db.add(db_user) - db.commit() - db.refresh(db_user) - return db_user - except Exception as e: - db.rollback() - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(e) - ) + Args: + user_data: The user data + current_user: The current admin user + auth_service: The authentication service + + Returns: + UserResponse: The created user + + Raises: + HTTPException: If user creation fails + """ + return auth_service.create_user(user_data) diff --git a/app/auth/security.py b/app/auth/security.py new file mode 100644 index 00000000..38dcc038 --- /dev/null +++ b/app/auth/security.py @@ -0,0 +1,108 @@ +""" +Security utilities for authentication handling. +Implements password hashing, verification, and JWT token creation/validation. +""" + +from datetime import datetime, timedelta +from typing import Optional +from fastapi import HTTPException, status +from jose import JWTError, jwt +from passlib.context import CryptContext +from pydantic import BaseModel + +# Configuration +SECRET_KEY = "your-secret-key-here" +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 30 + +# Password context for hashing +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + + +# Token validation and data extraction +class TokenData(BaseModel): + username: str + + +class TokenService: + """Service for JWT token operations""" + + @staticmethod + def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: + """ + Create a new JWT access token + + Args: + data: The data to encode in the token + expires_delta: Optional expiration time delta + + Returns: + str: The encoded JWT token + """ + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + + @staticmethod + def decode_token(token: str) -> TokenData: + """ + Decode and validate a JWT token + + Args: + token: The JWT token to decode + + Returns: + TokenData: The decoded token data + + Raises: + HTTPException: If token validation fails + """ + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") + if username is None: + raise credentials_exception + return TokenData(username=username) + except JWTError: + raise credentials_exception + + +class PasswordService: + """Service for password operations""" + + @staticmethod + def verify_password(plain_password: str, hashed_password: str) -> bool: + """ + Verify a password against its hash + + Args: + plain_password: The plain text password + hashed_password: The hashed password to compare against + + Returns: + bool: True if password matches, False otherwise + """ + return pwd_context.verify(plain_password, hashed_password) + + @staticmethod + def get_password_hash(password: str) -> str: + """ + Hash a password + + Args: + password: The password to hash + + Returns: + str: The hashed password + """ + return pwd_context.hash(password) diff --git a/app/auth/service.py b/app/auth/service.py new file mode 100644 index 00000000..b272f1fb --- /dev/null +++ b/app/auth/service.py @@ -0,0 +1,158 @@ +# app/auth/service.py +""" +Authentication service for user authentication and authorization. +Handles user authentication, creation, and token management. +""" +from datetime import timedelta +from typing import Optional, Tuple, Dict, Any + +from fastapi import HTTPException, status +from pydantic import BaseModel, Field, validator + +from app.models import User, UserRole +from app.auth.security import PasswordService, TokenService +from app.auth.repository import UserRepositoryProtocol + + +class UserCreate(BaseModel): + username: str = Field(..., min_length=3, max_length=50) + email: str + password: str + role: UserRole + + @validator("role") + def validate_role(cls, v): + if v not in [UserRole.admin, UserRole.case_worker]: + raise ValueError("Role must be either admin or case_worker") + return v + + +class UserResponse(BaseModel): + username: str + email: str + role: UserRole + + class Config: + from_attributes = True + + +class AuthenticationService: + """Service for user authentication and authorization""" + + def __init__(self, user_repository: UserRepositoryProtocol): + self.user_repository = user_repository + + def authenticate_user(self, username: str, password: str) -> Optional[User]: + """ + Authenticate a user with username and password + + Args: + username: The username + password: The plain text password + + Returns: + Optional[User]: The authenticated user if successful, None otherwise + """ + user = self.user_repository.get_by_username(username) + if not user or not PasswordService.verify_password(password, user.hashed_password): + return None + return user + + def create_user(self, user_data: UserCreate) -> User: + """ + Create a new user + + Args: + user_data: The user data + + Returns: + User: The created user + + Raises: + HTTPException: If username or email already exists + """ + # Check if username exists + if self.user_repository.get_by_username(user_data.username): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Username already registered" + ) + + # Check if email exists + if self.user_repository.get_by_email(user_data.email): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered" + ) + + # Create new user + hashed_password = PasswordService.get_password_hash(user_data.password) + + try: + return self.user_repository.create( + username=user_data.username, + email=user_data.email, + hashed_password=hashed_password, + role=user_data.role, + ) + except Exception as e: + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) + + def create_access_token(self, username: str) -> Dict[str, Any]: + """ + Create access token for a user + + Args: + username: The username + + Returns: + Dict[str, Any]: Access token response + """ + access_token_expires = timedelta(minutes=30) # Could be configurable + access_token = TokenService.create_access_token( + data={"sub": username}, expires_delta=access_token_expires + ) + return {"access_token": access_token, "token_type": "bearer"} + + +class AuthorizationService: + """Service for user authorization""" + + def __init__(self, user_repository: UserRepositoryProtocol): + self.user_repository = user_repository + + def get_current_user(self, token_data: str) -> User: + """ + Get the current user from a token + + Args: + token_data: The token data with username + + Returns: + User: The current user + + Raises: + HTTPException: If user not found + """ + user = self.user_repository.get_by_username(token_data.username) + if user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + return user + + def check_admin_role(self, user: User) -> None: + """ + Check if user has admin role + + Args: + user: The user to check + + Raises: + HTTPException: If user is not an admin + """ + if user.role != UserRole.admin: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only admin users can perform this operation", + ) diff --git a/app/clients/dependencies.py b/app/clients/dependencies.py new file mode 100644 index 00000000..aaef6cd7 --- /dev/null +++ b/app/clients/dependencies.py @@ -0,0 +1,54 @@ +""" +Client dependencies for FastAPI dependency injection. +Provides injectable dependencies for repositories and services. +""" + +from fastapi import Depends +from sqlalchemy.orm import Session + +from app.database import get_db +from app.clients.repository import SQLAlchemyClientRepository, SQLAlchemyClientCaseRepository +from app.clients.service.client_service import ClientService + + +def get_client_repository(db: Session = Depends(get_db)): + """ + Get client repository + + Args: + db: Database session + + Returns: + SQLAlchemyClientRepository: Client repository + """ + return SQLAlchemyClientRepository(db) + + +def get_client_case_repository(db: Session = Depends(get_db)): + """ + Get client case repository + + Args: + db: Database session + + Returns: + SQLAlchemyClientCaseRepository: Client case repository + """ + return SQLAlchemyClientCaseRepository(db) + + +def get_client_service( + client_repo: SQLAlchemyClientRepository = Depends(get_client_repository), + client_case_repo: SQLAlchemyClientCaseRepository = Depends(get_client_case_repository), +): + """ + Get client service + + Args: + client_repo: Client repository + client_case_repo: Client case repository + + Returns: + ClientService: Client service + """ + return ClientService(client_repo, client_case_repo) diff --git a/app/clients/repository.py b/app/clients/repository.py new file mode 100644 index 00000000..db7e9ad0 --- /dev/null +++ b/app/clients/repository.py @@ -0,0 +1,396 @@ +""" +Repository for client data access operations. +Implements the repository pattern for client-related database operations. +Single Responsibility Principle (SRP): Create a separate repository layer for database operations, leaving higher-level business logic in the service class. +""" + +from typing import Optional, Protocol, List, Dict, Any, Tuple +from sqlalchemy.orm import Session +from sqlalchemy import and_ +from fastapi import HTTPException, status + +from app.models import Client, ClientCase, User + + +class ClientRepositoryProtocol(Protocol): + """Protocol defining the interface for client repositories""" + + def get_by_id(self, client_id: int) -> Optional[Client]: ... + def get_all(self, skip: int, limit: int) -> Tuple[List[Client], int]: ... + def filter_by_criteria(self, **criteria) -> List[Client]: ... + def filter_by_services(self, **service_filters) -> List[Client]: ... + def get_clients_by_success_rate(self, min_rate: int) -> List[Client]: ... + def get_clients_by_case_worker(self, case_worker_id: int) -> List[Client]: ... + def update(self, client_id: int, update_data: Dict[str, Any]) -> Client: ... + def delete(self, client_id: int) -> None: ... + + +class ClientCaseRepositoryProtocol(Protocol): + """Protocol defining the interface for client case repositories""" + + def get_by_client(self, client_id: int) -> List[ClientCase]: ... + def get_by_client_and_user(self, client_id: int, user_id: int) -> Optional[ClientCase]: ... + def create(self, client_id: int, user_id: int) -> ClientCase: ... + def update(self, client_id: int, user_id: int, update_data: Dict[str, Any]) -> ClientCase: ... + + +class SQLAlchemyClientRepository: + """SQLAlchemy implementation of the client repository""" + + def __init__(self, db: Session): + self.db = db + + def get_by_id(self, client_id: int) -> Optional[Client]: + """ + Get a client by ID + + Args: + client_id: The client ID + + Returns: + Optional[Client]: The client if found, None otherwise + """ + client = self.db.query(Client).filter(Client.id == client_id).first() + if not client: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Client with id {client_id} not found", + ) + return client + + def get_all(self, skip: int, limit: int) -> Tuple[List[Client], int]: + """ + Get all clients with pagination + + Args: + skip: Number of records to skip + limit: Maximum number of records to return + + Returns: + Tuple[List[Client], int]: List of clients and total count + """ + if skip < 0: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Skip value cannot be negative" + ) + if limit < 1: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Limit must be greater than 0" + ) + + clients = self.db.query(Client).offset(skip).limit(limit).all() + total = self.db.query(Client).count() + return clients, total + + def filter_by_criteria(self, **criteria) -> List[Client]: + """ + Filter clients by criteria + + Args: + **criteria: Filter criteria as keyword arguments + + Returns: + List[Client]: Filtered clients + """ + query = self.db.query(Client) + + range_fields = { + "age_min": ("age", ">="), + "age_max": ("age", "<="), + "time_unemployed": ("time_unemployed", "=="), + } + + for field, value in criteria.items(): + if value is None: + continue + + if field in range_fields: + real_field, op = range_fields[field] + column = getattr(Client, real_field) + if op == ">=": + query = query.filter(column >= value) + elif op == "<=": + query = query.filter(column <= value) + elif op == "==": + query = query.filter(column == value) + + elif hasattr(Client, field): + column = getattr(Client, field) + query = query.filter(column == value) + + else: + print(f"avoid unknown: {field}") + + try: + return query.all() + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error retrieving clients: {str(e)}", + ) + + def filter_by_services(self, **service_filters) -> List[Client]: + """ + Filter clients by service statuses + + Args: + **service_filters: Service filters as keyword arguments + + Returns: + List[Client]: Filtered clients + """ + query = self.db.query(Client).join(ClientCase) + + for service_name, status in service_filters.items(): + if status is not None: + filter_criteria = getattr(ClientCase, service_name) == status + query = query.filter(filter_criteria) + + try: + return query.all() + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error retrieving clients: {str(e)}", + ) + + def get_clients_by_success_rate(self, min_rate: int) -> List[Client]: + """ + Get clients with success rate at or above the specified percentage + + Args: + min_rate: Minimum success rate percentage + + Returns: + List[Client]: Filtered clients + """ + if not (0 <= min_rate <= 100): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Success rate must be between 0 and 100", + ) + + return ( + self.db.query(Client).join(ClientCase).filter(ClientCase.success_rate >= min_rate).all() + ) + + def get_clients_by_case_worker(self, case_worker_id: int) -> List[Client]: + """ + Get all clients assigned to a specific case worker + + Args: + case_worker_id: The case worker ID + + Returns: + List[Client]: Filtered clients + """ + case_worker = self.db.query(User).filter(User.id == case_worker_id).first() + if not case_worker: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Case worker with id {case_worker_id} not found", + ) + + return ( + self.db.query(Client) + .join(ClientCase) + .filter(ClientCase.user_id == case_worker_id) + .all() + ) + + def update(self, client_id: int, update_data: Dict[str, Any]) -> Client: + """ + Update a client + + Args: + client_id: The client ID + update_data: The update data + + Returns: + Client: The updated client + """ + client = self.db.query(Client).filter(Client.id == client_id).first() + if not client: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Client with id {client_id} not found", + ) + + for field, value in update_data.items(): + setattr(client, field, value) + + try: + self.db.commit() + self.db.refresh(client) + return client + except Exception as e: + self.db.rollback() + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to update client: {str(e)}", + ) + + def delete(self, client_id: int) -> None: + """ + Delete a client + + Args: + client_id: The client ID + """ + client = self.db.query(Client).filter(Client.id == client_id).first() + if not client: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Client with id {client_id} not found", + ) + + try: + # Delete associated client_cases + self.db.query(ClientCase).filter(ClientCase.client_id == client_id).delete() + + # Delete the client + self.db.delete(client) + self.db.commit() + except Exception as e: + self.db.rollback() + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to delete client: {str(e)}", + ) + + +class SQLAlchemyClientCaseRepository: + """SQLAlchemy implementation of the client case repository""" + + def __init__(self, db: Session): + self.db = db + + def get_by_client(self, client_id: int) -> List[ClientCase]: + """ + Get all cases for a client + + Args: + client_id: The client ID + + Returns: + List[ClientCase]: The client cases + """ + client_cases = self.db.query(ClientCase).filter(ClientCase.client_id == client_id).all() + if not client_cases: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"No services found for client with id {client_id}", + ) + return client_cases + + def get_by_client_and_user(self, client_id: int, user_id: int) -> Optional[ClientCase]: + """ + Get a case by client and user + + Args: + client_id: The client ID + user_id: The user ID + + Returns: + Optional[ClientCase]: The client case if found, None otherwise + """ + return ( + self.db.query(ClientCase) + .filter(ClientCase.client_id == client_id, ClientCase.user_id == user_id) + .first() + ) + + def create(self, client_id: int, user_id: int) -> ClientCase: + """ + Create a new case assignment + + Args: + client_id: The client ID + user_id: The user ID + + Returns: + ClientCase: The created client case + """ + # Check if client exists + client = self.db.query(Client).filter(Client.id == client_id).first() + if not client: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Client with id {client_id} not found", + ) + + # Check if case worker exists + case_worker = self.db.query(User).filter(User.id == user_id).first() + if not case_worker: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Case worker with id {user_id} not found", + ) + + # Check if assignment already exists + existing_case = self.get_by_client_and_user(client_id, user_id) + if existing_case: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Client {client_id} already has a case assigned to case worker {user_id}", + ) + + try: + # Create new case assignment with default service values + new_case = ClientCase( + client_id=client_id, + user_id=user_id, + employment_assistance=False, + life_stabilization=False, + retention_services=False, + specialized_services=False, + employment_related_financial_supports=False, + employer_financial_supports=False, + enhanced_referrals=False, + success_rate=0, + ) + self.db.add(new_case) + self.db.commit() + self.db.refresh(new_case) + return new_case + + except Exception as e: + self.db.rollback() + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to create case assignment: {str(e)}", + ) + + def update(self, client_id: int, user_id: int, update_data: Dict[str, Any]) -> ClientCase: + """ + Update a client case + + Args: + client_id: The client ID + user_id: The user ID + update_data: The update data + + Returns: + ClientCase: The updated client case + """ + client_case = self.get_by_client_and_user(client_id, user_id) + if not client_case: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"No case found for client {client_id} with case worker {user_id}. " + f"Cannot update services for a non-existent case assignment.", + ) + + for field, value in update_data.items(): + setattr(client_case, field, value) + + try: + self.db.commit() + self.db.refresh(client_case) + return client_case + except Exception as e: + self.db.rollback() + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to update client services: {str(e)}", + ) diff --git a/app/clients/router.py b/app/clients/router.py index 4ecc83e4..28bcbae6 100644 --- a/app/clients/router.py +++ b/app/clients/router.py @@ -1,43 +1,114 @@ """ -Router module for client-related endpoints. -Handles all HTTP requests for client operations including create, read, update, and delete. +Router for client endpoints. +Handles HTTP requests for client-related operations. """ -from fastapi import APIRouter, Depends, HTTPException, status, Query -from sqlalchemy.orm import Session +from fastapi import HTTPException +from fastapi import APIRouter, Depends, status, Query from typing import List, Optional -from app.auth.router import get_current_user, get_admin_user +from app.auth.dependencies import get_current_user, get_admin_user from app.models import User, UserRole -from app.database import get_db -from app.clients.service.client_service import ClientService +from app.auth.dependencies import get_current_user, get_admin_user +from app.models import User +from app.clients.dependencies import get_client_service from app.clients.schema import ( - ClientResponse, - ClientUpdate, + ClientResponse, + ClientUpdate, ClientListResponse, ServiceResponse, - ServiceUpdate + ServiceUpdate, + PredictionInput, ) +from app.clients.service.logic import interpret_and_calculate + +# Add the Code to see the Prediction API and Test It per Piazza Post +from app.clients.service.logic import interpret_and_calculate +from app.clients.schema import PredictionInput + +# Add the Code to see the Prediction API and Test It per Piazza Post +from app.clients.service.logic import interpret_and_calculate +from app.clients.schema import PredictionInput + +router = APIRouter(tags=["clients"]) + + +@router.post("/predictions") +async def predict(data: PredictionInput): + print("HERE") + print(data.model_dump()) + return interpret_and_calculate(data.model_dump()) + + +# Import the model_manager functions for switching models, getting the current model, and listing all available models +from app.clients.service.model_manager import list_models, get_current_model_name, switch_model + +model_router = APIRouter(prefix="/models", tags=["models"]) + + +# API Endpoint for listing all available models +@model_router.get("/", response_model=list) +def get_available_models(): + return list_models() + + +# API Endpoint for getting the name of the currently active model +@model_router.get("/current", response_model=str) +def get_active_model(): + return get_current_model_name() + + +# API Endpoint for switching to a different model by name +@model_router.post("/switch/{model_name}") +def change_model(model_name: str): + try: + switch_model(model_name) + return {"message": f"Switched to model: {model_name}"} + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) -router = APIRouter(prefix="/clients", tags=["clients"]) @router.get("/", response_model=ClientListResponse) async def get_clients( - current_user: User = Depends(get_admin_user), + client_service=Depends(get_client_service), + current_user: User = Depends(get_admin_user), skip: int = Query(default=0, ge=0, description="Number of records to skip"), limit: int = Query(default=50, ge=1, le=150, description="Maximum number of records to return"), - db: Session = Depends(get_db) ): - return ClientService.get_clients(db, skip, limit) + """ + Get all clients with pagination (admin only) + + Args: + client_service: Client service + current_user: Current admin user + skip: Number of records to skip + limit: Maximum number of records + + Returns: + ClientListResponse: Clients and total count + """ + return client_service.get_clients(skip, limit) + @router.get("/{client_id}", response_model=ClientResponse) async def get_client( client_id: int, + client_service=Depends(get_client_service), current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) ): - """Get a specific client by ID""" - return ClientService.get_client(db, client_id) + """ + Get a specific client by ID (admin only) + + Args: + client_id: The client ID + client_service: Client service + current_user: Current admin user + + Returns: + ClientResponse: The client + """ + return client_service.get_client(client_id) + @router.get("/search/by-criteria", response_model=List[ClientResponse]) async def get_clients_by_criteria( @@ -65,12 +136,21 @@ async def get_clients_by_criteria( substance_use: Optional[bool] = None, time_unemployed: Optional[int] = Query(None, ge=0), need_mental_health_support_bool: Optional[bool] = None, + client_service=Depends(get_client_service), current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) ): - """Search clients by any combination of criteria""" - return ClientService.get_clients_by_criteria( - db, + """ + Search clients by criteria (admin only) + + Args: + Multiple filter criteria as query parameters + client_service: Client service + current_user: Current admin user + + Returns: + List[ClientResponse]: Filtered clients + """ + return client_service.get_clients_by_criteria( employment_status=employment_status, education_level=education_level, age_min=age_min, @@ -94,9 +174,10 @@ async def get_clients_by_criteria( attending_school=attending_school, substance_use=substance_use, time_unemployed=time_unemployed, - need_mental_health_support_bool=need_mental_health_support_bool + need_mental_health_support_bool=need_mental_health_support_bool, ) + @router.get("/search/by-services", response_model=List[ClientResponse]) async def get_clients_by_services( employment_assistance: Optional[bool] = None, @@ -106,83 +187,189 @@ async def get_clients_by_services( employment_related_financial_supports: Optional[bool] = None, employer_financial_supports: Optional[bool] = None, enhanced_referrals: Optional[bool] = None, + client_service=Depends(get_client_service), current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) ): - """Get clients filtered by multiple service statuses""" - return ClientService.get_clients_by_services( - db, + """ + Get clients filtered by service statuses (admin only) + + Args: + Multiple service filters as query parameters + client_service: Client service + current_user: Current admin user + + Returns: + List[ClientResponse]: Filtered clients + """ + return client_service.get_clients_by_services( employment_assistance=employment_assistance, life_stabilization=life_stabilization, retention_services=retention_services, specialized_services=specialized_services, employment_related_financial_supports=employment_related_financial_supports, employer_financial_supports=employer_financial_supports, - enhanced_referrals=enhanced_referrals + enhanced_referrals=enhanced_referrals, ) + @router.get("/{client_id}/services", response_model=List[ServiceResponse]) async def get_client_services( client_id: int, + client_service=Depends(get_client_service), current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) ): - """Get all services and their status for a specific client, including case worker info""" - return ClientService.get_client_services(db, client_id) + """ + Get all services for a specific client (admin only) + + Args: + client_id: The client ID + client_service: Client service + current_user: Current admin user + + Returns: + List[ServiceResponse]: Client services + """ + return client_service.get_client_services(client_id) + @router.get("/search/success-rate", response_model=List[ClientResponse]) async def get_clients_by_success_rate( min_rate: int = Query(70, ge=0, le=100, description="Minimum success rate percentage"), + client_service=Depends(get_client_service), current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) ): - """Get clients with success rate above specified threshold""" - return ClientService.get_clients_by_success_rate(db, min_rate) + """ + Get clients with success rate above threshold (admin only) + + Args: + min_rate: Minimum success rate + client_service: Client service + current_user: Current admin user + + Returns: + List[ClientResponse]: Filtered clients + """ + return client_service.get_clients_by_success_rate(min_rate) + @router.get("/case-worker/{case_worker_id}", response_model=List[ClientResponse]) async def get_clients_by_case_worker( case_worker_id: int, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + client_service=Depends(get_client_service), + current_user: User = Depends(get_current_user), ): - return ClientService.get_clients_by_case_worker(db, case_worker_id) + """ + Get clients by case worker + + Args: + case_worker_id: The case worker ID + client_service: Client service + current_user: Current user + + Returns: + List[ClientResponse]: Filtered clients + """ + return client_service.get_clients_by_case_worker(case_worker_id) + @router.put("/{client_id}", response_model=ClientResponse) async def update_client( client_id: int, client_data: ClientUpdate, + client_service=Depends(get_client_service), current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) ): - """Update a client's information""" - return ClientService.update_client(db, client_id, client_data) + """ + Update a client (admin only) + + Args: + client_id: The client ID + client_data: Client update data + client_service: Client service + current_user: Current admin user + + Returns: + ClientResponse: Updated client + """ + return client_service.update_client(client_id, client_data) + @router.put("/{client_id}/services/{user_id}", response_model=ServiceResponse) async def update_client_services( client_id: int, user_id: int, service_update: ServiceUpdate, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + client_service=Depends(get_client_service), + current_user: User = Depends(get_current_user), ): - return ClientService.update_client_services(db, client_id, user_id, service_update) + """ + Update client services + + Args: + client_id: The client ID + user_id: The user ID + service_update: Service update data + client_service: Client service + current_user: Current user + + Returns: + ServiceResponse: Updated service + """ + return client_service.update_client_services(client_id, user_id, service_update) + @router.post("/{client_id}/case-assignment", response_model=ServiceResponse) async def create_case_assignment( client_id: int, case_worker_id: int = Query(..., description="Case worker ID to assign"), + client_service=Depends(get_client_service), current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) ): - """Create a new case assignment for a client with a case worker""" - return ClientService.create_case_assignment(db, client_id, case_worker_id) + """ + Create a new case assignment (admin only) + + Args: + client_id: The client ID + case_worker_id: The case worker ID + client_service: Client service + current_user: Current admin user + + Returns: + ServiceResponse: Created case assignment + """ + return client_service.create_case_assignment(client_id, case_worker_id) + @router.delete("/{client_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_client( client_id: int, + client_service=Depends(get_client_service), current_user: User = Depends(get_admin_user), - db: Session = Depends(get_db) ): - """Delete a client""" - ClientService.delete_client(db, client_id) + """ + Delete a client (admin only) + + Args: + client_id: The client ID + client_service: Client service + current_user: Current admin user + """ + client_service.delete_client(client_id) return None + + +@router.get( + "/test-new-endpoint", + tags=["test-tag"], + summary="Brief description", + description="Detailed description", + response_description="Description of the response", +) +async def new_endpoint(): + """ + This docstring will appear in the Swagger documentation + + Returns: + dict: Description of what the endpoint returns + """ + return {"message": "Hello World"} diff --git a/app/clients/schema.py b/app/clients/schema.py index cff28897..fa5f0e04 100644 --- a/app/clients/schema.py +++ b/app/clients/schema.py @@ -9,16 +9,19 @@ from enum import IntEnum from app.models import UserRole + # Enums for validation class Gender(IntEnum): MALE = 1 FEMALE = 2 + class PredictionInput(BaseModel): """ Schema for prediction input data containing all client assessment fields. Used for making predictions about client outcomes. """ + age: int gender: str work_experience: int @@ -44,6 +47,7 @@ class PredictionInput(BaseModel): time_unemployed: int need_mental_health_support_bool: str + class ClientBase(BaseModel): age: int = Field(ge=18, description="Age of client, must be 18 or older") gender: Gender = Field(description="Gender: 1 for male, 2 for female") @@ -96,16 +100,18 @@ class Config: "currently_employed": False, "substance_use": False, "time_unemployed": 6, - "need_mental_health_support_bool": False + "need_mental_health_support_bool": False, } } + class ClientResponse(ClientBase): id: int class Config: from_attributes = True + class ClientUpdate(BaseModel): age: Optional[int] = Field(None, ge=18) gender: Optional[Gender] = None @@ -132,6 +138,7 @@ class ClientUpdate(BaseModel): time_unemployed: Optional[int] = Field(None, ge=0) need_mental_health_support_bool: Optional[bool] = None + class ServiceResponse(BaseModel): client_id: int user_id: int @@ -147,6 +154,7 @@ class ServiceResponse(BaseModel): class Config: from_attributes = True + class ServiceUpdate(BaseModel): employment_assistance: Optional[bool] = None life_stabilization: Optional[bool] = None @@ -157,6 +165,7 @@ class ServiceUpdate(BaseModel): enhanced_referrals: Optional[bool] = None success_rate: Optional[int] = Field(None, ge=0, le=100) + class ClientListResponse(BaseModel): clients: List[ClientResponse] total: int diff --git a/app/clients/service/client_service.py b/app/clients/service/client_service.py index 86c3ef4a..4048e884 100644 --- a/app/clients/service/client_service.py +++ b/app/clients/service/client_service.py @@ -1,361 +1,168 @@ +# app/clients/service/client_service.py """ -Client service module handling all database operations for clients. -Provides CRUD operations and business logic for client management. +Client service for client-related business logic. +Encapsulates business rules and coordinates with repositories. """ +from typing import Dict, Any, List, Optional, Tuple + +from app.models import Client, ClientCase +from app.clients.repository import ClientRepositoryProtocol, ClientCaseRepositoryProtocol +from app.clients.schema import ClientUpdate, ServiceUpdate -from sqlalchemy.orm import Session -from sqlalchemy import and_ -from fastapi import HTTPException, status -from typing import List, Optional, Dict, Any -from app.models import Client, ClientCase, User -from app.clients.schema import ClientUpdate, ServiceUpdate, ServiceResponse class ClientService: - @staticmethod - def get_client(db: Session, client_id: int): - """Get a specific client by ID""" - client = db.query(Client).filter(Client.id == client_id).first() - if not client: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" - ) - return client - - @staticmethod - def get_clients(db: Session, skip: int = 0, limit: int = 50): + """Service for client-related operations""" + + def __init__( + self, + client_repository: ClientRepositoryProtocol, + client_case_repository: ClientCaseRepositoryProtocol, + ): + """ + Initialize with repositories + + Args: + client_repository: Client repository + client_case_repository: Client case repository """ - Get clients with optional pagination. - Default shows first 50 clients, which means you'd need 3 pages for 150 records. + self.client_repository = client_repository + self.client_case_repository = client_case_repository + + def get_client(self, client_id: int) -> Client: """ - if skip < 0: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Skip value cannot be negative" - ) - if limit < 1: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Limit must be greater than 0" - ) - - clients = db.query(Client).offset(skip).limit(limit).all() - total = db.query(Client).count() + Get a client by ID + + Args: + client_id: The client ID + + Returns: + Client: The client + """ + return self.client_repository.get_by_id(client_id) + + def get_clients(self, skip: int = 0, limit: int = 50) -> Dict[str, Any]: + """ + Get clients with pagination + + Args: + skip: Number of records to skip + limit: Maximum number of records + + Returns: + Dict[str, Any]: Clients and total count + """ + clients, total = self.client_repository.get_all(skip, limit) return {"clients": clients, "total": total} - @staticmethod - def get_clients_by_criteria( - db: Session, - employment_status: Optional[bool] = None, - education_level: Optional[int] = None, - age_min: Optional[int] = None, - gender: Optional[int] = None, - work_experience: Optional[int] = None, - canada_workex: Optional[int] = None, - dep_num: Optional[int] = None, - canada_born: Optional[bool] = None, - citizen_status: Optional[bool] = None, - fluent_english: Optional[bool] = None, - reading_english_scale: Optional[int] = None, - speaking_english_scale: Optional[int] = None, - writing_english_scale: Optional[int] = None, - numeracy_scale: Optional[int] = None, - computer_scale: Optional[int] = None, - transportation_bool: Optional[bool] = None, - caregiver_bool: Optional[bool] = None, - housing: Optional[int] = None, - income_source: Optional[int] = None, - felony_bool: Optional[bool] = None, - attending_school: Optional[bool] = None, - substance_use: Optional[bool] = None, - time_unemployed: Optional[int] = None, - need_mental_health_support_bool: Optional[bool] = None - ): - """Get clients filtered by any combination of criteria""" - query = db.query(Client) - - if education_level is not None and not (1 <= education_level <= 14): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Education level must be between 1 and 14" - ) - - if age_min is not None and age_min < 18: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Minimum age must be at least 18" - ) - - if gender is not None and gender not in [1, 2]: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Gender must be 1 or 2" - ) - - # Apply filters for non-None values - if employment_status is not None: - query = query.filter(Client.currently_employed == employment_status) - if age_min is not None: - query = query.filter(Client.age >= age_min) - if gender is not None: - query = query.filter(Client.gender == gender) - if education_level is not None: - query = query.filter(Client.level_of_schooling == education_level) - if work_experience is not None: - query = query.filter(Client.work_experience == work_experience) - if canada_workex is not None: - query = query.filter(Client.canada_workex == canada_workex) - if dep_num is not None: - query = query.filter(Client.dep_num == dep_num) - if canada_born is not None: - query = query.filter(Client.canada_born == canada_born) - if citizen_status is not None: - query = query.filter(Client.citizen_status == citizen_status) - if fluent_english is not None: - query = query.filter(Client.fluent_english == fluent_english) - if reading_english_scale is not None: - query = query.filter(Client.reading_english_scale == reading_english_scale) - if speaking_english_scale is not None: - query = query.filter(Client.speaking_english_scale == speaking_english_scale) - if writing_english_scale is not None: - query = query.filter(Client.writing_english_scale == writing_english_scale) - if numeracy_scale is not None: - query = query.filter(Client.numeracy_scale == numeracy_scale) - if computer_scale is not None: - query = query.filter(Client.computer_scale == computer_scale) - if transportation_bool is not None: - query = query.filter(Client.transportation_bool == transportation_bool) - if caregiver_bool is not None: - query = query.filter(Client.caregiver_bool == caregiver_bool) - if housing is not None: - query = query.filter(Client.housing == housing) - if income_source is not None: - query = query.filter(Client.income_source == income_source) - if felony_bool is not None: - query = query.filter(Client.felony_bool == felony_bool) - if attending_school is not None: - query = query.filter(Client.attending_school == attending_school) - if substance_use is not None: - query = query.filter(Client.substance_use == substance_use) - if time_unemployed is not None: - query = query.filter(Client.time_unemployed == time_unemployed) - if need_mental_health_support_bool is not None: - query = query.filter(Client.need_mental_health_support_bool == need_mental_health_support_bool) - - try: - return query.all() - except Exception as e: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error retrieving clients: {str(e)}" - ) - - @staticmethod - def get_clients_by_services( - db: Session, - **service_filters: Optional[bool] - ): + def get_clients_by_criteria(self, **criteria) -> List[Client]: + """ + Get clients by criteria + + Args: + **criteria: Filter criteria + + Returns: + List[Client]: Filtered clients + """ + return self.client_repository.filter_by_criteria(**criteria) + + def get_clients_by_services(self, **service_filters) -> List[Client]: + """ + Get clients by service filters + + Args: + **service_filters: Service filters + + Returns: + List[Client]: Filtered clients """ - Get clients filtered by multiple service statuses. + return self.client_repository.filter_by_services(**service_filters) + + def get_client_services(self, client_id: int) -> List[ClientCase]: """ - query = db.query(Client).join(ClientCase) - - for service_name, status in service_filters.items(): - if status is not None: - filter_criteria = getattr(ClientCase, service_name) == status - query = query.filter(filter_criteria) - - try: - return query.all() - except Exception as e: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error retrieving clients: {str(e)}" - ) - - @staticmethod - def get_client_services(db: Session, client_id: int): - """Get all services for a specific client with case worker info""" - client_cases = db.query(ClientCase).filter(ClientCase.client_id == client_id).all() - if not client_cases: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"No services found for client with id {client_id}" - ) - return client_cases - - @staticmethod - def get_clients_by_success_rate(db: Session, min_rate: int = 70): - """Get clients with success rate at or above the specified percentage""" - if not (0 <= min_rate <= 100): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Success rate must be between 0 and 100" - ) - - return db.query(Client).join(ClientCase).filter( - ClientCase.success_rate >= min_rate - ).all() - - @staticmethod - def get_clients_by_case_worker(db: Session, case_worker_id: int): - """Get all clients assigned to a specific case worker""" - case_worker = db.query(User).filter(User.id == case_worker_id).first() - if not case_worker: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Case worker with id {case_worker_id} not found" - ) - - return db.query(Client).join(ClientCase).filter( - ClientCase.user_id == case_worker_id - ).all() - - @staticmethod - def update_client(db: Session, client_id: int, client_update: ClientUpdate): - """Update a client's information""" - client = db.query(Client).filter(Client.id == client_id).first() - if not client: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" - ) + Get services for a client + + Args: + client_id: The client ID + Returns: + List[ClientCase]: Client services + """ + return self.client_case_repository.get_by_client(client_id) + + def get_clients_by_success_rate(self, min_rate: int = 70) -> List[Client]: + """ + Get clients by success rate + + Args: + min_rate: Minimum success rate + + Returns: + List[Client]: Filtered clients + """ + return self.client_repository.get_clients_by_success_rate(min_rate) + + def get_clients_by_case_worker(self, case_worker_id: int) -> List[Client]: + """ + Get clients by case worker + + Args: + case_worker_id: The case worker ID + + Returns: + List[Client]: Filtered clients + """ + return self.client_repository.get_clients_by_case_worker(case_worker_id) + + def update_client(self, client_id: int, client_update: ClientUpdate) -> Client: + """ + Update a client + + Args: + client_id: The client ID + client_update: The update data + + Returns: + Client: The updated client + """ update_data = client_update.dict(exclude_unset=True) - for field, value in update_data.items(): - setattr(client, field, value) - - try: - db.commit() - db.refresh(client) - return client - except Exception as e: - db.rollback() - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update client: {str(e)}" - ) - - @staticmethod + return self.client_repository.update(client_id, update_data) + def update_client_services( - db: Session, - client_id: int, - user_id: int, - service_update: ServiceUpdate - ): - """Update a client's services and outcomes for a specific case worker""" - client_case = db.query(ClientCase).filter( - ClientCase.client_id == client_id, - ClientCase.user_id == user_id - ).first() - - if not client_case: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"No case found for client {client_id} with case worker {user_id}. " - f"Cannot update services for a non-existent case assignment." - ) + self, client_id: int, user_id: int, service_update: ServiceUpdate + ) -> ClientCase: + """ + Update client services + Args: + client_id: The client ID + user_id: The user ID + service_update: The service update data + + Returns: + ClientCase: The updated client case + """ update_data = service_update.dict(exclude_unset=True) - for field, value in update_data.items(): - setattr(client_case, field, value) - - try: - db.commit() - db.refresh(client_case) - return client_case - except Exception as e: - db.rollback() - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update client services: {str(e)}" - ) - - @staticmethod - def create_case_assignment( - db: Session, - client_id: int, - case_worker_id: int - ): - """Create a new case assignment""" - # Check if client exists - client = db.query(Client).filter(Client.id == client_id).first() - if not client: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" - ) - - # Check if case worker exists - case_worker = db.query(User).filter(User.id == case_worker_id).first() - if not case_worker: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Case worker with id {case_worker_id} not found" - ) - - # Check if assignment already exists - existing_case = db.query(ClientCase).filter( - ClientCase.client_id == client_id, - ClientCase.user_id == case_worker_id - ).first() - - if existing_case: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Client {client_id} already has a case assigned to case worker {case_worker_id}" - ) - - try: - # Create new case assignment with default service values - new_case = ClientCase( - client_id=client_id, - user_id=case_worker_id, - employment_assistance=False, - life_stabilization=False, - retention_services=False, - specialized_services=False, - employment_related_financial_supports=False, - employer_financial_supports=False, - enhanced_referrals=False, - success_rate=0 - ) - db.add(new_case) - db.commit() - db.refresh(new_case) - return new_case - - except Exception as e: - db.rollback() - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to create case assignment: {str(e)}" - ) - - @staticmethod - def delete_client(db: Session, client_id: int): - """Delete a client and their associated records""" - # First check if client exists - client = db.query(Client).filter(Client.id == client_id).first() - if not client: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Client with id {client_id} not found" - ) - - try: - # Delete associated client_cases - db.query(ClientCase).filter( - ClientCase.client_id == client_id - ).delete() - - # Delete the client - db.delete(client) - db.commit() - - except Exception as e: - db.rollback() - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to delete client: {str(e)}" - ) + return self.client_case_repository.update(client_id, user_id, update_data) + + def create_case_assignment(self, client_id: int, case_worker_id: int) -> ClientCase: + """ + Create a case assignment + + Args: + client_id: The client ID + case_worker_id: The case worker ID + + Returns: + ClientCase: The created client case + """ + return self.client_case_repository.create(client_id, case_worker_id) + + def delete_client(self, client_id: int) -> None: + """ + Delete a client + + Args: + client_id: The client ID + """ + self.client_repository.delete(client_id) diff --git a/app/clients/service/logic.py b/app/clients/service/logic.py index c25b4217..95764c51 100644 --- a/app/clients/service/logic.py +++ b/app/clients/service/logic.py @@ -5,7 +5,8 @@ # Standard library imports import os -#import json + +# import json from itertools import product # Third-party imports @@ -14,21 +15,22 @@ # Constants COLUMN_INTERVENTIONS = [ - 'Life Stabilization', - 'General Employment Assistance Services', - 'Retention Services', - 'Specialized Services', - 'Employment-Related Financial Supports for Job Seekers and Employers', - 'Employer Financial Supports', - 'Enhanced Referrals for Skills Development' + "Life Stabilization", + "General Employment Assistance Services", + "Retention Services", + "Specialized Services", + "Employment-Related Financial Supports for Job Seekers and Employers", + "Employer Financial Supports", + "Enhanced Referrals for Skills Development", ] # Load model CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) -MODEL_PATH = os.path.join(CURRENT_DIR, 'model.pkl') +MODEL_PATH = os.path.join(CURRENT_DIR, "model.pkl") with open(MODEL_PATH, "rb") as model_file: MODEL = pickle.load(model_file) + def clean_input_data(input_data): """ Clean and transform input data into model-compatible format. @@ -40,13 +42,30 @@ def clean_input_data(input_data): list: Cleaned and formatted data ready for model input """ columns = [ - "age", "gender", "work_experience", "canada_workex", "dep_num", - "canada_born", "citizen_status", "level_of_schooling", "fluent_english", - "reading_english_scale", "speaking_english_scale", "writing_english_scale", - "numeracy_scale", "computer_scale", "transportation_bool", "caregiver_bool", - "housing", "income_source", "felony_bool", "attending_school", - "currently_employed", "substance_use", "time_unemployed", - "need_mental_health_support_bool" + "age", + "gender", + "work_experience", + "canada_workex", + "dep_num", + "canada_born", + "citizen_status", + "level_of_schooling", + "fluent_english", + "reading_english_scale", + "speaking_english_scale", + "writing_english_scale", + "numeracy_scale", + "computer_scale", + "transportation_bool", + "caregiver_bool", + "housing", + "income_source", + "felony_bool", + "attending_school", + "currently_employed", + "substance_use", + "time_unemployed", + "need_mental_health_support_bool", ] demographics = {key: input_data[key] for key in columns} output = [] @@ -57,6 +76,7 @@ def clean_input_data(input_data): output.append(value) return output + def convert_text(text_data: str): """ Convert text answers from front end into numerical values. @@ -68,33 +88,47 @@ def convert_text(text_data: str): int: Converted numerical value """ categorical_mappings = [ + {"": 0, "true": 1, "false": 0, "no": 0, "yes": 1, "No": 0, "Yes": 1}, { - "": 0, "true": 1, "false": 0, "no": 0, "yes": 1, - "No": 0, "Yes": 1 - }, - { - "Grade 0-8": 1, "Grade 9": 2, "Grade 10": 3, "Grade 11": 4, - "Grade 12 or equivalent": 5, "OAC or Grade 13": 6, - "Some college": 7, "Some university": 8, "Some apprenticeship": 9, - "Certificate of Apprenticeship": 10, "Journeyperson": 11, - "Certificate/Diploma": 12, "Bachelor's degree": 13, - "Post graduate": 14 + "Grade 0-8": 1, + "Grade 9": 2, + "Grade 10": 3, + "Grade 11": 4, + "Grade 12 or equivalent": 5, + "OAC or Grade 13": 6, + "Some college": 7, + "Some university": 8, + "Some apprenticeship": 9, + "Certificate of Apprenticeship": 10, + "Journeyperson": 11, + "Certificate/Diploma": 12, + "Bachelor's degree": 13, + "Post graduate": 14, }, { - "Renting-private": 1, "Renting-subsidized": 2, - "Boarding or lodging": 3, "Homeowner": 4, - "Living with family/friend": 5, "Institution": 6, - "Temporary second residence": 7, "Band-owned home": 8, - "Homeless or transient": 9, "Emergency hostel": 10 + "Renting-private": 1, + "Renting-subsidized": 2, + "Boarding or lodging": 3, + "Homeowner": 4, + "Living with family/friend": 5, + "Institution": 6, + "Temporary second residence": 7, + "Band-owned home": 8, + "Homeless or transient": 9, + "Emergency hostel": 10, }, { - "No Source of Income": 1, "Employment Insurance": 2, + "No Source of Income": 1, + "Employment Insurance": 2, "Workplace Safety and Insurance Board": 3, "Ontario Works applied or receiving": 4, "Ontario Disability Support Program applied or receiving": 5, - "Dependent of someone receiving OW or ODSP": 6, "Crown Ward": 7, - "Employment": 8, "Self-Employment": 9, "Other (specify)": 10 - } + "Dependent of someone receiving OW or ODSP": 6, + "Crown Ward": 7, + "Employment": 8, + "Self-Employment": 9, + "Other (specify)": 10, + }, ] for category in categorical_mappings: if text_data in category: @@ -102,6 +136,7 @@ def convert_text(text_data: str): return int(text_data) if text_data.isnumeric() else text_data + def create_matrix(row_data): """ Create matrix of all possible intervention combinations. @@ -116,6 +151,7 @@ def create_matrix(row_data): perms = intervention_permutations(7) return np.concatenate((np.array(data), np.array(perms)), axis=1) + def intervention_permutations(num): """ Generate all possible intervention combinations. @@ -128,6 +164,7 @@ def intervention_permutations(num): """ return np.array(list(product([0, 1], repeat=num))) + def get_baseline_row(row_data): """ Create baseline row with no interventions. @@ -141,6 +178,7 @@ def get_baseline_row(row_data): base_interventions = np.zeros(7) return np.concatenate((np.array(row_data), base_interventions)) + def intervention_row_to_names(row_data): """ Convert intervention row to list of intervention names. @@ -153,6 +191,7 @@ def intervention_row_to_names(row_data): """ return [COLUMN_INTERVENTIONS[i] for i, value in enumerate(row_data) if value == 1] + def process_results(baseline_pred, results_matrix): """ Process model results into structured output. @@ -164,14 +203,9 @@ def process_results(baseline_pred, results_matrix): Returns: dict: Processed results with baseline and interventions """ - result_list = [ - (row[-1], intervention_row_to_names(row[:-1])) - for row in results_matrix - ] - return { - "baseline": baseline_pred[-1], - "interventions": result_list - } + result_list = [(row[-1], intervention_row_to_names(row[:-1])) for row in results_matrix] + return {"baseline": baseline_pred[-1], "interventions": result_list} + def interpret_and_calculate(input_data): """ @@ -194,19 +228,33 @@ def interpret_and_calculate(input_data): top_results = result_matrix[-3:, -8:] return process_results(baseline_prediction, top_results) + if __name__ == "__main__": test_data = { - "age": "23", "gender": "1", "work_experience": "1", - "canada_workex": "1", "dep_num": "0", "canada_born": "1", - "citizen_status": "2", "level_of_schooling": "2", - "fluent_english": "3", "reading_english_scale": "2", - "speaking_english_scale": "2", "writing_english_scale": "3", - "numeracy_scale": "2", "computer_scale": "3", - "transportation_bool": "2", "caregiver_bool": "1", - "housing": "1", "income_source": "5", "felony_bool": "1", - "attending_school": "0", "currently_employed": "1", - "substance_use": "1", "time_unemployed": "1", - "need_mental_health_support_bool": "1" + "age": "23", + "gender": "1", + "work_experience": "1", + "canada_workex": "1", + "dep_num": "0", + "canada_born": "1", + "citizen_status": "2", + "level_of_schooling": "2", + "fluent_english": "3", + "reading_english_scale": "2", + "speaking_english_scale": "2", + "writing_english_scale": "3", + "numeracy_scale": "2", + "computer_scale": "3", + "transportation_bool": "2", + "caregiver_bool": "1", + "housing": "1", + "income_source": "5", + "felony_bool": "1", + "attending_school": "0", + "currently_employed": "1", + "substance_use": "1", + "time_unemployed": "1", + "need_mental_health_support_bool": "1", } results = interpret_and_calculate(test_data) print(results) diff --git a/app/clients/service/model.py b/app/clients/service/model.py index b2406370..4809ba69 100644 --- a/app/clients/service/model.py +++ b/app/clients/service/model.py @@ -1,6 +1,9 @@ """ Model training module for the Common Assessment Tool. -Handles the preparation, training, and saving of the prediction model. +Trains and saves three models: +- Random Forest Regressor +- Linear Regression +- Decision Tree Regressor """ # Standard library imports @@ -11,74 +14,202 @@ import pandas as pd from sklearn.model_selection import train_test_split from sklearn.ensemble import RandomForestRegressor +from sklearn.linear_model import LinearRegression +from sklearn.tree import DecisionTreeRegressor + +# === RANDOM FOREST === + def prepare_models(): """ Prepare and train the Random Forest model using the dataset. - + Returns: RandomForestRegressor: Trained model for predicting success rates """ - # Load dataset - data = pd.read_csv('data_commontool.csv') - # Define feature columns + data = pd.read_csv("app/clients/service/data_commontool.csv") + feature_columns = [ - 'age', # Client's age - 'gender', # Client's gender (bool) - 'work_experience', # Years of work experience - 'canada_workex', # Years of work experience in Canada - 'dep_num', # Number of dependents - 'canada_born', # Born in Canada - 'citizen_status', # Citizenship status - 'level_of_schooling', # Highest level achieved (1-14) - 'fluent_english', # English fluency scale (1-10) - 'reading_english_scale', # Reading ability scale (1-10) - 'speaking_english_scale',# Speaking ability scale (1-10) - 'writing_english_scale', # Writing ability scale (1-10) - 'numeracy_scale', # Numeracy ability scale (1-10) - 'computer_scale', # Computer proficiency scale (1-10) - 'transportation_bool', # Needs transportation support (bool) - 'caregiver_bool', # Is primary caregiver (bool) - 'housing', # Housing situation (1-10) - 'income_source', # Source of income (1-10) - 'felony_bool', # Has a felony (bool) - 'attending_school', # Currently a student (bool) - 'currently_employed', # Currently employed (bool) - 'substance_use', # Substance use disorder (bool) - 'time_unemployed', # Years unemployed - 'need_mental_health_support_bool' # Needs mental health support (bool) + "age", + "gender", + "work_experience", + "canada_workex", + "dep_num", + "canada_born", + "citizen_status", + "level_of_schooling", + "fluent_english", + "reading_english_scale", + "speaking_english_scale", + "writing_english_scale", + "numeracy_scale", + "computer_scale", + "transportation_bool", + "caregiver_bool", + "housing", + "income_source", + "felony_bool", + "attending_school", + "currently_employed", + "substance_use", + "time_unemployed", + "need_mental_health_support_bool", ] - # Define intervention columns + intervention_columns = [ - 'employment_assistance', - 'life_stabilization', - 'retention_services', - 'specialized_services', - 'employment_related_financial_supports', - 'employer_financial_supports', - 'enhanced_referrals' + "employment_assistance", + "life_stabilization", + "retention_services", + "specialized_services", + "employment_related_financial_supports", + "employer_financial_supports", + "enhanced_referrals", ] - # Combine all feature columns + all_features = feature_columns + intervention_columns - # Prepare training data - features = np.array(data[all_features]) # Changed from X to features - targets = np.array(data['success_rate']) # Changed from y to targets - # Split the dataset - features_train, _, targets_train, _ = train_test_split( # Removed unused variables - features, - targets, - test_size=0.2, - random_state=42 + features = np.array(data[all_features]) + targets = np.array(data["success_rate"]) + + features_train, _, targets_train, _ = train_test_split( + features, targets, test_size=0.2, random_state=42 ) - # Initialize and train the model + model = RandomForestRegressor(n_estimators=100, random_state=42) model.fit(features_train, targets_train) return model + +# === LINEAR REGRESSION === + + +def prepare_linear_regression_model(): + """ + Prepare and train the Linear Regression model using the dataset. + + Returns: + LinearRegression: Trained model + """ + data = pd.read_csv("app/clients/service/data_commontool.csv") + + feature_columns = [ + "age", + "gender", + "work_experience", + "canada_workex", + "dep_num", + "canada_born", + "citizen_status", + "level_of_schooling", + "fluent_english", + "reading_english_scale", + "speaking_english_scale", + "writing_english_scale", + "numeracy_scale", + "computer_scale", + "transportation_bool", + "caregiver_bool", + "housing", + "income_source", + "felony_bool", + "attending_school", + "currently_employed", + "substance_use", + "time_unemployed", + "need_mental_health_support_bool", + ] + + intervention_columns = [ + "employment_assistance", + "life_stabilization", + "retention_services", + "specialized_services", + "employment_related_financial_supports", + "employer_financial_supports", + "enhanced_referrals", + ] + + all_features = feature_columns + intervention_columns + features = np.array(data[all_features]) + targets = np.array(data["success_rate"]) + + features_train, _, targets_train, _ = train_test_split( + features, targets, test_size=0.2, random_state=42 + ) + + model = LinearRegression() + model.fit(features_train, targets_train) + return model + + +# === DECISION TREE === + + +def prepare_decision_tree_model(): + """ + Prepare and train the Decision Tree model using the dataset. + + Returns: + DecisionTreeRegressor: Trained model + """ + data = pd.read_csv("app/clients/service/data_commontool.csv") + + feature_columns = [ + "age", + "gender", + "work_experience", + "canada_workex", + "dep_num", + "canada_born", + "citizen_status", + "level_of_schooling", + "fluent_english", + "reading_english_scale", + "speaking_english_scale", + "writing_english_scale", + "numeracy_scale", + "computer_scale", + "transportation_bool", + "caregiver_bool", + "housing", + "income_source", + "felony_bool", + "attending_school", + "currently_employed", + "substance_use", + "time_unemployed", + "need_mental_health_support_bool", + ] + + intervention_columns = [ + "employment_assistance", + "life_stabilization", + "retention_services", + "specialized_services", + "employment_related_financial_supports", + "employer_financial_supports", + "enhanced_referrals", + ] + + all_features = feature_columns + intervention_columns + features = np.array(data[all_features]) + targets = np.array(data["success_rate"]) + + features_train, _, targets_train, _ = train_test_split( + features, targets, test_size=0.2, random_state=42 + ) + + model = DecisionTreeRegressor(random_state=42) + model.fit(features_train, targets_train) + return model + + +# === SAVE FUNCTION (shared) === + + def save_model(model, filename="model.pkl"): """ Save the trained model to a file. - + Args: model: Trained model to save filename (str): Name of the file to save the model to @@ -86,25 +217,27 @@ def save_model(model, filename="model.pkl"): with open(filename, "wb") as model_file: pickle.dump(model, model_file) -def load_model(filename="model.pkl"): + +# === MAIN: train all models === + + +def train_all_models(): """ - Load a trained model from a file. - - Args: - filename (str): Name of the file to load the model from - - Returns: - The loaded model + Trains and saves all three models. """ - with open(filename, "rb") as model_file: - return pickle.load(model_file) + print("Training and saving all models...") + + rf = prepare_models() + save_model(rf, "model_rf.pkl") + + lr = prepare_linear_regression_model() + save_model(lr, "model_lr.pkl") + + dt = prepare_decision_tree_model() + save_model(dt, "model_dt.pkl") + + print("All models saved successfully!") -def main(): - """Main function to train and save the model.""" - print("Starting model training...") - model = prepare_models() - save_model(model) - print("Model training completed and saved successfully.") if __name__ == "__main__": - main() + train_all_models() diff --git a/app/clients/service/model_dt.pkl b/app/clients/service/model_dt.pkl new file mode 100644 index 00000000..b20adfda Binary files /dev/null and b/app/clients/service/model_dt.pkl differ diff --git a/app/clients/service/model_lr.pkl b/app/clients/service/model_lr.pkl new file mode 100644 index 00000000..a7de1a1b Binary files /dev/null and b/app/clients/service/model_lr.pkl differ diff --git a/app/clients/service/model_manager.py b/app/clients/service/model_manager.py new file mode 100644 index 00000000..f8ba8447 --- /dev/null +++ b/app/clients/service/model_manager.py @@ -0,0 +1,83 @@ +""" +Model Manager + +Loads and manages machine learning models saved as .pkl files. +Provides functions to switch between them and retrieve current model info. +""" + +import os +import pickle + +# Get absolute path of current file +BASE_DIR = os.path.dirname(__file__) + +# Map of model names to .pkl filenames +model_files = { + "random_forest": "model_rf.pkl", + "linear_regression": "model_lr.pkl", + "decision_tree": "model_dt.pkl", +} + +# Load all models on startup +models = {} +try: + for name, path in model_files.items(): + try: + full_path = os.path.join(BASE_DIR, path) + with open(full_path, "rb") as f: + models[name] = pickle.load(f) + except (ModuleNotFoundError, ImportError, FileNotFoundError) as e: + print( + f"Warning: Could not load model '{name}' from {full_path}. Using a placeholder model. Error: {e}" + ) + from sklearn.ensemble import RandomForestRegressor + + models[name] = RandomForestRegressor() +except Exception as e: + print(f"Error loading models: {e}") + from sklearn.ensemble import RandomForestRegressor + + models = {"default": RandomForestRegressor()} + + +# === Public functions === + + +def list_models(): + """ + Returns a list of all available model names. + """ + return list(models.keys()) + + +def get_current_model_name(): + """ + Returns the name of the currently active model. + """ + return current_model_name + + +def get_current_model(): + """ + Returns the actual model object currently in use. + """ + return current_model + + +def switch_model(model_name: str): + """ + Switches the currently active model to the given one. + + Args: + model_name (str): One of the keys in model_files + + Raises: + ValueError: If the model_name is not available + """ + global current_model_name, current_model + + if model_name not in models: + raise ValueError(f"Model '{model_name}' not found. Available: {list_models()}") + + current_model_name = model_name + current_model = models[model_name] diff --git a/app/clients/service/model_rf.pkl b/app/clients/service/model_rf.pkl new file mode 100644 index 00000000..a7cc0f61 Binary files /dev/null and b/app/clients/service/model_rf.pkl differ diff --git a/app/database.py b/app/database.py index 3a489f54..b5c8b948 100644 --- a/app/database.py +++ b/app/database.py @@ -7,22 +7,23 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker -#Here is where the database is located -SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db" +# Here is where the database is located +SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db" -#Open up a connection so that we are able to use the database +# Open up a connection so that we are able to use the database engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}) -#Bind the engine just created +# Bind the engine just created SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -#Create an object of our database so as to control the database +# Create an object of our database so as to control the database Base = declarative_base() + def get_db(): """ Creates a database session and ensures it's closed after use. - + Yields: Session: SQLAlchemy database session """ diff --git a/app/main.py b/app/main.py index a8e8fa7f..d75f8f8a 100644 --- a/app/main.py +++ b/app/main.py @@ -4,28 +4,71 @@ Handles database initialization and CORS middleware configuration. """ +"Just For Testing" from fastapi import FastAPI from app import models from app.database import engine -from app.clients.router import router as clients_router +from app.clients.router import router as clients_router, model_router from app.auth.router import router as auth_router from fastapi.middleware.cors import CORSMiddleware +import os +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() # Initialize database tables models.Base.metadata.create_all(bind=engine) # Create FastAPI application -app = FastAPI(title="Case Management API", description="API for managing client cases", version="1.0.0") +app = FastAPI( + title="Case Management API", description="API for managing client cases", version="1.0.0" +) # Include routers app.include_router(auth_router) -app.include_router(clients_router) +app.include_router(clients_router, prefix="/clients", tags=["Clients"]) +app.include_router(model_router) # Configure CORS middleware app.add_middleware( CORSMiddleware, - allow_origins=["*"], # Allows all origins - allow_methods=["*"], # Allows all methods - allow_headers=["*"], # Allows all headers + allow_origins=["*"], # Allows all origins + allow_methods=["*"], # Allows all methods + allow_headers=["*"], # Allows all headers allow_credentials=True, ) + + +@app.on_event("startup") +async def show_routes_on_startup(): + print("✅ LOADED ROUTES:") + for route in app.routes: + print(f" {route.path}") + + +# Health check endpoint +@app.get("/health", tags=["health"]) +async def health_check(): + """ + Health check endpoint for monitoring + + Returns: + dict: Status message + """ + return {"status": "ok", "version": app.version} + + +if __name__ == "__main__": + import uvicorn + + # Get port from environment variable or default to 8000 + port = int(os.getenv("PORT", 8000)) + + # Start the application with uvicorn + uvicorn.run( + "app.main:app", + host="0.0.0.0", + port=port, + reload=os.getenv("ENVIRONMENT", "production").lower() == "development", + ) diff --git a/app/models.py b/app/models.py index df778348..101744d3 100644 --- a/app/models.py +++ b/app/models.py @@ -8,6 +8,7 @@ from sqlalchemy.orm import relationship import enum + class UserRole(str, enum.Enum): admin = "admin" case_worker = "case_worker" @@ -24,46 +25,61 @@ class User(Base): cases = relationship("ClientCase", back_populates="user") + class Client(Base): """ Client model representing client data in the database. """ + __tablename__ = "clients" id = Column(Integer, primary_key=True, autoincrement=True) - age = Column(Integer, CheckConstraint('age >= 18')) + age = Column(Integer, CheckConstraint("age >= 18")) gender = Column(Integer, CheckConstraint("gender = 1 OR gender = 2")) - work_experience = Column(Integer, CheckConstraint('work_experience >= 0')) - canada_workex = Column(Integer, CheckConstraint('canada_workex >= 0')) - dep_num = Column(Integer, CheckConstraint('dep_num >= 0')) + work_experience = Column(Integer, CheckConstraint("work_experience >= 0")) + canada_workex = Column(Integer, CheckConstraint("canada_workex >= 0")) + dep_num = Column(Integer, CheckConstraint("dep_num >= 0")) canada_born = Column(Boolean) citizen_status = Column(Boolean) - level_of_schooling = Column(Integer, CheckConstraint('level_of_schooling >= 1 AND level_of_schooling <= 14')) + level_of_schooling = Column( + Integer, CheckConstraint("level_of_schooling >= 1 AND level_of_schooling <= 14") + ) fluent_english = Column(Boolean) - reading_english_scale = Column(Integer, CheckConstraint('reading_english_scale >= 0 AND reading_english_scale <= 10')) - speaking_english_scale = Column(Integer, CheckConstraint('speaking_english_scale >= 0 AND speaking_english_scale <= 10')) - writing_english_scale = Column(Integer, CheckConstraint('writing_english_scale >= 0 AND writing_english_scale <= 10')) - numeracy_scale = Column(Integer, CheckConstraint('numeracy_scale >= 0 AND numeracy_scale <= 10')) - computer_scale = Column(Integer, CheckConstraint('computer_scale >= 0 AND computer_scale <= 10')) + reading_english_scale = Column( + Integer, CheckConstraint("reading_english_scale >= 0 AND reading_english_scale <= 10") + ) + speaking_english_scale = Column( + Integer, CheckConstraint("speaking_english_scale >= 0 AND speaking_english_scale <= 10") + ) + writing_english_scale = Column( + Integer, CheckConstraint("writing_english_scale >= 0 AND writing_english_scale <= 10") + ) + numeracy_scale = Column( + Integer, CheckConstraint("numeracy_scale >= 0 AND numeracy_scale <= 10") + ) + computer_scale = Column( + Integer, CheckConstraint("computer_scale >= 0 AND computer_scale <= 10") + ) transportation_bool = Column(Boolean) caregiver_bool = Column(Boolean) - housing = Column(Integer, CheckConstraint('housing >= 1 AND housing <= 10')) - income_source = Column(Integer, CheckConstraint('income_source >= 1 AND income_source <= 11')) + housing = Column(Integer, CheckConstraint("housing >= 1 AND housing <= 10")) + income_source = Column(Integer, CheckConstraint("income_source >= 1 AND income_source <= 11")) felony_bool = Column(Boolean) attending_school = Column(Boolean) currently_employed = Column(Boolean) substance_use = Column(Boolean) - time_unemployed = Column(Integer, CheckConstraint('time_unemployed >= 0')) + time_unemployed = Column(Integer, CheckConstraint("time_unemployed >= 0")) need_mental_health_support_bool = Column(Boolean) cases = relationship("ClientCase", back_populates="client") + class ClientCase(Base): __tablename__ = "client_cases" client_id = Column(Integer, ForeignKey("clients.id"), primary_key=True) user_id = Column(Integer, ForeignKey("users.id"), primary_key=True) - + employment_assistance = Column(Boolean) life_stabilization = Column(Boolean) retention_services = Column(Boolean) @@ -71,7 +87,7 @@ class ClientCase(Base): employment_related_financial_supports = Column(Boolean) employer_financial_supports = Column(Boolean) enhanced_referrals = Column(Boolean) - success_rate = Column(Integer, CheckConstraint('success_rate >= 0 AND success_rate <= 100')) + success_rate = Column(Integer, CheckConstraint("success_rate >= 0 AND success_rate <= 100")) client = relationship("Client", back_populates="cases") user = relationship("User", back_populates="cases") diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 00000000..1ca903bd --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,11 @@ +version: '3.8' + +services: + backend: + build: . + ports: + - "8000:8000" + environment: + - ENV_VAR=value + volumes: + - .:/app \ No newline at end of file diff --git a/initialize_data.py b/initialize_data.py index 1444bf41..7eba5a24 100644 --- a/initialize_data.py +++ b/initialize_data.py @@ -2,19 +2,19 @@ from sqlalchemy.orm import Session from app.database import SessionLocal from app.models import Client, User, ClientCase, UserRole -from app.auth.router import get_password_hash +from app.auth.security import PasswordService def initialize_database(): print("Starting database initialization...") db = SessionLocal() try: # Create admin user if doesn't exist - admin = db.query(User).filter(User.username == "admin").first() - if not admin: + admin_user = db.query(User).filter(User.username == "admin").first() + if not admin_user: admin_user = User( username="admin", email="admin@example.com", - hashed_password=get_password_hash("admin123"), + hashed_password=PasswordService.get_password_hash("admin123"), role=UserRole.admin ) db.add(admin_user) @@ -29,7 +29,7 @@ def initialize_database(): case_worker = User( username="case_worker1", email="caseworker1@example.com", - hashed_password=get_password_hash("worker123"), + hashed_password=PasswordService.get_password_hash("worker123"), role=UserRole.case_worker ) db.add(case_worker) diff --git a/my-common-assessment-tool.pem b/my-common-assessment-tool.pem new file mode 100644 index 00000000..d830b7b4 --- /dev/null +++ b/my-common-assessment-tool.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAyFCsaaa+reGDHqdXqEoayp7dmaFJBrkIiBFf/Lzt/spXC9YI +JDXPP9nLUIOLzv7Ay3vwfB/LYVp/IYlpzAyixJnUdIcwGQbKK7AmgpWCtmVeB9Jx +SmmxEWPoTcVVOputGuReVDS2/CLUGTjqj5epP4Ugdwc8XsJjNuJYf3K1hTPSj4VF +1/GEu9fEF7vEZLR8QzEtaal9D2odmxSNJhLuNKhzqsWcBhTnEapYTxbOl+mQtxcm +4J8oCS4iAacCydnPyUcGhIRcuNz1xn6gYAEMYwDMcfpLFBlafnRYI9a1UuaU9Wz1 +8IDrjQi5va0p/WxPfttBkvDZlBz72uQVxNDGpwIDAQABAoIBAGFogX7a2+x4Nieo +3oJyjrarLD1x5a4EOnbYZCHlyaHVySBzUwAwvnhhM3ISleDxltUcjuP9HgxYUmv/ +g1f7aQdLermzp5rz50n5XbCwfaCuiFwrZHX4EWfQen2fEQPwAeyK0qgF/ll7okIl +oEJ1UJMX7KKU/TFjO5XL2ZcYM9bycDyyncio0ImfaFVpSoG3JEhEKGGnaLlxjlNg +hHClgOswg/hFYSvOtHVXir8VFlfWQq+oQDftLXDf6s6ozj4lTyTOb2TILf7y+1Iy +4ITd3HrSvnLeA3O40NWDMrrKo/IGXyoJvUGz/sBOFVQTn65MEVwCv06KuXOIj6e6 +TYvxkwECgYEA94CB4Nd3fOr3AYxp0AiMaSVbZyFX+pfABOWBFZGGxR/uzsfb7S8Y +qRkXJwksI3qRU7BZoAGp/7oX5HskNh0dUMiUQKQ2ocWEM5qFneh20lwCeo+9UfuL +bYYFz30KJxTJSlxZhHyG1SXfj1Lly8+7kcR3RlU1e7r0aSYfoy3IaoECgYEAzzFn +OS4lmGV0Y34Xzh2I9ZWaEaYGD86meb0xRkwVU1IBKuijMNvKPDgt6cRztsuIbnYb +14klrmCajXKwWtM3qXi+waQf66clnnTD3chnaMwcT/25Fik4cXpvi0Mn+uHOVawI +ygk//nqHh+spJWcWXi8wFh/4RaCGzBtW8O3gDScCgYEAyxcO/AGyUbW4g/PFK+in +1uvJieGpgL6e2SW9+4XTsdOXMOR8ya6YrMEi52w2ZNKBh8uwb4SOC4KXcmu9dg4D +7TL5u+VD0xDxfyqvs7h6L/lCK3HhZvFjIrcT84NmHlWHKtaGuhk4xpRyUvgyCkDm +aCFvwi3PWj05q0KWOV8rEoECgYEAsafIvIzHC68iZxT9UGyevQTzwGI9HFyy/fut +PnuKZZEREzu6gfBTreL161XZakmGyEBZiyw7tRN8MgC/GoG1Xoj794nFHQiLBx1T +vN1TXdZ2CFij1U6u6Q50ilKg+0uW4nrKZoIb7xYdE/wdocaMtWF8t9vdw8XrDyP6 +Hke5L00CgYEA20Z7IP5QhaT5stl9avNu/pvr9MI5PV9X9tM+cHXdolMiIwJhFqx5 +YBRG3IkOHRdTf2DOdyUDuqjQbYBwq/S/rLQGpTDeiqG9riPPQK62eIvjECb56TWM +qdGs+a530ePKDHcM0eL9DudIfDGXLJ4rccPbPm9drXhqD1f68/FgGNs= +-----END RSA PRIVATE KEY----- \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..19a3e527 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,46 @@ +[project] +name = "CommonAssessmentTool" +version = "0.1.0" +authors = [ + {name = "Richeng Yang"}, + {name = "Jiayi Liu"}, + {name = "Yanyue Wang"}, + {name = "Qilin Zeng"}, +] +readme = "README.md" +requires-python = ">=3.10" + +dependencies = [ + "fastapi>=0.103.2", + "uvicorn>=0.23.2", + "sqlalchemy>=2.0.21", + "pydantic>=2.4.2", + "python-dotenv>=1.0.0", + "pandas>=2.0.0", + "psycopg2-binary>=2.9.9", + "python-jose>=3.3.0", + "passlib>=1.7.4", + "bcrypt>=4.0.1", + "numpy>=1.24.2", + "scikit-learn>=1.4.2", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.2.0", + "pylint>=3.0.1", + "black>=23.10.0", + "httpx>=0.24.1", +] + +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.black] +line-length = 100 +target-version = ["py310"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = "test_*.py" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 93d35fbf..f04b5e50 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ attrs==22.1.0 backcall==0.2.0 bcrypt==4.0.1 beautifulsoup4==4.12.2 -black==23.10.0 +black==25.1.0 bleach==6.0.0 branca==0.6.0 certifi==2023.7.22 diff --git a/tests/conftest.py b/tests/conftest.py index aa30d094..d5c29371 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ from sqlalchemy.orm import sessionmaker from app.database import Base, get_db from app.main import app -from app.auth.router import get_password_hash +from app.auth.security import PasswordService from app.models import User, UserRole, Client, ClientCase # Create test database @@ -12,19 +12,20 @@ engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}) TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + @pytest.fixture def test_db(): # Create tables Base.metadata.create_all(bind=engine) - + db = TestingSessionLocal() try: # Create test admin user admin_user = User( username="testadmin", email="testadmin@example.com", - hashed_password=get_password_hash("testpass123"), - role=UserRole.admin + hashed_password=PasswordService.get_password_hash("testpass123"), + role=UserRole.admin, ) db.add(admin_user) @@ -32,11 +33,11 @@ def test_db(): case_worker = User( username="testworker", email="worker@example.com", - hashed_password=get_password_hash("workerpass123"), - role=UserRole.case_worker + hashed_password=PasswordService.get_password_hash("workerpass123"), + role=UserRole.case_worker, ) db.add(case_worker) - + # Create test clients client1 = Client( age=25, @@ -62,9 +63,9 @@ def test_db(): currently_employed=False, substance_use=False, time_unemployed=6, - need_mental_health_support_bool=False + need_mental_health_support_bool=False, ) - + client2 = Client( age=30, gender=2, @@ -89,13 +90,13 @@ def test_db(): currently_employed=True, substance_use=False, time_unemployed=0, - need_mental_health_support_bool=False + need_mental_health_support_bool=False, ) - + db.add(client1) db.add(client2) db.commit() - + # Create test client cases client_case1 = ClientCase( client_id=1, @@ -107,9 +108,9 @@ def test_db(): employment_related_financial_supports=True, employer_financial_supports=False, enhanced_referrals=True, - success_rate=75 + success_rate=75, ) - + client_case2 = ClientCase( client_id=2, user_id=2, # Assigned to case worker @@ -120,18 +121,19 @@ def test_db(): employment_related_financial_supports=False, employer_financial_supports=True, enhanced_referrals=False, - success_rate=85 + success_rate=85, ) - + db.add(client_case1) db.add(client_case2) db.commit() - + yield db finally: db.close() Base.metadata.drop_all(bind=engine) + @pytest.fixture def client(test_db): def override_get_db(): @@ -139,32 +141,31 @@ def override_get_db(): yield test_db finally: test_db.close() - + app.dependency_overrides[get_db] = override_get_db yield TestClient(app) app.dependency_overrides.clear() + @pytest.fixture def admin_token(client): - response = client.post( - "/auth/token", - data={"username": "testadmin", "password": "testpass123"} - ) + response = client.post("/auth/token", data={"username": "testadmin", "password": "testpass123"}) return response.json()["access_token"] + @pytest.fixture def case_worker_token(client): response = client.post( - "/auth/token", - data={"username": "testworker", "password": "workerpass123"} + "/auth/token", data={"username": "testworker", "password": "workerpass123"} ) return response.json()["access_token"] + @pytest.fixture def admin_headers(admin_token): return {"Authorization": f"Bearer {admin_token}"} + @pytest.fixture def case_worker_headers(case_worker_token): return {"Authorization": f"Bearer {case_worker_token}"} - \ No newline at end of file diff --git a/tests/test_auth.py b/tests/test_auth.py index 1d4692e4..d919ca99 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,123 +1,111 @@ import pytest from fastapi import status + def test_create_user_success(client, admin_headers): """Test successful user creation by admin""" user_data = { "username": "newuser", "email": "new@test.com", "password": "testpass123", - "role": "case_worker" + "role": "case_worker", } - response = client.post( - "/auth/users", - headers=admin_headers, - json=user_data - ) + response = client.post("/auth/users", headers=admin_headers, json=user_data) assert response.status_code == status.HTTP_200_OK data = response.json() assert data["username"] == "newuser" assert data["role"] == "case_worker" assert "password" not in data # Password should not be in response + def test_create_user_duplicate_username(client, admin_headers): """Test creating user with existing username""" user_data = { "username": "testadmin", # This username exists in test database "email": "another@test.com", "password": "testpass123", - "role": "case_worker" + "role": "case_worker", } - response = client.post( - "/auth/users", - headers=admin_headers, - json=user_data - ) + response = client.post("/auth/users", headers=admin_headers, json=user_data) assert response.status_code == status.HTTP_400_BAD_REQUEST assert "Username already registered" in response.json()["detail"] + def test_create_user_duplicate_email(client, admin_headers): """Test creating user with existing email""" user_data = { "username": "uniqueuser", "email": "testadmin@example.com", # This email exists in test database "password": "testpass123", - "role": "case_worker" + "role": "case_worker", } - response = client.post( - "/auth/users", - headers=admin_headers, - json=user_data - ) + response = client.post("/auth/users", headers=admin_headers, json=user_data) assert response.status_code == status.HTTP_400_BAD_REQUEST assert "Email already registered" in response.json()["detail"] + def test_create_user_invalid_role(client, admin_headers): """Test creating user with invalid role""" user_data = { "username": "newuser", "email": "new@test.com", "password": "testpass123", - "role": "invalid_role" # Invalid role + "role": "invalid_role", # Invalid role } - response = client.post( - "/auth/users", - headers=admin_headers, - json=user_data - ) + response = client.post("/auth/users", headers=admin_headers, json=user_data) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + def test_create_user_unauthorized(client): """Test user creation without authentication""" user_data = { "username": "newuser", "email": "new@test.com", "password": "testpass123", - "role": "case_worker" + "role": "case_worker", } response = client.post("/auth/users", json=user_data) assert response.status_code == status.HTTP_401_UNAUTHORIZED + def test_login_success_admin(client): """Test successful login for admin""" - response = client.post( - "/auth/token", - data={"username": "testadmin", "password": "testpass123"} - ) + response = client.post("/auth/token", data={"username": "testadmin", "password": "testpass123"}) assert response.status_code == status.HTTP_200_OK data = response.json() assert "access_token" in data assert data["token_type"] == "bearer" + def test_login_success_case_worker(client): """Test successful login for case worker""" response = client.post( - "/auth/token", - data={"username": "testworker", "password": "workerpass123"} + "/auth/token", data={"username": "testworker", "password": "workerpass123"} ) assert response.status_code == status.HTTP_200_OK data = response.json() assert "access_token" in data assert data["token_type"] == "bearer" + def test_login_wrong_password(client): """Test login with incorrect password""" response = client.post( - "/auth/token", - data={"username": "testadmin", "password": "wrongpassword"} + "/auth/token", data={"username": "testadmin", "password": "wrongpassword"} ) assert response.status_code == status.HTTP_401_UNAUTHORIZED assert "Incorrect username or password" in response.json()["detail"] + def test_login_nonexistent_user(client): """Test login with non-existent username""" response = client.post( - "/auth/token", - data={"username": "nonexistent", "password": "testpass123"} + "/auth/token", data={"username": "nonexistent", "password": "testpass123"} ) assert response.status_code == status.HTTP_401_UNAUTHORIZED assert "Incorrect username or password" in response.json()["detail"] + def test_invalid_token(client): """Test using invalid token""" headers = {"Authorization": "Bearer invalid_token_here"} @@ -125,12 +113,14 @@ def test_invalid_token(client): assert response.status_code == status.HTTP_401_UNAUTHORIZED assert "Could not validate credentials" in response.json()["detail"] + def test_missing_token(client): """Test accessing protected endpoint without token""" response = client.get("/clients/") assert response.status_code == status.HTTP_401_UNAUTHORIZED assert "Not authenticated" in response.json()["detail"] + def test_token_user_deleted(client, admin_headers): """Test using token of deleted user""" # First create a new user as admin @@ -138,22 +128,15 @@ def test_token_user_deleted(client, admin_headers): "username": "temporary", "email": "temp@test.com", "password": "temppass123", - "role": "admin" # Changed to admin so they can access /clients/ + "role": "admin", # Changed to admin so they can access /clients/ } - response = client.post( - "/auth/users", - headers=admin_headers, - json=user_data - ) + response = client.post("/auth/users", headers=admin_headers, json=user_data) assert response.status_code == status.HTTP_200_OK # Get token for new user - response = client.post( - "/auth/token", - data={"username": "temporary", "password": "temppass123"} - ) + response = client.post("/auth/token", data={"username": "temporary", "password": "temppass123"}) token = response.json()["access_token"] - + # Try using the token headers = {"Authorization": f"Bearer {token}"} response = client.get("/clients/", headers=headers) diff --git a/tests/test_clients.py b/tests/test_clients.py index 611a5b34..f7462d30 100644 --- a/tests/test_clients.py +++ b/tests/test_clients.py @@ -1,12 +1,14 @@ import pytest from fastapi import status + # Test GET Operations def test_get_clients_unauthorized(client): """Test that unauthorized access is prevented""" response = client.get("/clients/") assert response.status_code == status.HTTP_401_UNAUTHORIZED + def test_get_clients_as_admin(client, admin_headers): """Test getting all clients as admin""" response = client.get("/clients/", headers=admin_headers) @@ -16,24 +18,24 @@ def test_get_clients_as_admin(client, admin_headers): assert "total" in data assert len(data["clients"]) > 0 + def test_get_client_by_id(client, admin_headers): """Test getting specific client""" # Test existing client response = client.get("/clients/1", headers=admin_headers) assert response.status_code == status.HTTP_200_OK assert response.json()["id"] == 1 - + # Test non-existent client response = client.get("/clients/999", headers=admin_headers) assert response.status_code == status.HTTP_404_NOT_FOUND + def test_get_clients_by_criteria(client, admin_headers): """Test searching clients by various criteria""" # Test single criterion response = client.get( - "/clients/search/by-criteria", - params={"age_min": 25}, - headers=admin_headers + "/clients/search/by-criteria", params={"age_min": 25}, headers=admin_headers ) assert response.status_code == status.HTTP_200_OK assert len(response.json()) > 0 @@ -41,12 +43,8 @@ def test_get_clients_by_criteria(client, admin_headers): # Test multiple criteria response = client.get( "/clients/search/by-criteria", - params={ - "age_min": 25, - "currently_employed": True, - "gender": 2 - }, - headers=admin_headers + params={"age_min": 25, "currently_employed": True, "gender": 2}, + headers=admin_headers, ) assert response.status_code == status.HTTP_200_OK @@ -54,23 +52,22 @@ def test_get_clients_by_criteria(client, admin_headers): response = client.get( "/clients/search/by-criteria", params={"age_min": 15}, # Below minimum age - headers=admin_headers + headers=admin_headers, ) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY # Changed from 400 + def test_get_clients_by_services(client, admin_headers): """Test getting clients by service status""" response = client.get( "/clients/search/by-services", - params={ - "employment_assistance": True, - "life_stabilization": True - }, - headers=admin_headers + params={"employment_assistance": True, "life_stabilization": True}, + headers=admin_headers, ) assert response.status_code == status.HTTP_200_OK assert len(response.json()) > 0 + def test_get_client_services(client, admin_headers): """Test getting services for a specific client""" response = client.get("/clients/1/services", headers=admin_headers) @@ -81,63 +78,54 @@ def test_get_client_services(client, admin_headers): assert "employment_assistance" in services[0] assert "success_rate" in services[0] + def test_get_clients_by_success_rate(client, admin_headers): """Test getting clients by success rate threshold""" response = client.get( - "/clients/search/success-rate", - params={"min_rate": 70}, - headers=admin_headers + "/clients/search/success-rate", params={"min_rate": 70}, headers=admin_headers ) assert response.status_code == status.HTTP_200_OK assert len(response.json()) > 0 + def test_get_clients_by_case_worker(client, admin_headers, case_worker_headers): """Test getting clients assigned to a case worker""" # Test as admin response = client.get("/clients/case-worker/2", headers=admin_headers) assert response.status_code == status.HTTP_200_OK - + # Test as case worker response = client.get("/clients/case-worker/2", headers=case_worker_headers) assert response.status_code == status.HTTP_200_OK + # Test UPDATE Operations def test_update_client(client, admin_headers): """Test updating client information""" - update_data = { - "age": 26, - "currently_employed": True, - "time_unemployed": 0 - } - response = client.put( - "/clients/1", - json=update_data, - headers=admin_headers - ) + update_data = {"age": 26, "currently_employed": True, "time_unemployed": 0} + response = client.put("/clients/1", json=update_data, headers=admin_headers) assert response.status_code == status.HTTP_200_OK updated_client = response.json() assert updated_client["age"] == 26 assert updated_client["currently_employed"] == True assert updated_client["time_unemployed"] == 0 + # Test Create Case Assignment def test_create_case_assignment(client, admin_headers): """Test creating new case assignment""" response = client.post( - "/clients/1/case-assignment", - params={"case_worker_id": 2}, - headers=admin_headers + "/clients/1/case-assignment", params={"case_worker_id": 2}, headers=admin_headers ) assert response.status_code == status.HTTP_200_OK # Test duplicate assignment response = client.post( - "/clients/1/case-assignment", - params={"case_worker_id": 2}, - headers=admin_headers + "/clients/1/case-assignment", params={"case_worker_id": 2}, headers=admin_headers ) assert response.status_code == status.HTTP_400_BAD_REQUEST + # Test DELETE Operation def test_delete_client(client, admin_headers): """Test deleting a client"""