Skip to content

Commit 50541fe

Browse files
Sujit-1509, AI Bot, leandrodamascena, Sujit
authored
fix(batch): add optional logger injection for BatchProcessors (#7553) (#8272)
* chore: fix minor typos and grammar in docs and comments
* fix(batch): add optional logger injection for BatchProcessors (#7553)
* fix: satisfy batch logger mypy checks
* chore: address batch logger review feedback
* chore: simplify batch logger warning guard

---------

Co-authored-by: AI Bot <bot@example.com>
Co-authored-by: Leandro Damascena <lcdama@amazon.pt>
Co-authored-by: Sujit <sujit@example.com>
1 parent f7239b8 commit 50541fe

3 files changed

Lines changed: 102 additions & 5 deletions

File tree

aws_lambda_powertools/utilities/batch/base.py

Lines changed: 19 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
import sys
1515
from abc import ABC, abstractmethod
1616
from enum import Enum
17-
from typing import TYPE_CHECKING, Any, Tuple, Union, overload
17+
from typing import TYPE_CHECKING, Any, Tuple, TypeGuard, Union, overload
1818

1919
from aws_lambda_powertools.shared import constants
2020
from aws_lambda_powertools.utilities.batch.exceptions import (
@@ -35,6 +35,7 @@
3535

3636
if TYPE_CHECKING:
3737
from collections.abc import Callable
38+
from types import TracebackType
3839

3940
from aws_lambda_powertools.utilities.batch.types import (
4041
PartialItemFailureResponse,
@@ -61,17 +62,22 @@ class EventType(Enum):
6162
FailureResponse = Tuple[str, str, BatchEventTypes]
6263

6364

65+
def _has_traceback(exception: ExceptionInfo) -> TypeGuard[tuple[type[BaseException], BaseException, TracebackType]]:
66+
return exception[0] is not None and exception[1] is not None and exception[2] is not None
67+
68+
6469
class BasePartialProcessor(ABC):
6570
"""
6671
Abstract class for batch processors.
6772
"""
6873

6974
lambda_context: LambdaContext
7075

71-
def __init__(self):
76+
def __init__(self, logger: logging.Logger | None = None):
7277
self.success_messages: list[BatchEventTypes] = []
7378
self.fail_messages: list[BatchEventTypes] = []
7479
self.exceptions: list[ExceptionInfo] = []
80+
self.logger = logger
7581

7682
@abstractmethod
7783
def _prepare(self):
@@ -237,6 +243,13 @@ def failure_handler(self, record, exception: ExceptionInfo) -> FailureResponse:
237243
exception_string = f"{exception[0]}:{exception[1]}"
238244
entry = ("fail", exception_string, record)
239245
logger.debug(f"Record processing exception: {exception_string}")
246+
247+
if self.logger is not None and _has_traceback(exception):
248+
self.logger.warning(
249+
"Record processing exception; skipping this record",
250+
exc_info=exception,
251+
)
252+
240253
self.exceptions.append(exception)
241254
self.fail_messages.append(record)
242255
return entry
@@ -250,6 +263,7 @@ def __init__(
250263
event_type: EventType,
251264
model: BatchTypeModels | None = None,
252265
raise_on_entire_batch_failure: bool = True,
266+
logger: logging.Logger | None = None,
253267
):
254268
"""Process batch and partially report failed items
255269
@@ -262,6 +276,8 @@ def __init__(
262276
raise_on_entire_batch_failure: bool
263277
Raise an exception when the entire batch has failed processing.
264278
When set to False, partial failures are reported in the response
279+
logger: logging.Logger | None
280+
Optional Logger instance to output warnings with tracebacks for failed records.
265281
266282
Exceptions
267283
----------
@@ -285,7 +301,7 @@ def __init__(
285301
EventType.Kafka: KafkaEventRecord,
286302
}
287303

288-
super().__init__()
304+
super().__init__(logger=logger)
289305

290306
def response(self) -> PartialItemFailureResponse:
291307
"""Batch items that failed processing, if any"""

aws_lambda_powertools/utilities/batch/sqs_fifo_partial_processor.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,12 @@ def lambda_handler(event, context: LambdaContext):
6666
None,
6767
)
6868

69-
def __init__(self, model: BatchSqsTypeModel | None = None, skip_group_on_error: bool = False):
69+
def __init__(
70+
self,
71+
model: BatchSqsTypeModel | None = None,
72+
skip_group_on_error: bool = False,
73+
logger: logging.Logger | None = None,
74+
):
7075
"""
7176
Initialize the SqsFifoProcessor.
7277
@@ -77,12 +82,14 @@ def __init__(self, model: BatchSqsTypeModel | None = None, skip_group_on_error:
7782
skip_group_on_error: bool
7883
Determines whether to exclusively skip messages from the MessageGroupID that encountered processing failures
7984
Default is False.
85+
logger: logging.Logger | None
86+
Optional Logger instance to output warnings with tracebacks for failed records.
8087
8188
"""
8289
self._skip_group_on_error: bool = skip_group_on_error
8390
self._current_group_id = None
8491
self._failed_group_ids: set[str] = set()
85-
super().__init__(EventType.SQS, model)
92+
super().__init__(EventType.SQS, model, logger=logger)
8693

8794
def _process_record(self, record):
8895
self._current_group_id = record.get("attributes", {}).get("MessageGroupId")

tests/functional/batch/required_dependencies/test_utilities_batch.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import json
4+
import logging
45
import uuid
56
from random import randint
67
from typing import TYPE_CHECKING, Any
@@ -861,3 +862,76 @@ async def simple_async_handler(record: SQSRecord):
861862
# THEN record is processed successfully using asyncio.run()
862863
assert result == {"batchItemFailures": []}
863864
assert result == {"batchItemFailures": []}
865+
866+
867+
def test_batch_processor_logs_exception_with_injected_logger(sqs_event_factory, caplog):
868+
fail_record = sqs_event_factory("fail")
869+
success_record = sqs_event_factory("success")
870+
871+
def handler(record):
872+
if "fail" in record["body"]:
873+
raise ValueError("intentional failure")
874+
return record["body"]
875+
876+
test_logger = logging.getLogger("test_logger")
877+
processor = BatchProcessor(event_type=EventType.SQS, logger=test_logger)
878+
879+
with caplog.at_level(logging.WARNING, logger="test_logger"):
880+
process_partial_response(
881+
event={"Records": [fail_record, success_record]},
882+
record_handler=handler,
883+
processor=processor,
884+
)
885+
886+
warning_records = [r for r in caplog.records if r.levelno == logging.WARNING]
887+
assert len(warning_records) == 1, f"Expected 1 WARNING log, got {len(warning_records)}"
888+
assert "intentional failure" in warning_records[0].getMessage() or warning_records[0].exc_info is not None
889+
assert warning_records[0].exc_info is not None, "Expected exc_info (traceback) in log record"
890+
assert warning_records[0].exc_info[0] is ValueError
891+
892+
893+
def test_batch_processor_does_not_log_without_injected_logger(sqs_event_factory, caplog):
894+
fail_record = sqs_event_factory("fail")
895+
896+
def handler(record):
897+
raise ValueError("intentional failure")
898+
899+
processor = BatchProcessor(event_type=EventType.SQS, raise_on_entire_batch_failure=False, logger=None)
900+
901+
with caplog.at_level(logging.WARNING, logger="aws_lambda_powertools.utilities.batch.base"):
902+
process_partial_response(
903+
event={"Records": [fail_record]},
904+
record_handler=handler,
905+
processor=processor,
906+
)
907+
908+
warning_records = [r for r in caplog.records if r.levelno == logging.WARNING]
909+
assert len(warning_records) == 0, "Expected no WARNING logs when logger is None"
910+
911+
912+
def test_sqs_fifo_circuit_breaker_does_not_log(sqs_event_fifo_factory, caplog):
913+
failing_record = sqs_event_fifo_factory("fail", "group-1")
914+
short_circuited_record = sqs_event_fifo_factory("would-succeed", "group-1")
915+
916+
def handler(record):
917+
if "fail" in record["body"]:
918+
raise ValueError("first record failure")
919+
return record["body"]
920+
921+
test_logger = logging.getLogger("test_logger")
922+
processor = SqsFifoPartialProcessor(logger=test_logger)
923+
processor.raise_on_entire_batch_failure = False
924+
925+
with caplog.at_level(logging.WARNING, logger="test_logger"):
926+
process_partial_response(
927+
event={"Records": [failing_record, short_circuited_record]},
928+
record_handler=handler,
929+
processor=processor,
930+
)
931+
932+
warning_records = [r for r in caplog.records if r.levelno == logging.WARNING]
933+
assert len(warning_records) == 1, (
934+
f"Expected exactly 1 WARNING (real exception only), got {len(warning_records)}: "
935+
+ str([r.getMessage() for r in warning_records])
936+
)
937+
assert warning_records[0].exc_info[0] is ValueError

0 commit comments

Comments
 (0)