diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index 36d98543b..0404a0f21 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -64,20 +64,15 @@ def create_table(self) -> None: def add_schain_drop_rule(self, first_port: int, last_port: int) -> None: expr = [ - { - "match": { - "left": { - "payload": { - "protocol": "tcp", - "field": "dport" + { + 'match': { + 'op': '==', + 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}}, + 'right': {'range': [first_port, last_port]}, } - }, - "op": "==", - "right": {'range': [first_port, last_port]} - } - }, - {'counter': None}, - {"drop": None} + }, + {'counter': None}, + {'drop': None}, ] if self.expr_to_rule(expr) not in self.get_rules_by_policy(policy='drop'): @@ -178,7 +173,7 @@ def add_rule(self, rule: SChainRule) -> None: raise NFTablesCmdFailedError(f'Failed to add allow rule: {error}') @classmethod - def rule_to_expr(cls, rule: SChainRule) -> list: + def rule_to_expr(cls, rule: SChainRule, counter: bool = True) -> list: expr = [] if rule.first_ip: @@ -186,8 +181,8 @@ def rule_to_expr(cls, rule: SChainRule) -> list: expr.append( { 'match': { - 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, 'op': '==', + 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, 'right': f'{rule.first_ip}', } } @@ -196,8 +191,8 @@ def rule_to_expr(cls, rule: SChainRule) -> list: expr.append( { 'match': { - 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, 'op': '==', + 'left': {'payload': {'protocol': 'ip', 'field': 'saddr'}}, 'right': {'range': [f'{rule.first_ip}', f'{rule.last_ip}']}, } } @@ -207,14 +202,17 @@ def rule_to_expr(cls, rule: SChainRule) -> list: expr.append( { 'match': { - 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}}, 'op': '==', + 'left': {'payload': {'protocol': 'tcp', 'field': 'dport'}}, 'right': rule.port, } } ) - expr.extend([{'counter': None}, {'accept': None}]) + if counter: + expr.append({'counter': None}) + + expr.append({'accept': None}) return expr @classmethod @@ -237,16 +235,9 @@ def expr_to_rule(self, expr: list) -> None: if any([port, first_ip, last_ip]): return SChainRule(port=port, first_ip=first_ip, last_ip=last_ip) - @classmethod - def expr_equals(cls, expr_a: list[dict], expr_b: list[dict]) -> bool: - for item_a, item_b in zip(sorted(expr_a), sorted(expr_b)): - if 'counter' not in item_a and item_a != item_b: - return False - return True - def remove_rule(self, rule: SChainRule) -> None: if self.has_rule(rule): - expr = self.rule_to_expr(rule) + expr = self.rule_to_expr(rule, counter=False) output = None rc, output, error = self.run_cmd(f'list chain {self.FAMILY} {self.table} {self.chain}') @@ -255,15 +246,14 @@ def remove_rule(self, rule: SChainRule) -> None: current_rules = json.loads(output) - logger.info('HERE HERE %s', expr) - logger.info('HERE current rules %s', current_rules) handle = None for item in current_rules.get('nftables', []): if 'rule' in item: rule_data = item['rule'] - logger.info('HERE HERE 2 %s', rule_data['expr']) - logger.info('HERE HERE 3 %s', expr) - if self.expr_equals(rule_data.get('expr'), expr): + rule_expr = list( + filter(lambda statement: 'counter' not in statement, rule_data['expr']) + ) + if expr == rule_expr: handle = rule_data.get('handle') break diff --git a/tests/firewall/nftables_test.py b/tests/firewall/nftables_test.py index d11a2d6ff..06dc1d52f 100644 --- a/tests/firewall/nftables_test.py +++ b/tests/firewall/nftables_test.py @@ -17,12 +17,12 @@ def nf_test_tables(): @pytest.fixture def filter_table(nf_test_tables): - print(nf_test_tables.cmd('add table inet filter')) + print(nf_test_tables.cmd('add table inet firewall')) @pytest.fixture def custom_chain(nf_test_tables, filter_table): - nf_test_tables.cmd('add chain inet filter test-chain') + nf_test_tables.cmd('add chain inet firewall test-chain') return 'test-chain' @@ -35,7 +35,7 @@ def test_nftables_controller(custom_chain): assert nft_controller.has_rule(rule_a) assert nft_controller.has_rule(rule_b) rules = list(nft_controller.rules) - assert rules == sorted([rule_b, rule_a]) + assert sorted(rules) == sorted([rule_b, rule_a]), (rules, sorted([rule_b, rule_a])) nft_controller.remove_rule(rule_a) assert not nft_controller.has_rule(rule_a) assert nft_controller.has_rule(rule_b)