diff --git a/aws_lambda_powertools/utilities/batch/base.py b/aws_lambda_powertools/utilities/batch/base.py index 4cb1337cafd..f924bbf3c42 100644 --- a/aws_lambda_powertools/utilities/batch/base.py +++ b/aws_lambda_powertools/utilities/batch/base.py @@ -335,18 +335,22 @@ def _get_messages_to_report(self) -> list[PartialItemFailures]: # Event Source Data Classes follow python idioms for fields # while Parser/Pydantic follows the event field names to the latter + def _sqs_failure_item_identifier(self, msg) -> str: + # If a message failed due to model validation (e.g., poison pill) + # we convert to an event source data class...but self.model is still true + # therefore, we do an additional check on whether the failed message is still a model + # see https://github.com/aws-powertools/powertools-lambda-python/issues/2091 + if self.model and getattr(msg, "model_validate", None): + return msg.messageId + data = msg._data if hasattr(msg, "_data") else msg + if isinstance(data, dict): + return data.get("messageId", "") + return msg.message_id + def _collect_sqs_failures(self): failures = [] for msg in self.fail_messages: - # If a message failed due to model validation (e.g., poison pill) - # we convert to an event source data class...but self.model is still true - # therefore, we do an additional check on whether the failed message is still a model - # see https://github.com/aws-powertools/powertools-lambda-python/issues/2091 - if self.model and getattr(msg, "model_validate", None): - msg_id = msg.messageId - else: - msg_id = msg.message_id - failures.append({"itemIdentifier": msg_id}) + failures.append({"itemIdentifier": self._sqs_failure_item_identifier(msg)}) return failures def _collect_kinesis_failures(self): diff --git a/tests/functional/batch/required_dependencies/test_utilities_batch.py b/tests/functional/batch/required_dependencies/test_utilities_batch.py index 43c2aa16191..98edda23b55 100644 --- a/tests/functional/batch/required_dependencies/test_utilities_batch.py +++ b/tests/functional/batch/required_dependencies/test_utilities_batch.py @@ -861,3 +861,27 @@ async def simple_async_handler(record: SQSRecord): # THEN record is processed successfully using asyncio.run() assert result == {"batchItemFailures": []} assert result == {"batchItemFailures": []} + + +def test_sqs_batch_processor_missing_message_id_does_not_crash_on_handler_failure(): + processor = BatchProcessor(event_type=EventType.SQS, raise_on_entire_batch_failure=False) + + def record_handler(record): + raise RuntimeError("boom") + + malformed_record = { + "body": "{}", + "receiptHandle": "rh", + "attributes": {"ApproximateReceiveCount": "1"}, + "messageAttributes": {}, + "md5OfBody": "abc", + "eventSource": "aws:sqs", + "eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:my-queue", + "awsRegion": "us-east-1", + } + + with processor(records=[malformed_record], handler=record_handler): + processor.process() + + response = processor.response() + assert response == {"batchItemFailures": [{"itemIdentifier": ""}]}