Skip to content

Commit

Permalink
Merge pull request #819 from skalenetwork/develop
Browse files Browse the repository at this point in the history
Fix migration
  • Loading branch information
DmytroNazarenko authored Jan 9, 2025
2 parents 413fa65 + 440b6af commit 34df742
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 11 deletions.
110 changes: 104 additions & 6 deletions node_cli/core/nftables.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def rule_exists(self, chain: str, new_rule_expr: list[dict]) -> bool:
return True
return False

def add_drop_rule_if_node_exists(self, protocol: str) -> None:
def add_drop_rule(self, protocol: str) -> None:
expr = [
{
"match": {
Expand Down Expand Up @@ -197,7 +197,46 @@ def add_drop_rule_if_node_exists(self, protocol: str) -> None:
self.execute_cmd(cmd)
logger.info('Added drop rule for %s', protocol)

def add_rule_if_not_exists(self, rule: Rule) -> None:
def remove_drop_rule(self, protocol: str) -> None:
expr = [
{
"match": {
"op": "==",
"left": {
"payload": {
"protocol": "ip",
"field": "protocol"
}
},
"right": protocol
}
},
{'counter': None},
{"drop": None}
]

# Check if the drop rule exists before attempting to remove it
if self.rule_exists(self.chain, expr):
cmd = {
'nftables': [
{
'delete': {
'rule': {
'family': self.family,
'table': self.table,
'chain': self.chain,
'expr': expr,
}
}
}
]
}
self.execute_cmd(cmd)
logger.info('Removed drop rule for %s', protocol)
else:
logger.info('Drop rule does not exist for %s', protocol)

def add_rule(self, rule: Rule) -> None:
expr = []

if rule.protocol in ['tcp', 'udp']:
Expand Down Expand Up @@ -249,6 +288,56 @@ def add_rule_if_not_exists(self, rule: Rule) -> None:
'Rule already exists in chain %s: %s port %s', rule.chain, rule.protocol, rule.port
)

def remove_rule(self, rule: Rule) -> None:
expr = []

if rule.protocol in ['tcp', 'udp']:
if rule.port:
expr.append(
{
'match': {
'left': {'payload': {'protocol': rule.protocol, 'field': 'dport'}},
'op': '==',
'right': rule.port,
}
}
)
elif rule.protocol == 'icmp' and rule.icmp_type:
expr.append(
{
'match': {
'left': {'payload': {'protocol': 'icmp', 'field': 'type'}},
'op': '==',
'right': rule.icmp_type,
}
}
)

# Check if the rule exists before attempting to remove it
if self.rule_exists(rule.chain, expr):
cmd = {
'nftables': [
{
'delete': {
'rule': {
'family': self.family,
'table': self.table,
'chain': rule.chain,
'expr': expr,
}
}
}
]
}
self.execute_cmd(cmd)
logger.info(
'Removed rule from chain %s: %s port %s', rule.chain, rule.protocol, rule.port
)
else:
logger.info(
'Rule does not exist in chain %s: %s port %s', rule.chain, rule.protocol, rule.port
)

def add_connection_tracking_rule(self, chain: str) -> None:
expr = [
{
Expand Down Expand Up @@ -339,28 +428,37 @@ def setup_firewall(self, enable_monitoring: bool = False) -> None:
if enable_monitoring:
tcp_ports.extend([8080, 9100])
for port in tcp_ports:
self.add_rule_if_not_exists(Rule(chain=self.chain, protocol='tcp', port=port))
self.add_rule(Rule(chain=self.chain, protocol='tcp', port=port))

self.add_rule_if_not_exists(Rule(chain=self.chain, protocol='udp', port=53))
self.add_rule(Rule(chain=self.chain, protocol='udp', port=53))
self.add_loopback_rule(chain=self.chain)

icmp_types = ['destination-unreachable', 'source-quench', 'time-exceeded']
for icmp_type in icmp_types:
self.add_rule_if_not_exists(
self.add_rule(
Rule(
chain=self.chain,
protocol='icmp',
icmp_type=icmp_type
)
)

self.add_drop_rule_if_node_exists(protocol='udp')
self.add_drop_rule(protocol='udp')

except Exception as e:
logger.error('Failed to setup firewall: %s', e)
raise NFTablesError(e)
logger.info('Firewall rules are configured')

def cleanup_rules(self):
""" Cleanups all node-cli generated rules """
self.remove_drop_rule('tcp')
self.remove_drop_rule('udp')
tcp_ports = [get_ssh_port(), 53, 443, 3009, 8080, 9100]
for port in tcp_ports:
self.remove_rule(Rule(chain=self.chain, protocol='tcp', port=port))
self.remove_rule(Rule(chain=self.chain, protocol='udp', port=53))

def flush_chain(self, chain: str) -> None:
"""Remove all rules from a specific chain"""
json_cmd = {
Expand Down
9 changes: 8 additions & 1 deletion node_cli/migrations/focal_to_jammy.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,14 @@ def migrate() -> None:
ssh_port = get_ssh_port()
logger.info('Running migration from focal to jammy')
remove_old_iptables_rules(ssh_port)

logger.info('Flushing nftables rules generated by release upgrade')
nft = NFTablesManager(family='ip', table='filter')
nft.flush_chain(IPTABLES_CHAIN)
nft.cleanup_rules()

# Logging rules after migration
res = run_cmd(['nft', 'list', 'ruleset'])
plain_rules = res.stdout.decode('utf-8')
logger.debug(plain_rules)

logger.info('Migration from focal to jammy completed')
4 changes: 3 additions & 1 deletion tests/core/migration_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from node_cli.migrations.focal_to_jammy import migrate
from node_cli.migrations.focal_to_jammy import migrate, NFTablesManager

from node_cli.utils.helper import run_cmd

Expand Down Expand Up @@ -41,3 +41,5 @@ def test_migration(base_rules):
res = run_cmd(['iptables', '-S'])
output = res.stdout.decode('utf-8')
assert output == f'-P INPUT ACCEPT\n-P FORWARD ACCEPT\n-P OUTPUT ACCEPT\n-N {CUSTOM_CHAIN_NAME}\n-A {CUSTOM_CHAIN_NAME} -p tcp -m tcp --dport 2222 -j ACCEPT\n' # noqa
nft = NFTablesManager(family='ip', table='filter')
assert nft.get_rules(chain='INPUT') == []
6 changes: 3 additions & 3 deletions tests/core/nftables_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,12 @@ def test_create_chain_if_not_exists(mock_exists, mock_execute, nft_manager):
)
@patch.object(NFTablesManager, 'execute_cmd')
@patch.object(NFTablesManager, 'rule_exists')
def test_add_rule_if_not_exists(mock_exists, mock_execute, nft_manager, rule_data):
def test_add_rule(mock_exists, mock_execute, nft_manager, rule_data):
"""Test rule addition with different types"""
mock_exists.return_value = False

rule = Rule(**rule_data)
nft_manager.add_rule_if_not_exists(rule)
nft_manager.add_rule(rule)
mock_execute.assert_called_once()


Expand All @@ -138,4 +138,4 @@ def test_invalid_protocol(nft_manager):
"""Test adding rule with invalid protocol"""
rule = Rule(chain='INPUT', protocol='invalid', port=80)
with pytest.raises(Exception):
nft_manager.add_rule_if_not_exists(rule)
nft_manager.add_rule(rule)

0 comments on commit 34df742

Please sign in to comment.