diff --git a/Dockerfile b/Dockerfile index 0a2c24bb..540b256f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ FROM python:3.11-bookworm -RUN apt-get update && apt-get install -y wget git libxslt-dev iptables kmod swig +RUN apt-get update && apt-get install -y wget git libxslt-dev iptables kmod swig nftables python3-nftables RUN mkdir /usr/src/admin WORKDIR /usr/src/admin @@ -8,12 +8,13 @@ WORKDIR /usr/src/admin COPY requirements.txt ./ COPY requirements-dev.txt ./ -RUN pip3 install --no-cache-dir -r requirements.txt +RUN pip3 install -r requirements.txt COPY . . RUN update-alternatives --set iptables /usr/sbin/iptables-legacy && \ update-alternatives --set ip6tables /usr/sbin/ip6tables-legacy -ENV PYTHONPATH="/usr/src/admin" +ENV PYTHONPATH="/usr/src/admin":/usr/lib/python3/dist-packages/ + ENV COLUMNS=80 diff --git a/core/schains/firewall/__init__.py b/core/schains/firewall/__init__.py index 8edbd1a7..1bba60b7 100644 --- a/core/schains/firewall/__init__.py +++ b/core/schains/firewall/__init__.py @@ -19,6 +19,7 @@ from .firewall_manager import SChainFirewallManager # noqa from .iptables import IptablesController # noqa +from .nftables import NFTablesController # noqa from .rule_controller import SChainRuleController # noqa from .types import IRuleController # noqa from .utils import get_default_rule_controller # noqa diff --git a/core/schains/firewall/firewall_manager.py b/core/schains/firewall/firewall_manager.py index e2216fc7..5b2f7640 100644 --- a/core/schains/firewall/firewall_manager.py +++ b/core/schains/firewall/firewall_manager.py @@ -22,6 +22,7 @@ from typing import Iterable, Optional from core.schains.firewall.iptables import IptablesController +from core.schains.firewall.nftables import NFTablesController from core.schains.firewall.types import ( IFirewallManager, IHostFirewallController, @@ -70,6 +71,11 @@ def update_rules(self, rules: Iterable[SChainRule]) -> None: rules_to_remove = actual_rules - expected_rules self.add_rules(rules_to_add) self.remove_rules(rules_to_remove) + self.save_rules() + + def save_rules(self) -> None: + """ Saves rules into persistent storage """ + self.host_controller.save_rules() def add_rules(self, rules: Iterable[SChainRule]) -> None: logger.debug('Adding rules %s', rules) @@ -83,8 +89,17 @@ def remove_rules(self, rules: Iterable[SChainRule]) -> None: def flush(self) -> None: self.remove_rules(self.rules) + self.host_controller.cleanup() class IptablesSChainFirewallManager(SChainFirewallManager): def create_host_controller(self) -> IptablesController: return IptablesController() + + +class NFTSchainFirewallManager(SChainFirewallManager): + def create_host_controller(self) -> NFTablesController: + nc_controller = NFTablesController(chain=self.name) + nc_controller.create_table() + nc_controller.create_chain(self.first_port, self.last_port) + return nc_controller diff --git a/core/schains/firewall/iptables.py b/core/schains/firewall/iptables.py index 1d28c403..589250d6 100644 --- a/core/schains/firewall/iptables.py +++ b/core/schains/firewall/iptables.py @@ -139,3 +139,9 @@ def from_ip_network(cls, ip: str) -> str: @classmethod def to_ip_network(cls, ip: str) -> str: return str(ipaddress.ip_network(ip)) + + def save_rules(self) -> None: + raise NotImplementedError('save_rules is not implemented for iptables host controller') + + def cleanup(self) -> None: + raise NotImplementedError('cleanup is not implemented for iptables host controller') diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py new file mode 100644 index 00000000..d0ff2dd2 --- /dev/null +++ b/core/schains/firewall/nftables.py @@ -0,0 +1,358 @@ +# -*- coding: utf-8 -*- +# +# This file is part of SKALE Admin +# +# Copyright (C) 2024 SKALE Labs +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + + +import importlib +import ipaddress +import json +import logging +import multiprocessing +import os +from typing import Iterable, TypeVar + +from core.schains.firewall.types import IHostFirewallController, SChainRule + +from tools.configs import NFT_CHAIN_BASE_PATH + +T = TypeVar('T') + + +logger = logging.getLogger(__name__) + + +TABLE = 'firewall' +CHAIN = 'skale' + + +class NFTablesCmdFailedError(Exception): + pass + + +class NFTablesController(IHostFirewallController): + plock = multiprocessing.Lock() + FAMILY = 'inet' + + def __init__(self, table: str = TABLE, chain: str = CHAIN) -> None: + self.table = table + self.chain = f'skale-{chain}' + self._nftables = importlib.import_module('nftables') + self.nft = self._nftables.Nftables() + self.nft.set_json_output(True) + + def _compose_json(self, commands: list[dict]) -> dict: + json_cmd = {'nftables': commands} + self.nft.json_validate(json_cmd) + return json_cmd + + def create_table(self) -> None: + if not self.has_table(self.table): + return self.run_cmd(f'add table inet {self.table}') + + def add_schain_drop_rule(self, first_port: int, last_port: int) -> None: + expr = [ + { + 'match': { + 'op': '==', + 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}}, + 'right': {'range': [first_port, last_port]}, + } + }, + {'counter': None}, + {'drop': None}, + ] + + if self.expr_to_rule(expr) not in self.get_rules_by_policy(policy='drop'): + cmd = { + 'nftables': [ + { + 'add': { + 'rule': { + 'family': self.FAMILY, + 'table': self.table, + 'chain': self.chain, + 'expr': expr, + } + } + } + ] + } + self.run_json_cmd(cmd) + logger.info('Added drop rule for chain %s', self.chain) + + def create_chain(self, first_port: int, last_port: int) -> None: + if not self.has_chain(self.chain): + logger.info('Creating chain %s', self.chain) + self.run_json_cmd( + self._compose_json( + [ + { + 'add': { + 'chain': { + 'family': self.FAMILY, + 'table': self.table, + 'name': self.chain, + 'hook': 'input', + 'type': 'filter', + 'prio': 0, + 'policy': 'accept', + } + } + } + ] + ) + ) + self.add_schain_drop_rule(first_port, last_port) + + def delete_chain(self) -> None: + if self.has_chain(self.chain): + logger.info('Removing chain %s', self.chain) + self.run_json_cmd( + self._compose_json( + [ + { + 'delete': { + 'chain': { + 'family': self.FAMILY, + 'table': self.table, + 'name': self.chain + } + } + } + ] + ) + ) + + @property + def chains(self) -> list[dict]: + output = self.run_cmd('list chains') + if output[0] != 0: + raise NFTablesCmdFailedError(output) + parsed = json.loads(output[1])['nftables'] + return [record['chain']['name'] for record in parsed if 'chain' in record] + + @property + def tables(self) -> list[dict]: + output = self.run_cmd('list tables') + if output[0] != 0: + raise NFTablesCmdFailedError(output) + parsed = json.loads(output[1])['nftables'] + return [record['table']['name'] for record in parsed if 'table' in record] + + def run_json_cmd(self, cmd: dict) -> tuple: + logger.debug('NFTables json cmd %s', cmd) + with self.plock: + return self.nft.json_cmd(cmd) + + def run_cmd(self, cmd: str) -> tuple: + logger.debug('NFTables cmd %s', cmd) + with self.plock: + return self.nft.cmd(cmd) + + def has_chain(self, chain: str) -> bool: + return chain in self.chains + + def has_table(self, table: str) -> bool: + return table in self.tables + + def add_rule(self, rule: SChainRule) -> None: + if self.has_rule(rule): + return + expr = self.rule_to_expr(rule) + + json_cmd = self._compose_json( + [ + { + 'insert': { + 'rule': { + 'family': self.FAMILY, + 'table': self.table, + 'chain': self.chain, + 'expr': expr, + } + } + } + ] + ) + + rc, output, error = self.run_json_cmd(json_cmd) + if rc != 0: + raise NFTablesCmdFailedError(f'Failed to add allow rule: {error}') + + @classmethod + def rule_to_expr(cls, rule: SChainRule, counter: bool = True) -> list: + expr = [] + + if rule.first_ip: + if rule.last_ip == rule.first_ip: + expr.append( + { + 'match': { + 'op': '==', + 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, + 'right': f'{rule.first_ip}', + } + } + ) + else: + expr.append( + { + 'match': { + 'op': '==', + 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, + 'right': {'range': [f'{rule.first_ip}', f'{rule.last_ip}']}, + } + } + ) + + if rule.port: + expr.append( + { + 'match': { + 'op': '==', + 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}}, + 'right': rule.port, + } + } + ) + + if counter: + expr.append({'counter': None}) + + expr.append({'accept': None}) + return expr + + @classmethod + def expr_to_rule(self, expr: list) -> None: + port, first_ip, last_ip = None, None, None + for item in expr: + if 'match' in item: + match = item['match'] + + if match.get('left', {}).get('payload', {}).get('field') == 'dport': + port = match.get('right') + + if match.get('left', {}).get('payload', {}).get('field') == 'saddr': + right = match.get('right') + if isinstance(right, str): + first_ip = right + else: + first_ip, last_ip = right['range'] + + if any([port, first_ip, last_ip]): + return SChainRule(port=port, first_ip=first_ip, last_ip=last_ip) + + def remove_rule(self, rule: SChainRule) -> None: + if self.has_rule(rule): + expr = self.rule_to_expr(rule, counter=False) + + output = None + rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}') + if rc != 0: + raise NFTablesCmdFailedError(f'Failed to list rules: {error}') + + current_rules = json.loads(output) + + handle = None + for item in current_rules.get('nftables', []): + if 'rule' in item: + rule_data = item['rule'] + rule_expr = list( + filter(lambda statement: 'counter' not in statement, rule_data['expr']) + ) + if expr == rule_expr: + handle = rule_data.get('handle') + break + + if handle is None: + raise NFTablesCmdFailedError('Rule not found') + + json_cmd = self._compose_json( + [ + { + 'delete': { + 'rule': { + 'family': self.FAMILY, + 'table': self.table, + 'chain': self.chain, + 'handle': handle, + } + } + } + ] + ) + + rc, output, error = self.run_json_cmd(json_cmd) + if rc != 0: + raise NFTablesCmdFailedError(f'Failed to delete rule: {error}') + + @property # type: ignore + def rules(self) -> Iterable[SChainRule]: + return self.get_rules_by_policy(policy='accept') + + def has_rule(self, rule: SChainRule) -> bool: + return rule in self.rules + + def get_rules_by_policy(self, policy: str) -> list[SChainRule]: + output = None + rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}') + if output == '': + return [] + + data = json.loads(output) + rules = [] + + for item in data.get('nftables', []): + if 'rule' in item: + plain_rule = item['rule'] + expr = plain_rule.get('expr', []) + if {policy: None} in expr: + rule = self.expr_to_rule(expr) + if rule: + rules.append(rule) + logger.debug('Rules for policy %s: %s', policy, rules) + return rules + + @classmethod + def from_ip_network(cls, ip: str) -> str: + return str(ipaddress.ip_network(ip).hosts()[0]) + + @classmethod + def to_ip_network(cls, ip: str) -> str: + return str(ipaddress.ip_network(ip)) + + def get_plain_chain_rules(self) -> str: + self.nft.set_json_output(False) + output = '' + try: + rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}') + if rc != 0: + raise NFTablesCmdFailedError(f"Failed to get table content: {error}") + finally: + self.nft.set_json_output(True) + + return output + + def save_rules(self) -> None: + chain_rules = self.get_plain_chain_rules() + nft_chain_path = os.path.join(NFT_CHAIN_BASE_PATH, f'{self.chain}.conf') + with open(nft_chain_path, 'w') as nft_chain_file: + nft_chain_file.write(chain_rules) + + def cleanup(self) -> None: + self.delete_chain() diff --git a/core/schains/firewall/rule_controller.py b/core/schains/firewall/rule_controller.py index 51e8920a..3b63026b 100644 --- a/core/schains/firewall/rule_controller.py +++ b/core/schains/firewall/rule_controller.py @@ -23,7 +23,7 @@ from functools import wraps from typing import Any, Callable, cast, Dict, Iterable, List, Optional, TypeVar -from .firewall_manager import IptablesSChainFirewallManager +from .firewall_manager import IptablesSChainFirewallManager, NFTSchainFirewallManager from .types import ( IFirewallManager, IpRange, @@ -214,3 +214,13 @@ def create_firewall_manager(self) -> IptablesSChainFirewallManager: self.base_port, # type: ignore self.base_port + self.ports_per_schain - 1 # type: ignore ) + + +class NFTSchainRuleController(SChainRuleController): + @configured_only + def create_firewall_manager(self) -> NFTSchainFirewallManager: + return NFTSchainFirewallManager( + self.name, + self.base_port, # type: ignore + self.base_port + self.ports_per_schain - 1 # type: ignore + ) diff --git a/core/schains/firewall/types.py b/core/schains/firewall/types.py index 65ba8885..0062ccce 100644 --- a/core/schains/firewall/types.py +++ b/core/schains/firewall/types.py @@ -88,6 +88,14 @@ def rules(self) -> Iterable[SChainRule]: # pragma: no cover def has_rule(self, rule: SChainRule) -> bool: # pragma: no cover pass + @abstractmethod + def save_rules(self) -> None: # pragma: no cover + pass + + @abstractmethod + def cleanup(self) -> None: # pragma: no cover + pass + class IFirewallManager(ABC): @property diff --git a/core/schains/firewall/utils.py b/core/schains/firewall/utils.py index 737361e1..1f94694f 100644 --- a/core/schains/firewall/utils.py +++ b/core/schains/firewall/utils.py @@ -25,7 +25,7 @@ from skale import Skale from .types import IpRange -from .rule_controller import IptablesSChainRuleController +from .rule_controller import IptablesSChainRuleController, NFTSchainRuleController logger = logging.getLogger(__name__) @@ -37,6 +37,22 @@ def get_default_rule_controller( own_ip: Optional[str] = None, node_ips: List[str] = [], sync_agent_ranges: Optional[List[IpRange]] = [] +) -> IptablesSChainRuleController: + return get_nftables_rule_controller( + name=name, + base_port=base_port, + own_ip=own_ip, + node_ips=node_ips, + sync_agent_ranges=sync_agent_ranges + ) + + +def get_iptables_rule_controller( + name: str, + base_port: Optional[int] = None, + own_ip: Optional[str] = None, + node_ips: List[str] = [], + sync_agent_ranges: Optional[List[IpRange]] = [] ) -> IptablesSChainRuleController: sync_agent_ranges = sync_agent_ranges or [] logger.info('Creating rule controller for %s', name) @@ -50,6 +66,25 @@ def get_default_rule_controller( ) +def get_nftables_rule_controller( + name: str, + base_port: Optional[int] = None, + own_ip: Optional[str] = None, + node_ips: List[str] = [], + sync_agent_ranges: Optional[List[IpRange]] = [] +) -> NFTSchainRuleController: + sync_agent_ranges = sync_agent_ranges or [] + logger.info('Creating rule controller for %s', name) + logger.debug('Rule controller ranges for %s: %s', name, sync_agent_ranges) + return NFTSchainRuleController( + name=name, + base_port=base_port, + own_ip=own_ip, + node_ips=node_ips, + sync_ip_ranges=sync_agent_ranges + ) + + def get_sync_agent_ranges(skale: Skale) -> List[IpRange]: sync_agent_ranges = [] rnum = skale.sync_manager.get_ip_ranges_number() diff --git a/tests.Dockerfile b/tests.Dockerfile index b31db00e..75b72f63 100644 --- a/tests.Dockerfile +++ b/tests.Dockerfile @@ -1,3 +1,7 @@ FROM admin:base -RUN pip3 install --no-cache-dir -r requirements-dev.txt +RUN apt update && apt install -y nftables python3-nftables + +RUN pip3 install -r requirements-dev.txt + +ENV PYTHONPATH=${PYTHONPATH}:/usr/lib/python3/dist-packages/ diff --git a/tests/firewall/default_rule_controller_test.py b/tests/firewall/default_rule_controller_test.py index c2473e16..c21489a8 100644 --- a/tests/firewall/default_rule_controller_test.py +++ b/tests/firewall/default_rule_controller_test.py @@ -1,26 +1,40 @@ + +import concurrent.futures import mock +import os +import shutil + import pytest -import concurrent.futures from skale.schain_config import PORTS_PER_SCHAIN # noqa -from core.schains.firewall import IptablesController +from core.schains.firewall import NFTablesController from core.schains.firewall.utils import get_default_rule_controller from core.schains.firewall.types import IpRange, SkaledPorts -from tests.firewall.iptables_test import get_rules_through_subprocess + from tools.helper import run_cmd @pytest.fixture def refresh(): - run_cmd(['iptables', '-F']) + run_cmd(['nft', 'flush', 'ruleset']) try: yield finally: - run_cmd(['iptables', '-F']) + run_cmd(['nft', 'flush', 'ruleset']) -def test_get_default_rule_controller(): +@pytest.fixture() +def nft_chain_folder(): + path = '/etc/nft.conf.d/chains' + try: + os.makedirs(path) + yield path + finally: + shutil.rmtree(path) + + +def test_get_default_rule_controller(nft_chain_folder): own_ip = '3.3.3.3' node_ips = ['1.1.1.1', '2.2.2.2', '3.3.3.3', '4.4.4.4'] base_port = 10064 @@ -58,22 +72,6 @@ def sync_rules(*args): return False -def parse_plain_rule(plain_rule): - first_ip, last_ip, port = None, None, None - pr = plain_rule.split() - if '--src-range' in pr: - srange = pr[11] - first_ip, last_ip = srange.split('-') - port = pr[7] - elif '-s' in pr: - first_ip = last_ip = pr[3][:-3] - port = pr[9] - elif '--dport' in pr: - port = pr[7] - - return first_ip, last_ip, int(port) - - def run_concurrent_rc_syncing( node_number, schain_number, @@ -147,59 +145,50 @@ def run_concurrent_rc_syncing( else: assert not r - pr = get_rules_through_subprocess(unique=False)[3:] - rules = [parse_plain_rule(r) for r in pr] + controllers = [NFTablesController(chain=name) for name in schain_names] + rules = [] + for controller in controllers: + rules.extend(controller.rules) + + print([r.port for r in rules]) + print([r.first_ip for r in rules]) - c = IptablesController() # Check that all ip rules are there for ip in node_ips: if ip != own_ip: assert sum( - map(lambda x: x[0] == ip, rules) - ) == 5 * schain_number, ip - assert sum( - map(lambda x: x.first_ip == ip, c.rules) + map(lambda x: x.first_ip == ip, rules) ) == 5 * schain_number, ip # Check that all internal ports rules are there except CATCHUP for p in internal_ports: - assert sum(map(lambda x: x[2] == p, rules)) == node_number - 1, p - assert sum(map(lambda x: x.port == p, c.rules)) == node_number - 1, p + assert sum(map(lambda x: x.port == p, rules)) == node_number - 1, p # Check CATCHUP rules including sync agents rules catchup_e_number = node_number + sync_agent_ranges_number - 1 for p in catchup_ports: - assert sum(map(lambda x: x[2] == p, rules)) == catchup_e_number, p - assert sum(map(lambda x: x.port == p, c.rules)) == catchup_e_number, p + assert sum(map(lambda x: x.port == p, rules)) == catchup_e_number, p # Check ZMQ rules including sync agents rules zmq_e_number = node_number + sync_agent_ranges_number - 1 for p in zmq_ports: - assert sum(map(lambda x: x[2] == p, rules)) == zmq_e_number, p - assert sum(map(lambda x: x.port == p, c.rules)) == zmq_e_number, p + assert sum(map(lambda x: x.port == p, rules)) == zmq_e_number, p # Check sync ip ranges rules for r in sync_agent_ranges: assert sum( - map(lambda x: x[0] == r.start_ip, rules) - ) == schain_number * 2, ip - assert sum( - map(lambda x: x.first_ip == r.start_ip, c.rules) - ) == schain_number * 2, ip - assert sum( - map(lambda x: x[1] == r.end_ip, rules) + map(lambda x: x.first_ip == r.start_ip, rules) ) == schain_number * 2, ip assert sum( - map(lambda x: x.last_ip == r.end_ip, c.rules) + map(lambda x: x.last_ip == r.end_ip, rules) ) == schain_number * 2, ip for port in public_ports: - assert sum(map(lambda x: x[2] == port, rules)) == 1, port - assert sum(map(lambda x: x.port == port, c.rules)) == 1, port + assert sum(map(lambda x: x.port == port, rules)) == 1, port @pytest.mark.parametrize('attempt', range(5)) -def test_concurrent_rc_behavior_no_refresh(attempt): +def test_concurrent_rc_behavior_no_refresh(attempt, nft_chain_folder): node_number = 16 schain_number = 8 own_ip = '1.1.1.1' @@ -215,7 +204,7 @@ def test_concurrent_rc_behavior_no_refresh(attempt): @pytest.mark.parametrize('attempt', range(5)) -def test_concurrent_rc_behavior_with_refresh(attempt, refresh): +def test_concurrent_rc_behavior_with_refresh(attempt, refresh, nft_chain_folder): node_number = 16 schain_number = 8 own_ip = '1.1.1.1' diff --git a/tests/firewall/nftables_test.py b/tests/firewall/nftables_test.py new file mode 100644 index 00000000..4481265f --- /dev/null +++ b/tests/firewall/nftables_test.py @@ -0,0 +1,105 @@ +import concurrent.futures +import importlib +import time + +import pytest + +from core.schains.firewall.nftables import NFTablesController +from core.schains.firewall.types import SChainRule + + +@pytest.fixture +def nf_test_tables(): + nft = importlib.import_module('nftables').Nftables() + nft.cmd('flush ruleset') + return nft + + +@pytest.fixture +def filter_table(nf_test_tables): + print(nf_test_tables.cmd('add table inet firewall')) + + +@pytest.fixture +def custom_chain(nf_test_tables, filter_table): + name = 'test-chain' + nf_test_tables.cmd(f'add chain inet firewall skale-{name}') + return name + + +def test_nftables_controller(custom_chain): + nft_controller = NFTablesController(chain='test-chain') + rule_a = SChainRule(10000, '1.1.1.1', '2.2.2.2') + rule_b = SChainRule(10001, '3.3.3.3') + nft_controller.add_rule(rule_a) + nft_controller.add_rule(rule_b) + assert nft_controller.has_rule(rule_a) + assert nft_controller.has_rule(rule_b) + rules = list(nft_controller.rules) + assert sorted(rules) == sorted([rule_b, rule_a]), (rules, sorted([rule_b, rule_a])) + nft_controller.remove_rule(rule_a) + assert not nft_controller.has_rule(rule_a) + assert nft_controller.has_rule(rule_b) + nft_controller.remove_rule(rule_b) + assert not nft_controller.has_rule(rule_a) + + +def test_nftables_controller_duplicates(custom_chain): + rule_a = SChainRule(10000, '1.1.1.1', '2.2.2.2') + manager = NFTablesController(chain='test-chain') + manager.add_rule(rule_a) + rule_b = SChainRule(10001, '3.3.3.3', '4.4.4.4') + manager.add_rule(rule_b) + assert sorted(list(manager.rules)) == sorted([ + SChainRule(port=10001, first_ip='3.3.3.3', last_ip='4.4.4.4'), + SChainRule(port=10000, first_ip='1.1.1.1', last_ip='2.2.2.2') + ]) + assert manager.has_rule(rule_b) + manager.add_rule(rule_b) + assert manager.has_rule(rule_b) + assert sorted(list(manager.rules)) == sorted([ + SChainRule(port=10001, first_ip='3.3.3.3', last_ip='4.4.4.4'), + SChainRule(port=10000, first_ip='1.1.1.1', last_ip='2.2.2.2') + ]) + manager.remove_rule(rule_b) + assert list(manager.rules) == [ + SChainRule(port=10000, first_ip='1.1.1.1', last_ip='2.2.2.2') + ] + + +def add_remove_rule(srule, refresh): + manager = NFTablesController() + manager.add_rule(srule) + time.sleep(1) + if not manager.has_rule(srule): + return False + time.sleep(1) + manager.remove_rule(srule) + return True + + +def generate_srules(number=5): + return [ + SChainRule( + 10000 + 1, + f'{i}.{i}.{i}.{i}', f'{i + 1}.{i + 1}.{i + 1}.{i + 1}' + ) + for i in range(1, number * 2, 2) + ] + + +def test_nftables_manager_parallel(custom_chain): + srules = generate_srules(number=12) + + futures = [] + with concurrent.futures.ProcessPoolExecutor(max_workers=12) as executor: + futures = [ + executor.submit(add_remove_rule, srule) + for srule in srules + ] + + for future in concurrent.futures.as_completed(futures): + assert future.result + manager = NFTablesController(custom_chain) + time.sleep(10) + assert len(list(manager.rules)) == 0 diff --git a/tests/utils.py b/tests/utils.py index dc33bf91..e7fa5688 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -222,6 +222,9 @@ def rules(self): def has_rule(self, srule): return srule in self._rules + def save_rules(self): + pass + class SChainTestFirewallManager(SChainFirewallManager): def create_host_controller(self): diff --git a/tools/configs/__init__.py b/tools/configs/__init__.py index 4794de04..e1b0e053 100644 --- a/tools/configs/__init__.py +++ b/tools/configs/__init__.py @@ -106,3 +106,5 @@ SYNC_NODE = os.getenv('SYNC_NODE') == 'True' DOCKER_NODE_CONFIG_FILEPATH = os.path.join(NODE_DATA_PATH, 'docker.json') + +NFT_CHAIN_BASE_PATH = '/etc/nft.conf.d/chains'