diff --git a/infrastructure/instance/sqs_id_sync.tf b/infrastructure/instance/sqs_id_sync.tf index 6cf8add56b..e545c299a1 100644 --- a/infrastructure/instance/sqs_id_sync.tf +++ b/infrastructure/instance/sqs_id_sync.tf @@ -1,7 +1,7 @@ resource "aws_sqs_queue" "id_sync_queue" { name = "imms-${local.resource_scope}-id-sync-queue" kms_master_key_id = data.aws_kms_key.existing_id_sync_sqs_encryption_key.arn - visibility_timeout_seconds = 360 + visibility_timeout_seconds = 1080 # as per AWS docs to be 6 times the Lambda function timeout but kept to 3 times redrive_policy = jsonencode({ deadLetterTargetArn = aws_sqs_queue.id_sync_dlq.arn maxReceiveCount = 4 diff --git a/lambdas/id_sync/src/id_sync.py b/lambdas/id_sync/src/id_sync.py index 9441458472..ba905bc63a 100644 --- a/lambdas/id_sync/src/id_sync.py +++ b/lambdas/id_sync/src/id_sync.py @@ -1,8 +1,8 @@ """ -- Parses the incoming AWS event into `AwsLambdaEvent` and iterate its `records`. -- Delegate each record to `process_record` and collect `nhs_number` from each result. -- If any record has status == "error" raise `IdSyncException` with aggregated nhs_numbers. -- Any unexpected error is wrapped into `IdSyncException(message="Error processing id_sync event")`. +- Parses the incoming AWS event into `AwsLambdaEvent` and iterates its `records`. +- Delegates each record to `process_record` with per-record exception isolation. +- Returns {"batchItemFailures": [...]} for any failed records so SQS only re-drives the failing messages. +- A handler-level exception (bad event schema etc.) re-raises to trigger full batch retry. """ from typing import Any @@ -10,7 +10,6 @@ from common.aws_lambda_event import AwsLambdaEvent from common.clients import STREAM_NAME, logger from common.log_decorator import logging_decorator -from exceptions.id_sync_exception import IdSyncException from record_processor import process_record @@ -25,28 +24,32 @@ def handler(event_data: dict[str, Any], _context) -> dict[str, Any]: logger.info("id_sync processing event with %d records", len(records)) - error_count = 0 + batch_item_failures = [] for record in records: - result = process_record(record) - - if result.get("status") == "error": - error_count += 1 - - if error_count > 0: - raise IdSyncException( - message=f"Processed {len(records)} records with {error_count} errors", - ) + try: + result = process_record(record) + if result.get("status") == "error": + message_id = record.get("messageId") + logger.error( + "id_sync record processing failed for messageId: %s — %s", + message_id, + result.get("message"), + ) + batch_item_failures.append({"itemIdentifier": message_id}) + except Exception: + message_id = record.get("messageId") + logger.exception("Unexpected error processing messageId: %s", message_id) + batch_item_failures.append({"itemIdentifier": message_id}) + + if batch_item_failures: + logger.error("id_sync completed with %d/%d failures", len(batch_item_failures), len(records)) + return {"batchItemFailures": batch_item_failures} response = {"status": "success", "message": f"Successfully processed {len(records)} records"} - logger.info("id_sync handler completed: %s", response) return response - except IdSyncException as e: - logger.exception(f"id_sync error: {e.message}") - raise except Exception: - msg = "Error processing id_sync event" - logger.exception(msg) - raise IdSyncException(message=msg) + logger.exception("Unexpected error processing id_sync event") + raise diff --git a/lambdas/id_sync/tests/test_id_sync.py b/lambdas/id_sync/tests/test_id_sync.py index 619ffd9818..55c41f8476 100644 --- a/lambdas/id_sync/tests/test_id_sync.py +++ b/lambdas/id_sync/tests/test_id_sync.py @@ -2,15 +2,12 @@ from unittest.mock import MagicMock, patch with patch("common.log_decorator.logging_decorator") as mock_decorator: - mock_decorator.return_value = lambda f: f # Pass-through decorator - from exceptions.id_sync_exception import IdSyncException + mock_decorator.return_value = lambda f: f from id_sync import handler class TestIdSyncHandler(unittest.TestCase): def setUp(self): - """Set up all patches and test fixtures""" - # Patch all dependencies self.aws_lambda_event_patcher = patch("id_sync.AwsLambdaEvent") self.mock_aws_lambda_event = self.aws_lambda_event_patcher.start() @@ -19,266 +16,161 @@ def setUp(self): self.logger_patcher = patch("id_sync.logger") self.mock_logger = self.logger_patcher.start() - # Set up test data - self.single_sqs_event = {"Records": [{"body": '{"source":"aws:sqs","data":"test-data"}'}]} + self.single_sqs_event = {"Records": [{"messageId": "msg-1", "body": '{"subject":"9000000001"}'}]} self.multi_sqs_event = { "Records": [ - { - "body": ('{"source":"aws:sqs","data":"a"}'), - }, - { - "body": ('{"source":"aws:sqs","data":"b"}'), - }, + {"messageId": "msg-1", "body": '{"subject":"9000000001"}'}, + {"messageId": "msg-2", "body": '{"subject":"9000000002"}'}, + {"messageId": "msg-3", "body": '{"subject":"9000000003"}'}, ] } - self.empty_event = {"Records": []} self.no_records_event = {"someOtherKey": "value"} def tearDown(self): - """Stop all patches""" patch.stopall() - def test_handler_success_single_record(self): - """Test handler with single successful record""" - # Setup mocks + def test_single_record_success(self): mock_event = MagicMock() - mock_event.records = [MagicMock()] + mock_event.records = [{"messageId": "msg-1"}] self.mock_aws_lambda_event.return_value = mock_event + self.mock_process_record.return_value = {"status": "success"} - self.mock_process_record.return_value = { - "status": "success", - } - - # Call handler result = handler(self.single_sqs_event, None) - # Assertions - self.mock_aws_lambda_event.assert_called_once_with(self.single_sqs_event) - self.mock_process_record.assert_called_once_with(mock_event.records[0]) - self.assertEqual(result["status"], "success") self.assertEqual(result["message"], "Successfully processed 1 records") + self.assertNotIn("batchItemFailures", result) - def test_handler_success_multiple_records(self): - """Test handler with multiple successful records""" - # Setup mocks + def test_multiple_records_all_success(self): mock_event = MagicMock() - mock_event.records = [MagicMock(), MagicMock()] + mock_event.records = [{"messageId": "msg-1"}, {"messageId": "msg-2"}, {"messageId": "msg-3"}] self.mock_aws_lambda_event.return_value = mock_event - self.mock_process_record.side_effect = [ {"status": "success"}, {"status": "success"}, + {"status": "success"}, ] - # Call handler result = handler(self.multi_sqs_event, None) - # Assertions - self.assertEqual(self.mock_process_record.call_count, 2) self.assertEqual(result["status"], "success") - self.assertEqual(result["message"], "Successfully processed 2 records") + self.assertEqual(result["message"], "Successfully processed 3 records") + self.assertNotIn("batchItemFailures", result) - def test_handler_error_single_record(self): - """Test handler with single failed record""" - # Setup mocks + def test_single_record_error_returns_batch_item_failure(self): + """A record returning status=error must appear in batchItemFailures — not raise.""" mock_event = MagicMock() - mock_event.records = [MagicMock()] + mock_event.records = [{"messageId": "msg-1"}] self.mock_aws_lambda_event.return_value = mock_event + self.mock_process_record.return_value = {"status": "error", "message": "PDS timeout"} - self.mock_process_record.return_value = { - "status": "error", - } - - # Call handler - with self.assertRaises(IdSyncException) as exception_context: - handler(self.single_sqs_event, None) - - exception = exception_context.exception - # Assertions - self.mock_process_record.assert_called_once_with(mock_event.records[0]) - self.mock_logger.info.assert_any_call("id_sync processing event with %d records", 1) + result = handler(self.single_sqs_event, None) - self.assertEqual(exception.message, "Processed 1 records with 1 errors") + self.assertIn("batchItemFailures", result) + self.assertEqual(result["batchItemFailures"], [{"itemIdentifier": "msg-1"}]) - def test_handler_mixed_success_error(self): - """Test handler with mix of successful and failed records""" - # Setup mocks + def test_mixed_batch_only_failures_in_response(self): + """Only the failing messageId appears in batchItemFailures; successes are not listed.""" mock_event = MagicMock() - mock_event.records = [MagicMock(), MagicMock(), MagicMock()] + mock_event.records = [ + {"messageId": "msg-1"}, + {"messageId": "msg-2"}, + {"messageId": "msg-3"}, + ] self.mock_aws_lambda_event.return_value = mock_event - self.mock_process_record.side_effect = [ {"status": "success"}, - {"status": "error"}, + {"status": "error", "message": "PDS returned 404"}, {"status": "success"}, ] - # Call handler - with self.assertRaises(IdSyncException) as exception_context: - handler(self.multi_sqs_event, None) - - error = exception_context.exception - # Assertions - self.assertEqual(self.mock_process_record.call_count, 3) + result = handler(self.multi_sqs_event, None) - self.assertEqual(error.message, "Processed 3 records with 1 errors") + self.assertIn("batchItemFailures", result) + self.assertEqual(result["batchItemFailures"], [{"itemIdentifier": "msg-2"}]) - def test_handler_all_records_fail(self): - """Test handler when all records fail""" - # Setup mocks + def test_all_records_fail_all_in_batch_item_failures(self): mock_event = MagicMock() - mock_event.records = [MagicMock(), MagicMock()] + mock_event.records = [{"messageId": "msg-1"}, {"messageId": "msg-2"}] self.mock_aws_lambda_event.return_value = mock_event + self.mock_process_record.side_effect = [ + {"status": "error", "message": "err"}, + {"status": "error", "message": "err"}, + ] + result = handler(self.multi_sqs_event, None) + + self.assertEqual( + result["batchItemFailures"], + [{"itemIdentifier": "msg-1"}, {"itemIdentifier": "msg-2"}], + ) + + def test_process_record_raises_exception_is_isolated_per_record(self): + """ + Core regression test for the alarm incident. + If process_record throws for one record, only that messageId is in batchItemFailures. + The other records still process normally — no full-batch failure. + """ + mock_event = MagicMock() + mock_event.records = [ + {"messageId": "msg-1"}, + {"messageId": "msg-2"}, + {"messageId": "msg-3"}, + ] + self.mock_aws_lambda_event.return_value = mock_event self.mock_process_record.side_effect = [ - {"status": "error"}, - {"status": "error"}, + {"status": "success"}, + RuntimeError("Unexpected crash in process_record"), + {"status": "success"}, ] - # Call handler - with self.assertRaises(IdSyncException) as exception_context: - handler(self.multi_sqs_event, None) - exception = exception_context.exception - # Assertions - self.assertEqual(self.mock_process_record.call_count, 2) + result = handler(self.multi_sqs_event, None) - self.assertEqual(exception.message, "Processed 2 records with 2 errors") + self.assertIn("batchItemFailures", result) + self.assertEqual(result["batchItemFailures"], [{"itemIdentifier": "msg-2"}]) + # Verify the other two records were still processed + self.assertEqual(self.mock_process_record.call_count, 3) - def test_handler_empty_records(self): - """Test handler with empty records""" - # Setup mocks + def test_process_record_raises_logs_exception(self): + """Unexpected exception must be logged at ERROR level.""" mock_event = MagicMock() - mock_event.records = [] + mock_event.records = [{"messageId": "msg-1"}] self.mock_aws_lambda_event.return_value = mock_event + self.mock_process_record.side_effect = RuntimeError("boom") - # Call handler - result = handler(self.empty_event, None) - - # Assertions - self.mock_aws_lambda_event.assert_called_once_with(self.empty_event) - self.mock_process_record.assert_not_called() + handler(self.single_sqs_event, None) - self.assertEqual(result["status"], "success") - self.assertEqual(result["message"], "No records found in event") + self.mock_logger.exception.assert_called_once_with("Unexpected error processing messageId: %s", "msg-1") - def test_handler_no_records_key(self): - """Test handler with no Records key in event""" - # Setup mocks + def test_empty_records_returns_success(self): mock_event = MagicMock() mock_event.records = [] self.mock_aws_lambda_event.return_value = mock_event - # Call handler - result = handler(self.no_records_event, None) - - # Assertions - self.mock_aws_lambda_event.assert_called_once_with(self.no_records_event) - self.mock_process_record.assert_not_called() + result = handler(self.empty_event, None) self.assertEqual(result["status"], "success") self.assertEqual(result["message"], "No records found in event") - - def test_handler_aws_lambda_event_exception(self): - """Test handler when AwsLambdaEvent raises exception""" - # Setup mock to raise exception - self.mock_aws_lambda_event.side_effect = Exception("AwsLambdaEvent creation failed") - - # Call handler - with self.assertRaises(IdSyncException) as exception_context: - handler(self.single_sqs_event, None) - - result = exception_context.exception - # Assertions - self.mock_aws_lambda_event.assert_called_once_with(self.single_sqs_event) - self.mock_logger.exception.assert_called_once_with("Error processing id_sync event") self.mock_process_record.assert_not_called() - self.assertEqual(result.message, "Error processing id_sync event") - - def test_handler_process_record_exception(self): - """Test handler when process_record raises exception""" - # Setup mocks - mock_event = MagicMock() - mock_event.records = [MagicMock()] - self.mock_aws_lambda_event.return_value = mock_event - - self.mock_process_record.side_effect = Exception("Process record failed") - - # Call handler - with self.assertRaises(IdSyncException) as exception_context: - handler(self.single_sqs_event, None) - exception = exception_context.exception - # Assertions - self.mock_process_record.assert_called_once_with(mock_event.records[0]) - self.mock_logger.exception.assert_called_once_with("Error processing id_sync event") - - self.assertEqual(exception.message, "Error processing id_sync event") - - def test_handler_process_record_missing_nhs_number(self): - """Test handler when process_record returns error and missing NHS number""" - - # Setup mocks - mock_event = MagicMock() - mock_event.records = [MagicMock()] - self.mock_aws_lambda_event.return_value = mock_event - - # Return result without 'nhs_number' but with an 'error' status - self.mock_process_record.return_value = { - "status": "error", - "message": "Missing NHS number", - # No 'nhs_number' - } - - # Call handler and expect exception - with self.assertRaises(IdSyncException) as exception_context: - handler(self.single_sqs_event, None) - - exception = exception_context.exception - - self.assertIsInstance(exception, IdSyncException) - self.assertEqual(exception.message, "Processed 1 records with 1 errors") - self.mock_logger.exception.assert_called_once_with(f"id_sync error: {exception.message}") - - def test_handler_context_parameter_ignored(self): - """Test that context parameter is properly ignored""" - # Setup mocks + def test_no_records_key_returns_success(self): mock_event = MagicMock() - mock_event.records = [MagicMock()] + mock_event.records = [] self.mock_aws_lambda_event.return_value = mock_event - self.mock_process_record.return_value = { - "status": "success", - } - - # Call handler with mock context - mock_context = MagicMock() - result = handler(self.single_sqs_event, mock_context) + result = handler(self.no_records_event, None) - # Should work normally regardless of context + self.mock_process_record.assert_not_called() self.assertEqual(result["status"], "success") - def test_handler_error_count_tracking(self): - """Test that error count is properly tracked""" - # Setup mocks - mock_event = MagicMock() - mock_event.records = [MagicMock(), MagicMock(), MagicMock(), MagicMock()] - self.mock_aws_lambda_event.return_value = mock_event + def test_aws_lambda_event_raises_propagates_as_exception(self): + """ + A crash before the record loop (e.g. malformed event schema) must re-raise + so SQS retries the full batch — nothing was processed yet. + """ + self.mock_aws_lambda_event.side_effect = Exception("malformed event") - self.mock_process_record.side_effect = [ - {"status": "success"}, - {"status": "error"}, - {"status": "error"}, - {"status": "success"}, - ] - - # Call handler - with self.assertRaises(IdSyncException) as exception_context: - handler(self.multi_sqs_event, None) - exception = exception_context.exception - # Assertions - should track 2 errors out of 4 records - self.assertEqual(self.mock_process_record.call_count, 4) - - self.assertEqual(exception.message, "Processed 4 records with 2 errors") + with self.assertRaises(Exception, msg="malformed event"): + handler(self.single_sqs_event, None) diff --git a/lambdas/shared/src/common/api_clients/retry.py b/lambdas/shared/src/common/api_clients/retry.py index 1951fa50f7..5694aebabf 100644 --- a/lambdas/shared/src/common/api_clients/retry.py +++ b/lambdas/shared/src/common/api_clients/retry.py @@ -40,7 +40,21 @@ def request_with_retry_backoff( if data is not None: api_request_kwargs["data"] = data - response = requests.request(**api_request_kwargs) + try: + response = requests.request(**api_request_kwargs) + except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e: + if request_attempt < max_retries: + logger.warning( + "Network error on attempt %d/%d: %s. Retrying...", + request_attempt + 1, + max_retries + 1, + e, + ) + time.sleep(Constants.API_CLIENTS_BACKOFF_SECONDS * (2**request_attempt)) + continue + logger.error("Network error after %d attempts: %s", max_retries + 1, e) + raise + if response.status_code not in Constants.RETRYABLE_STATUS_CODES: break @@ -49,7 +63,6 @@ def request_with_retry_backoff( f"Retryable response. Status={response.status_code}. " f"Attempt={request_attempt + 1}/{max_retries + 1}. Retrying..." ) - time.sleep(Constants.API_CLIENTS_BACKOFF_SECONDS * (2**request_attempt)) return response diff --git a/lambdas/shared/tests/test_common/api_clients/test_retry.py b/lambdas/shared/tests/test_common/api_clients/test_retry.py index 4ae1291880..d784a81088 100644 --- a/lambdas/shared/tests/test_common/api_clients/test_retry.py +++ b/lambdas/shared/tests/test_common/api_clients/test_retry.py @@ -1,6 +1,8 @@ import unittest from unittest.mock import MagicMock, call, patch +import requests + from common.api_clients.constants import Constants from common.api_clients.errors import ( BadRequestError, @@ -137,3 +139,85 @@ def test_backoff_values_are_exponential(self, mock_get, mock_sleep): mock_sleep.assert_has_calls( [call(Constants.API_CLIENTS_BACKOFF_SECONDS), call(Constants.API_CLIENTS_BACKOFF_SECONDS * 2)] ) + + +class TestRequestWithRetryBackoffNetworkErrors(unittest.TestCase): + """ + Regression tests for the ReadTimeout incident (imms-blue-id-sync-lambda-error, 17 Mar 2026). + Verifies that network-level exceptions are retried identically to retryable HTTP status codes. + """ + + @patch("time.sleep") + @patch("requests.request") + def test_read_timeout_is_retried(self, mock_request, mock_sleep): + """ReadTimeout on attempt 1 then success on attempt 2 — must not raise.""" + mock_request.side_effect = [ + requests.exceptions.ReadTimeout("read timeout=5"), + _make_response(200), + ] + + resp = request_with_retry_backoff("GET", "http://example.com", {}) + + self.assertEqual(resp.status_code, 200) + self.assertEqual(mock_request.call_count, 2) + self.assertEqual(mock_sleep.call_count, 1) + + @patch("time.sleep") + @patch("requests.request") + def test_connection_error_is_retried(self, mock_request, mock_sleep): + """ConnectionError on attempt 1 then success on attempt 2 — must not raise.""" + mock_request.side_effect = [ + requests.exceptions.ConnectionError("connection refused"), + _make_response(200), + ] + + resp = request_with_retry_backoff("GET", "http://example.com", {}) + + self.assertEqual(resp.status_code, 200) + self.assertEqual(mock_request.call_count, 2) + + @patch("time.sleep") + @patch("requests.request") + def test_timeout_exhausted_raises_after_max_retries(self, mock_request, mock_sleep): + """ + ReadTimeout on every attempt — must raise after max_retries+1 total attempts. + With API_CLIENTS_MAX_RETRIES=2, that is 3 attempts total. + """ + mock_request.side_effect = requests.exceptions.ReadTimeout("read timeout=5") + + with self.assertRaises(requests.exceptions.ReadTimeout): + request_with_retry_backoff("GET", "http://example.com", {}) + + self.assertEqual(mock_request.call_count, Constants.API_CLIENTS_MAX_RETRIES + 1) + + @patch("time.sleep") + @patch("requests.request") + def test_timeout_retry_backoff_is_exponential(self, mock_request, mock_sleep): + """Sleep intervals between network-error retries must be identical to HTTP retryable backoff.""" + mock_request.side_effect = requests.exceptions.ReadTimeout("read timeout=5") + + with self.assertRaises(requests.exceptions.ReadTimeout): + request_with_retry_backoff("GET", "http://example.com", {}) + + mock_sleep.assert_has_calls( + [ + call(Constants.API_CLIENTS_BACKOFF_SECONDS), # after attempt 1: 0.5s + call(Constants.API_CLIENTS_BACKOFF_SECONDS * 2), # after attempt 2: 1.0s + ] + ) + + @patch("time.sleep") + @patch("requests.request") + def test_timeout_then_retryable_status_then_success(self, mock_request, mock_sleep): + """Network error on attempt 1, HTTP 503 on attempt 2, success on attempt 3 — full coverage.""" + mock_request.side_effect = [ + requests.exceptions.ReadTimeout("read timeout=5"), + _make_response(503), + _make_response(200), + ] + + resp = request_with_retry_backoff("GET", "http://example.com", {}) + + self.assertEqual(resp.status_code, 200) + self.assertEqual(mock_request.call_count, 3) + self.assertEqual(mock_sleep.call_count, 2)