Skip to content
22 changes: 19 additions & 3 deletions aws_lambda_powertools/utilities/batch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import sys
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Any, Tuple, Union, overload
from typing import TYPE_CHECKING, Any, Tuple, TypeGuard, Union, overload

from aws_lambda_powertools.shared import constants
from aws_lambda_powertools.utilities.batch.exceptions import (
Expand All @@ -35,6 +35,7 @@

if TYPE_CHECKING:
from collections.abc import Callable
from types import TracebackType

from aws_lambda_powertools.utilities.batch.types import (
PartialItemFailureResponse,
Expand All @@ -61,17 +62,22 @@ class EventType(Enum):
FailureResponse = Tuple[str, str, BatchEventTypes]


def _has_traceback(exception: ExceptionInfo) -> TypeGuard[tuple[type[BaseException], BaseException, TracebackType]]:
    """Tell whether an exception-info tuple is fully populated.

    Parameters
    ----------
    exception: ExceptionInfo
        Tuple indexed as ``(type, value, traceback)``.

    Returns
    -------
    bool
        ``True`` only when none of the first three slots is ``None``. Acts as a
        type guard so the tuple can be treated as a fully populated triple.
    """
    # Inspect the slots left to right and stop at the first empty one.
    for position in (0, 1, 2):
        if exception[position] is None:
            return False
    return True


class BasePartialProcessor(ABC):
"""
Abstract class for batch processors.
"""

lambda_context: LambdaContext

def __init__(self):
def __init__(self, logger: logging.Logger | None = None):
self.success_messages: list[BatchEventTypes] = []
self.fail_messages: list[BatchEventTypes] = []
self.exceptions: list[ExceptionInfo] = []
self.logger = logger

@abstractmethod
def _prepare(self):
Expand Down Expand Up @@ -237,6 +243,13 @@ def failure_handler(self, record, exception: ExceptionInfo) -> FailureResponse:
exception_string = f"{exception[0]}:{exception[1]}"
entry = ("fail", exception_string, record)
logger.debug(f"Record processing exception: {exception_string}")

if self.logger is not None and _has_traceback(exception):
self.logger.warning(
"Record processing exception; skipping this record",
exc_info=exception,
)
Comment on lines +246 to +251

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Log with full traceback when a customer-provided logger is present
# and the exception carries a real traceback (e.g. not a synthetic FIFO circuit-breaker)
batch_logger = self.logger
if batch_logger is not None and exception[2] is not None:
# ExceptionInfo allows None on every slot, but logging.warning's exc_info
# requires a fully populated tuple. We already excluded synthetic exceptions
# (no traceback) above, so the type and value are guaranteed to be set.
assert exception[0] is not None
assert exception[1] is not None
exc_info = cast("tuple[type[BaseException], BaseException, TracebackType]", exception)
batch_logger.warning(
"Record processing exception; skipping this record",
exc_info=exc_info,
)
if self.logger is not None and exception[2] is not None:
self.logger.warning(
"Record processing exception; skipping this record",
exc_info=exception,
)

Could you try this, please? I think we can simplify this block of code.


self.exceptions.append(exception)
self.fail_messages.append(record)
return entry
Expand All @@ -250,6 +263,7 @@ def __init__(
event_type: EventType,
model: BatchTypeModels | None = None,
raise_on_entire_batch_failure: bool = True,
logger: logging.Logger | None = None,
):
"""Process batch and partially report failed items

Expand All @@ -262,6 +276,8 @@ def __init__(
raise_on_entire_batch_failure: bool
Raise an exception when the entire batch has failed processing.
When set to False, partial failures are reported in the response
logger: logging.Logger | None
Optional Logger instance to output warnings with tracebacks for failed records.

Exceptions
----------
Expand All @@ -285,7 +301,7 @@ def __init__(
EventType.Kafka: KafkaEventRecord,
}

super().__init__()
super().__init__(logger=logger)

def response(self) -> PartialItemFailureResponse:
"""Batch items that failed processing, if any"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ def lambda_handler(event, context: LambdaContext):
None,
)

def __init__(self, model: BatchSqsTypeModel | None = None, skip_group_on_error: bool = False):
def __init__(
self,
model: BatchSqsTypeModel | None = None,
skip_group_on_error: bool = False,
logger: logging.Logger | None = None,
):
"""
Initialize the SqsFifoProcessor.

Expand All @@ -77,12 +82,14 @@ def __init__(self, model: BatchSqsTypeModel | None = None, skip_group_on_error:
skip_group_on_error: bool
Determines whether to exclusively skip messages from the MessageGroupID that encountered processing failures
Default is False.
logger: logging.Logger | None
Optional Logger instance to output warnings with tracebacks for failed records.

"""
self._skip_group_on_error: bool = skip_group_on_error
self._current_group_id = None
self._failed_group_ids: set[str] = set()
super().__init__(EventType.SQS, model)
super().__init__(EventType.SQS, model, logger=logger)

def _process_record(self, record):
self._current_group_id = record.get("attributes", {}).get("MessageGroupId")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import json
import logging
import uuid
from random import randint
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -861,3 +862,76 @@ async def simple_async_handler(record: SQSRecord):
# THEN record is processed successfully using asyncio.run()
assert result == {"batchItemFailures": []}
assert result == {"batchItemFailures": []}


def test_batch_processor_logs_exception_with_injected_logger(sqs_event_factory, caplog):
    # GIVEN one failing and one succeeding record, and a processor with an injected logger
    fail_record = sqs_event_factory("fail")
    success_record = sqs_event_factory("success")

    def handler(record):
        if "fail" in record["body"]:
            raise ValueError("intentional failure")
        return record["body"]

    test_logger = logging.getLogger("test_logger")
    processor = BatchProcessor(event_type=EventType.SQS, logger=test_logger)

    # WHEN the batch is processed
    with caplog.at_level(logging.WARNING, logger="test_logger"):
        process_partial_response(
            event={"Records": [fail_record, success_record]},
            record_handler=handler,
            processor=processor,
        )

    # THEN exactly one WARNING is emitted and it carries the handler's exception
    warning_records = [r for r in caplog.records if r.levelno == logging.WARNING]
    assert len(warning_records) == 1, f"Expected 1 WARNING log, got {len(warning_records)}"

    exc_info = warning_records[0].exc_info
    assert exc_info is not None, "Expected exc_info (traceback) in log record"
    assert exc_info[0] is ValueError
    # The logged message is static, so the exception text can only be verified
    # through the exception instance attached to the record.
    assert str(exc_info[1]) == "intentional failure"
    assert exc_info[2] is not None, "Expected a real traceback object in exc_info"


def test_batch_processor_does_not_log_without_injected_logger(sqs_event_factory, caplog):
    # GIVEN a record whose handler always raises, and a processor created without a logger
    def handler(record):
        raise ValueError("intentional failure")

    batch = {"Records": [sqs_event_factory("fail")]}
    processor = BatchProcessor(event_type=EventType.SQS, raise_on_entire_batch_failure=False, logger=None)

    # WHEN the batch is processed while capturing the batch module's logger
    with caplog.at_level(logging.WARNING, logger="aws_lambda_powertools.utilities.batch.base"):
        process_partial_response(event=batch, record_handler=handler, processor=processor)

    # THEN nothing is logged at WARNING level
    captured_warnings = [entry for entry in caplog.records if entry.levelno == logging.WARNING]
    assert len(captured_warnings) == 0, "Expected no WARNING logs when logger is None"


def test_sqs_fifo_circuit_breaker_does_not_log(sqs_event_fifo_factory, caplog):
    # GIVEN two records in the same message group where only the first one raises;
    # the second is expected to be short-circuited rather than handled
    batch = {
        "Records": [
            sqs_event_fifo_factory("fail", "group-1"),
            sqs_event_fifo_factory("would-succeed", "group-1"),
        ],
    }

    def handler(record):
        body = record["body"]
        if "fail" not in body:
            return body
        raise ValueError("first record failure")

    test_logger = logging.getLogger("test_logger")
    processor = SqsFifoPartialProcessor(logger=test_logger)
    # Set on the instance after construction, so a fully failed batch does not raise here.
    processor.raise_on_entire_batch_failure = False

    # WHEN the batch is processed
    with caplog.at_level(logging.WARNING, logger="test_logger"):
        process_partial_response(event=batch, record_handler=handler, processor=processor)

    # THEN only the real handler exception produced a WARNING
    warning_records = [entry for entry in caplog.records if entry.levelno == logging.WARNING]
    assert len(warning_records) == 1, (
        f"Expected exactly 1 WARNING (real exception only), got {len(warning_records)}: "
        + str([entry.getMessage() for entry in warning_records])
    )
    first_warning = warning_records[0]
    assert first_warning.exc_info[0] is ValueError