Skip to content

Commit

Permalink
Decrypt packets in place
Browse files Browse the repository at this point in the history
  • Loading branch information
Hasan6979 committed Dec 10, 2024
1 parent 8f50b1e commit e394cac
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 72 deletions.
14 changes: 5 additions & 9 deletions neptun/src/device/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -775,10 +775,10 @@ impl Device {
let src_buf =
unsafe { &mut *(&mut t.src_buf[..] as *mut [u8] as *mut [MaybeUninit<u8>]) };
while let Ok((packet_len, addr)) = udp.recv_from(src_buf) {
let packet = &t.src_buf[..packet_len];
// let packet = &t.src_buf[..packet_len];
// The rate limiter initially checks mac1 and mac2, and optionally asks to send a cookie
let parsed_packet =
match rate_limiter.verify_packet(Some(addr.as_socket().unwrap().ip()), packet, &mut t.dst_buf) {
match rate_limiter.verify_packet(Some(addr.as_socket().unwrap().ip()), &mut t.src_buf, packet_len, &mut t.dst_buf) {
Ok(packet) => packet,
Err(TunnResult::WriteToNetwork(cookie)) => {
if let Err(err) = udp.send_to(cookie, &addr) {
Expand Down Expand Up @@ -866,7 +866,7 @@ impl Device {
loop {
let res = {
let mut tun = peer.tunnel.lock();
tun.decapsulate(None, &[], &mut t.dst_buf[..])
tun.decapsulate(None, &mut [], 0, &mut t.dst_buf[..])
};

let TunnResult::WriteToNetwork(packet) = res else {
Expand Down Expand Up @@ -926,11 +926,7 @@ impl Device {

let res = {
let mut tun = peer.tunnel.lock();
tun.decapsulate(
Some(peer_addr),
&t.src_buf[..read_bytes],
&mut t.dst_buf[..],
)
tun.decapsulate(Some(peer_addr), &mut t.src_buf, read_bytes, &mut t.dst_buf)
};

match res {
Expand Down Expand Up @@ -987,7 +983,7 @@ impl Device {
loop {
let res = {
let mut tun = peer.tunnel.lock();
tun.decapsulate(None, &[], &mut t.dst_buf[..])
tun.decapsulate(None, &mut [], 0, &mut t.dst_buf[..])
};
let TunnResult::WriteToNetwork(packet) = res else {
break;
Expand Down
11 changes: 7 additions & 4 deletions neptun/src/noise/integration_tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ mod tests {
let mut receiving_buffer = vec![0u8; MAX_PACKET];

// Initiate handshake from a side
match a.tunnel.lock().encapsulate(&[], &mut sending_buffer) {
match a.tunnel.lock().encapsulate(&mut sending_buffer, 0) {
TunnResult::WriteToNetwork(msg) => {
a.client_socket
.send_to(msg, b.client_address)
Expand All @@ -95,7 +95,8 @@ mod tests {

match b.tunnel.lock().decapsulate(
None,
&receiving_buffer[..bytes_read],
&mut receiving_buffer,
bytes_read,
&mut sending_buffer,
) {
TunnResult::WriteToNetwork(msg) => {
Expand All @@ -118,7 +119,8 @@ mod tests {

match a.tunnel.lock().decapsulate(
None,
&receiving_buffer[..bytes_read],
&mut receiving_buffer,
bytes_read,
&mut sending_buffer,
) {
TunnResult::WriteToNetwork(msg) => {
Expand All @@ -142,7 +144,8 @@ mod tests {

match b.tunnel.lock().decapsulate(
None,
&receiving_buffer[..bytes_read],
&mut receiving_buffer,
bytes_read,
&mut sending_buffer,
) {
TunnResult::Done => (),
Expand Down
109 changes: 61 additions & 48 deletions neptun/src/noise/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ const DATA_OVERHEAD_SZ: usize = 32;

#[derive(Debug)]
pub struct HandshakeInit<'a> {
msg_buffer: &'a [u8],
sender_idx: u32,
unencrypted_ephemeral: &'a [u8; 32],
encrypted_static: &'a [u8],
Expand All @@ -104,6 +105,7 @@ pub struct HandshakeInit<'a> {

#[derive(Debug)]
pub struct HandshakeResponse<'a> {
msg_buffer: &'a [u8],
sender_idx: u32,
pub receiver_idx: u32,
unencrypted_ephemeral: &'a [u8; 32],
Expand All @@ -121,7 +123,8 @@ pub struct PacketCookieReply<'a> {
pub struct PacketData<'a> {
pub receiver_idx: u32,
counter: u64,
encrypted_encapsulated_packet: &'a [u8],
encrypted_encapsulated_packet: &'a mut [u8],
data_len: usize,
}

/// Describes a packet from network
Expand All @@ -139,23 +142,27 @@ impl Tunn {
}

#[inline(always)]
pub fn parse_incoming_packet(src: &[u8]) -> Result<Packet, WireGuardError> {
pub fn parse_incoming_packet<'a>(
src: &'a mut [u8],
data_len: usize,
) -> Result<Packet<'a>, WireGuardError> {
if src.len() < 4 {
return Err(WireGuardError::InvalidPacket);
}

// Checks the type, as well as the reserved zero fields
let packet_type = u32::from_le_bytes(src[0..4].try_into().unwrap());

Ok(match (packet_type, src.len()) {
Ok(match (packet_type, data_len) {
(HANDSHAKE_INIT, HANDSHAKE_INIT_SZ) => Packet::HandshakeInit(HandshakeInit {
msg_buffer: &src[..data_len],
sender_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()),
unencrypted_ephemeral: <&[u8; 32] as TryFrom<&[u8]>>::try_from(&src[8..40])
.expect("length already checked above"),
encrypted_static: &src[40..88],
encrypted_timestamp: &src[88..116],
}),
(HANDSHAKE_RESP, HANDSHAKE_RESP_SZ) => Packet::HandshakeResponse(HandshakeResponse {
msg_buffer: &src[..data_len],
sender_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()),
receiver_idx: u32::from_le_bytes(src[8..12].try_into().unwrap()),
unencrypted_ephemeral: <&[u8; 32] as TryFrom<&[u8]>>::try_from(&src[12..44])
Expand All @@ -170,7 +177,8 @@ impl Tunn {
(DATA, DATA_OVERHEAD_SZ..=std::usize::MAX) => Packet::PacketData(PacketData {
receiver_idx: u32::from_le_bytes(src[4..8].try_into().unwrap()),
counter: u64::from_le_bytes(src[8..16].try_into().unwrap()),
encrypted_encapsulated_packet: &src[16..],
encrypted_encapsulated_packet: &mut src[16..],
data_len: data_len - 16,
}),
_ => return Err(WireGuardError::InvalidPacket),
})
Expand Down Expand Up @@ -308,42 +316,44 @@ impl Tunn {
pub fn decapsulate<'a>(
&mut self,
src_addr: Option<IpAddr>,
datagram: &[u8],
data_buffer: &'a mut [u8],
data_len: usize,
dst: &'a mut [u8],
) -> TunnResult<'a> {
if datagram.is_empty() {
if data_len == 0 {
// Indicates a repeated call
return self.send_queued_packet(dst);
}

let mut cookie = [0u8; COOKIE_REPLY_SZ];
let packet = match self
.rate_limiter
.verify_packet(src_addr, datagram, &mut cookie)
{
Ok(packet) => packet,
Err(TunnResult::WriteToNetwork(cookie)) => {
dst[..cookie.len()].copy_from_slice(cookie);
self.tx_bytes += cookie.len();
return TunnResult::WriteToNetwork(&mut dst[..cookie.len()]);
}
Err(TunnResult::Err(e)) => return TunnResult::Err(e),
_ => unreachable!(),
};
let packet =
match self
.rate_limiter
.verify_packet(src_addr, data_buffer, data_len, &mut cookie)
{
Ok(packet) => packet,
Err(TunnResult::WriteToNetwork(cookie)) => {
dst[..cookie.len()].copy_from_slice(cookie);
self.tx_bytes += cookie.len();
return TunnResult::WriteToNetwork(&mut dst[..cookie.len()]);
}
Err(TunnResult::Err(e)) => return TunnResult::Err(e),
_ => unreachable!(),
};

self.handle_verified_packet(packet, dst)
}

pub(crate) fn handle_verified_packet<'a>(
&mut self,
packet: Packet,
packet: Packet<'a>,
dst: &'a mut [u8],
) -> TunnResult<'a> {
match packet {
Packet::HandshakeInit(p) => self.handle_handshake_init(p, dst),
Packet::HandshakeResponse(p) => self.handle_handshake_response(p, dst),
Packet::PacketCookieReply(p) => self.handle_cookie_reply(p),
Packet::PacketData(p) => self.handle_data(p, dst),
Packet::PacketData(p) => self.handle_data(p),
}
.unwrap_or_else(TunnResult::from)
}
Expand Down Expand Up @@ -454,8 +464,7 @@ impl Tunn {
/// Decrypts a data packet, and stores the decapsulated packet in dst.
fn handle_data<'a>(
&mut self,
packet: PacketData,
dst: &'a mut [u8],
packet: PacketData<'a>,
) -> Result<TunnResult<'a>, WireGuardError> {
let r_idx = packet.receiver_idx as usize;
let idx = r_idx % N_SESSIONS;
Expand All @@ -467,7 +476,7 @@ impl Tunn {
tracing::trace!(message = "No current session available", remote_idx = r_idx);
WireGuardError::NoCurrentSession
})?;
session.receive_packet_data(packet, dst)?
session.receive_packet_data(packet)?
};

self.set_current_session(r_idx);
Expand Down Expand Up @@ -679,9 +688,9 @@ mod tests {
handshake_init.into()
}

fn create_handshake_response(tun: &mut Tunn, handshake_init: &[u8]) -> Vec<u8> {
fn create_handshake_response(tun: &mut Tunn, handshake_init: &mut [u8]) -> Vec<u8> {
let mut dst = vec![0u8; 2048];
let handshake_resp = tun.decapsulate(None, handshake_init, &mut dst);
let handshake_resp = tun.decapsulate(None, handshake_init, handshake_init.len(), &mut dst);
assert!(matches!(handshake_resp, TunnResult::WriteToNetwork(_)));

let handshake_resp = if let TunnResult::WriteToNetwork(sent) = handshake_resp {
Expand All @@ -693,9 +702,9 @@ mod tests {
handshake_resp.into()
}

fn parse_handshake_resp(tun: &mut Tunn, handshake_resp: &[u8]) -> Vec<u8> {
fn parse_handshake_resp(tun: &mut Tunn, handshake_resp: &mut [u8]) -> Vec<u8> {
let mut dst = vec![0u8; 2048];
let keepalive = tun.decapsulate(None, handshake_resp, &mut dst);
let keepalive = tun.decapsulate(None, handshake_resp, handshake_resp.len(), &mut dst);
assert!(matches!(keepalive, TunnResult::WriteToNetwork(_)));

let keepalive = if let TunnResult::WriteToNetwork(sent) = keepalive {
Expand All @@ -707,18 +716,18 @@ mod tests {
keepalive.into()
}

fn parse_keepalive(tun: &mut Tunn, keepalive: &[u8]) {
fn parse_keepalive(tun: &mut Tunn, keepalive: &mut [u8]) {
let mut dst = vec![0u8; 2048];
let keepalive = tun.decapsulate(None, keepalive, &mut dst);
let keepalive = tun.decapsulate(None, keepalive, 0, &mut dst);
assert!(matches!(keepalive, TunnResult::Done));
}

fn create_two_tuns_and_handshake() -> (Tunn, Tunn) {
let (mut my_tun, mut their_tun) = create_two_tuns();
let init = create_handshake_init(&mut my_tun);
let resp = create_handshake_response(&mut their_tun, &init);
let keepalive = parse_handshake_resp(&mut my_tun, &resp);
parse_keepalive(&mut their_tun, &keepalive);
let mut init = create_handshake_init(&mut my_tun);
let mut resp = create_handshake_response(&mut their_tun, &mut init);
let mut keepalive = parse_handshake_resp(&mut my_tun, &mut resp);
parse_keepalive(&mut their_tun, &mut keepalive);

(my_tun, their_tun)
}
Expand Down Expand Up @@ -754,27 +763,30 @@ mod tests {
#[test]
fn handshake_init() {
let (mut my_tun, _their_tun) = create_two_tuns();
let init = create_handshake_init(&mut my_tun);
let packet = Tunn::parse_incoming_packet(&init).unwrap();
let mut init = create_handshake_init(&mut my_tun);
let init_len = init.len();
let packet = Tunn::parse_incoming_packet(&mut init, init_len).unwrap();
assert!(matches!(packet, Packet::HandshakeInit(_)));
}

#[test]
fn handshake_init_and_response() {
let (mut my_tun, mut their_tun) = create_two_tuns();
let init = create_handshake_init(&mut my_tun);
let resp = create_handshake_response(&mut their_tun, &init);
let packet = Tunn::parse_incoming_packet(&resp).unwrap();
let mut init = create_handshake_init(&mut my_tun);
let mut resp = create_handshake_response(&mut their_tun, &mut init);
let resp_len = resp.len();
let packet = Tunn::parse_incoming_packet(&mut resp, resp_len).unwrap();
assert!(matches!(packet, Packet::HandshakeResponse(_)));
}

#[test]
fn full_handshake() {
let (mut my_tun, mut their_tun) = create_two_tuns();
let init = create_handshake_init(&mut my_tun);
let resp = create_handshake_response(&mut their_tun, &init);
let keepalive = parse_handshake_resp(&mut my_tun, &resp);
let packet = Tunn::parse_incoming_packet(&keepalive).unwrap();
let mut init = create_handshake_init(&mut my_tun);
let mut resp = create_handshake_response(&mut their_tun, &mut init);
let mut keepalive = parse_handshake_resp(&mut my_tun, &mut resp);
let kepalive_len = keepalive.len();
let packet = Tunn::parse_incoming_packet(&mut keepalive, kepalive_len).unwrap();
assert!(matches!(packet, Packet::PacketData(_)));
}

Expand Down Expand Up @@ -829,17 +841,18 @@ mod tests {
let mut my_dst = [0u8; 1024];
let mut their_dst = [0u8; 1024];

let mut sent_packet_buf = create_ipv4_udp_packet();

let data = my_tun.encapsulate(&mut sent_packet_buf, sent_packet_buf.len());
let sent_packet_buf = create_ipv4_udp_packet();
let packet_buf_len = sent_packet_buf.len();
my_dst[16..16 + packet_buf_len].copy_from_slice(&sent_packet_buf[..packet_buf_len]);
let data = my_tun.encapsulate(&mut my_dst, packet_buf_len);
assert!(matches!(data, TunnResult::WriteToNetwork(_)));
let data = if let TunnResult::WriteToNetwork(sent) = data {
sent
} else {
unreachable!();
};

let data = their_tun.decapsulate(None, data, &mut their_dst);
let data = their_tun.decapsulate(None, data, data.len(), &mut their_dst);
assert!(matches!(data, TunnResult::WriteToTunnelV4(..)));
let recv_packet_buf = if let TunnResult::WriteToTunnelV4(recv, _addr) = data {
recv
Expand Down
19 changes: 14 additions & 5 deletions neptun/src/noise/rate_limiter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,16 +156,25 @@ impl RateLimiter {
pub fn verify_packet<'a, 'b>(
&self,
src_addr: Option<IpAddr>,
src: &'a [u8],
src: &'a mut [u8],
message_len: usize,
dst: &'b mut [u8],
) -> Result<Packet<'a>, TunnResult<'b>> {
let packet = Tunn::parse_incoming_packet(src)?;
let packet = Tunn::parse_incoming_packet(src, message_len)?;

// Verify and rate limit handshake messages only
if let Packet::HandshakeInit(HandshakeInit { sender_idx, .. })
| Packet::HandshakeResponse(HandshakeResponse { sender_idx, .. }) = packet
if let Packet::HandshakeInit(HandshakeInit {
msg_buffer,
sender_idx,
..
})
| Packet::HandshakeResponse(HandshakeResponse {
msg_buffer,
sender_idx,
..
}) = packet
{
let (msg, macs) = src.split_at(src.len() - 32);
let (msg, macs) = msg_buffer.split_at(message_len - 32);
let (mac1, mac2) = macs.split_at(16);

let computed_mac1 = b2s_keyed_mac_16(&self.mac1_key, msg);
Expand Down
Loading

0 comments on commit e394cac

Please sign in to comment.