Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

- Create first draft of key blocking code at the broker. #76

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions wgkex/broker/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@

WG_PUBKEY_PATTERN = re.compile(r"^[A-Za-z0-9+/]{42}[AEIMQUYcgkosw480]=$")

_BANNED_KEYS = list()
_BANNED_CLIENTS = list()


@dataclasses.dataclass
class KeyExchange:
Expand Down Expand Up @@ -93,6 +96,19 @@ def wg_api_v1_key_exchange() -> Tuple[Response | Dict, int]:
return {"error": {"message": str(ex)}}, 400

key = data.public_key
if key in _BANNED_KEYS:
logger.info(
f"wg_key_exchange: Got bad key from %s (%s)", request.remote_addr, data
)
return abort(403, jsonify({"error": {"Key is banned."}}))
if request.remote_addr in _BANNED_CLIENTS:
logger.info(
f"wg_key_exchange: Got key from banned client: %s (%s)",
request.remote_addr,
data,
)
return abort(403, jsonify({"error": {"Client is banned."}}))

domain = data.domain
# in case we want to decide here later we want to publish it only to dedicated gateways
gateway = "all"
Expand Down Expand Up @@ -160,6 +176,35 @@ def wg_api_v2_key_exchange() -> Tuple[Response | Dict, int]:
return {"Endpoint": endpoint}, 200


@app.route("/api/v1/wg/key/block", methods=["POST"])
def wg_key_block() -> Tuple[str, int]:
"""Blocks a key from being send onwards to MQTT.

Message format is as follows:
{
'client_literal': '',
'key_literal': '',
}

key_literal is a literal key.
client_literal is a string representing the source IP (v4,v6) of a client banned from sending keys.

Returns:
Status message.
"""
try:
data = request.get_json(force=True)
except TypeError as ex:
return abort(400, jsonify({"error": {"message": str(ex)}}))
key = data.get("key_literal")
client = data.get("client_literal")
if key:
_BANNED_KEYS.append(key)
if client:
_BANNED_CLIENTS.append(client)
jsonify({"Message": "OK"}), 200


@mqtt.on_connect()
def handle_mqtt_connect(
client: mqtt_client.Client, userdata: bytes, flags: Any, rc: Any
Expand Down
Loading