diff --git a/oryx-ebpf/src/main.rs b/oryx-ebpf/src/main.rs index 6afd83c..2950b70 100644 --- a/oryx-ebpf/src/main.rs +++ b/oryx-ebpf/src/main.rs @@ -34,16 +34,11 @@ static TRANSPORT_FILTERS: Array = Array::with_max_entries(8, 0); static LINK_FILTERS: Array = Array::with_max_entries(8, 0); #[map] -static BLOCKLIST_IPV6_INGRESS: HashMap = +static BLOCKLIST_IPV6: HashMap = HashMap::::with_max_entries(128, 0); + #[map] -static BLOCKLIST_IPV6_EGRESS: HashMap = - HashMap::::with_max_entries(128, 0); -#[map] -static BLOCKLIST_IPV4_INGRESS: HashMap = - HashMap::::with_max_entries(128, 0); -#[map] -static BLOCKLIST_IPV4_EGRESS: HashMap = +static BLOCKLIST_IPV4: HashMap = HashMap::::with_max_entries(128, 0); #[classifier] @@ -96,6 +91,28 @@ fn filter_for_ipv4_address( false } +#[inline] +fn filter_for_ipv6_address( + addr: u128, + port: u16, + blocked_ports_map: &HashMap, +) -> bool { + if let Some(blocked_ports) = unsafe { blocked_ports_map.get(&addr) } { + for (idx, blocked_port) in blocked_ports.iter().enumerate() { + if *blocked_port == 0 { + if idx == 0 { + return true; + } else { + break; + } + } else if *blocked_port == port { + return true; + } + } + } + false +} + #[inline] fn filter_packet(protocol: Protocol) -> bool { match protocol { @@ -134,8 +151,8 @@ fn process(ctx: TcContext) -> Result { let src_port = u16::from_be(unsafe { (*tcphdr).source }); let dst_port = u16::from_be(unsafe { (*tcphdr).dest }); - if filter_for_ipv4_address(src_addr, src_port, &BLOCKLIST_IPV4_INGRESS) - || filter_for_ipv4_address(dst_addr, dst_port, &BLOCKLIST_IPV4_EGRESS) + if filter_for_ipv4_address(src_addr, src_port, &BLOCKLIST_IPV4) + || filter_for_ipv4_address(dst_addr, dst_port, &BLOCKLIST_IPV4) { return Ok(TC_ACT_SHOT); } @@ -155,8 +172,8 @@ fn process(ctx: TcContext) -> Result { let src_port = u16::from_be(unsafe { (*udphdr).source }); let dst_port = u16::from_be(unsafe { (*udphdr).dest }); - if filter_for_ipv4_address(src_addr, src_port, &BLOCKLIST_IPV4_INGRESS) - || filter_for_ipv4_address(dst_addr, dst_port, &BLOCKLIST_IPV4_EGRESS) + if filter_for_ipv4_address(src_addr, src_port, &BLOCKLIST_IPV4) + || filter_for_ipv4_address(dst_addr, dst_port, &BLOCKLIST_IPV4) { return Ok(TC_ACT_SHOT); } @@ -187,11 +204,20 @@ fn process(ctx: TcContext) -> Result { } EtherType::Ipv6 => { let header: Ipv6Hdr = ctx.load(EthHdr::LEN).map_err(|_| ())?; + let src_addr = header.src_addr().to_bits(); + let dst_addr = header.dst_addr().to_bits(); match header.next_hdr { IpProto::Tcp => { let tcphdr: *const TcpHdr = ptr_at(&ctx, EthHdr::LEN + Ipv6Hdr::LEN)?; + let src_port = u16::from_be(unsafe { (*tcphdr).source }); + let dst_port = u16::from_be(unsafe { (*tcphdr).dest }); + if filter_for_ipv6_address(src_addr, src_port, &BLOCKLIST_IPV6) + || filter_for_ipv6_address(dst_addr, dst_port, &BLOCKLIST_IPV6) + { + return Ok(TC_ACT_SHOT); + } if filter_packet(Protocol::Network(NetworkProtocol::Ipv6)) || filter_packet(Protocol::Transport(TransportProtocol::TCP)) { @@ -204,7 +230,14 @@ fn process(ctx: TcContext) -> Result { } IpProto::Udp => { let udphdr: *const UdpHdr = ptr_at(&ctx, EthHdr::LEN + Ipv6Hdr::LEN)?; + let src_port = u16::from_be(unsafe { (*udphdr).source }); + let dst_port = u16::from_be(unsafe { (*udphdr).dest }); + if filter_for_ipv6_address(src_addr, src_port, &BLOCKLIST_IPV6) + || filter_for_ipv6_address(dst_addr, dst_port, &BLOCKLIST_IPV6) + { + return Ok(TC_ACT_SHOT); + } if filter_packet(Protocol::Network(NetworkProtocol::Ipv6)) || filter_packet(Protocol::Transport(TransportProtocol::UDP)) { diff --git a/oryx-tui/src/ebpf.rs b/oryx-tui/src/ebpf.rs index d8d3145..11497cd 100644 --- a/oryx-tui/src/ebpf.rs +++ b/oryx-tui/src/ebpf.rs @@ -295,9 +295,9 @@ impl Ebpf { Array::try_from(bpf.take_map("LINK_FILTERS").unwrap()).unwrap(); // firewall-ebpf interface let mut ipv4_firewall: HashMap<_, u32, [u16; 32]> = - HashMap::try_from(bpf.take_map("BLOCKLIST_IPV4_INGRESS").unwrap()).unwrap(); + HashMap::try_from(bpf.take_map("BLOCKLIST_IPV4").unwrap()).unwrap(); let mut ipv6_firewall: HashMap<_, u128, [u16; 32]> = - HashMap::try_from(bpf.take_map("BLOCKLIST_IPV6_INGRESS").unwrap()).unwrap(); + HashMap::try_from(bpf.take_map("BLOCKLIST_IPV6").unwrap()).unwrap(); thread::spawn(move || loop { if let Ok(signal) = firewall_ingress_receiver.recv() { @@ -482,9 +482,9 @@ impl Ebpf { // firewall-ebpf interface let mut ipv4_firewall: HashMap<_, u32, [u16; 32]> = - HashMap::try_from(bpf.take_map("BLOCKLIST_IPV4_EGRESS").unwrap()).unwrap(); + HashMap::try_from(bpf.take_map("BLOCKLIST_IPV4").unwrap()).unwrap(); let mut ipv6_firewall: HashMap<_, u128, [u16; 32]> = - HashMap::try_from(bpf.take_map("BLOCKLIST_IPV6_EGRESS").unwrap()).unwrap(); + HashMap::try_from(bpf.take_map("BLOCKLIST_IPV6").unwrap()).unwrap(); thread::spawn(move || loop { if let Ok(signal) = firewall_egress_receiver.recv() { match signal {