diff --git a/aws_lambda_powertools/utilities/batch/base.py b/aws_lambda_powertools/utilities/batch/base.py index 4cb1337cafd..f51b17db6f5 100644 --- a/aws_lambda_powertools/utilities/batch/base.py +++ b/aws_lambda_powertools/utilities/batch/base.py @@ -36,6 +36,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from aws_lambda_powertools.logging import Logger from aws_lambda_powertools.utilities.batch.types import ( PartialItemFailureResponse, PartialItemFailures, @@ -68,10 +69,11 @@ class BasePartialProcessor(ABC): lambda_context: LambdaContext - def __init__(self): + def __init__(self, logger: logging.Logger | Logger | None = None): self.success_messages: list[BatchEventTypes] = [] self.fail_messages: list[BatchEventTypes] = [] self.exceptions: list[ExceptionInfo] = [] + self.logger = logger @abstractmethod def _prepare(self): @@ -237,6 +239,13 @@ def failure_handler(self, record, exception: ExceptionInfo) -> FailureResponse: exception_string = f"{exception[0]}:{exception[1]}" entry = ("fail", exception_string, record) logger.debug(f"Record processing exception: {exception_string}") + + if getattr(self, "logger", None) and exception[2] is not None: + self.logger.warning( + "Record processing exception; skipping this record", + exc_info=exception, + ) + self.exceptions.append(exception) self.fail_messages.append(record) return entry @@ -250,6 +259,7 @@ def __init__( event_type: EventType, model: BatchTypeModels | None = None, raise_on_entire_batch_failure: bool = True, + logger: logging.Logger | Logger | None = None, ): """Process batch and partially report failed items @@ -262,6 +272,8 @@ def __init__( raise_on_entire_batch_failure: bool Raise an exception when the entire batch has failed processing. When set to False, partial failures are reported in the response + logger: logging.Logger | Logger | None + Optional Logger instance to output warnings with tracebacks for failed records. Exceptions ---------- @@ -285,7 +297,7 @@ def __init__( EventType.Kafka: KafkaEventRecord, } - super().__init__() + super().__init__(logger=logger) def response(self) -> PartialItemFailureResponse: """Batch items that failed processing, if any""" diff --git a/aws_lambda_powertools/utilities/batch/sqs_fifo_partial_processor.py b/aws_lambda_powertools/utilities/batch/sqs_fifo_partial_processor.py index 2e680e2f04e..441efc1d288 100644 --- a/aws_lambda_powertools/utilities/batch/sqs_fifo_partial_processor.py +++ b/aws_lambda_powertools/utilities/batch/sqs_fifo_partial_processor.py @@ -10,6 +10,7 @@ ) if TYPE_CHECKING: + from aws_lambda_powertools.logging import Logger from aws_lambda_powertools.utilities.batch.types import BatchSqsTypeModel logger = logging.getLogger(__name__) @@ -66,7 +67,12 @@ def lambda_handler(event, context: LambdaContext): None, ) - def __init__(self, model: BatchSqsTypeModel | None = None, skip_group_on_error: bool = False): + def __init__( + self, + model: BatchSqsTypeModel | None = None, + skip_group_on_error: bool = False, + logger: logging.Logger | Logger | None = None, + ): """ Initialize the SqsFifoProcessor. @@ -77,12 +83,14 @@ def __init__(self, model: BatchSqsTypeModel | None = None, skip_group_on_error: skip_group_on_error: bool Determines whether to exclusively skip messages from the MessageGroupID that encountered processing failures Default is False. + logger: logging.Logger | Logger | None + Optional Logger instance to output warnings with tracebacks for failed records. """ self._skip_group_on_error: bool = skip_group_on_error self._current_group_id = None self._failed_group_ids: set[str] = set() - super().__init__(EventType.SQS, model) + super().__init__(EventType.SQS, model, logger=logger) def _process_record(self, record): self._current_group_id = record.get("attributes", {}).get("MessageGroupId") diff --git a/tests/functional/batch/required_dependencies/test_utilities_batch.py b/tests/functional/batch/required_dependencies/test_utilities_batch.py index 43c2aa16191..91672d46620 100644 --- a/tests/functional/batch/required_dependencies/test_utilities_batch.py +++ b/tests/functional/batch/required_dependencies/test_utilities_batch.py @@ -861,3 +861,85 @@ async def simple_async_handler(record: SQSRecord): # THEN record is processed successfully using asyncio.run() assert result == {"batchItemFailures": []} assert result == {"batchItemFailures": []} + + +def test_batch_processor_logs_exception_with_injected_logger(sqs_event_factory, caplog): + import logging + from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, process_partial_response + + fail_record = sqs_event_factory("fail") + success_record = sqs_event_factory("success") + + def handler(record): + if "fail" in record["body"]: + raise ValueError("intentional failure") + return record["body"] + + test_logger = logging.getLogger("test_logger") + processor = BatchProcessor(event_type=EventType.SQS, logger=test_logger) + + with caplog.at_level(logging.WARNING, logger="test_logger"): + process_partial_response( + event={"Records": [fail_record, success_record]}, + record_handler=handler, + processor=processor, + ) + + warning_records = [r for r in caplog.records if r.levelno == logging.WARNING] + assert len(warning_records) == 1, f"Expected 1 WARNING log, got {len(warning_records)}" + assert "intentional failure" in warning_records[0].getMessage() or warning_records[0].exc_info is not None + assert warning_records[0].exc_info is not None, "Expected exc_info (traceback) in log record" + assert warning_records[0].exc_info[0] is ValueError + + +def test_batch_processor_does_not_log_without_injected_logger(sqs_event_factory, caplog): + import logging + from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, process_partial_response + + fail_record = sqs_event_factory("fail") + + def handler(record): + raise ValueError("intentional failure") + + processor = BatchProcessor(event_type=EventType.SQS, raise_on_entire_batch_failure=False, logger=None) + + with caplog.at_level(logging.WARNING, logger="aws_lambda_powertools.utilities.batch.base"): + process_partial_response( + event={"Records": [fail_record]}, + record_handler=handler, + processor=processor, + ) + + warning_records = [r for r in caplog.records if r.levelno == logging.WARNING] + assert len(warning_records) == 0, "Expected no WARNING logs when logger is None" + + +def test_sqs_fifo_circuit_breaker_does_not_log(sqs_event_fifo_factory, caplog): + import logging + from aws_lambda_powertools.utilities.batch import SqsFifoPartialProcessor, process_partial_response + + failing_record = sqs_event_fifo_factory("fail", "group-1") + short_circuited_record = sqs_event_fifo_factory("would-succeed", "group-1") + + def handler(record): + if "fail" in record["body"]: + raise ValueError("first record failure") + return record["body"] + + test_logger = logging.getLogger("test_logger") + processor = SqsFifoPartialProcessor(logger=test_logger) + processor.raise_on_entire_batch_failure = False + + with caplog.at_level(logging.WARNING, logger="test_logger"): + process_partial_response( + event={"Records": [failing_record, short_circuited_record]}, + record_handler=handler, + processor=processor, + ) + + warning_records = [r for r in caplog.records if r.levelno == logging.WARNING] + assert len(warning_records) == 1, ( + f"Expected exactly 1 WARNING (real exception only), got {len(warning_records)}: " + + str([r.getMessage() for r in warning_records]) + ) + assert warning_records[0].exc_info[0] is ValueError