From 43a1e638f9cd39747034e4aa6bdb390b247c791f Mon Sep 17 00:00:00 2001 From: badrogger Date: Fri, 15 Nov 2024 18:12:03 +0000 Subject: [PATCH 01/17] Add NFTablesRuleController --- Dockerfile | 7 +- core/schains/firewall/firewall_manager.py | 9 + core/schains/firewall/nftables.py | 288 ++++++++++++++++++++++ core/schains/firewall/rule_controller.py | 12 +- core/schains/firewall/utils.py | 37 ++- tests.Dockerfile | 6 +- tests/firewall/nftables_test.py | 105 ++++++++ 7 files changed, 458 insertions(+), 6 deletions(-) create mode 100644 core/schains/firewall/nftables.py create mode 100644 tests/firewall/nftables_test.py diff --git a/Dockerfile b/Dockerfile index 0a2c24bb3..540b256f6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ FROM python:3.11-bookworm -RUN apt-get update && apt-get install -y wget git libxslt-dev iptables kmod swig +RUN apt-get update && apt-get install -y wget git libxslt-dev iptables kmod swig nftables python3-nftables RUN mkdir /usr/src/admin WORKDIR /usr/src/admin @@ -8,12 +8,13 @@ WORKDIR /usr/src/admin COPY requirements.txt ./ COPY requirements-dev.txt ./ -RUN pip3 install --no-cache-dir -r requirements.txt +RUN pip3 install -r requirements.txt COPY . . RUN update-alternatives --set iptables /usr/sbin/iptables-legacy && \ update-alternatives --set ip6tables /usr/sbin/ip6tables-legacy -ENV PYTHONPATH="/usr/src/admin" +ENV PYTHONPATH="/usr/src/admin":/usr/lib/python3/dist-packages/ + ENV COLUMNS=80 diff --git a/core/schains/firewall/firewall_manager.py b/core/schains/firewall/firewall_manager.py index e2216fc73..43c5cb81d 100644 --- a/core/schains/firewall/firewall_manager.py +++ b/core/schains/firewall/firewall_manager.py @@ -22,6 +22,7 @@ from typing import Iterable, Optional from core.schains.firewall.iptables import IptablesController +from core.schains.firewall.nftables import NftablesController from core.schains.firewall.types import ( IFirewallManager, IHostFirewallController, @@ -88,3 +89,11 @@ def flush(self) -> None: class IptablesSChainFirewallManager(SChainFirewallManager): def create_host_controller(self) -> IptablesController: return IptablesController() + + +class NftSchainFirewallManager(SChainFirewallManager): + def create_host_controller(self) -> NftablesController: + nc_controller = NftablesController(chain=self.name) + nc_controller.create_table() + nc_controller.create_chain() + return nc_controller diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py new file mode 100644 index 000000000..32cc24405 --- /dev/null +++ b/core/schains/firewall/nftables.py @@ -0,0 +1,288 @@ +# -*- coding: utf-8 -*- +# +# This file is part of SKALE Admin +# +# Copyright (C) 2024 SKALE Labs +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + + +import logging +import importlib +import ipaddress +import multiprocessing +from functools import wraps +from typing import Callable, Iterable + +from core.schains.firewall.types import IHostFirewallController, SChainRule + +from typing import TypeVar +import json + +T = TypeVar('T') + + +logger = logging.getLogger(__name__) + +TABLE = 'filter' +CHAIN = 'INPUT' + + +def refreshed(func: Callable) -> Callable: + @wraps(func) + def wrapper(self, *args, **kwargs): + self.refresh() + return func(self, *args, **kwargs) + + return wrapper + + +def is_like_number(value): + if value is None: + return False + try: + int(value) + except ValueError: + return False + return True + + +class NftablesCmdFailedError(Exception): + pass + + +class NftablesController(IHostFirewallController): + plock = multiprocessing.Lock() + FAMILY = 'inet' + + def __init__(self, table: str = TABLE, chain: str = CHAIN) -> None: + self.table = table + self.chain = chain + self._nftables = importlib.import_module('nftables') + self.nft = self._nftables.Nftables() + self.nft.set_json_output(True) + + def _compose_json(self, commands: list[dict]) -> dict: + json_cmd = {'nftables': commands} + self.nft.json_validate(json_cmd) + return json_cmd + + def create_table(self) -> None: + if not self.has_table(self.table): + return self.run_cmd(f'add table inet {self.table}') + + def create_chain(self) -> None: + if not self.has_chain(self.chain): + return self.run_json_cmd( + self._compose_json( + [ + { + 'add': { + 'chain': { + 'family': self.FAMILY, + 'table': self.table, + 'name': self.chain, + 'hook': 'input', + } + } + } + ] + ) + ) + + @property + def chains(self) -> list[dict]: + output = self.run_cmd('list chains') + if output[0] != 0: + raise NftablesCmdFailedError(output) + parsed = json.loads(output[1])['nftables'] + return [record['chain']['name'] for record in parsed if 'chain' in record] + + @property + def tables(self) -> list[dict]: + output = self.run_cmd('list tables') + if output[0] != 0: + raise NftablesCmdFailedError(output) + parsed = json.loads(output[1])['nftables'] + return [record['table']['name'] for record in parsed if 'table' in record] + + def run_json_cmd(self, cmd: dict) -> tuple: + logger.debug('Nftables json cmd %s', cmd) + with self.plock: + return self.nft.json_cmd(cmd) + + def run_cmd(self, cmd: str) -> tuple: + logger.debug('Nftables cmd %s', cmd) + with self.plock: + return self.nft.cmd(cmd) + + def has_chain(self, chain: str) -> bool: + return chain in self.chains + + def has_table(self, table: str) -> bool: + return table in self.tables + + def add_rule(self, rule: SChainRule) -> None: + if self.has_rule(rule): + return + expr = self.rule_to_expr(rule) + + json_cmd = self._compose_json( + [ + { + 'add': { + 'rule': { + 'family': self.FAMILY, + 'table': self.table, + 'chain': self.chain, + 'expr': expr, + } + } + } + ] + ) + + rc, output, error = self.run_json_cmd(json_cmd) + if rc != 0: + raise NftablesCmdFailedError(f'Failed to add allow rule: {error}') + + @classmethod + def rule_to_expr(cls, rule: SChainRule) -> list: + expr = [] + + if rule.first_ip: + if rule.last_ip == rule.first_ip: + expr.append( + { + 'match': { + 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, + 'op': '==', + 'right': f'{rule.first_ip}', + } + } + ) + else: + expr.append( + { + 'match': { + 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, + 'op': '==', + 'right': {'range': [f'{rule.first_ip}', f'{rule.last_ip}']}, + } + } + ) + + if rule.port: + expr.append( + { + 'match': { + 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}}, + 'op': '==', + 'right': rule.port, + } + } + ) + + expr.append({'accept': None}) + return expr + + @classmethod + def expr_to_rule(self, expr: list) -> None: + port, first_ip, last_ip = None, None, None + for item in expr: + if 'match' in item: + match = item['match'] + + if match.get('left', {}).get('payload', {}).get('field') == 'dport': + port = match.get('right') + + if match.get('left', {}).get('payload', {}).get('field') == 'saddr': + right = match.get('right') + if isinstance(right, str): + first_ip = right + else: + first_ip, last_ip = right['range'] + + if any([port, first_ip, last_ip]): + return SChainRule(port=port, first_ip=first_ip, last_ip=last_ip) + + def remove_rule(self, rule: SChainRule) -> None: + if self.has_rule(rule): + expr = self.rule_to_expr(rule) + + output = None + rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}') + if rc != 0: + raise Exception(f'Failed to list rules: {error}') + + current_rules = json.loads(output) + + handle = None + for item in current_rules.get('nftables', []): + if 'rule' in item: + rule_data = item['rule'] + if rule_data.get('expr') == expr: + handle = rule_data.get('handle') + break + + if handle is None: + raise Exception('Rule not found') + + json_cmd = self._compose_json( + [ + { + 'delete': { + 'rule': { + 'family': self.FAMILY, + 'table': self.table, + 'chain': self.chain, + 'handle': handle, + } + } + } + ] + ) + + rc, output, error = self.run_json_cmd(json_cmd) + if rc != 0: + raise NftablesCmdFailedError(f'Failed to delete rule: {error}') + + @property # type: ignore + def rules(self) -> Iterable[SChainRule]: + output = None + rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}') + if output == '': + return [] + + data = json.loads(output) + rules = [] + + for item in data.get('nftables', []): + if 'rule' in item: + plain_rule = item['rule'] + rule = self.expr_to_rule(plain_rule.get('expr', [])) + if rule: + rules.append(rule) + return rules + + def has_rule(self, rule: SChainRule) -> bool: + return rule in self.rules + + @classmethod + def from_ip_network(cls, ip: str) -> str: + return str(ipaddress.ip_network(ip).hosts()[0]) + + @classmethod + def to_ip_network(cls, ip: str) -> str: + return str(ipaddress.ip_network(ip)) diff --git a/core/schains/firewall/rule_controller.py b/core/schains/firewall/rule_controller.py index 51e8920a8..08bfcd48d 100644 --- a/core/schains/firewall/rule_controller.py +++ b/core/schains/firewall/rule_controller.py @@ -23,7 +23,7 @@ from functools import wraps from typing import Any, Callable, cast, Dict, Iterable, List, Optional, TypeVar -from .firewall_manager import IptablesSChainFirewallManager +from .firewall_manager import IptablesSChainFirewallManager, NftSchainFirewallManager from .types import ( IFirewallManager, IpRange, @@ -214,3 +214,13 @@ def create_firewall_manager(self) -> IptablesSChainFirewallManager: self.base_port, # type: ignore self.base_port + self.ports_per_schain - 1 # type: ignore ) + + +class NftSchainRuleController(SChainRuleController): + @configured_only + def create_firewall_manager(self) -> NftSchainFirewallManager: + return NftSchainFirewallManager( + self.name, + self.base_port, # type: ignore + self.base_port + self.ports_per_schain - 1 # type: ignore + ) diff --git a/core/schains/firewall/utils.py b/core/schains/firewall/utils.py index 737361e18..7bb5ec006 100644 --- a/core/schains/firewall/utils.py +++ b/core/schains/firewall/utils.py @@ -25,7 +25,7 @@ from skale import Skale from .types import IpRange -from .rule_controller import IptablesSChainRuleController +from .rule_controller import IptablesSChainRuleController, NftSchainRuleController logger = logging.getLogger(__name__) @@ -37,6 +37,22 @@ def get_default_rule_controller( own_ip: Optional[str] = None, node_ips: List[str] = [], sync_agent_ranges: Optional[List[IpRange]] = [] +) -> IptablesSChainRuleController: + return get_nftables_rule_controller( + name=name, + base_port=base_port, + own_ip=own_ip, + node_ips=node_ips, + sync_agent_ranges=sync_agent_ranges + ) + + +def get_iptables_rule_controller( + name: str, + base_port: Optional[int] = None, + own_ip: Optional[str] = None, + node_ips: List[str] = [], + sync_agent_ranges: Optional[List[IpRange]] = [] ) -> IptablesSChainRuleController: sync_agent_ranges = sync_agent_ranges or [] logger.info('Creating rule controller for %s', name) @@ -50,6 +66,25 @@ def get_default_rule_controller( ) +def get_nftables_rule_controller( + name: str, + base_port: Optional[int] = None, + own_ip: Optional[str] = None, + node_ips: List[str] = [], + sync_agent_ranges: Optional[List[IpRange]] = [] +) -> NftSchainRuleController: + sync_agent_ranges = sync_agent_ranges or [] + logger.info('Creating rule controller for %s', name) + logger.debug('Rule controller ranges for %s: %s', name, sync_agent_ranges) + return NftSchainRuleController( + name=name, + base_port=base_port, + own_ip=own_ip, + node_ips=node_ips, + sync_ip_ranges=sync_agent_ranges + ) + + def get_sync_agent_ranges(skale: Skale) -> List[IpRange]: sync_agent_ranges = [] rnum = skale.sync_manager.get_ip_ranges_number() diff --git a/tests.Dockerfile b/tests.Dockerfile index b31db00ee..75b72f635 100644 --- a/tests.Dockerfile +++ b/tests.Dockerfile @@ -1,3 +1,7 @@ FROM admin:base -RUN pip3 install --no-cache-dir -r requirements-dev.txt +RUN apt update && apt install -y nftables python3-nftables + +RUN pip3 install -r requirements-dev.txt + +ENV PYTHONPATH=${PYTHONPATH}:/usr/lib/python3/dist-packages/ diff --git a/tests/firewall/nftables_test.py b/tests/firewall/nftables_test.py new file mode 100644 index 000000000..038ec6cc8 --- /dev/null +++ b/tests/firewall/nftables_test.py @@ -0,0 +1,105 @@ +import concurrent.futures +import importlib +import subprocess +import time + +import pytest + +from core.schains.firewall.nftables import NftablesController +from core.schains.firewall.types import SChainRule + + +@pytest.fixture +def nf_test_tables(): + nft = importlib.import_module('nftables').Nftables() + nft.cmd('flush ruleset') + return nft + + +@pytest.fixture +def filter_table(nf_test_tables): + print(nf_test_tables.cmd('add table inet filter')) + + +@pytest.fixture +def custom_chain(nf_test_tables, filter_table): + nf_test_tables.cmd('add chain inet filter test-chain') + return 'test-chain' + + +def test_nftables_controller(custom_chain): + nft_controller = NftablesController(chain='test-chain') + rule_a = SChainRule(10000, '1.1.1.1', '2.2.2.2') + rule_b = SChainRule(10001, '3.3.3.3') + nft_controller.add_rule(rule_a) + nft_controller.add_rule(rule_b) + assert nft_controller.has_rule(rule_a) + assert nft_controller.has_rule(rule_b) + rules = list(nft_controller.rules) + assert rules == sorted([rule_b, rule_a]) + nft_controller.remove_rule(rule_a) + assert not nft_controller.has_rule(rule_a) + assert nft_controller.has_rule(rule_b) + nft_controller.remove_rule(rule_b) + assert not nft_controller.has_rule(rule_a) + + +def test_nftables_controller_duplicates(custom_chain): + rule_a = SChainRule(10000, '1.1.1.1', '2.2.2.2') + manager = NftablesController(chain='test-chain') + manager.add_rule(rule_a) + rule_b = SChainRule(10001, '3.3.3.3', '4.4.4.4') + manager.add_rule(rule_b) + assert sorted(list(manager.rules)) == sorted([ + SChainRule(port=10001, first_ip='3.3.3.3', last_ip='4.4.4.4'), + SChainRule(port=10000, first_ip='1.1.1.1', last_ip='2.2.2.2') + ]) + assert manager.has_rule(rule_b) + manager.add_rule(rule_b) + assert manager.has_rule(rule_b) + assert sorted(list(manager.rules)) == sorted([ + SChainRule(port=10001, first_ip='3.3.3.3', last_ip='4.4.4.4'), + SChainRule(port=10000, first_ip='1.1.1.1', last_ip='2.2.2.2') + ]) + manager.remove_rule(rule_b) + assert list(manager.rules) == [ + SChainRule(port=10000, first_ip='1.1.1.1', last_ip='2.2.2.2') + ] + + +def add_remove_rule(srule, refresh): + manager = NftablesController() + manager.add_rule(srule) + time.sleep(1) + if not manager.has_rule(srule): + return False + time.sleep(1) + manager.remove_rule(srule) + return True + + +def generate_srules(number=5): + return [ + SChainRule( + 10000 + 1, + f'{i}.{i}.{i}.{i}', f'{i + 1}.{i + 1}.{i + 1}.{i + 1}' + ) + for i in range(1, number * 2, 2) + ] + + +def test_nftables_manager_parallel(custom_chain): + srules = generate_srules(number=12) + + futures = [] + with concurrent.futures.ProcessPoolExecutor(max_workers=12) as executor: + futures = [ + executor.submit(add_remove_rule, srule) + for srule in srules + ] + + for future in concurrent.futures.as_completed(futures): + assert future.result + manager = NftablesController(custom_chain) + time.sleep(10) + assert len(list(manager.rules)) == 0 From 6980a49edf1d479da677b871e6fe64bf6188443a Mon Sep 17 00:00:00 2001 From: badrogger Date: Fri, 15 Nov 2024 18:38:11 +0000 Subject: [PATCH 02/17] Rename Nftables to NFTables --- core/schains/firewall/firewall_manager.py | 8 ++++---- core/schains/firewall/nftables.py | 18 +++++++++--------- core/schains/firewall/rule_controller.py | 8 ++++---- core/schains/firewall/utils.py | 6 +++--- tests/firewall/nftables_test.py | 13 ++++++------- 5 files changed, 26 insertions(+), 27 deletions(-) diff --git a/core/schains/firewall/firewall_manager.py b/core/schains/firewall/firewall_manager.py index 43c5cb81d..b43f3a223 100644 --- a/core/schains/firewall/firewall_manager.py +++ b/core/schains/firewall/firewall_manager.py @@ -22,7 +22,7 @@ from typing import Iterable, Optional from core.schains.firewall.iptables import IptablesController -from core.schains.firewall.nftables import NftablesController +from core.schains.firewall.nftables import NFTablesController from core.schains.firewall.types import ( IFirewallManager, IHostFirewallController, @@ -91,9 +91,9 @@ def create_host_controller(self) -> IptablesController: return IptablesController() -class NftSchainFirewallManager(SChainFirewallManager): - def create_host_controller(self) -> NftablesController: - nc_controller = NftablesController(chain=self.name) +class NFTSchainFirewallManager(SChainFirewallManager): + def create_host_controller(self) -> NFTablesController: + nc_controller = NFTablesController(chain=self.name) nc_controller.create_table() nc_controller.create_chain() return nc_controller diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index 32cc24405..95e677826 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -58,11 +58,11 @@ def is_like_number(value): return True -class NftablesCmdFailedError(Exception): +class NFTablesCmdFailedError(Exception): pass -class NftablesController(IHostFirewallController): +class NFTablesController(IHostFirewallController): plock = multiprocessing.Lock() FAMILY = 'inet' @@ -70,7 +70,7 @@ def __init__(self, table: str = TABLE, chain: str = CHAIN) -> None: self.table = table self.chain = chain self._nftables = importlib.import_module('nftables') - self.nft = self._nftables.Nftables() + self.nft = self._nftables.NFTables() self.nft.set_json_output(True) def _compose_json(self, commands: list[dict]) -> dict: @@ -105,7 +105,7 @@ def create_chain(self) -> None: def chains(self) -> list[dict]: output = self.run_cmd('list chains') if output[0] != 0: - raise NftablesCmdFailedError(output) + raise NFTablesCmdFailedError(output) parsed = json.loads(output[1])['nftables'] return [record['chain']['name'] for record in parsed if 'chain' in record] @@ -113,17 +113,17 @@ def chains(self) -> list[dict]: def tables(self) -> list[dict]: output = self.run_cmd('list tables') if output[0] != 0: - raise NftablesCmdFailedError(output) + raise NFTablesCmdFailedError(output) parsed = json.loads(output[1])['nftables'] return [record['table']['name'] for record in parsed if 'table' in record] def run_json_cmd(self, cmd: dict) -> tuple: - logger.debug('Nftables json cmd %s', cmd) + logger.debug('NFTables json cmd %s', cmd) with self.plock: return self.nft.json_cmd(cmd) def run_cmd(self, cmd: str) -> tuple: - logger.debug('Nftables cmd %s', cmd) + logger.debug('NFTables cmd %s', cmd) with self.plock: return self.nft.cmd(cmd) @@ -155,7 +155,7 @@ def add_rule(self, rule: SChainRule) -> None: rc, output, error = self.run_json_cmd(json_cmd) if rc != 0: - raise NftablesCmdFailedError(f'Failed to add allow rule: {error}') + raise NFTablesCmdFailedError(f'Failed to add allow rule: {error}') @classmethod def rule_to_expr(cls, rule: SChainRule) -> list: @@ -256,7 +256,7 @@ def remove_rule(self, rule: SChainRule) -> None: rc, output, error = self.run_json_cmd(json_cmd) if rc != 0: - raise NftablesCmdFailedError(f'Failed to delete rule: {error}') + raise NFTablesCmdFailedError(f'Failed to delete rule: {error}') @property # type: ignore def rules(self) -> Iterable[SChainRule]: diff --git a/core/schains/firewall/rule_controller.py b/core/schains/firewall/rule_controller.py index 08bfcd48d..3b63026bb 100644 --- a/core/schains/firewall/rule_controller.py +++ b/core/schains/firewall/rule_controller.py @@ -23,7 +23,7 @@ from functools import wraps from typing import Any, Callable, cast, Dict, Iterable, List, Optional, TypeVar -from .firewall_manager import IptablesSChainFirewallManager, NftSchainFirewallManager +from .firewall_manager import IptablesSChainFirewallManager, NFTSchainFirewallManager from .types import ( IFirewallManager, IpRange, @@ -216,10 +216,10 @@ def create_firewall_manager(self) -> IptablesSChainFirewallManager: ) -class NftSchainRuleController(SChainRuleController): +class NFTSchainRuleController(SChainRuleController): @configured_only - def create_firewall_manager(self) -> NftSchainFirewallManager: - return NftSchainFirewallManager( + def create_firewall_manager(self) -> NFTSchainFirewallManager: + return NFTSchainFirewallManager( self.name, self.base_port, # type: ignore self.base_port + self.ports_per_schain - 1 # type: ignore diff --git a/core/schains/firewall/utils.py b/core/schains/firewall/utils.py index 7bb5ec006..1f94694fd 100644 --- a/core/schains/firewall/utils.py +++ b/core/schains/firewall/utils.py @@ -25,7 +25,7 @@ from skale import Skale from .types import IpRange -from .rule_controller import IptablesSChainRuleController, NftSchainRuleController +from .rule_controller import IptablesSChainRuleController, NFTSchainRuleController logger = logging.getLogger(__name__) @@ -72,11 +72,11 @@ def get_nftables_rule_controller( own_ip: Optional[str] = None, node_ips: List[str] = [], sync_agent_ranges: Optional[List[IpRange]] = [] -) -> NftSchainRuleController: +) -> NFTSchainRuleController: sync_agent_ranges = sync_agent_ranges or [] logger.info('Creating rule controller for %s', name) logger.debug('Rule controller ranges for %s: %s', name, sync_agent_ranges) - return NftSchainRuleController( + return NFTSchainRuleController( name=name, base_port=base_port, own_ip=own_ip, diff --git a/tests/firewall/nftables_test.py b/tests/firewall/nftables_test.py index 038ec6cc8..2cfa02058 100644 --- a/tests/firewall/nftables_test.py +++ b/tests/firewall/nftables_test.py @@ -1,17 +1,16 @@ import concurrent.futures import importlib -import subprocess import time import pytest -from core.schains.firewall.nftables import NftablesController +from core.schains.firewall.nftables import NFTablesController from core.schains.firewall.types import SChainRule @pytest.fixture def nf_test_tables(): - nft = importlib.import_module('nftables').Nftables() + nft = importlib.import_module('nftables').NFTables() nft.cmd('flush ruleset') return nft @@ -28,7 +27,7 @@ def custom_chain(nf_test_tables, filter_table): def test_nftables_controller(custom_chain): - nft_controller = NftablesController(chain='test-chain') + nft_controller = NFTablesController(chain='test-chain') rule_a = SChainRule(10000, '1.1.1.1', '2.2.2.2') rule_b = SChainRule(10001, '3.3.3.3') nft_controller.add_rule(rule_a) @@ -46,7 +45,7 @@ def test_nftables_controller(custom_chain): def test_nftables_controller_duplicates(custom_chain): rule_a = SChainRule(10000, '1.1.1.1', '2.2.2.2') - manager = NftablesController(chain='test-chain') + manager = NFTablesController(chain='test-chain') manager.add_rule(rule_a) rule_b = SChainRule(10001, '3.3.3.3', '4.4.4.4') manager.add_rule(rule_b) @@ -68,7 +67,7 @@ def test_nftables_controller_duplicates(custom_chain): def add_remove_rule(srule, refresh): - manager = NftablesController() + manager = NFTablesController() manager.add_rule(srule) time.sleep(1) if not manager.has_rule(srule): @@ -100,6 +99,6 @@ def test_nftables_manager_parallel(custom_chain): for future in concurrent.futures.as_completed(futures): assert future.result - manager = NftablesController(custom_chain) + manager = NFTablesController(custom_chain) time.sleep(10) assert len(list(manager.rules)) == 0 From 17a9eeade8719dc62fd0badaebfd8c1c4741f68f Mon Sep 17 00:00:00 2001 From: badrogger Date: Mon, 18 Nov 2024 16:10:57 +0000 Subject: [PATCH 03/17] Fix tests --- tests/firewall/nftables_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/firewall/nftables_test.py b/tests/firewall/nftables_test.py index 2cfa02058..d11a2d6ff 100644 --- a/tests/firewall/nftables_test.py +++ b/tests/firewall/nftables_test.py @@ -10,7 +10,7 @@ @pytest.fixture def nf_test_tables(): - nft = importlib.import_module('nftables').NFTables() + nft = importlib.import_module('nftables').Nftables() nft.cmd('flush ruleset') return nft From 204673a4038f0aebd9b95e23991812d31b35cab7 Mon Sep 17 00:00:00 2001 From: badrogger Date: Mon, 18 Nov 2024 19:09:47 +0000 Subject: [PATCH 04/17] Fix import --- core/schains/firewall/nftables.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index 95e677826..5c8808c16 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -70,7 +70,7 @@ def __init__(self, table: str = TABLE, chain: str = CHAIN) -> None: self.table = table self.chain = chain self._nftables = importlib.import_module('nftables') - self.nft = self._nftables.NFTables() + self.nft = self._nftables.Nftables() self.nft.set_json_output(True) def _compose_json(self, commands: list[dict]) -> dict: From e9cb5064f461a8f4c52667424fbf330cd9462a62 Mon Sep 17 00:00:00 2001 From: badrogger Date: Fri, 22 Nov 2024 15:46:27 +0000 Subject: [PATCH 05/17] Fix default_rule_controller_test --- core/schains/firewall/__init__.py | 1 + .../firewall/default_rule_controller_test.py | 67 +++++-------------- 2 files changed, 17 insertions(+), 51 deletions(-) diff --git a/core/schains/firewall/__init__.py b/core/schains/firewall/__init__.py index 8edbd1a7c..1bba60b76 100644 --- a/core/schains/firewall/__init__.py +++ b/core/schains/firewall/__init__.py @@ -19,6 +19,7 @@ from .firewall_manager import SChainFirewallManager # noqa from .iptables import IptablesController # noqa +from .nftables import NFTablesController # noqa from .rule_controller import SChainRuleController # noqa from .types import IRuleController # noqa from .utils import get_default_rule_controller # noqa diff --git a/tests/firewall/default_rule_controller_test.py b/tests/firewall/default_rule_controller_test.py index c2473e16f..2032b918c 100644 --- a/tests/firewall/default_rule_controller_test.py +++ b/tests/firewall/default_rule_controller_test.py @@ -1,23 +1,13 @@ + import mock import pytest import concurrent.futures from skale.schain_config import PORTS_PER_SCHAIN # noqa -from core.schains.firewall import IptablesController +from core.schains.firewall import NFTablesController from core.schains.firewall.utils import get_default_rule_controller from core.schains.firewall.types import IpRange, SkaledPorts -from tests.firewall.iptables_test import get_rules_through_subprocess -from tools.helper import run_cmd - - -@pytest.fixture -def refresh(): - run_cmd(['iptables', '-F']) - try: - yield - finally: - run_cmd(['iptables', '-F']) def test_get_default_rule_controller(): @@ -58,22 +48,6 @@ def sync_rules(*args): return False -def parse_plain_rule(plain_rule): - first_ip, last_ip, port = None, None, None - pr = plain_rule.split() - if '--src-range' in pr: - srange = pr[11] - first_ip, last_ip = srange.split('-') - port = pr[7] - elif '-s' in pr: - first_ip = last_ip = pr[3][:-3] - port = pr[9] - elif '--dport' in pr: - port = pr[7] - - return first_ip, last_ip, int(port) - - def run_concurrent_rc_syncing( node_number, schain_number, @@ -147,55 +121,46 @@ def run_concurrent_rc_syncing( else: assert not r - pr = get_rules_through_subprocess(unique=False)[3:] - rules = [parse_plain_rule(r) for r in pr] + controllers = [NFTablesController(chain=name) for name in schain_names] + rules = [] + for controller in controllers: + rules.extend(controller.rules) + + print([r.port for r in rules]) + print([r.first_ip for r in rules]) - c = IptablesController() # Check that all ip rules are there for ip in node_ips: if ip != own_ip: assert sum( - map(lambda x: x[0] == ip, rules) - ) == 5 * schain_number, ip - assert sum( - map(lambda x: x.first_ip == ip, c.rules) + map(lambda x: x.first_ip == ip, rules) ) == 5 * schain_number, ip # Check that all internal ports rules are there except CATCHUP for p in internal_ports: - assert sum(map(lambda x: x[2] == p, rules)) == node_number - 1, p - assert sum(map(lambda x: x.port == p, c.rules)) == node_number - 1, p + assert sum(map(lambda x: x.port == p, rules)) == node_number - 1, p # Check CATCHUP rules including sync agents rules catchup_e_number = node_number + sync_agent_ranges_number - 1 for p in catchup_ports: - assert sum(map(lambda x: x[2] == p, rules)) == catchup_e_number, p - assert sum(map(lambda x: x.port == p, c.rules)) == catchup_e_number, p + assert sum(map(lambda x: x.port == p, rules)) == catchup_e_number, p # Check ZMQ rules including sync agents rules zmq_e_number = node_number + sync_agent_ranges_number - 1 for p in zmq_ports: - assert sum(map(lambda x: x[2] == p, rules)) == zmq_e_number, p - assert sum(map(lambda x: x.port == p, c.rules)) == zmq_e_number, p + assert sum(map(lambda x: x.port == p, rules)) == zmq_e_number, p # Check sync ip ranges rules for r in sync_agent_ranges: assert sum( - map(lambda x: x[0] == r.start_ip, rules) - ) == schain_number * 2, ip - assert sum( - map(lambda x: x.first_ip == r.start_ip, c.rules) - ) == schain_number * 2, ip - assert sum( - map(lambda x: x[1] == r.end_ip, rules) + map(lambda x: x.first_ip == r.start_ip, rules) ) == schain_number * 2, ip assert sum( - map(lambda x: x.last_ip == r.end_ip, c.rules) + map(lambda x: x.last_ip == r.end_ip, rules) ) == schain_number * 2, ip for port in public_ports: - assert sum(map(lambda x: x[2] == port, rules)) == 1, port - assert sum(map(lambda x: x.port == port, c.rules)) == 1, port + assert sum(map(lambda x: x.port == port, rules)) == 1, port @pytest.mark.parametrize('attempt', range(5)) From 8e40cc4c5018665a4d42f43fffaf91aadd576242 Mon Sep 17 00:00:00 2001 From: badrogger Date: Mon, 25 Nov 2024 13:33:29 +0000 Subject: [PATCH 06/17] Fix default_rule_controller test --- tests/firewall/default_rule_controller_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/firewall/default_rule_controller_test.py b/tests/firewall/default_rule_controller_test.py index 2032b918c..bd29fb40d 100644 --- a/tests/firewall/default_rule_controller_test.py +++ b/tests/firewall/default_rule_controller_test.py @@ -180,7 +180,7 @@ def test_concurrent_rc_behavior_no_refresh(attempt): @pytest.mark.parametrize('attempt', range(5)) -def test_concurrent_rc_behavior_with_refresh(attempt, refresh): +def test_concurrent_rc_behavior_with_refresh(attempt): node_number = 16 schain_number = 8 own_ip = '1.1.1.1' From b8d57817004a85b1909b0fa48ae3479bc88e48a7 Mon Sep 17 00:00:00 2001 From: badrogger Date: Mon, 25 Nov 2024 15:09:24 +0000 Subject: [PATCH 07/17] Fix test_concurrent_rc_behavior_with_refresh --- tests/firewall/default_rule_controller_test.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/firewall/default_rule_controller_test.py b/tests/firewall/default_rule_controller_test.py index bd29fb40d..0eb4668a3 100644 --- a/tests/firewall/default_rule_controller_test.py +++ b/tests/firewall/default_rule_controller_test.py @@ -9,6 +9,17 @@ from core.schains.firewall.utils import get_default_rule_controller from core.schains.firewall.types import IpRange, SkaledPorts +from tools.helper import run_cmd + + +@pytest.fixture +def refresh(): + run_cmd(['nft', 'flush', 'ruleset']) + try: + yield + finally: + run_cmd(['nft', 'flush', 'ruleset']) + def test_get_default_rule_controller(): own_ip = '3.3.3.3' @@ -180,7 +191,7 @@ def test_concurrent_rc_behavior_no_refresh(attempt): @pytest.mark.parametrize('attempt', range(5)) -def test_concurrent_rc_behavior_with_refresh(attempt): +def test_concurrent_rc_behavior_with_refresh(attempt, refresh): node_number = 16 schain_number = 8 own_ip = '1.1.1.1' From 50cd4ce53d961c5f3a04302636600a32d82a74f1 Mon Sep 17 00:00:00 2001 From: badrogger Date: Mon, 25 Nov 2024 18:35:10 +0000 Subject: [PATCH 08/17] Do not raise Exception in nftables --- core/schains/firewall/nftables.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index 5c8808c16..753015823 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -224,7 +224,7 @@ def remove_rule(self, rule: SChainRule) -> None: output = None rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}') if rc != 0: - raise Exception(f'Failed to list rules: {error}') + raise NFTablesCmdFailedError(f'Failed to list rules: {error}') current_rules = json.loads(output) @@ -237,7 +237,7 @@ def remove_rule(self, rule: SChainRule) -> None: break if handle is None: - raise Exception('Rule not found') + raise NFTablesCmdFailedError('Rule not found') json_cmd = self._compose_json( [ From 3ff53f841d4d0f8bca71ee4ce3d3a7a77b201108 Mon Sep 17 00:00:00 2001 From: badrogger Date: Fri, 13 Dec 2024 20:44:54 +0000 Subject: [PATCH 09/17] Fix chain creation --- core/schains/firewall/firewall_manager.py | 2 +- core/schains/firewall/nftables.py | 101 +++++++++++++++------- 2 files changed, 70 insertions(+), 33 deletions(-) diff --git a/core/schains/firewall/firewall_manager.py b/core/schains/firewall/firewall_manager.py index b43f3a223..d16f71574 100644 --- a/core/schains/firewall/firewall_manager.py +++ b/core/schains/firewall/firewall_manager.py @@ -95,5 +95,5 @@ class NFTSchainFirewallManager(SChainFirewallManager): def create_host_controller(self) -> NFTablesController: nc_controller = NFTablesController(chain=self.name) nc_controller.create_table() - nc_controller.create_chain() + nc_controller.create_chain(self.first_port, self.last_port) return nc_controller diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index 753015823..36d98543b 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -22,8 +22,7 @@ import importlib import ipaddress import multiprocessing -from functools import wraps -from typing import Callable, Iterable +from typing import Iterable from core.schains.firewall.types import IHostFirewallController, SChainRule @@ -35,29 +34,10 @@ logger = logging.getLogger(__name__) -TABLE = 'filter' +TABLE = 'firewall' CHAIN = 'INPUT' -def refreshed(func: Callable) -> Callable: - @wraps(func) - def wrapper(self, *args, **kwargs): - self.refresh() - return func(self, *args, **kwargs) - - return wrapper - - -def is_like_number(value): - if value is None: - return False - try: - int(value) - except ValueError: - return False - return True - - class NFTablesCmdFailedError(Exception): pass @@ -82,7 +62,43 @@ def create_table(self) -> None: if not self.has_table(self.table): return self.run_cmd(f'add table inet {self.table}') - def create_chain(self) -> None: + def add_schain_drop_rule(self, first_port: int, last_port: int) -> None: + expr = [ + { + "match": { + "left": { + "payload": { + "protocol": "tcp", + "field": "dport" + } + }, + "op": "==", + "right": {'range': [first_port, last_port]} + } + }, + {'counter': None}, + {"drop": None} + ] + + if self.expr_to_rule(expr) not in self.get_rules_by_policy(policy='drop'): + cmd = { + 'nftables': [ + { + 'add': { + 'rule': { + 'family': self.FAMILY, + 'table': self.table, + 'chain': self.chain, + 'expr': expr, + } + } + } + ] + } + self.run_json_cmd(cmd) + logger.info('Added drop rule for chain %s', self.chain) + + def create_chain(self, first_port: int, last_port: int) -> None: if not self.has_chain(self.chain): return self.run_json_cmd( self._compose_json( @@ -94,12 +110,16 @@ def create_chain(self) -> None: 'table': self.table, 'name': self.chain, 'hook': 'input', + 'type': 'filter', + 'prio': 0, + 'policy': 'accept', } } } ] ) ) + self.add_schain_drop_rule(first_port, last_port) @property def chains(self) -> list[dict]: @@ -141,7 +161,7 @@ def add_rule(self, rule: SChainRule) -> None: json_cmd = self._compose_json( [ { - 'add': { + 'insert': { 'rule': { 'family': self.FAMILY, 'table': self.table, @@ -194,7 +214,7 @@ def rule_to_expr(cls, rule: SChainRule) -> list: } ) - expr.append({'accept': None}) + expr.extend([{'counter': None}, {'accept': None}]) return expr @classmethod @@ -217,6 +237,13 @@ def expr_to_rule(self, expr: list) -> None: if any([port, first_ip, last_ip]): return SChainRule(port=port, first_ip=first_ip, last_ip=last_ip) + @classmethod + def expr_equals(cls, expr_a: list[dict], expr_b: list[dict]) -> bool: + for item_a, item_b in zip(sorted(expr_a), sorted(expr_b)): + if 'counter' not in item_a and item_a != item_b: + return False + return True + def remove_rule(self, rule: SChainRule) -> None: if self.has_rule(rule): expr = self.rule_to_expr(rule) @@ -228,11 +255,15 @@ def remove_rule(self, rule: SChainRule) -> None: current_rules = json.loads(output) + logger.info('HERE HERE %s', expr) + logger.info('HERE current rules %s', current_rules) handle = None for item in current_rules.get('nftables', []): if 'rule' in item: rule_data = item['rule'] - if rule_data.get('expr') == expr: + logger.info('HERE HERE 2 %s', rule_data['expr']) + logger.info('HERE HERE 3 %s', expr) + if self.expr_equals(rule_data.get('expr'), expr): handle = rule_data.get('handle') break @@ -260,6 +291,12 @@ def remove_rule(self, rule: SChainRule) -> None: @property # type: ignore def rules(self) -> Iterable[SChainRule]: + return self.get_rules_by_policy(policy='accept') + + def has_rule(self, rule: SChainRule) -> bool: + return rule in self.rules + + def get_rules_by_policy(self, policy: str) -> list[SChainRule]: output = None rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}') if output == '': @@ -271,14 +308,14 @@ def rules(self) -> Iterable[SChainRule]: for item in data.get('nftables', []): if 'rule' in item: plain_rule = item['rule'] - rule = self.expr_to_rule(plain_rule.get('expr', [])) - if rule: - rules.append(rule) + expr = plain_rule.get('expr', []) + if {policy: None} in expr: + rule = self.expr_to_rule(expr) + if rule: + rules.append(rule) + logger.debug('Rules for policy %s: %s', policy, rules) return rules - def has_rule(self, rule: SChainRule) -> bool: - return rule in self.rules - @classmethod def from_ip_network(cls, ip: str) -> str: return str(ipaddress.ip_network(ip).hosts()[0]) From 1b8e97050f065c5abe427487c1d9b27c57ab81a6 Mon Sep 17 00:00:00 2001 From: badrogger Date: Mon, 16 Dec 2024 13:36:08 +0000 Subject: [PATCH 10/17] Fix tests --- core/schains/firewall/nftables.py | 52 +++++++++++++------------------ tests/firewall/nftables_test.py | 6 ++-- 2 files changed, 24 insertions(+), 34 deletions(-) diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index 36d98543b..0404a0f21 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -64,20 +64,15 @@ def create_table(self) -> None: def add_schain_drop_rule(self, first_port: int, last_port: int) -> None: expr = [ - { - "match": { - "left": { - "payload": { - "protocol": "tcp", - "field": "dport" + { + 'match': { + 'op': '==', + 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}}, + 'right': {'range': [first_port, last_port]}, } - }, - "op": "==", - "right": {'range': [first_port, last_port]} - } - }, - {'counter': None}, - {"drop": None} + }, + {'counter': None}, + {'drop': None}, ] if self.expr_to_rule(expr) not in self.get_rules_by_policy(policy='drop'): @@ -178,7 +173,7 @@ def add_rule(self, rule: SChainRule) -> None: raise NFTablesCmdFailedError(f'Failed to add allow rule: {error}') @classmethod - def rule_to_expr(cls, rule: SChainRule) -> list: + def rule_to_expr(cls, rule: SChainRule, counter: bool = True) -> list: expr = [] if rule.first_ip: @@ -186,8 +181,8 @@ def rule_to_expr(cls, rule: SChainRule) -> list: expr.append( { 'match': { - 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, 'op': '==', + 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, 'right': f'{rule.first_ip}', } } @@ -196,8 +191,8 @@ def rule_to_expr(cls, rule: SChainRule) -> list: expr.append( { 'match': { - 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, 'op': '==', + 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, 'right': {'range': [f'{rule.first_ip}', f'{rule.last_ip}']}, } } @@ -207,14 +202,17 @@ def rule_to_expr(cls, rule: SChainRule) -> list: expr.append( { 'match': { - 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}}, 'op': '==', + 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}}, 'right': rule.port, } } ) - expr.extend([{'counter': None}, {'accept': None}]) + if counter: + expr.append({'counter': None}) + + expr.append({'accept': None}) return expr @classmethod @@ -237,16 +235,9 @@ def expr_to_rule(self, expr: list) -> None: if any([port, first_ip, last_ip]): return SChainRule(port=port, first_ip=first_ip, last_ip=last_ip) - @classmethod - def expr_equals(cls, expr_a: list[dict], expr_b: list[dict]) -> bool: - for item_a, item_b in zip(sorted(expr_a), sorted(expr_b)): - if 'counter' not in item_a and item_a != item_b: - return False - return True - def remove_rule(self, rule: SChainRule) -> None: if self.has_rule(rule): - expr = self.rule_to_expr(rule) + expr = self.rule_to_expr(rule, counter=False) output = None rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}') @@ -255,15 +246,14 @@ def remove_rule(self, rule: SChainRule) -> None: current_rules = json.loads(output) - logger.info('HERE HERE %s', expr) - logger.info('HERE current rules %s', current_rules) handle = None for item in current_rules.get('nftables', []): if 'rule' in item: rule_data = item['rule'] - logger.info('HERE HERE 2 %s', rule_data['expr']) - logger.info('HERE HERE 3 %s', expr) - if self.expr_equals(rule_data.get('expr'), expr): + rule_expr = list( + filter(lambda statement: 'counter' not in statement, rule_data['expr']) + ) + if expr == rule_expr: handle = rule_data.get('handle') break diff --git a/tests/firewall/nftables_test.py b/tests/firewall/nftables_test.py index d11a2d6ff..06dc1d52f 100644 --- a/tests/firewall/nftables_test.py +++ b/tests/firewall/nftables_test.py @@ -17,12 +17,12 @@ def nf_test_tables(): @pytest.fixture def filter_table(nf_test_tables): - print(nf_test_tables.cmd('add table inet filter')) + print(nf_test_tables.cmd('add table inet firewall')) @pytest.fixture def custom_chain(nf_test_tables, filter_table): - nf_test_tables.cmd('add chain inet filter test-chain') + nf_test_tables.cmd('add chain inet firewall test-chain') return 'test-chain' @@ -35,7 +35,7 @@ def test_nftables_controller(custom_chain): assert nft_controller.has_rule(rule_a) assert nft_controller.has_rule(rule_b) rules = list(nft_controller.rules) - assert rules == sorted([rule_b, rule_a]) + assert sorted(rules) == sorted([rule_b, rule_a]), (rules, sorted([rule_b, rule_a])) nft_controller.remove_rule(rule_a) assert not nft_controller.has_rule(rule_a) assert nft_controller.has_rule(rule_b) From 5c16c2d8d9b2f0f450ac90e5402c0333849c256c Mon Sep 17 00:00:00 2001 From: badrogger Date: Mon, 16 Dec 2024 19:56:32 +0000 Subject: [PATCH 11/17] Fix chain creation --- core/schains/firewall/nftables.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index 0404a0f21..669bd0273 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -34,8 +34,9 @@ logger = logging.getLogger(__name__) + TABLE = 'firewall' -CHAIN = 'INPUT' +CHAIN = 'skale' class NFTablesCmdFailedError(Exception): @@ -48,7 +49,7 @@ class NFTablesController(IHostFirewallController): def __init__(self, table: str = TABLE, chain: str = CHAIN) -> None: self.table = table - self.chain = chain + self.chain = f'skale-{chain}' self._nftables = importlib.import_module('nftables') self.nft = self._nftables.Nftables() self.nft.set_json_output(True) @@ -95,7 +96,8 @@ def add_schain_drop_rule(self, first_port: int, last_port: int) -> None: def create_chain(self, first_port: int, last_port: int) -> None: if not self.has_chain(self.chain): - return self.run_json_cmd( + logger.info('Creating chain %s', self.chain) + self.run_json_cmd( self._compose_json( [ { From 319b806e7bb70e9e23c579707231b50750423ca1 Mon Sep 17 00:00:00 2001 From: badrogger Date: Tue, 17 Dec 2024 12:10:24 +0000 Subject: [PATCH 12/17] Fix nftables tests --- tests/firewall/nftables_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/firewall/nftables_test.py b/tests/firewall/nftables_test.py index 06dc1d52f..0ee564ec0 100644 --- a/tests/firewall/nftables_test.py +++ b/tests/firewall/nftables_test.py @@ -22,8 +22,9 @@ def filter_table(nf_test_tables): @pytest.fixture def custom_chain(nf_test_tables, filter_table): - nf_test_tables.cmd('add chain inet firewall test-chain') - return 'test-chain' + name = 'test-chain' + nf_test_tables.cmd('add chain inet firewall skale-{name}') + return name def test_nftables_controller(custom_chain): From d7d09c10d00421fe7af62ef1427488abbd0e78b6 Mon Sep 17 00:00:00 2001 From: badrogger Date: Tue, 17 Dec 2024 12:31:16 +0000 Subject: [PATCH 13/17] Fix nftables test --- tests/firewall/nftables_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/firewall/nftables_test.py b/tests/firewall/nftables_test.py index 0ee564ec0..4481265fd 100644 --- a/tests/firewall/nftables_test.py +++ b/tests/firewall/nftables_test.py @@ -23,7 +23,7 @@ def filter_table(nf_test_tables): @pytest.fixture def custom_chain(nf_test_tables, filter_table): name = 'test-chain' - nf_test_tables.cmd('add chain inet firewall skale-{name}') + nf_test_tables.cmd(f'add chain inet firewall skale-{name}') return name From 275c7d54be2fe9acd963752a3857023dd44b0b9c Mon Sep 17 00:00:00 2001 From: badrogger Date: Tue, 17 Dec 2024 19:54:59 +0000 Subject: [PATCH 14/17] Save rules backup after sync --- core/schains/firewall/firewall_manager.py | 5 +++++ core/schains/firewall/iptables.py | 3 +++ core/schains/firewall/nftables.py | 27 +++++++++++++++++++---- core/schains/firewall/types.py | 4 ++++ tests/utils.py | 3 +++ tools/configs/__init__.py | 2 ++ 6 files changed, 40 insertions(+), 4 deletions(-) diff --git a/core/schains/firewall/firewall_manager.py b/core/schains/firewall/firewall_manager.py index d16f71574..1393864bc 100644 --- a/core/schains/firewall/firewall_manager.py +++ b/core/schains/firewall/firewall_manager.py @@ -71,6 +71,11 @@ def update_rules(self, rules: Iterable[SChainRule]) -> None: rules_to_remove = actual_rules - expected_rules self.add_rules(rules_to_add) self.remove_rules(rules_to_remove) + self.save_rules() + + def save_rules(self) -> None: + """ Saves rules into persistent storage """ + self.host_controller.save_rules() def add_rules(self, rules: Iterable[SChainRule]) -> None: logger.debug('Adding rules %s', rules) diff --git a/core/schains/firewall/iptables.py b/core/schains/firewall/iptables.py index 1d28c4037..fbe1b55f4 100644 --- a/core/schains/firewall/iptables.py +++ b/core/schains/firewall/iptables.py @@ -139,3 +139,6 @@ def from_ip_network(cls, ip: str) -> str: @classmethod def to_ip_network(cls, ip: str) -> str: return str(ipaddress.ip_network(ip)) + + def save_rules(self): + raise NotImplementedError('save_rules is not implemented for iptables host controller') diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index 669bd0273..57551ce5b 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -18,16 +18,17 @@ # along with this program. If not, see . -import logging import importlib import ipaddress +import json +import logging import multiprocessing -from typing import Iterable +import os +from typing import Iterable, TypeVar from core.schains.firewall.types import IHostFirewallController, SChainRule -from typing import TypeVar -import json +from tools.configs import NFT_CHAIN_BASE_PATH T = TypeVar('T') @@ -315,3 +316,21 @@ def from_ip_network(cls, ip: str) -> str: @classmethod def to_ip_network(cls, ip: str) -> str: return str(ipaddress.ip_network(ip)) + + def get_plain_chain_rules(self) -> str: + self.nft.set_json_output(False) + output = '' + try: + rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}') + if rc != 0: + raise NFTablesCmdFailedError(f"Failed to get table content: {error}") + finally: + self.nft.set_json_output(True) + + return output + + def save_rules(self) -> None: + chain_rules = self.get_plain_chain_rules() + nft_chain_path = os.path.join(NFT_CHAIN_BASE_PATH, f'{self.chain}.conf') + with open(nft_chain_path, 'w') as nft_chain_file: + nft_chain_file.write(chain_rules) diff --git a/core/schains/firewall/types.py b/core/schains/firewall/types.py index 65ba8885d..0d25e6d7f 100644 --- a/core/schains/firewall/types.py +++ b/core/schains/firewall/types.py @@ -88,6 +88,10 @@ def rules(self) -> Iterable[SChainRule]: # pragma: no cover def has_rule(self, rule: SChainRule) -> bool: # pragma: no cover pass + @abstractmethod + def save_rules(self) -> None: # pragma: no cover + pass + class IFirewallManager(ABC): @property diff --git a/tests/utils.py b/tests/utils.py index dc33bf91b..e7fa56881 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -222,6 +222,9 @@ def rules(self): def has_rule(self, srule): return srule in self._rules + def save_rules(self): + pass + class SChainTestFirewallManager(SChainFirewallManager): def create_host_controller(self): diff --git a/tools/configs/__init__.py b/tools/configs/__init__.py index 4794de043..e1b0e053f 100644 --- a/tools/configs/__init__.py +++ b/tools/configs/__init__.py @@ -106,3 +106,5 @@ SYNC_NODE = os.getenv('SYNC_NODE') == 'True' DOCKER_NODE_CONFIG_FILEPATH = os.path.join(NODE_DATA_PATH, 'docker.json') + +NFT_CHAIN_BASE_PATH = '/etc/nft.conf.d/chains' From 517a3a858bb8c20fd4242252018d431df89b9214 Mon Sep 17 00:00:00 2001 From: badrogger Date: Wed, 18 Dec 2024 12:03:19 +0000 Subject: [PATCH 15/17] Fix tests --- .../firewall/default_rule_controller_test.py | 21 +++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/tests/firewall/default_rule_controller_test.py b/tests/firewall/default_rule_controller_test.py index 0eb4668a3..c21489a8f 100644 --- a/tests/firewall/default_rule_controller_test.py +++ b/tests/firewall/default_rule_controller_test.py @@ -1,7 +1,10 @@ +import concurrent.futures import mock +import os +import shutil + import pytest -import concurrent.futures from skale.schain_config import PORTS_PER_SCHAIN # noqa @@ -21,7 +24,17 @@ def refresh(): run_cmd(['nft', 'flush', 'ruleset']) -def test_get_default_rule_controller(): +@pytest.fixture() +def nft_chain_folder(): + path = '/etc/nft.conf.d/chains' + try: + os.makedirs(path) + yield path + finally: + shutil.rmtree(path) + + +def test_get_default_rule_controller(nft_chain_folder): own_ip = '3.3.3.3' node_ips = ['1.1.1.1', '2.2.2.2', '3.3.3.3', '4.4.4.4'] base_port = 10064 @@ -175,7 +188,7 @@ def run_concurrent_rc_syncing( @pytest.mark.parametrize('attempt', range(5)) -def test_concurrent_rc_behavior_no_refresh(attempt): +def test_concurrent_rc_behavior_no_refresh(attempt, nft_chain_folder): node_number = 16 schain_number = 8 own_ip = '1.1.1.1' @@ -191,7 +204,7 @@ def test_concurrent_rc_behavior_no_refresh(attempt): @pytest.mark.parametrize('attempt', range(5)) -def test_concurrent_rc_behavior_with_refresh(attempt, refresh): +def test_concurrent_rc_behavior_with_refresh(attempt, refresh, nft_chain_folder): node_number = 16 schain_number = 8 own_ip = '1.1.1.1' From 0aacca37e28b5ece4e282c04b517adede9f727ce Mon Sep 17 00:00:00 2001 From: badrogger Date: Thu, 26 Dec 2024 12:59:41 +0000 Subject: [PATCH 16/17] Cleanup nftables chain after schain removal --- core/schains/firewall/firewall_manager.py | 1 + core/schains/firewall/iptables.py | 5 ++++- core/schains/firewall/nftables.py | 22 ++++++++++++++++++++++ core/schains/firewall/types.py | 4 ++++ 4 files changed, 31 insertions(+), 1 deletion(-) diff --git a/core/schains/firewall/firewall_manager.py b/core/schains/firewall/firewall_manager.py index 1393864bc..5b2f76407 100644 --- a/core/schains/firewall/firewall_manager.py +++ b/core/schains/firewall/firewall_manager.py @@ -89,6 +89,7 @@ def remove_rules(self, rules: Iterable[SChainRule]) -> None: def flush(self) -> None: self.remove_rules(self.rules) + self.host_controller.cleanup() class IptablesSChainFirewallManager(SChainFirewallManager): diff --git a/core/schains/firewall/iptables.py b/core/schains/firewall/iptables.py index fbe1b55f4..589250d68 100644 --- a/core/schains/firewall/iptables.py +++ b/core/schains/firewall/iptables.py @@ -140,5 +140,8 @@ def from_ip_network(cls, ip: str) -> str: def to_ip_network(cls, ip: str) -> str: return str(ipaddress.ip_network(ip)) - def save_rules(self): + def save_rules(self) -> None: raise NotImplementedError('save_rules is not implemented for iptables host controller') + + def cleanup(self) -> None: + raise NotImplementedError('cleanup is not implemented for iptables host controller') diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index 57551ce5b..d0ff2dd2a 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -119,6 +119,25 @@ def create_chain(self, first_port: int, last_port: int) -> None: ) self.add_schain_drop_rule(first_port, last_port) + def delete_chain(self) -> None: + if self.has_chain(self.chain): + logger.info('Removing chain %s', self.chain) + self.run_json_cmd( + self._compose_json( + [ + { + 'delete': { + 'chain': { + 'family': self.FAMILY, + 'table': self.table, + 'name': self.chain + } + } + } + ] + ) + ) + @property def chains(self) -> list[dict]: output = self.run_cmd('list chains') @@ -334,3 +353,6 @@ def save_rules(self) -> None: nft_chain_path = os.path.join(NFT_CHAIN_BASE_PATH, f'{self.chain}.conf') with open(nft_chain_path, 'w') as nft_chain_file: nft_chain_file.write(chain_rules) + + def cleanup(self) -> None: + self.delete_chain() diff --git a/core/schains/firewall/types.py b/core/schains/firewall/types.py index 0d25e6d7f..0062cccec 100644 --- a/core/schains/firewall/types.py +++ b/core/schains/firewall/types.py @@ -92,6 +92,10 @@ def has_rule(self, rule: SChainRule) -> bool: # pragma: no cover def save_rules(self) -> None: # pragma: no cover pass + @abstractmethod + def cleanup(self) -> None: # pragma: no cover + pass + class IFirewallManager(ABC): @property From 80f773b0eb3b1948aba59aa6aa8f4017c6ac1b19 Mon Sep 17 00:00:00 2001 From: badrogger Date: Thu, 26 Dec 2024 15:28:26 +0000 Subject: [PATCH 17/17] Fix tests --- tests/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/utils.py b/tests/utils.py index e7fa56881..014c054f8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -225,6 +225,9 @@ def has_rule(self, srule): def save_rules(self): pass + def cleanup(self): + pass + class SChainTestFirewallManager(SChainFirewallManager): def create_host_controller(self):