generated from CDCgov/template
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into documentation/installation-guide
- Loading branch information
Showing
6 changed files
with
295 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import json | ||
import time | ||
import typing | ||
import urllib.parse | ||
import urllib.request | ||
|
||
TIMEOUT = 5 | ||
|
||
|
||
class SplunkError(Exception): | ||
pass | ||
|
||
|
||
class SplunkHECClient: | ||
PATH = "/services/collector/event" | ||
|
||
def __init__(self, splunk_uri: str) -> None: | ||
""" | ||
Create a new Splunk HEC client and test its connection. | ||
The URI uses a custom scheme to specify the Splunk HEC server and parameters. | ||
The URI format is: | ||
splunkhec://<token>@<host>:<port>?index=<index>&proto=<protocol>&ssl_verify=<verify>&source=<source> | ||
""" | ||
try: | ||
uri: urllib.parse.ParseResult = urllib.parse.urlparse(splunk_uri) | ||
# flatten the query string values from lists to single values | ||
qs: dict[str, str] = {k: v[0] for k, v in urllib.parse.parse_qs(uri.query).items()} | ||
|
||
if uri.scheme != "splunkhec": | ||
raise SplunkError(f"invalid scheme: {uri.scheme}") | ||
|
||
scheme = qs.get("proto", "https").lower() | ||
host = f"{uri.hostname}:{uri.port}" if uri.port else uri.hostname | ||
self.url = f"{scheme}://{host}{self.PATH}" | ||
self.headers = { | ||
"Authorization": f"Splunk {uri.username}", | ||
"Content-Type": "application/json", | ||
} | ||
# initialize the default payload parameters | ||
self.params: dict[str, str] = {"host": uri.hostname or "", "sourcetype": "_json"} | ||
if qs.get("index"): | ||
self.params["index"] = qs["index"] | ||
if qs.get("source"): | ||
self.params["source"] = qs["source"] | ||
self._test_connection() | ||
except Exception as exc: | ||
raise SplunkError(f"invalid connection: {splunk_uri}") from exc | ||
|
||
def _send_request(self, body: bytes | None = None): | ||
request = urllib.request.Request(self.url, data=body, method="POST", headers=self.headers) | ||
try: | ||
with urllib.request.urlopen(request, timeout=TIMEOUT) as response: | ||
# return the response status code | ||
return response.getcode() | ||
except urllib.error.HTTPError as exc: | ||
return exc.code | ||
|
||
def _test_connection(self) -> None: | ||
status = self._send_request() | ||
# check for a 400 bad request, which indicates a successful connection | ||
# 400 is expected because the payload is empty | ||
if status != 400: | ||
raise urllib.error.HTTPError(self.url, status, "could not connect", None, None) # type: ignore | ||
|
||
def send(self, data: dict, epoch: float = 0) -> int: | ||
""" | ||
Send data to the Splunk HEC endpoint. | ||
:param data: The data to send. | ||
:param epoch: The timestamp to use for the event. If not provided, the current time is used. | ||
:return: The HTTP status code of the response. | ||
""" | ||
epoch = epoch or int(time.time()) | ||
payload: dict[str, typing.Any] = {"time": epoch, "event": data} | self.params | ||
body: bytes = json.dumps(payload).encode("utf-8") | ||
try: | ||
return self._send_request(body=body) | ||
except Exception as exc: | ||
raise SplunkError(f"could not send data: {data}") from exc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import unittest.mock | ||
|
||
import pytest | ||
|
||
from recordlinker import splunk | ||
|
||
|
||
class TestSplunkHECClient: | ||
def test_invalid_uri(self): | ||
with pytest.raises(splunk.SplunkError): | ||
splunk.SplunkHECClient("http://localhost") | ||
|
||
def test_valid_uri(self): | ||
with unittest.mock.patch("urllib.request.urlopen") as mock_urlopen: | ||
mock_response = unittest.mock.MagicMock() | ||
mock_response.read.return_value = b"{}" | ||
mock_response.getcode.return_value = 400 # Set getcode() to return 400 | ||
mock_urlopen.return_value.__enter__.return_value = mock_response | ||
client = splunk.SplunkHECClient("splunkhec://token@localhost:8088?index=idx&source=src") | ||
assert client.url == "https://localhost:8088/services/collector/event" | ||
assert client.headers == { | ||
"Authorization": "Splunk token", | ||
"Content-Type": "application/json", | ||
} | ||
assert client.params == {"host": "localhost", "sourcetype": "_json", "index": "idx", "source": "src"} | ||
|
||
def test_valid_uri_no_port(self): | ||
with unittest.mock.patch("urllib.request.urlopen") as mock_urlopen: | ||
mock_response = unittest.mock.MagicMock() | ||
mock_response.read.return_value = b"{}" | ||
mock_response.getcode.return_value = 400 # Set getcode() to return 400 | ||
mock_urlopen.return_value.__enter__.return_value = mock_response | ||
client = splunk.SplunkHECClient("splunkhec://token@localhost?index=idx&source=src") | ||
assert client.url == "https://localhost/services/collector/event" | ||
assert client.headers == { | ||
"Authorization": "Splunk token", | ||
"Content-Type": "application/json", | ||
} | ||
assert client.params == {"host": "localhost", "sourcetype": "_json", "index": "idx", "source": "src"} | ||
|
||
def test_send(self): | ||
with unittest.mock.patch("urllib.request.urlopen") as mock_urlopen: | ||
mock_response = unittest.mock.MagicMock() | ||
mock_response.read.return_value = b"{}" | ||
mock_response.getcode.side_effect = [400, 200] # Set getcode() to return 400 | ||
mock_urlopen.return_value.__enter__.return_value = mock_response | ||
client = splunk.SplunkHECClient("splunkhec://token@localhost?index=idx&source=src") | ||
assert client.send({"key": "value"}, epoch=10.5) == 200 | ||
req = mock_urlopen.call_args[0][0] | ||
assert req.method == "POST" | ||
assert req.get_full_url() == "https://localhost/services/collector/event" | ||
assert req.headers == { | ||
"Authorization": "Splunk token", | ||
"Content-type": "application/json", | ||
} | ||
assert req.data == b'{"time": 10.5, "event": {"key": "value"}, "host": "localhost", "sourcetype": "_json", "index": "idx", "source": "src"}' |