Skip to content

Commit

Permalink
[sdk/logs] Replace mocks with real instances where possible (open-tel…
Browse files Browse the repository at this point in the history
  • Loading branch information
pmcollins authored Jul 23, 2024
1 parent 8749168 commit be02f98
Showing 1 changed file with 73 additions and 70 deletions.
143 changes: 73 additions & 70 deletions opentelemetry-sdk/tests/logs/test_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,50 +19,42 @@
from opentelemetry._logs import get_logger as APIGetLogger
from opentelemetry.attributes import BoundedAttributes
from opentelemetry.sdk import trace
from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler
from opentelemetry.sdk._logs import (
LogData,
LoggerProvider,
LoggingHandler,
LogRecordProcessor,
)
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import INVALID_SPAN_CONTEXT


def get_logger(level=logging.NOTSET, logger_provider=None):
logger = logging.getLogger(__name__)
handler = LoggingHandler(level=level, logger_provider=logger_provider)
logger.addHandler(handler)
return logger


class TestLoggingHandler(unittest.TestCase):
def test_handler_default_log_level(self):
emitter_provider_mock = Mock(spec=LoggerProvider)
emitter_mock = APIGetLogger(
__name__, logger_provider=emitter_provider_mock
)
logger = get_logger(logger_provider=emitter_provider_mock)
processor, logger = set_up_test_logging(logging.NOTSET)

# Make sure debug messages are ignored by default
logger.debug("Debug message")
self.assertEqual(emitter_mock.emit.call_count, 0)
assert processor.emit_count() == 0

# Assert emit gets called for warning message
with self.assertLogs(level=logging.WARNING):
logger.warning("Warning message")
self.assertEqual(emitter_mock.emit.call_count, 1)
self.assertEqual(processor.emit_count(), 1)

def test_handler_custom_log_level(self):
emitter_provider_mock = Mock(spec=LoggerProvider)
emitter_mock = APIGetLogger(
__name__, logger_provider=emitter_provider_mock
)
logger = get_logger(
level=logging.ERROR, logger_provider=emitter_provider_mock
)
processor, logger = set_up_test_logging(logging.ERROR)

with self.assertLogs(level=logging.WARNING):
logger.warning("Warning message test custom log level")
# Make sure any log with level < ERROR is ignored
self.assertEqual(emitter_mock.emit.call_count, 0)
assert processor.emit_count() == 0

with self.assertLogs(level=logging.ERROR):
logger.error("Mumbai, we have a major problem")
with self.assertLogs(level=logging.CRITICAL):
logger.critical("No Time For Caution")
self.assertEqual(emitter_mock.emit.call_count, 2)
self.assertEqual(processor.emit_count(), 2)

# pylint: disable=protected-access
def test_log_record_emit_noop(self):
Expand All @@ -77,14 +69,16 @@ def test_log_record_emit_noop(self):
logger.addHandler(handler_mock)
with self.assertLogs(level=logging.WARNING):
logger.warning("Warning message")
handler_mock._translate.assert_not_called()

def test_log_flush_noop(self):

no_op_logger_provider = NoOpLoggerProvider()
no_op_logger_provider.force_flush = Mock()

logger = get_logger(logger_provider=no_op_logger_provider)
logger = logging.getLogger("foo")
handler = LoggingHandler(
level=logging.NOTSET, logger_provider=no_op_logger_provider
)
logger.addHandler(handler)

with self.assertLogs(level=logging.WARNING):
logger.warning("Warning message")
Expand All @@ -93,16 +87,13 @@ def test_log_flush_noop(self):
no_op_logger_provider.force_flush.assert_not_called()

def test_log_record_no_span_context(self):
emitter_provider_mock = Mock(spec=LoggerProvider)
emitter_mock = APIGetLogger(
__name__, logger_provider=emitter_provider_mock
)
logger = get_logger(logger_provider=emitter_provider_mock)
processor, logger = set_up_test_logging(logging.WARNING)

# Assert emit gets called for warning message
with self.assertLogs(level=logging.WARNING):
logger.warning("Warning message")
args, _ = emitter_mock.emit.call_args_list[0]
log_record = args[0]

log_record = processor.get_log_record(0)

self.assertIsNotNone(log_record)
self.assertEqual(log_record.trace_id, INVALID_SPAN_CONTEXT.trace_id)
Expand All @@ -112,31 +103,23 @@ def test_log_record_no_span_context(self):
)

def test_log_record_observed_timestamp(self):
emitter_provider_mock = Mock(spec=LoggerProvider)
emitter_mock = APIGetLogger(
__name__, logger_provider=emitter_provider_mock
)
logger = get_logger(logger_provider=emitter_provider_mock)
# Assert emit gets called for warning message
processor, logger = set_up_test_logging(logging.WARNING)

with self.assertLogs(level=logging.WARNING):
logger.warning("Warning message")
args, _ = emitter_mock.emit.call_args_list[0]
log_record = args[0]

log_record = processor.get_log_record(0)
self.assertIsNotNone(log_record.observed_timestamp)

def test_log_record_user_attributes(self):
"""Attributes can be injected into logs by adding them to the LogRecord"""
emitter_provider_mock = Mock(spec=LoggerProvider)
emitter_mock = APIGetLogger(
__name__, logger_provider=emitter_provider_mock
)
logger = get_logger(logger_provider=emitter_provider_mock)
processor, logger = set_up_test_logging(logging.WARNING)

# Assert emit gets called for warning message
with self.assertLogs(level=logging.WARNING):
logger.warning("Warning message", extra={"http.status_code": 200})
args, _ = emitter_mock.emit.call_args_list[0]
log_record = args[0]

log_record = processor.get_log_record(0)

self.assertIsNotNone(log_record)
self.assertEqual(len(log_record.attributes), 4)
Expand All @@ -157,18 +140,15 @@ def test_log_record_user_attributes(self):

def test_log_record_exception(self):
"""Exception information will be included in attributes"""
emitter_provider_mock = Mock(spec=LoggerProvider)
emitter_mock = APIGetLogger(
__name__, logger_provider=emitter_provider_mock
)
logger = get_logger(logger_provider=emitter_provider_mock)
processor, logger = set_up_test_logging(logging.ERROR)

try:
raise ZeroDivisionError("division by zero")
except ZeroDivisionError:
with self.assertLogs(level=logging.ERROR):
logger.exception("Zero Division Error")
args, _ = emitter_mock.emit.call_args_list[0]
log_record = args[0]

log_record = processor.get_log_record(0)

self.assertIsNotNone(log_record)
self.assertEqual(log_record.body, "Zero Division Error")
Expand All @@ -191,18 +171,15 @@ def test_log_record_exception(self):

def test_log_exc_info_false(self):
"""Exception information will be included in attributes"""
emitter_provider_mock = Mock(spec=LoggerProvider)
emitter_mock = APIGetLogger(
__name__, logger_provider=emitter_provider_mock
)
logger = get_logger(logger_provider=emitter_provider_mock)
processor, logger = set_up_test_logging(logging.NOTSET)

try:
raise ZeroDivisionError("division by zero")
except ZeroDivisionError:
with self.assertLogs(level=logging.ERROR):
logger.error("Zero Division Error", exc_info=False)
args, _ = emitter_mock.emit.call_args_list[0]
log_record = args[0]

log_record = processor.get_log_record(0)

self.assertIsNotNone(log_record)
self.assertEqual(log_record.body, "Zero Division Error")
Expand All @@ -215,23 +192,49 @@ def test_log_exc_info_false(self):
)

def test_log_record_trace_correlation(self):
emitter_provider_mock = Mock(spec=LoggerProvider)
emitter_mock = APIGetLogger(
__name__, logger_provider=emitter_provider_mock
)
logger = get_logger(logger_provider=emitter_provider_mock)
processor, logger = set_up_test_logging(logging.WARNING)

tracer = trace.TracerProvider().get_tracer(__name__)
with tracer.start_as_current_span("test") as span:
with self.assertLogs(level=logging.CRITICAL):
logger.critical("Critical message within span")

args, _ = emitter_mock.emit.call_args_list[0]
log_record = args[0]
log_record = processor.get_log_record(0)

self.assertEqual(log_record.body, "Critical message within span")
self.assertEqual(log_record.severity_text, "CRITICAL")
self.assertEqual(log_record.severity_number, SeverityNumber.FATAL)
span_context = span.get_span_context()
self.assertEqual(log_record.trace_id, span_context.trace_id)
self.assertEqual(log_record.span_id, span_context.span_id)
self.assertEqual(log_record.trace_flags, span_context.trace_flags)


def set_up_test_logging(level):
logger_provider = LoggerProvider()
processor = FakeProcessor()
logger_provider.add_log_record_processor(processor)
logger = logging.getLogger("foo")
handler = LoggingHandler(level=level, logger_provider=logger_provider)
logger.addHandler(handler)
return processor, logger


class FakeProcessor(LogRecordProcessor):
def __init__(self):
self.log_data_emitted = []

def emit(self, log_data: LogData):
self.log_data_emitted.append(log_data)

def shutdown(self):
pass

def force_flush(self, timeout_millis: int = 30000):
pass

def emit_count(self):
return len(self.log_data_emitted)

def get_log_record(self, i):
return self.log_data_emitted[i].log_record

0 comments on commit be02f98

Please sign in to comment.