109 changes: 93 additions & 16 deletions fia_api/routers/jobs.py
@@ -1,18 +1,19 @@
"""Jobs API Router"""
"""Jobs API Router."""

import io
import json
import os
import zipfile
from http import HTTPStatus
from typing import Annotated, Literal
from typing import Annotated, Any, Literal

from fastapi import APIRouter, Depends, Query, Response
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials
from sqlalchemy.orm import Session

from fia_api.core.auth.tokens import JWTAPIBearer, get_user_from_token
from fia_api.core.cache import cache_get_json, cache_set_json, hash_key
from fia_api.core.exceptions import (
AuthError,
NoFilesAddedError,
@@ -34,6 +35,8 @@

JobsRouter = APIRouter(tags=["jobs"])
jwt_api_security = JWTAPIBearer()
JOB_LIST_CACHE_TTL_SECONDS = int(os.environ.get("JOB_LIST_CACHE_TTL_SECONDS", "15"))
JOB_COUNT_CACHE_TTL_SECONDS = int(os.environ.get("JOB_COUNT_CACHE_TTL_SECONDS", "15"))

OrderField = Literal[
"start",
@@ -49,6 +52,11 @@
]


def _jobs_cache_key(scope: str, payload: dict[str, Any]) -> str:
digest = hash_key(json.dumps(payload, sort_keys=True, separators=(",", ":")))
return f"fia_api:jobs:{scope}:{digest}"


@JobsRouter.get("/jobs", tags=["jobs"])
async def get_jobs(
credentials: Annotated[HTTPAuthorizationCredentials, Depends(jwt_api_security)],
@@ -87,6 +95,24 @@ async def get_jobs(
else:
user_number = user.user_number

cache_key = None
if JOB_LIST_CACHE_TTL_SECONDS > 0:
cache_key = _jobs_cache_key(
"list:all",
{
"user_number": user_number,
"include_run": include_run,
"limit": limit,
"offset": offset,
"order_by": order_by,
"order_direction": order_direction,
"filters": filters,
},
)
cached = cache_get_json(cache_key)
if isinstance(cached, list):
return cached

jobs = get_all_jobs(
session,
limit=limit,
@@ -98,8 +124,14 @@
)

if include_run:
return [JobWithRunResponse.from_job(j) for j in jobs]
return [JobResponse.from_job(j) for j in jobs]
payload = [JobWithRunResponse.from_job(j).model_dump(mode="json") for j in jobs]
else:
payload = [JobResponse.from_job(j).model_dump(mode="json") for j in jobs]

if cache_key:
cache_set_json(cache_key, payload, JOB_LIST_CACHE_TTL_SECONDS)

return payload # type: ignore[return-value]
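
Serializing with `model_dump(mode="json")` before caching keeps the stored payload JSON-safe, so a cache hit can be returned verbatim. A standalone Pydantic illustration of the difference (the `Item` model is hypothetical, not from this codebase):

    from datetime import datetime

    from pydantic import BaseModel


    class Item(BaseModel):
        start: datetime


    item = Item(start=datetime(2024, 1, 1))
    item.model_dump()              # {'start': datetime(2024, 1, 1, 0, 0)} - not JSON-serializable
    item.model_dump(mode="json")   # {'start': '2024-01-01T00:00:00'} - safe to cache as JSON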


@JobsRouter.get("/instrument/{instrument}/jobs", tags=["jobs"])
@@ -144,6 +176,25 @@ async def get_jobs_by_instrument(
else:
user_number = user.user_number

cache_key = None
if JOB_LIST_CACHE_TTL_SECONDS > 0:
cache_key = _jobs_cache_key(
"list:instrument",
{
"instrument": instrument,
"user_number": user_number,
"include_run": include_run,
"limit": limit,
"offset": offset,
"order_by": order_by,
"order_direction": order_direction,
"filters": filters,
},
)
cached = cache_get_json(cache_key)
if isinstance(cached, list):
return cached

jobs = get_job_by_instrument(
instrument,
session,
@@ -156,8 +207,14 @@
)

if include_run:
return [JobWithRunResponse.from_job(j) for j in jobs]
return [JobResponse.from_job(j) for j in jobs]
payload = [JobWithRunResponse.from_job(j).model_dump(mode="json") for j in jobs]
else:
payload = [JobResponse.from_job(j).model_dump(mode="json") for j in jobs]

if cache_key:
cache_set_json(cache_key, payload, JOB_LIST_CACHE_TTL_SECONDS)

return payload # type: ignore[return-value]
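
Both list endpoints now follow the same cache-aside shape: derive a key only when the TTL is positive, trust a well-formed hit, otherwise query the database and write the payload back. A hypothetical condensation of that shared flow (the helper name and `fetch` callable are illustrative, not part of this diff):

    def _cached_list(scope: str, params: dict[str, Any], fetch) -> list:
        """Sketch of the read-through flow used by both list routes."""
        key = _jobs_cache_key(scope, params) if JOB_LIST_CACHE_TTL_SECONDS > 0 else None
        if key:
            cached = cache_get_json(key)
            if isinstance(cached, list):  # only trust well-formed hits
                return cached
        payload = fetch()  # e.g. serialized JobResponse dicts
        if key:
            cache_set_json(key, payload, JOB_LIST_CACHE_TTL_SECONDS)
        return payload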


@JobsRouter.get("/instrument/{instrument}/jobs/count", tags=["jobs"])
@@ -174,9 +231,20 @@ async def count_jobs_for_instrument(
:return: CountResponse containing the count
"""
instrument = instrument.upper()
return CountResponse(
count=count_jobs_by_instrument(instrument, session, filters=json.loads(filters) if filters else None)
)
parsed_filters = json.loads(filters) if filters else None

cache_key = None
if JOB_COUNT_CACHE_TTL_SECONDS > 0:
cache_key = _jobs_cache_key("count:instrument", {"instrument": instrument, "filters": parsed_filters})
cached = cache_get_json(cache_key)
if isinstance(cached, dict) and "count" in cached:
return CountResponse.model_validate(cached)

count = count_jobs_by_instrument(instrument, session, filters=parsed_filters)
payload = {"count": count}
if cache_key:
cache_set_json(cache_key, payload, JOB_COUNT_CACHE_TTL_SECONDS)
return CountResponse.model_validate(payload)
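
Re-validating the cached dict through `CountResponse.model_validate` means a corrupted cache entry fails validation instead of being served; assuming `CountResponse` declares `count: int`, for example:

    from pydantic import ValidationError

    CountResponse.model_validate({"count": 7})  # returns CountResponse(count=7)
    try:
        CountResponse.model_validate({"count": "many"})  # malformed cache entry
    except ValidationError:
        pass  # rejected rather than returned to the client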


@JobsRouter.get("/job/{job_id}", tags=["jobs"])
@@ -229,13 +297,22 @@ async def count_all_jobs(
session: Annotated[Session, Depends(get_db_session)],
filters: Annotated[str | None, Query(description="json string of filters")] = None,
) -> CountResponse:
"""
Count all jobs
\f
:param filters: json string of filters
:return: CountResponse containing the count
"""
return CountResponse(count=count_jobs(session, filters=json.loads(filters) if filters else None))
"""Count all jobs \f :param filters: json string of filters :return:
CountResponse containing the count."""
parsed_filters = json.loads(filters) if filters else None

cache_key = None
if JOB_COUNT_CACHE_TTL_SECONDS > 0:
cache_key = _jobs_cache_key("count:all", {"filters": parsed_filters})
cached = cache_get_json(cache_key)
if isinstance(cached, dict) and "count" in cached:
return CountResponse.model_validate(cached)

count = count_jobs(session, filters=parsed_filters)
payload = {"count": count}
if cache_key:
cache_set_json(cache_key, payload, JOB_COUNT_CACHE_TTL_SECONDS)
return CountResponse.model_validate(payload)


@JobsRouter.get("/job/{job_id}/filename/{filename}", tags=["jobs"])
77 changes: 77 additions & 0 deletions test/e2e/test_jobs_cache.py
@@ -0,0 +1,77 @@
"""Cache behavior tests for jobs endpoints."""

from http import HTTPStatus
from unittest.mock import patch

from starlette.testclient import TestClient

from fia_api.fia_api import app

from .constants import STAFF_HEADER

client = TestClient(app)


@patch("fia_api.routers.jobs.JOB_LIST_CACHE_TTL_SECONDS", 15)
@patch("fia_api.core.auth.tokens.requests.post")
@patch("fia_api.routers.jobs.cache_set_json")
@patch("fia_api.routers.jobs.get_all_jobs")
@patch("fia_api.routers.jobs.cache_get_json")
def test_jobs_list_cache_hit_returns_cached_payload(
mock_cache_get,
mock_get_all_jobs,
mock_cache_set,
mock_post,
):
cached_payload = [
{
"id": 1,
"start": None,
"end": None,
"state": "NOT_STARTED",
"status_message": None,
"inputs": {},
"outputs": None,
"stacktrace": None,
"script": None,
"runner_image": None,
"type": "JobType.AUTOREDUCTION",
}
]
mock_cache_get.return_value = cached_payload
mock_post.return_value.status_code = HTTPStatus.OK

response = client.get("/jobs?limit=1", headers=STAFF_HEADER)

assert response.status_code == HTTPStatus.OK
assert response.json() == cached_payload
mock_get_all_jobs.assert_not_called()
mock_cache_set.assert_not_called()


@patch("fia_api.routers.jobs.JOB_COUNT_CACHE_TTL_SECONDS", 15)
@patch("fia_api.routers.jobs.count_jobs")
@patch("fia_api.routers.jobs.cache_get_json")
def test_jobs_count_cache_hit_returns_cached_payload(mock_cache_get, mock_count_jobs):
cached_payload = {"count": 42}
mock_cache_get.return_value = cached_payload

response = client.get("/jobs/count")

assert response.status_code == HTTPStatus.OK
assert response.json() == cached_payload
mock_count_jobs.assert_not_called()


@patch("fia_api.routers.jobs.JOB_COUNT_CACHE_TTL_SECONDS", 15)
@patch("fia_api.routers.jobs.count_jobs_by_instrument")
@patch("fia_api.routers.jobs.cache_get_json")
def test_jobs_count_by_instrument_cache_hit_returns_cached_payload(mock_cache_get, mock_count_jobs):
cached_payload = {"count": 7}
mock_cache_get.return_value = cached_payload

response = client.get("/instrument/TEST/jobs/count")

assert response.status_code == HTTPStatus.OK
assert response.json() == cached_payload
mock_count_jobs.assert_not_called()
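

These tests cover the hit path; a complementary miss-path sketch (hypothetical, not part of this diff) would assert that the route falls through to the database and writes the serialized payload back with the TTL:

    @patch("fia_api.routers.jobs.JOB_COUNT_CACHE_TTL_SECONDS", 15)
    @patch("fia_api.routers.jobs.cache_set_json")
    @patch("fia_api.routers.jobs.count_jobs", return_value=3)
    @patch("fia_api.routers.jobs.cache_get_json", return_value=None)
    def test_jobs_count_cache_miss_writes_payload(mock_cache_get, mock_count_jobs, mock_cache_set):
        response = client.get("/jobs/count")

        assert response.status_code == HTTPStatus.OK
        assert response.json() == {"count": 3}
        mock_cache_set.assert_called_once()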