Skip to content
This repository was archived by the owner on Sep 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions src/dispatch/auth/permissions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from abc import ABC, abstractmethod
import json

from fastapi import HTTPException
from starlette.requests import Request
Expand All @@ -16,6 +17,7 @@
from dispatch.organization import service as organization_service
from dispatch.organization.models import OrganizationRead
from dispatch.participant_role.enums import ParticipantRoleType
from dispatch.task import service as task_service

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -335,6 +337,51 @@ def has_required_permissions(
)


class IncidentTaskCreateEditPermission(BasePermission):
"""
Permissions dependency to apply incident edit permissions to task-based requests.
"""

def has_required_permissions(self, request: Request) -> bool:
incident_id = None
# for task creation, retrieve the incident id from the payload
if request.method == "POST" and hasattr(request, "_body"):
try:
body = json.loads(request._body.decode())
incident_id = body["incident"]["id"]
except (json.JSONDecodeError, KeyError, AttributeError):
log.error(
"Encountered create_task request without expected incident ID. Cannot properly ascertain incident permissions."
)
return False
else: # otherwise, retrieve via the task id
pk = PrimaryKeyModel(id=request.path_params["task_id"])
current_task = task_service.get(db_session=request.state.db, task_id=pk.id)
if not current_task or not current_task.incident:
return False
incident_id = current_task.incident.id

# minimal object with the attributes required for IncidentViewPermission
incident_request = type(
"IncidentRequest",
(),
{
"path_params": {**request.path_params, "incident_id": incident_id},
"state": request.state,
},
)()

# copy necessary request attributes
for attr in ["headers", "method", "url", "query_params"]:
if hasattr(request, attr):
setattr(incident_request, attr, getattr(request, attr))

return any_permission(
permissions=[IncidentEditPermission],
request=incident_request,
)


class IncidentReporterPermission(BasePermission):
def has_required_permissions(
self,
Expand Down
41 changes: 30 additions & 11 deletions src/dispatch/task/views.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json
from fastapi import APIRouter, HTTPException, Query, status
from fastapi import APIRouter, HTTPException, Query, status, Depends


from dispatch.auth.service import CurrentUser
from dispatch.auth.permissions import PermissionsDependency, IncidentTaskCreateEditPermission
from dispatch.common.utils.views import create_pydantic_include
from dispatch.database.core import DbSession
from dispatch.database.service import CommonParameters, search_filter_sort_paginate
Expand Down Expand Up @@ -43,7 +44,12 @@ def get_tasks(common: CommonParameters, include: list[str] = Query([], alias="in
return json.loads(TaskPagination(**pagination).json())


@router.post("", response_model=TaskRead, tags=["tasks"])
@router.post(
"",
response_model=TaskRead,
tags=["tasks"],
dependencies=[Depends(PermissionsDependency([IncidentTaskCreateEditPermission]))],
)
def create_task(
db_session: DbSession,
task_in: TaskCreate,
Expand All @@ -64,11 +70,12 @@ def create_task(
return task


@router.post("/ticket/{task_id}", tags=["tasks"])
def create_ticket(
db_session: DbSession,
task_id: PrimaryKey,
):
@router.post(
"/ticket/{task_id}",
tags=["tasks"],
dependencies=[Depends(PermissionsDependency([IncidentTaskCreateEditPermission]))],
)
def create_ticket(db_session: DbSession, task_id: PrimaryKey, current_user: CurrentUser):
"""Creates a ticket for an existing task."""
task = get(db_session=db_session, task_id=task_id)
if not task:
Expand All @@ -79,8 +86,15 @@ def create_ticket(
return create_task_ticket(task=task, db_session=db_session)


@router.put("/{task_id}", response_model=TaskRead, tags=["tasks"])
def update_task(db_session: DbSession, task_id: PrimaryKey, task_in: TaskUpdate):
@router.put(
"/{task_id}",
response_model=TaskRead,
tags=["tasks"],
dependencies=[Depends(PermissionsDependency([IncidentTaskCreateEditPermission]))],
)
def update_task(
db_session: DbSession, task_id: PrimaryKey, task_in: TaskUpdate, current_user: CurrentUser
):
"""Updates an existing task."""
task = get(db_session=db_session, task_id=task_id)
if not task:
Expand All @@ -104,8 +118,13 @@ def update_task(db_session: DbSession, task_id: PrimaryKey, task_in: TaskUpdate)
return task


@router.delete("/{task_id}", response_model=None, tags=["tasks"])
def delete_task(db_session: DbSession, task_id: PrimaryKey):
@router.delete(
"/{task_id}",
response_model=None,
tags=["tasks"],
dependencies=[Depends(PermissionsDependency([IncidentTaskCreateEditPermission]))],
)
def delete_task(db_session: DbSession, task_id: PrimaryKey, current_user: CurrentUser):
"""Deletes an existing task."""
task = get(db_session=db_session, task_id=task_id)
if not task:
Expand Down
Loading