diff --git a/oryx-tui/src/ebpf.rs b/oryx-tui/src/ebpf.rs index 2cecd60..4a55c30 100644 --- a/oryx-tui/src/ebpf.rs +++ b/oryx-tui/src/ebpf.rs @@ -18,7 +18,7 @@ use oryx_common::{protocols::Protocol, RawPacket}; use crate::{ event::Event, notification::{Notification, NotificationLevel}, - section::firewall::FirewallRule, + section::firewall::{BlockedPort, FirewallRule}, }; use mio::{event::Source, unix::SourceFd, Events, Interest, Poll, Registry, Token}; @@ -66,51 +66,70 @@ impl Source for RingBuffer<'_> { fn update_ipv4_blocklist( ipv4_firewall: &mut HashMap, addr: Ipv4Addr, - port: u16, + port: BlockedPort, enabled: bool, ) { + // hashmap entry exists if let Ok(mut blocked_ports) = ipv4_firewall.get(&addr.to_bits(), 0) { - if enabled { - // add port to blocklist - if let Some(first_zero) = blocked_ports.iter().enumerate().find(|&x| *x.1 == 0) { - blocked_ports[first_zero.0] = port; - // dbg!("UPSERTING"); - // dbg!(blocked_ports[0], blocked_ports[1]); - ipv4_firewall - .insert(addr.to_bits(), blocked_ports, 0) - .unwrap(); - } else { - todo!(); // list is full - } - } else { - // remove port from blocklist - // on veut rebuild une blocklist avec les ports restants non nuls - // par example là [8888,0,80,0,..] - // hashmap = key:[0,0,0] - // => [8888,80,0 ....] - let non_null_ports = blocked_ports - .into_iter() - .filter(|p| (*p != 0 && *p != port)) - .collect::>(); - let mut blocked_ports = [0; 32]; - for (idx, p) in non_null_ports.iter().enumerate() { - blocked_ports[idx] = *p; + match port { + // single port update + BlockedPort::Single(port) => { + if enabled { + // add port to blocklist + if let Some(first_zero) = blocked_ports.iter().enumerate().find(|&x| *x.1 == 0) + { + blocked_ports[first_zero.0] = port; + // dbg!("UPSERTING"); + // dbg!(blocked_ports[0], blocked_ports[1]); + ipv4_firewall + .insert(addr.to_bits(), blocked_ports, 0) + .unwrap(); + } else { + todo!(); // list is full + } + } else { + // remove port from blocklist + // eg: remove port 53 [8888,53,80,0,..] => [8888,0,80,0,..] => [8888,80,0 ....] + let non_null_ports = blocked_ports + .into_iter() + .filter(|p| (*p != 0 && *p != port)) + .collect::>(); + let mut blocked_ports = [0; 32]; + for (idx, p) in non_null_ports.iter().enumerate() { + blocked_ports[idx] = *p; + } + if blocked_ports.iter().sum::() == 0 { + //if block_list is now empty, we need to delete key + ipv4_firewall.remove(&addr.to_bits()).unwrap(); + } else { + ipv4_firewall + .insert(addr.to_bits(), blocked_ports, 0) + .unwrap(); + } + } } - if blocked_ports.iter().sum::() == 0 { - //now block_list is empty, we need to delete key - ipv4_firewall.remove(&addr.to_bits()).unwrap(); - } else { - ipv4_firewall - .insert(addr.to_bits(), blocked_ports, 0) - .unwrap(); + BlockedPort::All => { + if enabled { + ipv4_firewall.insert(addr.to_bits(), [0; 32], 0).unwrap(); + } else { + ipv4_firewall.remove(&addr.to_bits()).unwrap(); + } } } - } else { + } + // no hashmap entry, create new blocklist + else { // shouldn't be disabling if blocklist is empty assert!(enabled); - // create new blocklist with port as first element + let mut blocked_ports: [u16; 32] = [0; 32]; - blocked_ports[0] = port; + match port { + BlockedPort::Single(port) => { + blocked_ports[0] = port; + } + BlockedPort::All => {} + } + ipv4_firewall .insert(addr.to_bits(), blocked_ports, 0) .unwrap(); @@ -120,41 +139,75 @@ fn update_ipv4_blocklist( fn update_ipv6_blocklist( ipv6_firewall: &mut HashMap, addr: Ipv6Addr, - port: u16, + port: BlockedPort, enabled: bool, ) { + // hashmap entry exists if let Ok(mut blocked_ports) = ipv6_firewall.get(&addr.to_bits(), 0) { - if enabled { - // add port to blocklist - if let Some(first_zero) = blocked_ports.iter().enumerate().find(|&x| *x.1 == 0) { - blocked_ports[first_zero.0] = port; - ipv6_firewall - .insert(addr.to_bits(), blocked_ports, 0) - .unwrap(); - } else { - todo!(); // list is full + match port { + // single port update + BlockedPort::Single(port) => { + if enabled { + // add port to blocklist + if let Some(first_zero) = blocked_ports.iter().enumerate().find(|&x| *x.1 == 0) + { + blocked_ports[first_zero.0] = port; + // dbg!("UPSERTING"); + // dbg!(blocked_ports[0], blocked_ports[1]); + ipv6_firewall + .insert(addr.to_bits(), blocked_ports, 0) + .unwrap(); + } else { + todo!(); // list is full + } + } else { + // remove port from blocklist + // eg: remove port 53 [8888,53,80,0,..] => [8888,0,80,0,..] => [8888,80,0 ....] + let non_null_ports = blocked_ports + .into_iter() + .filter(|p| (*p != 0 && *p != port)) + .collect::>(); + let mut blocked_ports = [0; 32]; + for (idx, p) in non_null_ports.iter().enumerate() { + blocked_ports[idx] = *p; + } + if blocked_ports.iter().sum::() == 0 { + //if block_list is now empty, we need to delete key + ipv6_firewall.remove(&addr.to_bits()).unwrap(); + } else { + ipv6_firewall + .insert(addr.to_bits(), blocked_ports, 0) + .unwrap(); + } + } } - } else { - // remove port from blocklist - if let Some(matching_port) = blocked_ports.iter().enumerate().find(|&x| *x.1 == port) { - blocked_ports[matching_port.0] = 0; - ipv6_firewall - .insert(addr.to_bits(), blocked_ports, 0) - .unwrap(); + BlockedPort::All => { + if enabled { + ipv6_firewall.insert(addr.to_bits(), [0; 32], 0).unwrap(); + } else { + ipv6_firewall.remove(&addr.to_bits()).unwrap(); + } } } - } else { + } + // no hashmap entry, create new blocklist + else { // shouldn't be disabling if blocklist is empty assert!(enabled); - //create new blocklist with port as first element + let mut blocked_ports: [u16; 32] = [0; 32]; - blocked_ports[0] = port; + match port { + BlockedPort::Single(port) => { + blocked_ports[0] = port; + } + BlockedPort::All => {} + } + ipv6_firewall .insert(addr.to_bits(), blocked_ports, 0) .unwrap(); } } - impl Ebpf { pub fn load_ingress( iface: String, diff --git a/oryx-tui/src/section/firewall.rs b/oryx-tui/src/section/firewall.rs index 2958dc7..35175b0 100644 --- a/oryx-tui/src/section/firewall.rs +++ b/oryx-tui/src/section/firewall.rs @@ -7,7 +7,7 @@ use ratatui::{ widgets::{Block, Borders, Cell, Clear, HighlightSpacing, Padding, Row, Table, TableState}, Frame, }; -use std::{net::IpAddr, str::FromStr}; +use std::{net::IpAddr, num::ParseIntError, str::FromStr}; use tui_input::{backend::crossterm::EventHandler, Input}; use uuid; @@ -19,7 +19,33 @@ pub struct FirewallRule { name: String, pub enabled: bool, pub ip: IpAddr, - pub port: u16, + pub port: BlockedPort, +} + +#[derive(Debug, Clone)] +pub enum BlockedPort { + Single(u16), + All, +} + +impl Display for BlockedPort { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BlockedPort::Single(p) => write!(f, "{}", p), + BlockedPort::All => write!(f, "*"), + } + } +} + +impl FromStr for BlockedPort { + type Err = ParseIntError; + fn from_str(s: &str) -> Result { + if s == "*" { + return Ok(BlockedPort::All); + } else { + return Ok(BlockedPort::Single(u16::from_str(s)?)); + } + } } impl Display for FirewallRule { @@ -80,7 +106,7 @@ impl UserInput { self.port.error = None; if self.port.field.value().is_empty() { self.port.error = Some("Required field.".to_string()); - } else if u16::from_str(self.port.field.value()).is_err() { + } else if BlockedPort::from_str(self.port.field.value()).is_err() { self.port.error = Some("Invalid Port number.".to_string()); } } @@ -280,13 +306,14 @@ impl Firewall { // update rule with user input rule.name = user_input.name.field.to_string(); rule.ip = IpAddr::from_str(user_input.ip.field.value()).unwrap(); - rule.port = u16::from_str(user_input.port.field.value()).unwrap(); + rule.port = + BlockedPort::from_str(user_input.port.field.value()).unwrap(); } else { let rule = FirewallRule { id: uuid::Uuid::new_v4(), name: user_input.name.field.to_string(), ip: IpAddr::from_str(user_input.ip.field.value()).unwrap(), - port: u16::from_str(user_input.port.field.value()).unwrap(), + port: BlockedPort::from_str(user_input.port.field.value()).unwrap(), enabled: false, }; self.rules.push(rule);