Skip to content

Commit 84872d0

Browse files
[CDAPI-118]: Add x-correlation-id
1 parent 15c9cbe commit 84872d0

File tree

6 files changed

+179
-4
lines changed

6 files changed

+179
-4
lines changed

pathology-api/lambda_handler.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pathology_api.fhir.r4.resources import Bundle, OperationOutcome
1414
from pathology_api.handler import handle_request
1515
from pathology_api.logging import get_logger
16+
from pathology_api.request_context import set_correlation_id
1617

1718
_logger = get_logger(__name__)
1819

@@ -102,8 +103,16 @@ def status() -> Response[str]:
102103
return Response(status_code=200, body="OK", headers={"Content-Type": "text/plain"})
103104

104105

106+
_CORRELATION_ID_HEADER = "nhsd-correlation-id"
107+
108+
105109
@app.post("/FHIR/R4/Bundle")
106110
def post_result() -> Response[str]:
111+
correlation_id = app.current_event.headers.get(_CORRELATION_ID_HEADER)
112+
if not correlation_id:
113+
raise ValidationError(f"Missing required header: {_CORRELATION_ID_HEADER}")
114+
set_correlation_id(correlation_id)
115+
107116
_logger.debug("Post result endpoint called.")
108117

109118
try:

pathology-api/src/pathology_api/logging.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
1+
import logging
12
from typing import Any, Protocol
23

34
from aws_lambda_powertools import Logger
45

6+
from pathology_api.request_context import get_correlation_id
7+
8+
9+
class _CorrelationIdFilter(logging.Filter):
10+
"""Injects the current correlation ID into every log record."""
11+
12+
def filter(self, record: logging.LogRecord) -> bool:
13+
record.correlation_id = get_correlation_id()
14+
return True
15+
516

