Skip to content

Commit

Permalink
Re-factor stanza filter and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
moparisthebest committed Apr 15, 2021
1 parent 6f95db7 commit 3792d22
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 148 deletions.
223 changes: 76 additions & 147 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@ use anyhow::{bail, Result};
mod slicesubsequence;
use slicesubsequence::*;

mod stanzafilter;
use stanzafilter::*;

const IN_BUFFER_SIZE: usize = 8192;
const OUT_BUFFER_SIZE: usize = 8192;

const WHITESPACE: &[u8] = b" \t\n\r";
pub const WHITESPACE: &[u8] = b" \t\n\r";

#[cfg(debug_assertions)]
fn c2s(is_c2s: bool) -> &'static str {
Expand Down Expand Up @@ -148,91 +151,83 @@ async fn handle_connection(mut stream: tokio::net::TcpStream, client_addr: Socke

// starttls
if !direct_tls {
let mut stream_open = Vec::new();
let mut proceed_sent = false;

let (in_rd, mut in_wr) = stream.split();
// we naively read 1 byte at a time, which buffering significantly speeds up
let mut in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd);

while let Ok(n) = in_rd.read(in_filter.current_buf()).await {
if n == 0 {
bail!("stream ended before open");
}
if let Some(buf) = in_filter.process_next_byte()? {
debug!("received pre-tls stanza: {} '{}'", client_addr, to_str(&buf));
let buf = buf.trim_start(WHITESPACE);
if buf.starts_with(b"<?xml ") {
stream_open.extend_from_slice(buf);
continue;
} else if buf.starts_with(b"<stream:stream ") {
debug!("> {} '{}'", client_addr, to_str(&stream_open));
in_wr.write_all(&stream_open).await?;
stream_open.clear();

// gajim seems to REQUIRE an id here...
let buf = if buf.contains_seq(b"id=") {
buf.replace_first(b" id='", b" id='xmpp-proxy")
.replace_first(br#" id=""#, br#" id="xmpp-proxy"#)
.replace_first(b" to=", br#" bla toblala="#)
.replace_first(b" from=", b" to=")
.replace_first(br#" bla toblala="#, br#" from="#)
} else {
buf.replace_first(b" to=", br#" bla toblala="#)
.replace_first(b" from=", b" to=")
.replace_first(br#" bla toblala="#, br#" id='xmpp-proxy' from="#)
};

debug!("> {} '{}'", client_addr, to_str(&buf));
in_wr.write_all(&buf).await?;

// ejabberd never sends <starttls/> with the first, only the second?
//let buf = br###"<features xmlns="http://etherx.jabber.org/streams"><starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls"><required/></starttls></features>"###;
let buf = br###"<stream:features><starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls"><required/></starttls></stream:features>"###;
debug!("> {} '{}'", client_addr, to_str(buf));
in_wr.write_all(buf).await?;
in_wr.flush().await?;
} else if buf.starts_with(b"<starttls ") {
let buf = br###"<proceed xmlns="urn:ietf:params:xml:ns:xmpp-tls" />"###;
debug!("> {} '{}'", client_addr, to_str(buf));
in_wr.write_all(buf).await?;
in_wr.flush().await?;
break;
let in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd);
let mut in_rd = StanzaReader(in_rd);

while let Ok(Some(buf)) = in_rd.next(&mut in_filter).await {
debug!("received pre-tls stanza: {} '{}'", client_addr, to_str(&buf));
let buf = buf.trim_start(WHITESPACE);
if buf.starts_with(b"<?xml ") {
debug!("> {} '{}'", client_addr, to_str(&buf));
in_wr.write_all(&buf).await?;
in_wr.flush().await?;
} else if buf.starts_with(b"<stream:stream ") {
// gajim seems to REQUIRE an id here...
let buf = if buf.contains_seq(b"id=") {
buf.replace_first(b" id='", b" id='xmpp-proxy")
.replace_first(br#" id=""#, br#" id="xmpp-proxy"#)
.replace_first(b" to=", br#" bla toblala="#)
.replace_first(b" from=", b" to=")
.replace_first(br#" bla toblala="#, br#" from="#)
} else {
bail!("bad pre-tls stanza: {}", to_str(&buf));
}
buf.replace_first(b" to=", br#" bla toblala="#)
.replace_first(b" from=", b" to=")
.replace_first(br#" bla toblala="#, br#" id='xmpp-proxy' from="#)
};

debug!("> {} '{}'", client_addr, to_str(&buf));
in_wr.write_all(&buf).await?;

// ejabberd never sends <starttls/> with the first, only the second?
//let buf = br###"<features xmlns="http://etherx.jabber.org/streams"><starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls"><required/></starttls></features>"###;
let buf = br###"<stream:features><starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls"><required/></starttls></stream:features>"###;
debug!("> {} '{}'", client_addr, to_str(buf));
in_wr.write_all(buf).await?;
in_wr.flush().await?;
} else if buf.starts_with(b"<starttls ") {
let buf = br###"<proceed xmlns="urn:ietf:params:xml:ns:xmpp-tls" />"###;
debug!("> {} '{}'", client_addr, to_str(buf));
in_wr.write_all(buf).await?;
in_wr.flush().await?;
proceed_sent = true;
break;
} else {
bail!("bad pre-tls stanza: {}", to_str(&buf));
}
}
if !proceed_sent {
bail!("stream ended before open");
}
}

let stream = config.acceptor.accept(stream).await?;

let (in_rd, mut in_wr) = tokio::io::split(stream);
// we naively read 1 byte at a time, which buffering significantly speeds up
let mut in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd);
let in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd);
let mut in_rd = StanzaReader(in_rd);

// now read to figure out client vs server
let (stream_open, is_c2s) = {
let mut stream_open = Vec::new();
let mut ret = None;

while let Ok(n) = in_rd.read(in_filter.current_buf()).await {
if n == 0 {
bail!("stream ended before open");
}
if let Some(buf) = in_filter.process_next_byte()? {
debug!("received pre-<stream:stream> stanza: {} '{}'", client_addr, to_str(&buf));
let buf = buf.trim_start(WHITESPACE);
if buf.starts_with(b"<?xml ") {
stream_open.extend_from_slice(buf);
continue;
} else if buf.starts_with(b"<stream:stream ") {
stream_open.extend_from_slice(buf);
//return (stream_open, stanza.contains(r#" xmlns="jabber:client""#) || stanza.contains(r#" xmlns='jabber:client'"#));
ret = Some((stream_open, buf.contains_seq(br#" xmlns="jabber:client""#) || buf.contains_seq(br#" xmlns='jabber:client'"#)));
break;
} else {
bail!("bad pre-<stream:stream> stanza: {}", to_str(&buf));
}
while let Ok(Some(buf)) = in_rd.next(&mut in_filter).await {
debug!("received pre-<stream:stream> stanza: {} '{}'", client_addr, to_str(&buf));
let buf = buf.trim_start(WHITESPACE);
if buf.starts_with(b"<?xml ") {
stream_open.extend_from_slice(buf);
} else if buf.starts_with(b"<stream:stream ") {
stream_open.extend_from_slice(buf);
//return (stream_open, stanza.contains(r#" xmlns="jabber:client""#) || stanza.contains(r#" xmlns='jabber:client'"#));
ret = Some((stream_open, buf.contains_seq(br#" xmlns="jabber:client""#) || buf.contains_seq(br#" xmlns='jabber:client'"#)));
break;
} else {
bail!("bad pre-<stream:stream> stanza: {}", to_str(&buf));
}
}
if ret.is_some() {
Expand Down Expand Up @@ -281,14 +276,14 @@ async fn handle_connection(mut stream: tokio::net::TcpStream, client_addr: Socke

loop {
tokio::select! {
Ok(n) = in_rd.read(in_filter.current_buf()) => {
if n == 0 {
break;
}
if let Some(buf) = in_filter.process_next_byte()? {
debug!("< {} {} '{}'", client_addr, c2s(is_c2s), to_str(buf));
out_wr.write_all(buf).await?;
out_wr.flush().await?;
Ok(buf) = in_rd.next(&mut in_filter) => {
match buf {
None => break,
Some(buf) => {
debug!("< {} {} '{}'", client_addr, c2s(is_c2s), to_str(buf));
out_wr.write_all(buf).await?;
out_wr.flush().await?;
}
}
},
// we could filter outgoing from-server stanzas by size here too by doing same as above
Expand All @@ -308,6 +303,11 @@ async fn handle_connection(mut stream: tokio::net::TcpStream, client_addr: Socke
Ok(())
}

/*
async fn handle_connection(mut stream: tokio::net::TcpStream, client_addr: SocketAddr, local_addr: SocketAddr, config: CloneableConfig) -> Result<()> {
Ok(())
}
*/
fn spawn_listener(listener: TcpListener, config: CloneableConfig) -> JoinHandle<Result<()>> {
let local_addr = listener.local_addr().die("could not get local_addr?");
tokio::spawn(async move {
Expand Down Expand Up @@ -339,74 +339,3 @@ async fn main() {
}
futures::future::join_all(handles).await;
}

struct StanzaFilter {
buf_size: usize,
buf: Vec<u8>,
cnt: usize,
tag_cnt: usize,
last_char_was_lt: bool,
last_char_was_backslash: bool,
}

impl StanzaFilter {
pub fn new(buf_size: usize) -> StanzaFilter {
StanzaFilter {
buf_size,
buf: vec![0u8; buf_size],
cnt: 0,
tag_cnt: 0,
last_char_was_lt: false,
last_char_was_backslash: false,
}
}

#[inline(always)]
pub fn current_buf(&mut self) -> &mut [u8] {
&mut self.buf[self.cnt..(self.cnt + 1)]
}

pub fn process_next_byte(&mut self) -> Result<Option<&[u8]>> {
//println!("n: {}", n);
let b = self.buf[self.cnt];
if b == b'<' {
self.tag_cnt += 1;
self.last_char_was_lt = true;
} else {
if b == b'/' {
// if last_char_was_lt but tag_cnt < 2, should only be </stream:stream>
if self.last_char_was_lt && self.tag_cnt >= 2 {
// non-self-closing tag
self.tag_cnt -= 2;
}
self.last_char_was_backslash = true;
} else {
if b == b'>' {
if self.last_char_was_backslash {
// self-closing tag
self.tag_cnt -= 1;
}
// now special case some tags we want to send stand-alone:
if self.tag_cnt == 1 && self.cnt >= 15 && (b"<?xml" == &self.buf[0..5] || b"<stream:stream" == &self.buf[0..14] || b"</stream:stream" == &self.buf[0..15]) {
self.tag_cnt = 0; // to fall through to next logic
}
if self.tag_cnt == 0 {
let ret = Ok(Some(&self.buf[0..(self.cnt + 1)]));
self.cnt = 0;
self.last_char_was_backslash = false;
self.last_char_was_lt = false;
return ret;
}
}
self.last_char_was_backslash = false;
}
self.last_char_was_lt = false;
}
//println!("b: '{}', cnt: {}, tag_cnt: {}, self.buf.len(): {}", b as char, self.cnt, self.tag_cnt, self.buf.len());
self.cnt += 1;
if self.cnt == self.buf_size {
bail!("stanza too big: {}", to_str(&self.buf));
}
Ok(None)
}
}
2 changes: 1 addition & 1 deletion src/slicesubsequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl<T: PartialEq + Clone> SliceSubsequence<T> for Vec<T> {
#[cfg(test)]
mod tests {
use crate::slicesubsequence::*;
const WHITESPACE: &[u8] = b" \t\n\r";
use crate::WHITESPACE;

#[test]
fn trim_start() {
Expand Down
Loading

0 comments on commit 3792d22

Please sign in to comment.