Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions aws_lambda_powertools/utilities/batch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,18 +335,22 @@ def _get_messages_to_report(self) -> list[PartialItemFailures]:

# Event Source Data Classes follow python idioms for fields
# while Parser/Pydantic follows the event field names to the latter
def _sqs_failure_item_identifier(self, msg) -> str:
# If a message failed due to model validation (e.g., poison pill)
# we convert to an event source data class...but self.model is still true
# therefore, we do an additional check on whether the failed message is still a model
# see https://github.com/aws-powertools/powertools-lambda-python/issues/2091
if self.model and getattr(msg, "model_validate", None):
return msg.messageId
data = msg._data if hasattr(msg, "_data") else msg
if isinstance(data, dict):
return data.get("messageId", "")
return msg.message_id

def _collect_sqs_failures(self):
failures = []
for msg in self.fail_messages:
# If a message failed due to model validation (e.g., poison pill)
# we convert to an event source data class...but self.model is still true
# therefore, we do an additional check on whether the failed message is still a model
# see https://github.com/aws-powertools/powertools-lambda-python/issues/2091
if self.model and getattr(msg, "model_validate", None):
msg_id = msg.messageId
else:
msg_id = msg.message_id
failures.append({"itemIdentifier": msg_id})
failures.append({"itemIdentifier": self._sqs_failure_item_identifier(msg)})
return failures

def _collect_kinesis_failures(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -861,3 +861,27 @@ async def simple_async_handler(record: SQSRecord):
# THEN record is processed successfully using asyncio.run()
assert result == {"batchItemFailures": []}
assert result == {"batchItemFailures": []}


def test_sqs_batch_processor_missing_message_id_does_not_crash_on_handler_failure():
processor = BatchProcessor(event_type=EventType.SQS, raise_on_entire_batch_failure=False)

def record_handler(record):
raise RuntimeError("boom")

malformed_record = {
"body": "{}",
"receiptHandle": "rh",
"attributes": {"ApproximateReceiveCount": "1"},
"messageAttributes": {},
"md5OfBody": "abc",
"eventSource": "aws:sqs",
"eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:my-queue",
"awsRegion": "us-east-1",
}

with processor(records=[malformed_record], handler=record_handler):
processor.process()

response = processor.response()
assert response == {"batchItemFailures": [{"itemIdentifier": ""}]}