diff --git a/src/dispatch/auth/permissions.py b/src/dispatch/auth/permissions.py index bd0b1a95da46..9f68dbb1e06f 100644 --- a/src/dispatch/auth/permissions.py +++ b/src/dispatch/auth/permissions.py @@ -1,5 +1,6 @@ import logging from abc import ABC, abstractmethod +import json from fastapi import HTTPException from starlette.requests import Request @@ -16,6 +17,7 @@ from dispatch.organization import service as organization_service from dispatch.organization.models import OrganizationRead from dispatch.participant_role.enums import ParticipantRoleType +from dispatch.task import service as task_service log = logging.getLogger(__name__) @@ -335,6 +337,51 @@ def has_required_permissions( ) +class IncidentTaskCreateEditPermission(BasePermission): + """ + Permissions dependency to apply incident edit permissions to task-based requests. + """ + + def has_required_permissions(self, request: Request) -> bool: + incident_id = None + # for task creation, retrieve the incident id from the payload + if request.method == "POST" and hasattr(request, "_body"): + try: + body = json.loads(request._body.decode()) + incident_id = body["incident"]["id"] + except (json.JSONDecodeError, KeyError, AttributeError): + log.error( + "Encountered create_task request without expected incident ID. Cannot properly ascertain incident permissions." + ) + return False + else: # otherwise, retrieve via the task id + pk = PrimaryKeyModel(id=request.path_params["task_id"]) + current_task = task_service.get(db_session=request.state.db, task_id=pk.id) + if not current_task or not current_task.incident: + return False + incident_id = current_task.incident.id + + # minimal object with the attributes required for IncidentViewPermission + incident_request = type( + "IncidentRequest", + (), + { + "path_params": {**request.path_params, "incident_id": incident_id}, + "state": request.state, + }, + )() + + # copy necessary request attributes + for attr in ["headers", "method", "url", "query_params"]: + if hasattr(request, attr): + setattr(incident_request, attr, getattr(request, attr)) + + return any_permission( + permissions=[IncidentEditPermission], + request=incident_request, + ) + + class IncidentReporterPermission(BasePermission): def has_required_permissions( self, diff --git a/src/dispatch/task/views.py b/src/dispatch/task/views.py index 3a7eedddf01a..843bb95b81c2 100644 --- a/src/dispatch/task/views.py +++ b/src/dispatch/task/views.py @@ -1,8 +1,9 @@ import json -from fastapi import APIRouter, HTTPException, Query, status +from fastapi import APIRouter, HTTPException, Query, status, Depends from dispatch.auth.service import CurrentUser +from dispatch.auth.permissions import PermissionsDependency, IncidentTaskCreateEditPermission from dispatch.common.utils.views import create_pydantic_include from dispatch.database.core import DbSession from dispatch.database.service import CommonParameters, search_filter_sort_paginate @@ -43,7 +44,12 @@ def get_tasks(common: CommonParameters, include: list[str] = Query([], alias="in return json.loads(TaskPagination(**pagination).json()) -@router.post("", response_model=TaskRead, tags=["tasks"]) +@router.post( + "", + response_model=TaskRead, + tags=["tasks"], + dependencies=[Depends(PermissionsDependency([IncidentTaskCreateEditPermission]))], +) def create_task( db_session: DbSession, task_in: TaskCreate, @@ -64,11 +70,12 @@ def create_task( return task -@router.post("/ticket/{task_id}", tags=["tasks"]) -def create_ticket( - db_session: DbSession, - task_id: PrimaryKey, -): +@router.post( + "/ticket/{task_id}", + tags=["tasks"], + dependencies=[Depends(PermissionsDependency([IncidentTaskCreateEditPermission]))], +) +def create_ticket(db_session: DbSession, task_id: PrimaryKey, current_user: CurrentUser): """Creates a ticket for an existing task.""" task = get(db_session=db_session, task_id=task_id) if not task: @@ -79,8 +86,15 @@ def create_ticket( return create_task_ticket(task=task, db_session=db_session) -@router.put("/{task_id}", response_model=TaskRead, tags=["tasks"]) -def update_task(db_session: DbSession, task_id: PrimaryKey, task_in: TaskUpdate): +@router.put( + "/{task_id}", + response_model=TaskRead, + tags=["tasks"], + dependencies=[Depends(PermissionsDependency([IncidentTaskCreateEditPermission]))], +) +def update_task( + db_session: DbSession, task_id: PrimaryKey, task_in: TaskUpdate, current_user: CurrentUser +): """Updates an existing task.""" task = get(db_session=db_session, task_id=task_id) if not task: @@ -104,8 +118,13 @@ def update_task(db_session: DbSession, task_id: PrimaryKey, task_in: TaskUpdate) return task -@router.delete("/{task_id}", response_model=None, tags=["tasks"]) -def delete_task(db_session: DbSession, task_id: PrimaryKey): +@router.delete( + "/{task_id}", + response_model=None, + tags=["tasks"], + dependencies=[Depends(PermissionsDependency([IncidentTaskCreateEditPermission]))], +) +def delete_task(db_session: DbSession, task_id: PrimaryKey, current_user: CurrentUser): """Deletes an existing task.""" task = get(db_session=db_session, task_id=task_id) if not task: