diff --git a/assets/production_log_config.json b/assets/production_log_config.json
index 6f4b6afa..c0b9f551 100644
--- a/assets/production_log_config.json
+++ b/assets/production_log_config.json
@@ -34,6 +34,16 @@
             "class": "logging.StreamHandler",
             "filters": ["dict_values"],
             "stream": "ext://sys.stdout"
+        },
+        "splunk_console": {
+            "formatter": "default",
+            "class": "recordlinker.log.SplunkHecHandler",
+            "filters": ["correlation_id"]
+        },
+        "splunk_access": {
+            "formatter": "access",
+            "class": "recordlinker.log.SplunkHecHandler",
+            "filters": ["dict_values"]
         }
     },
     "loggers": {
@@ -57,12 +67,12 @@
             "propagate": false
         },
         "recordlinker": {
-            "handlers": ["console"],
+            "handlers": ["console", "splunk_console"],
             "level": "INFO",
             "propagate": false
         },
         "recordlinker.access": {
-            "handlers": ["access"],
+            "handlers": ["access", "splunk_access"],
             "level": "INFO",
             "propagate": false
         }
diff --git a/src/recordlinker/config.py b/src/recordlinker/config.py
index 69356773..ac311a72 100644
--- a/src/recordlinker/config.py
+++ b/src/recordlinker/config.py
@@ -43,6 +43,10 @@ class Settings(pydantic_settings.BaseSettings):
         description="The path to the logging configuration file",
         default="",
     )
+    splunk_uri: typing.Optional[str] = pydantic.Field(
+        description="The URI for the Splunk HEC server",
+        default="",
+    )
     initial_algorithms: str = pydantic.Field(
         description=(
             "The path to the initial algorithms file that is loaded on startup if the "
@@ -78,7 +82,11 @@ def default_log_config(self) -> dict:
             "loggers": {
                 "": {"handlers": ["console"], "level": "WARNING"},
                 "recordlinker": {"handlers": ["console"], "level": "INFO", "propagate": False},
-                "recordlinker.access": {"handlers": ["console"], "level": "CRITICAL", "propagate": False},
+                "recordlinker.access": {
+                    "handlers": ["console"],
+                    "level": "CRITICAL",
+                    "propagate": False,
+                },
             },
         }
 
@@ -94,9 +102,8 @@ def configure_logging(self) -> None:
             with open(self.log_config, "r") as fobj:
                 config = json.loads(fobj.read())
         except Exception as exc:
-            raise ConfigurationError(
-                f"Error loading log configuration: {self.log_config}"
-            ) from exc
+            msg = f"Error loading log configuration: {self.log_config}"
+            raise ConfigurationError(msg) from exc
         logging.config.dictConfig(config or self.default_log_config())
 
 
diff --git a/src/recordlinker/log.py b/src/recordlinker/log.py
index 5bdf9753..72a39ca2 100644
--- a/src/recordlinker/log.py
+++ b/src/recordlinker/log.py
@@ -1,8 +1,13 @@
+import concurrent.futures
+import json
 import logging
 import typing
 
 import pythonjsonlogger.jsonlogger
 
+from recordlinker import config
+from recordlinker import splunk
+
 RESERVED_ATTRS = pythonjsonlogger.jsonlogger.RESERVED_ATTRS + ("taskName",)
 
 
@@ -42,3 +47,86 @@ def __init__(
         **kwargs: typing.Any,
     ):
         super().__init__(*args, reserved_attrs=reserved_attrs, **kwargs)
+
+
+class SplunkHecHandler(logging.Handler):
+    """
+    A custom logging handler that sends log records to a Splunk HTTP Event Collector (HEC)
+    server. This handler is only enabled if the `splunk_uri` setting is configured,
+    otherwise each log record is ignored.
+
+    WARNING: This handler does not guarantee delivery of log records to the Splunk HEC
+    server. Events are sent asynchronously to reduce blocking IO calls, and the client
+    does not wait for a response from the server. Thus it's possible that some log records
+    will be dropped. Other logging handlers should be used in conjunction with this handler
+    in production environments to ensure log records are not lost.
+    """
+
+    MAX_WORKERS = 10
+
+    class SplunkHecClientSingleton:
+        """
+        A singleton class for the Splunk HEC client.
+        """
+
+        _instance: splunk.SplunkHECClient | None = None
+
+        @classmethod
+        def get_instance(cls, uri: str) -> splunk.SplunkHECClient:
+            """
+            Get the singleton instance of the Splunk HEC client.
+            """
+            if cls._instance is None:
+                cls._instance = splunk.SplunkHECClient(uri)
+            return cls._instance
+
+    def __init__(self, uri: str | None = None, **kwargs: typing.Any) -> None:
+        """
+        Initialize the Splunk HEC logging handler. If the `splunk_uri` setting is
+        configured, create a new Splunk HEC client instance or use the existing
+        singleton instance. It's optimal to use a singleton instance to avoid
+        re-testing the connection to the Splunk HEC server.
+        """
+        logging.Handler.__init__(self)
+        self.client: splunk.SplunkHECClient | None = None
+        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.MAX_WORKERS)
+        self.last_future: concurrent.futures.Future | None = None
+        uri = uri or config.settings.splunk_uri
+        if uri:
+            self.client = self.SplunkHecClientSingleton.get_instance(uri)
+
+    def __del__(self) -> None:
+        """
+        Clean up the executor when the handler is deleted.
+        """
+        self.executor.shutdown(wait=True)
+
+    def flush(self) -> None:
+        """
+        Wait for the last future to complete before flushing the handler.
+        """
+        if self.last_future is not None:
+            self.last_future.result()
+            self.last_future = None
+
+    def emit(self, record: logging.LogRecord) -> None:
+        """
+        Emit the log record to the Splunk HEC server, if a client is configured.
+        """
+        if self.client is None:
+            # No Splunk HEC client configured, do nothing
+            return
+        msg = self.format(record)
+        data: dict[str, typing.Any] = {}
+        try:
+            # Attempt to parse the message as a JSON object
+            data = json.loads(msg)
+        except json.JSONDecodeError:
+            # If the message is not JSON, create a new dictionary with the message
+            data = {"message": msg}
+        # Run this in a separate thread to avoid blocking the main thread.
+        # Logging to Splunk is a bonus feature and should not block the main thread;
+        # using a ThreadPoolExecutor to send the request asynchronously allows us
+        # to initiate the request and continue processing without waiting for the IO
+        # operation to complete.
+        self.last_future = self.executor.submit(self.client.send, data, epoch=record.created)
diff --git a/src/recordlinker/splunk.py b/src/recordlinker/splunk.py
new file mode 100644
index 00000000..7e0d2980
--- /dev/null
+++ b/src/recordlinker/splunk.py
@@ -0,0 +1,79 @@
+import json
+import time
+import typing
+import urllib.parse
+import urllib.request
+
+TIMEOUT = 5
+
+
+class SplunkError(Exception):
+    pass
+
+
+class SplunkHECClient:
+    PATH = "/services/collector/event"
+
+    def __init__(self, splunk_uri: str) -> None:
+        """
+        Create a new Splunk HEC client and test its connection.
+        The URI uses a custom scheme to specify the Splunk HEC server and parameters.
+        The URI format is:
+        splunkhec://<token>@<host>:<port>?index=<index>&proto=<protocol>&ssl_verify=<ssl_verify>&source=<source>
+        """
+        try:
+            uri: urllib.parse.ParseResult = urllib.parse.urlparse(splunk_uri)
+            # flatten the query string values from lists to single values
+            qs: dict[str, str] = {k: v[0] for k, v in urllib.parse.parse_qs(uri.query).items()}
+
+            if uri.scheme != "splunkhec":
+                raise SplunkError(f"invalid scheme: {uri.scheme}")
+
+            scheme = qs.get("proto", "https").lower()
+            host = f"{uri.hostname}:{uri.port}" if uri.port else uri.hostname
+            self.url = f"{scheme}://{host}{self.PATH}"
+            self.headers = {
+                "Authorization": f"Splunk {uri.username}",
+                "Content-Type": "application/json",
+            }
+            # initialize the default payload parameters
+            self.params: dict[str, str] = {"host": uri.hostname or "", "sourcetype": "_json"}
+            if qs.get("index"):
+                self.params["index"] = qs["index"]
+            if qs.get("source"):
+                self.params["source"] = qs["source"]
+            self._test_connection()
+        except Exception as exc:
+            raise SplunkError(f"invalid connection: {splunk_uri}") from exc
+
+    def _send_request(self, body: bytes | None = None):
+        request = urllib.request.Request(self.url, data=body, method="POST", headers=self.headers)
+        try:
+            with urllib.request.urlopen(request, timeout=TIMEOUT) as response:
+                # return the response status code
+                return response.getcode()
+        except urllib.error.HTTPError as exc:
+            return exc.code
+
+    def _test_connection(self) -> None:
+        status = self._send_request()
+        # check for a 400 bad request, which indicates a successful connection
+        # 400 is expected because the payload is empty
+        if status != 400:
+            raise urllib.error.HTTPError(self.url, status, "could not connect", None, None)  # type: ignore
+
+    def send(self, data: dict, epoch: float = 0) -> int:
+        """
+        Send data to the Splunk HEC endpoint.
+
+        :param data: The data to send.
+        :param epoch: The timestamp to use for the event. If not provided, the current time is used.
+        :return: The HTTP status code of the response.
+        """
+        epoch = epoch or int(time.time())
+        payload: dict[str, typing.Any] = {"time": epoch, "event": data} | self.params
+        body: bytes = json.dumps(payload).encode("utf-8")
+        try:
+            return self._send_request(body=body)
+        except Exception as exc:
+            raise SplunkError(f"could not send data: {data}") from exc
diff --git a/tests/unit/test_log.py b/tests/unit/test_log.py
index ae17948a..b372b36a 100644
--- a/tests/unit/test_log.py
+++ b/tests/unit/test_log.py
@@ -1,4 +1,5 @@
 import logging
+import unittest.mock
 
 from recordlinker import log
 
@@ -88,3 +89,51 @@ def test_format_reserved_attrs(self):
         )
         record.taskName = "task"
         assert formatter.format(record) == '{"message": "test"}'
+
+
+class TestSplunkHecHandler:
+    def test_json_record(self):
+        with unittest.mock.patch("recordlinker.splunk.SplunkHECClient") as mock_client:
+            mock_instance = mock_client.return_value
+            mock_instance.send.return_value = 200
+            uri = "splunkhec://token@localhost:8088?index=index&source=source"
+            handler = log.SplunkHecHandler(uri=uri)
+            record = logging.LogRecord(
+                name="test",
+                level=logging.INFO,
+                pathname="test_log.py",
+                lineno=10,
+                exc_info=None,
+                msg='{"key": "value"}',
+                args=[],
+            )
+            assert handler.emit(record) is None
+            handler.flush()
+            send_args = mock_instance.send.call_args.args
+            assert send_args == ({"key": "value"},)
+            send_kwargs = mock_instance.send.call_args.kwargs
+            assert send_kwargs == {"epoch": record.created}
+            log.SplunkHecHandler.SplunkHecClientSingleton._instance = None
+
+    def test_non_json_record(self):
+        with unittest.mock.patch("recordlinker.splunk.SplunkHECClient") as mock_client:
+            mock_instance = mock_client.return_value
+            mock_instance.send.return_value = 200
+            uri = "splunkhec://token@localhost:8088?index=index&source=source"
+            handler = log.SplunkHecHandler(uri=uri)
+            record = logging.LogRecord(
+                name="test",
+                level=logging.INFO,
+                pathname="test_log.py",
+                lineno=10,
+                exc_info=None,
+                msg="test",
+                args=[],
+            )
+            assert handler.emit(record) is None
+            handler.flush()
+            send_args = mock_instance.send.call_args.args
+            assert send_args == ({"message": "test"},)
+            send_kwargs = mock_instance.send.call_args.kwargs
+            assert send_kwargs == {"epoch": record.created}
+            log.SplunkHecHandler.SplunkHecClientSingleton._instance = None
diff --git a/tests/unit/test_splunk.py b/tests/unit/test_splunk.py
new file mode 100644
index 00000000..66d6ba9f
--- /dev/null
+++ b/tests/unit/test_splunk.py
@@ -0,0 +1,56 @@
+import unittest.mock
+
+import pytest
+
+from recordlinker import splunk
+
+
+class TestSplunkHECClient:
+    def test_invalid_uri(self):
+        with pytest.raises(splunk.SplunkError):
+            splunk.SplunkHECClient("http://localhost")
+
+    def test_valid_uri(self):
+        with unittest.mock.patch("urllib.request.urlopen") as mock_urlopen:
+            mock_response = unittest.mock.MagicMock()
+            mock_response.read.return_value = b"{}"
+            mock_response.getcode.return_value = 400  # Set getcode() to return 400
+            mock_urlopen.return_value.__enter__.return_value = mock_response
+            client = splunk.SplunkHECClient("splunkhec://token@localhost:8088?index=idx&source=src")
+            assert client.url == "https://localhost:8088/services/collector/event"
+            assert client.headers == {
+                "Authorization": "Splunk token",
+                "Content-Type": "application/json",
+            }
+            assert client.params == {"host": "localhost", "sourcetype": "_json", "index": "idx", "source": "src"}
+
+    def test_valid_uri_no_port(self):
+        with unittest.mock.patch("urllib.request.urlopen") as mock_urlopen:
+            mock_response = unittest.mock.MagicMock()
+            mock_response.read.return_value = b"{}"
+            mock_response.getcode.return_value = 400  # Set getcode() to return 400
+            mock_urlopen.return_value.__enter__.return_value = mock_response
+            client = splunk.SplunkHECClient("splunkhec://token@localhost?index=idx&source=src")
+            assert client.url == "https://localhost/services/collector/event"
+            assert client.headers == {
+                "Authorization": "Splunk token",
+                "Content-Type": "application/json",
+            }
+            assert client.params == {"host": "localhost", "sourcetype": "_json", "index": "idx", "source": "src"}
+
+    def test_send(self):
+        with unittest.mock.patch("urllib.request.urlopen") as mock_urlopen:
+            mock_response = unittest.mock.MagicMock()
+            mock_response.read.return_value = b"{}"
+            mock_response.getcode.side_effect = [400, 200]  # 400 for the connection test, then 200 for send
+            mock_urlopen.return_value.__enter__.return_value = mock_response
+            client = splunk.SplunkHECClient("splunkhec://token@localhost?index=idx&source=src")
+            assert client.send({"key": "value"}, epoch=10.5) == 200
+            req = mock_urlopen.call_args[0][0]
+            assert req.method == "POST"
+            assert req.get_full_url() == "https://localhost/services/collector/event"
+            assert req.headers == {
+                "Authorization": "Splunk token",
+                "Content-type": "application/json",
+            }
+            assert req.data == b'{"time": 10.5, "event": {"key": "value"}, "host": "localhost", "sourcetype": "_json", "index": "idx", "source": "src"}'
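
Usage note (illustrative, not part of the diff): a minimal sketch of wiring the new handler up by hand, using the same splunkhec:// URI shape exercised in the tests above. The token, host, port, index, and source values are placeholders; SplunkHECClient performs a connection test at construction time, so the URI must point at a reachable HEC endpoint (if no URI is given and the splunk_uri setting is empty, emit() is a no-op).

import logging

from recordlinker import log

# Placeholder URI; in a deployment this would normally come from the splunk_uri setting.
handler = log.SplunkHecHandler(
    uri="splunkhec://my-token@splunk.example.com:8088?index=my_index&source=recordlinker"
)
logger = logging.getLogger("recordlinker")
logger.setLevel(logging.INFO)
logger.addHandler(handler)

# JSON-formatted messages are parsed and sent as structured events;
# anything else is wrapped as {"message": ...} by emit().
logger.info('{"event": "example", "count": 1}')
handler.flush()  # block until the last asynchronous send has completed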