
Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories.

base repository: CDCgov/RecordLinker
base: f55fc4a5b6ffcc490892f80f776c673eea08e5aa
head repository: CDCgov/RecordLinker
compare: e4a8c7b52c71c3623c62e8f85f00ddd10d52384b

Showing with 297 additions and 6 deletions.
  1. +12 −2 assets/production_log_config.json
  2. +2 −0 pyproject.toml
  3. +11 −4 src/recordlinker/config.py
  4. +88 −0 src/recordlinker/log.py
  5. +79 −0 src/recordlinker/splunk.py
  6. +49 −0 tests/unit/test_log.py
  7. +56 −0 tests/unit/test_splunk.py
14 changes: 12 additions & 2 deletions assets/production_log_config.json
@@ -34,6 +34,16 @@
"class": "logging.StreamHandler",
"filters": ["dict_values"],
"stream": "ext://sys.stdout"
},
"splunk_console": {
"formatter": "default",
"class": "recordlinker.log.SplunkHecHandler",
"filters": ["correlation_id"]
},
"splunk_access": {
"formatter": "access",
"class": "recordlinker.log.SplunkHecHandler",
"filters": ["dict_values"]
}
},
"loggers": {
@@ -57,12 +67,12 @@
"propagate": false
},
"recordlinker": {
"handlers": ["console"],
"handlers": ["console", "splunk_console"],
"level": "INFO",
"propagate": false
},
"recordlinker.access": {
"handlers": ["access"],
"handlers": ["access", "splunk_access"],
"level": "INFO",
"propagate": false
}
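
For context, a minimal sketch of how handler entries like these are consumed: configure_logging() feeds the JSON into logging.config.dictConfig. The snippet below is illustrative only and assumes the recordlinker package is importable; the formatters and filters referenced in the production config are omitted.

    import logging
    import logging.config

    # Illustrative only: a pared-down dictConfig using the new Splunk handler.
    # The real formatters/filters ("default", "correlation_id", etc.) come from
    # the production config above and are left out here.
    logging.config.dictConfig({
        "version": 1,
        "disable_existing_loggers": False,
        "handlers": {
            "splunk_console": {"class": "recordlinker.log.SplunkHecHandler"},
        },
        "loggers": {
            "recordlinker": {
                "handlers": ["splunk_console"],
                "level": "INFO",
                "propagate": False,
            },
        },
    })
    logging.getLogger("recordlinker").info("dictConfig wired the Splunk handler")

With no splunk_uri configured, the handler simply drops records, so the wiring is safe in environments without Splunk.
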
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
"python-dateutil==2.9.0.post0",
"sqlalchemy",
"rapidfuzz",
# Observability
"python-json-logger",
"asgi-correlation-id",
# Database drivers
@@ -45,6 +46,7 @@ dev = [
"ruff",
"mypy",
"types-python-dateutil",
# Observability
"opentelemetry-api",
"opentelemetry-sdk",
]
15 changes: 11 additions & 4 deletions src/recordlinker/config.py
@@ -43,6 +43,10 @@ class Settings(pydantic_settings.BaseSettings):
description="The path to the logging configuration file",
default="",
)
splunk_uri: typing.Optional[str] = pydantic.Field(
description="The URI for the Splunk HEC server",
default="",
)
initial_algorithms: str = pydantic.Field(
description=(
"The path to the initial algorithms file that is loaded on startup if the "
@@ -78,7 +82,11 @@ def default_log_config(self) -> dict:
"loggers": {
"": {"handlers": ["console"], "level": "WARNING"},
"recordlinker": {"handlers": ["console"], "level": "INFO", "propagate": False},
"recordlinker.access": {"handlers": ["console"], "level": "CRITICAL", "propagate": False},
"recordlinker.access": {
"handlers": ["console"],
"level": "CRITICAL",
"propagate": False,
},
},
}

@@ -94,9 +102,8 @@ def configure_logging(self) -> None:
with open(self.log_config, "r") as fobj:
config = json.loads(fobj.read())
except Exception as exc:
raise ConfigurationError(
f"Error loading log configuration: {self.log_config}"
) from exc
msg = f"Error loading log configuration: {self.log_config}"
raise ConfigurationError(msg) from exc
logging.config.dictConfig(config or self.default_log_config())
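
Because Settings is a pydantic-settings class, the new field can presumably also be supplied through the environment rather than a config file. The environment-variable name below assumes pydantic-settings' default field-to-variable mapping for splunk_uri and is illustrative only:

    import os

    # Assumed mapping: the `splunk_uri` field is read from the SPLUNK_URI
    # environment variable (pydantic-settings' default behaviour, unless the
    # Settings class defines a prefix). Token, host, and index are placeholders.
    os.environ["SPLUNK_URI"] = "splunkhec://hec-token@splunk.example.com:8088?index=recordlinker"

    from recordlinker import config

    # configure_logging() loads log_config (or the default config) and applies it.
    config.settings.configure_logging()
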


88 changes: 88 additions & 0 deletions src/recordlinker/log.py
@@ -1,8 +1,13 @@
import concurrent.futures
import json
import logging
import typing

import pythonjsonlogger.jsonlogger

from recordlinker import config
from recordlinker import splunk

RESERVED_ATTRS = pythonjsonlogger.jsonlogger.RESERVED_ATTRS + ("taskName",)


@@ -42,3 +47,86 @@ def __init__(
**kwargs: typing.Any,
):
super().__init__(*args, reserved_attrs=reserved_attrs, **kwargs)


class SplunkHecHandler(logging.Handler):
"""
A custom logging handler that sends log records to a Splunk HTTP Event Collector (HEC)
server. This handler is only enabled if the `splunk_uri` setting is configured,
otherwise each log record is ignored.
WARNING: This handler does not guarantee delivery of log records to the Splunk HEC
server. Events are sent asynchronously to reduce blocking IO calls, and the client
does not wait for a response from the server. Thus it's possible that some log records
will be dropped. Other logging handlers should be used in conjunction with this handler
in production environments to ensure log records are not lost.
"""

MAX_WORKERS = 10

class SplunkHecClientSingleton:
"""
A singleton class for the Splunk HEC client.
"""

_instance: splunk.SplunkHECClient | None = None

@classmethod
def get_instance(cls, uri: str) -> splunk.SplunkHECClient:
"""
Get the singleton instance of the Splunk HEC client.
"""
if cls._instance is None:
cls._instance = splunk.SplunkHECClient(uri)
return cls._instance

def __init__(self, uri: str | None = None, **kwargs: typing.Any) -> None:
"""
Initialize the Splunk HEC logging handler. If the `splunk_uri` setting is
configured, create a new Splunk HEC client instance or use the existing
singleton instance. It's optimal to use a singleton instance to avoid
re-testing the connection to the Splunk HEC server.
"""
logging.Handler.__init__(self)
self.client: splunk.SplunkHECClient | None = None
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.MAX_WORKERS)
self.last_future: concurrent.futures.Future | None = None
uri = uri or config.settings.splunk_uri
if uri:
self.client = self.SplunkHecClientSingleton.get_instance(uri)

def __del__(self) -> None:
"""
Clean up the executor when the handler is deleted.
"""
self.executor.shutdown(wait=True)

def flush(self) -> None:
"""
Wait for the last future to complete before flushing the handler.
"""
if self.last_future is not None:
self.last_future.result()
self.last_future = None

def emit(self, record: logging.LogRecord) -> None:
"""
Emit the log record to the Splunk HEC server, if a client is configured.
"""
if self.client is None:
# No Splunk HEC client configured, do nothing
return
msg = self.format(record)
data: dict[str, typing.Any] = {}
try:
# Attempt to parse the message as a JSON object
data = json.loads(msg)
except json.JSONDecodeError:
# If the message is not JSON, create a new dictionary with the message
data = {"message": msg}
# Run this in a separate thread to avoid blocking the main thread.
# Logging to Splunk is a bonus feature and should not block the main thread,
# using a ThreadPoolExecutor to send the request asynchronously allows us
# to initiate the request and continue processing without waiting for the IO
# operation to complete.
self.last_future = self.executor.submit(self.client.send, data, epoch=record.created)
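
A short usage sketch of the handler in isolation, assuming a reachable HEC endpoint; the URI, token, and logger name are placeholders:

    import logging

    from recordlinker.log import SplunkHecHandler

    # Placeholder URI; in practice the handler falls back to settings.splunk_uri.
    # Constructing the handler with a URI creates a client and tests the
    # connection, so this assumes an HEC endpoint is actually reachable.
    handler = SplunkHecHandler(uri="splunkhec://hec-token@splunk.example.com:8088?index=recordlinker")

    logger = logging.getLogger("recordlinker.example")
    logger.setLevel(logging.INFO)
    logger.addHandler(handler)

    # JSON-formatted messages are forwarded as structured events; anything else
    # is wrapped as {"message": ...}. flush() blocks until the last send finishes.
    logger.info('{"event": "linkage", "status": "match"}')
    handler.flush()
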
79 changes: 79 additions & 0 deletions src/recordlinker/splunk.py
@@ -0,0 +1,79 @@
import json
import time
import typing
import urllib.parse
import urllib.request

TIMEOUT = 5


class SplunkError(Exception):
pass


class SplunkHECClient:
PATH = "/services/collector/event"

def __init__(self, splunk_uri: str) -> None:
"""
Create a new Splunk HEC client and test its connection.
The URI uses a custom scheme to specify the Splunk HEC server and parameters.
The URI format is:
splunkhec://<token>@<host>:<port>?index=<index>&proto=<protocol>&ssl_verify=<verify>&source=<source>
"""
try:
uri: urllib.parse.ParseResult = urllib.parse.urlparse(splunk_uri)
# flatten the query string values from lists to single values
qs: dict[str, str] = {k: v[0] for k, v in urllib.parse.parse_qs(uri.query).items()}

if uri.scheme != "splunkhec":
raise SplunkError(f"invalid scheme: {uri.scheme}")

scheme = qs.get("proto", "https").lower()
host = f"{uri.hostname}:{uri.port}" if uri.port else uri.hostname
self.url = f"{scheme}://{host}{self.PATH}"
self.headers = {
"Authorization": f"Splunk {uri.username}",
"Content-Type": "application/json",
}
# initialize the default payload parameters
self.params: dict[str, str] = {"host": uri.hostname or "", "sourcetype": "_json"}
if qs.get("index"):
self.params["index"] = qs["index"]
if qs.get("source"):
self.params["source"] = qs["source"]
self._test_connection()
except Exception as exc:
raise SplunkError(f"invalid connection: {splunk_uri}") from exc

def _send_request(self, body: bytes | None = None):
request = urllib.request.Request(self.url, data=body, method="POST", headers=self.headers)
try:
with urllib.request.urlopen(request, timeout=TIMEOUT) as response:
# return the response status code
return response.getcode()
except urllib.error.HTTPError as exc:
return exc.code

def _test_connection(self) -> None:
status = self._send_request()
# check for a 400 bad request, which indicates a successful connection
# 400 is expected because the payload is empty
if status != 400:
raise urllib.error.HTTPError(self.url, status, "could not connect", None, None) # type: ignore

def send(self, data: dict, epoch: float = 0) -> int:
"""
Send data to the Splunk HEC endpoint.
:param data: The data to send.
:param epoch: The timestamp to use for the event. If not provided, the current time is used.
:return: The HTTP status code of the response.
"""
epoch = epoch or int(time.time())
payload: dict[str, typing.Any] = {"time": epoch, "event": data} | self.params
body: bytes = json.dumps(payload).encode("utf-8")
try:
return self._send_request(body=body)
except Exception as exc:
raise SplunkError(f"could not send data: {data}") from exc
49 changes: 49 additions & 0 deletions tests/unit/test_log.py
@@ -1,4 +1,5 @@
import logging
import unittest.mock

from recordlinker import log

@@ -88,3 +89,51 @@ def test_format_reserved_attrs(self):
)
record.taskName = "task"
assert formatter.format(record) == '{"message": "test"}'


class TestSplunkHecHandler:
def test_json_record(self):
with unittest.mock.patch("recordlinker.splunk.SplunkHECClient") as mock_client:
mock_instance = mock_client.return_value
mock_instance.send.return_value = 200
uri = "splunkhec://token@localhost:8088?index=index&source=source"
handler = log.SplunkHecHandler(uri=uri)
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test_log.py",
lineno=10,
exc_info=None,
msg='{"key": "value"}',
args=[],
)
assert handler.emit(record) is None
handler.flush()
send_args = mock_instance.send.call_args.args
assert send_args == ({"key": "value"},)
send_kwargs = mock_instance.send.call_args.kwargs
assert send_kwargs == {"epoch": record.created}
log.SplunkHecHandler.SplunkHecClientSingleton._instance = None

def test_non_json_record(self):
with unittest.mock.patch("recordlinker.splunk.SplunkHECClient") as mock_client:
mock_instance = mock_client.return_value
mock_instance.send.return_value = 200
uri = "splunkhec://token@localhost:8088?index=index&source=source"
handler = log.SplunkHecHandler(uri=uri)
record = logging.LogRecord(
name="test",
level=logging.INFO,
pathname="test_log.py",
lineno=10,
exc_info=None,
msg="test",
args=[],
)
assert handler.emit(record) is None
handler.flush()
send_args = mock_instance.send.call_args.args
assert send_args == ({"message": "test"},)
send_kwargs = mock_instance.send.call_args.kwargs
assert send_kwargs == {"epoch": record.created}
log.SplunkHecHandler.SplunkHecClientSingleton._instance = None
56 changes: 56 additions & 0 deletions tests/unit/test_splunk.py
@@ -0,0 +1,56 @@
import unittest.mock

import pytest

from recordlinker import splunk


class TestSplunkHECClient:
def test_invalid_uri(self):
with pytest.raises(splunk.SplunkError):
splunk.SplunkHECClient("http://localhost")

def test_valid_uri(self):
with unittest.mock.patch("urllib.request.urlopen") as mock_urlopen:
mock_response = unittest.mock.MagicMock()
mock_response.read.return_value = b"{}"
mock_response.getcode.return_value = 400 # Set getcode() to return 400
mock_urlopen.return_value.__enter__.return_value = mock_response
client = splunk.SplunkHECClient("splunkhec://token@localhost:8088?index=idx&source=src")
assert client.url == "https://localhost:8088/services/collector/event"
assert client.headers == {
"Authorization": "Splunk token",
"Content-Type": "application/json",
}
assert client.params == {"host": "localhost", "sourcetype": "_json", "index": "idx", "source": "src"}

def test_valid_uri_no_port(self):
with unittest.mock.patch("urllib.request.urlopen") as mock_urlopen:
mock_response = unittest.mock.MagicMock()
mock_response.read.return_value = b"{}"
mock_response.getcode.return_value = 400 # Set getcode() to return 400
mock_urlopen.return_value.__enter__.return_value = mock_response
client = splunk.SplunkHECClient("splunkhec://token@localhost?index=idx&source=src")
assert client.url == "https://localhost/services/collector/event"
assert client.headers == {
"Authorization": "Splunk token",
"Content-Type": "application/json",
}
assert client.params == {"host": "localhost", "sourcetype": "_json", "index": "idx", "source": "src"}

def test_send(self):
with unittest.mock.patch("urllib.request.urlopen") as mock_urlopen:
mock_response = unittest.mock.MagicMock()
mock_response.read.return_value = b"{}"
mock_response.getcode.side_effect = [400, 200]  # getcode() returns 400 for the connection test, then 200 for send
mock_urlopen.return_value.__enter__.return_value = mock_response
client = splunk.SplunkHECClient("splunkhec://token@localhost?index=idx&source=src")
assert client.send({"key": "value"}, epoch=10.5) == 200
req = mock_urlopen.call_args[0][0]
assert req.method == "POST"
assert req.get_full_url() == "https://localhost/services/collector/event"
assert req.headers == {
"Authorization": "Splunk token",
"Content-type": "application/json",
}
assert req.data == b'{"time": 10.5, "event": {"key": "value"}, "host": "localhost", "sourcetype": "_json", "index": "idx", "source": "src"}'