617
class LogProvider(Protocol):
718
"""Protocol defining required contract for a logger."""
@@ -19,4 +30,6 @@ def exception(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
1930

2031
def get_logger(service: str) -> LogProvider:
2132
"""Get a configured logger instance."""
22-
return Logger(service=service, level="DEBUG", serialize_stacktrace=True)
33+
logger = Logger(service=service, level="DEBUG", serialize_stacktrace=True)
34+
logger.addFilter(_CorrelationIdFilter())
35+
return logger
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from contextvars import ContextVar
2+
3+
_correlation_id: ContextVar[str] = ContextVar("correlation_id", default="")
4+
5+
6+
def set_correlation_id(value: str) -> None:
7+
"""Set the correlation ID for the current request context."""
8+
_correlation_id.set(value)
9+
10+
11+
def get_correlation_id() -> str:
12+
"""Get the correlation ID for the current request context."""
13+
return _correlation_id.get()
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import logging
2+
3+
from pathology_api.logging import (
4+
_CorrelationIdFilter,
5+
get_logger,
6+
)
7+
from pathology_api.request_context import set_correlation_id
8+
9+
10+
class TestCorrelationIdFilter:
11+
def test_filter_injects_correlation_id_into_log_record(self) -> None:
12+
set_correlation_id("test-abc-123")
13+
14+
record = logging.LogRecord(
15+
name="test",
16+
level=logging.INFO,
17+
pathname="test.py",
18+
lineno=1,
19+
msg="test message",
20+
args=None,
21+
exc_info=None,
22+
)
23+
24+
f = _CorrelationIdFilter()
25+
result = f.filter(record)
26+
27+
assert result is True
28+
assert record.correlation_id == "test-abc-123" # type: ignore[attr-defined]
29+
30+
def test_filter_uses_empty_default_when_no_correlation_id_set(self) -> None:
31+
set_correlation_id("")
32+
33+
record = logging.LogRecord(
34+
name="test",
35+
level=logging.INFO,
36+
pathname="test.py",
37+
lineno=1,
38+
msg="test message",
39+
args=None,
40+
exc_info=None,
41+
)
42+
43+
f = _CorrelationIdFilter()
44+
f.filter(record)
45+
46+
assert record.correlation_id == "" # type: ignore[attr-defined]
47+
48+
49+
class TestGetLogger:
50+
def test_get_logger_attaches_correlation_id_filter(self) -> None:
51+
logger = get_logger("test-service")
52+
53+
filters = getattr(logger, "filters", [])
54+
assert any(isinstance(f, _CorrelationIdFilter) for f in filters)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from pathology_api.request_context import get_correlation_id, set_correlation_id
2+
3+
4+
class TestSetAndGetCorrelationId:
5+
def test_set_and_get_correlation_id(self) -> None:
6+
set_correlation_id("round-trip-test-123")
7+
assert get_correlation_id() == "round-trip-test-123"
8+
9+
def test_default_correlation_id_is_empty(self) -> None:
10+
set_correlation_id("")
11+
assert get_correlation_id() == ""

pathology-api/test_lambda_handler.py

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pathology_api.exception import ValidationError
99
from pathology_api.fhir.r4.elements import LogicalReference, PatientIdentifier
1010
from pathology_api.fhir.r4.resources import Bundle, Composition, OperationOutcome
11+
from pathology_api.request_context import get_correlation_id
1112

1213

1314
class TestHandler:
@@ -16,9 +17,11 @@ def _create_test_event(
1617
body: str | None = None,
1718
path_params: str | None = None,
1819
request_method: str | None = None,
20+
headers: dict[str, str] | None = None,
1921
) -> dict[str, Any]:
2022
return {
2123
"body": body,
24+
"headers": headers or {},
2225
"requestContext": {
2326
"http": {
2427
"path": f"/{path_params}",
@@ -58,6 +61,7 @@ def test_create_test_result_success(self) -> None:
5861
body=bundle.model_dump_json(by_alias=True),
5962
path_params="FHIR/R4/Bundle",
6063
request_method="POST",
64+
headers={"nhsd-correlation-id": "test-correlation-id"},
6165
)
6266
context = LambdaContext()
6367

@@ -76,9 +80,72 @@ def test_create_test_result_success(self) -> None:
7680
# A UUID value so can only check its presence.
7781
assert response_bundle.id is not None
7882

83+
def test_correlation_id_is_set_from_request_header(self) -> None:
84+
correlation_id = "test-correlation-id-abc-123"
85+
bundle = Bundle.create(
86+
type="document",
87+
entry=[
88+
Bundle.Entry(
89+
fullUrl="composition",
90+
resource=Composition.create(
91+
subject=LogicalReference(
92+
PatientIdentifier.from_nhs_number("nhs_number")
93+
)
94+
),
95+
)
96+
],
97+
)
98+
event = self._create_test_event(
99+
body=bundle.model_dump_json(by_alias=True),
100+
path_params="FHIR/R4/Bundle",
101+
request_method="POST",
102+
headers={"nhsd-correlation-id": correlation_id},
103+
)
104+
context = LambdaContext()
105+
106+
handler(event, context)
107+
108+
assert get_correlation_id() == correlation_id
109+
110+
def test_missing_correlation_id_header_returns_400(self) -> None:
111+
bundle = Bundle.create(
112+
type="document",
113+
entry=[
114+
Bundle.Entry(
115+
fullUrl="composition",
116+
resource=Composition.create(
117+
subject=LogicalReference(
118+
PatientIdentifier.from_nhs_number("nhs_number")
119+
)
120+
),
121+
)
122+
],
123+
)
124+
event = self._create_test_event(
125+
body=bundle.model_dump_json(by_alias=True),
126+
path_params="FHIR/R4/Bundle",
127+
request_method="POST",
128+
)
129+
context = LambdaContext()
130+
131+
response = handler(event, context)
132+
133+
assert response["statusCode"] == 400
134+
assert response["headers"] == {"Content-Type": "application/fhir+json"}
135+
136+
returned_issue = self._parse_returned_issue(response["body"])
137+
assert returned_issue["severity"] == "error"
138+
assert returned_issue["code"] == "invalid"
139+
assert (
140+
returned_issue["diagnostics"]
141+
== "Missing required header: nhsd-correlation-id"
142+
)
143+
79144
def test_create_test_result_no_payload(self) -> None:
80145
event = self._create_test_event(
81-
path_params="FHIR/R4/Bundle", request_method="POST"
146+
path_params="FHIR/R4/Bundle",
147+
request_method="POST",
148+
headers={"nhsd-correlation-id": "test-correlation-id"},
82149
)
83150
context = LambdaContext()
84151

@@ -98,7 +165,10 @@ def test_create_test_result_no_payload(self) -> None:
98165

99166
def test_create_test_result_empty_payload(self) -> None:
100167
event = self._create_test_event(
101-
body="{}", path_params="FHIR/R4/Bundle", request_method="POST"
168+
body="{}",
169+
path_params="FHIR/R4/Bundle",
170+
request_method="POST",
171+
headers={"nhsd-correlation-id": "test-correlation-id"},
102172
)
103173
context = LambdaContext()
104174

@@ -118,7 +188,10 @@ def test_create_test_result_empty_payload(self) -> None:
118188

119189
def test_create_test_result_invalid_json(self) -> None:
120190
event = self._create_test_event(
121-
body="invalid json", path_params="FHIR/R4/Bundle", request_method="POST"
191+
body="invalid json",
192+
path_params="FHIR/R4/Bundle",
193+
request_method="POST",
194+
headers={"nhsd-correlation-id": "test-correlation-id"},
122195
)
123196
context = LambdaContext()
124197

@@ -169,6 +242,7 @@ def test_create_test_result_processing_error(
169242
body=bundle.model_dump_json(by_alias=True),
170243
path_params="FHIR/R4/Bundle",
171244
request_method="POST",
245+
headers={"nhsd-correlation-id": "test-correlation-id"},
172246
)
173247
context = LambdaContext()
174248

@@ -207,6 +281,7 @@ def test_create_test_result_model_validate_error(
207281
body=bundle.model_dump_json(by_alias=True),
208282
path_params="FHIR/R4/Bundle",
209283
request_method="POST",
284+
headers={"nhsd-correlation-id": "test-correlation-id"},
210285
)
211286
context = LambdaContext()
212287

0 commit comments

Comments
 (0)