diff --git a/Cargo.lock b/Cargo.lock index cd67c2e..08a4ae9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -195,6 +195,10 @@ name = "cc" version = "1.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "17f6e324229dc011159fcc089755d1e2e216a90d43a7dea6853ca740b84f35e7" +dependencies = [ + "jobserver", + "libc", +] [[package]] name = "cesu8" @@ -244,7 +248,7 @@ dependencies = [ [[package]] name = "common" -version = "1.2.9" +version = "1.2.10" [[package]] name = "console" @@ -606,6 +610,15 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" +[[package]] +name = "jobserver" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2b099aaa34a9751c5bf0878add70444e1ed2dd73f347be99003d4577277de6e" +dependencies = [ + "libc", +] + [[package]] name = "js-sys" version = "0.3.69" @@ -720,6 +733,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "lz4_flex" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75761162ae2b0e580d7e7c390558127e5f01b4194debd6221fd8c207fc80e3f5" + [[package]] name = "memchr" version = "2.7.2" @@ -1586,7 +1605,7 @@ checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] name = "vnt" -version = "1.2.9" +version = "1.2.10" dependencies = [ "aes", "aes-gcm", @@ -1602,6 +1621,7 @@ dependencies = [ "libloading", "libsm", "log", + "lz4_flex", "mio", "openssl-sys", "packet", @@ -1619,11 +1639,12 @@ dependencies = [ "thiserror", "tokio", "tun", + "zstd", ] [[package]] name = "vnt-cli" -version = "1.2.9" +version = "1.2.10" dependencies = [ "anyhow", "chrono", @@ -1645,7 +1666,7 @@ dependencies = [ [[package]] name = "vnt-jni" -version = "1.2.9" +version = "1.2.10" dependencies = [ "android_logger", "common", @@ -2003,3 +2024,31 @@ name = "zeroize" version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" + +[[package]] +name = "zstd" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d789b1514203a1120ad2429eae43a7bd32b90976a7bb8a05f7ec02fa88cc23a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cd99b45c6bc03a018c8b8a86025678c87e55526064e38f9df301989dce7ec0a" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.10+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c253a4914af5bafc8fa8c86ee400827e83cf6ec01195ec1f1ed8441bf00d65aa" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/README.md b/README.md index 7017090..8a3e80a 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,8 @@ features说明 | log | 日志 | 是 | | command | list、route等命令 | 是 | | file_config | yaml配置文件 | 是 | +| lz4 | lz4压缩 | 是 | +| zstd | zstd压缩 | 否 | ### ip转发/代理 diff --git a/common/Cargo.toml b/common/Cargo.toml index 72c4ae8..16e39e1 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "common" -version = "1.2.9" +version = "1.2.10" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/vnt-cli/Cargo.toml b/vnt-cli/Cargo.toml index e591254..a351838 100644 --- a/vnt-cli/Cargo.toml +++ b/vnt-cli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "vnt-cli" -version = "1.2.9" +version = "1.2.10" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -29,7 +29,7 @@ sudo = "0.6.0" winapi = { version = "0.3.9", features = ["handleapi", "processthreadsapi", "winnt", "securitybaseapi", "impl-default"] } [features] -default = ["server_encrypt", "aes_gcm", "aes_cbc", "aes_ecb", "sm4_cbc", "ip_proxy", "port_mapping", "log", "command", "file_config"] +default = ["server_encrypt", "aes_gcm", "aes_cbc", "aes_ecb", "sm4_cbc", "ip_proxy", "port_mapping", "log", "command", "file_config", "lz4"] openssl = ["vnt/openssl"] openssl-vendored = ["vnt/openssl-vendored"] ring-cipher = ["vnt/ring-cipher"] @@ -40,6 +40,8 @@ aes_gcm = ["vnt/aes_gcm"] server_encrypt = ["vnt/server_encrypt"] ip_proxy = ["vnt/ip_proxy"] port_mapping = ["vnt/port_mapping"] +lz4 = ["vnt/lz4_compress"] +zstd = ["vnt/zstd_compress"] log = ["log4rs"] command = [] file_config = [] diff --git a/vnt-cli/README.md b/vnt-cli/README.md index 97c5ea2..e9ca1be 100644 --- a/vnt-cli/README.md +++ b/vnt-cli/README.md @@ -95,6 +95,14 @@ ### --mapping `10.26.0.10:80>` 端口映射,可以设置多个映射地址,例如 '--mapping udp:0.0.0.0:80->10.26.0.10:80 --mapping tcp:0.0.0.0:80->10.26.0.11:81' 表示将本地udp 80端口的数据转发到10.26.0.10:80,将本地tcp 80端口的数据转发到10.26.0.11:81,转发的目的地址可以使用域名+端口 + +### --compressor `` +启用压缩,默认仅支持lz4压缩,开启压缩后,如果数据包长度大于等于128,则会使用压缩,否则还是会按原数据发送 + +也支持开启zstd压缩,但是需要自行编译,编译时加入参数--features zstd + +如果宽度速度比较慢,可以考虑使用高级别的压缩 + ### -f `` 指定配置文件 配置文件采用yaml格式,可参考: diff --git a/vnt-cli/src/config/file_config.rs b/vnt-cli/src/config/file_config.rs new file mode 100644 index 0000000..3e7c9e3 --- /dev/null +++ b/vnt-cli/src/config/file_config.rs @@ -0,0 +1,181 @@ +use anyhow::anyhow; +use std::net::Ipv4Addr; +use std::str::FromStr; + +use serde::{Deserialize, Serialize}; + +use crate::config::get_device_id; +use vnt::channel::punch::PunchModel; +use vnt::channel::UseChannelType; +use vnt::cipher::CipherModel; +use vnt::compression::Compressor; +use vnt::core::Config; + +#[derive(Serialize, Deserialize, Debug)] +#[serde(default)] +pub struct FileConfig { + #[cfg(target_os = "windows")] + pub tap: bool, + pub token: String, + pub device_id: String, + pub name: String, + pub server_address: String, + pub stun_server: Vec, + pub dns: Vec, + pub in_ips: Vec, + pub out_ips: Vec, + pub password: Option, + pub mtu: Option, + pub tcp: bool, + pub ip: Option, + pub use_channel: String, + #[cfg(feature = "ip_proxy")] + pub no_proxy: bool, + pub server_encrypt: bool, + pub parallel: usize, + pub cipher_model: Option, + pub finger: bool, + pub punch_model: String, + pub ports: Option>, + pub cmd: bool, + pub first_latency: bool, + pub device_name: Option, + pub packet_loss: Option, + pub packet_delay: u32, + #[cfg(feature = "port_mapping")] + pub mapping: Vec, + pub compressor: Option, +} + +impl Default for FileConfig { + fn default() -> Self { + Self { + #[cfg(target_os = "windows")] + tap: false, + token: "".to_string(), + device_id: get_device_id(), + name: os_info::get().to_string(), + server_address: "nat1.wherewego.top:29872".to_string(), + stun_server: vec![ + "stun1.l.google.com:19302".to_string(), + "stun2.l.google.com:19302".to_string(), + "stun.miwifi.com:3478".to_string(), + ], + dns: vec![], + in_ips: vec![], + out_ips: vec![], + password: None, + mtu: None, + tcp: false, + ip: None, + use_channel: "all".to_string(), + #[cfg(feature = "ip_proxy")] + no_proxy: false, + server_encrypt: false, + parallel: 1, + cipher_model: None, + finger: false, + punch_model: "all".to_string(), + ports: None, + cmd: false, + first_latency: false, + device_name: None, + packet_loss: None, + packet_delay: 0, + #[cfg(feature = "port_mapping")] + mapping: vec![], + compressor: None, + } + } +} + +pub fn read_config(file_path: &str) -> anyhow::Result<(Config, bool)> { + let conf = std::fs::read_to_string(file_path)?; + let file_conf = match serde_yaml::from_str::(&conf) { + Ok(val) => val, + Err(e) => { + log::error!("{:?}", e); + return Err(anyhow!("{}", e)); + } + }; + if file_conf.token.is_empty() { + return Err(anyhow!("token is_empty")); + } + + let in_ips = match common::args_parse::ips_parse(&file_conf.in_ips) { + Ok(in_ips) => in_ips, + Err(e) => { + return Err(anyhow!("in_ips {:?} error:{}", &file_conf.in_ips, e)); + } + }; + let out_ips = match common::args_parse::out_ips_parse(&file_conf.out_ips) { + Ok(out_ips) => out_ips, + Err(e) => { + return Err(anyhow!("out_ips {:?} error:{}", &file_conf.out_ips, e)); + } + }; + let virtual_ip = match file_conf.ip.clone().map(|v| Ipv4Addr::from_str(&v)) { + None => None, + Some(r) => Some(r.map_err(|e| anyhow!("ip {:?} error:{}", &file_conf.ip, e))?), + }; + let cipher_model = { + #[cfg(not(any(feature = "aes_gcm", feature = "server_encrypt")))] + if file_conf.password.is_some() && file_conf.cipher_model.is_none() { + Err(anyhow!("cipher_model undefined"))? + } + #[cfg(not(any( + feature = "aes_gcm", + feature = "server_encrypt", + feature = "aes_cbc", + feature = "aes_ecb", + feature = "sm4_cbc" + )))] + { + CipherModel::None + } + #[cfg(any(feature = "aes_gcm", feature = "server_encrypt"))] + CipherModel::AesGcm + }; + + let punch_model = PunchModel::from_str(&file_conf.punch_model).map_err(|e| anyhow!("{}", e))?; + let use_channel_type = + UseChannelType::from_str(&file_conf.use_channel).map_err(|e| anyhow!("{}", e))?; + let compressor = if let Some(compressor) = file_conf.compressor.as_ref() { + Compressor::from_str(compressor).map_err(|e| anyhow!("{}", e))? + } else { + Compressor::None + }; + let config = Config::new( + #[cfg(target_os = "windows")] + file_conf.tap, + file_conf.token, + file_conf.device_id, + file_conf.name, + file_conf.server_address, + file_conf.dns, + file_conf.stun_server, + in_ips, + out_ips, + file_conf.password, + file_conf.mtu, + file_conf.tcp, + virtual_ip, + #[cfg(feature = "ip_proxy")] + file_conf.no_proxy, + file_conf.server_encrypt, + file_conf.parallel, + cipher_model, + file_conf.finger, + punch_model, + file_conf.ports, + file_conf.first_latency, + file_conf.device_name, + use_channel_type, + file_conf.packet_loss, + file_conf.packet_delay, + #[cfg(feature = "port_mapping")] + file_conf.mapping, + compressor, + )?; + Ok((config, file_conf.cmd)) +} diff --git a/vnt-cli/src/main.rs b/vnt-cli/src/main.rs index 9ad1521..044cd68 100644 --- a/vnt-cli/src/main.rs +++ b/vnt-cli/src/main.rs @@ -1,3 +1,4 @@ +use anyhow::anyhow; use std::io; use std::net::Ipv4Addr; use std::path::PathBuf; @@ -10,6 +11,7 @@ use common::args_parse::{ips_parse, out_ips_parse}; use vnt::channel::punch::PunchModel; use vnt::channel::UseChannelType; use vnt::cipher::CipherModel; +use vnt::compression::Compressor; use vnt::core::{Config, Vnt}; #[cfg(feature = "command")] @@ -78,6 +80,7 @@ fn main() { opts.optmulti("", "dns", "dns", ""); opts.optmulti("", "mapping", "mapping", ""); opts.optopt("f", "", "配置文件", ""); + opts.optopt("", "compressor", "压缩算法", ""); //"后台运行时,查看其他设备列表" opts.optflag("", "list", "后台运行时,查看其他设备列表"); opts.optflag("", "all", "后台运行时,查看其他设备完整信息"); @@ -294,6 +297,13 @@ fn main() { .unwrap_or(0); #[cfg(feature = "port_mapping")] let port_mapping_list = matches.opt_strs("mapping"); + let compressor = if let Some(compressor) = matches.opt_str("compressor").as_ref() { + Compressor::from_str(compressor) + .map_err(|e| anyhow!("{}", e)) + .unwrap() + } else { + Compressor::None + }; let config = match Config::new( #[cfg(target_os = "windows")] tap, @@ -324,10 +334,11 @@ fn main() { packet_delay, #[cfg(feature = "port_mapping")] port_mapping_list, + compressor, ) { Ok(config) => config, Err(e) => { - println!("config error: {}", e); + println!("config.toml error: {}", e); return; } }; @@ -502,7 +513,14 @@ fn print_usage(program: &str, _opts: Options) { println!(" --dns DNS服务器地址,可使用多个dns,不指定时使用系统解析"); #[cfg(feature = "port_mapping")] println!(" --mapping 端口映射,例如 --mapping udp:0.0.0.0:80->10.26.0.10:80 --mapping tcp:0.0.0.0:80->10.26.0.10:80"); - + #[cfg(all(feature = "lz4", feature = "zstd"))] + println!(" --compressor 启用压缩,可选值lz4/zstd<,level>,level为压缩级别,例如 --compressor lz4 或--compressor zstd,10"); + #[cfg(feature = "lz4")] + #[cfg(not(feature = "zstd"))] + println!(" --compressor 启用压缩,可选值lz4,例如 --compressor lz4"); + #[cfg(feature = "zstd")] + #[cfg(not(feature = "lz4"))] + println!(" --compressor 启用压缩,可选值zstd<,level>,level为压缩级别,例如 --compressor zstd,10"); println!(); #[cfg(feature = "command")] { diff --git a/vnt-jni/Cargo.toml b/vnt-jni/Cargo.toml index 339c3f5..7ac63b5 100644 --- a/vnt-jni/Cargo.toml +++ b/vnt-jni/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "vnt-jni" -version = "1.2.9" +version = "1.2.10" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/vnt-jni/src/config.rs b/vnt-jni/src/config.rs index 2a94cab..20ccf4b 100644 --- a/vnt-jni/src/config.rs +++ b/vnt-jni/src/config.rs @@ -7,6 +7,7 @@ use jni::JNIEnv; use vnt::channel::punch::PunchModel; use vnt::channel::UseChannelType; use vnt::cipher::CipherModel; +use vnt::compression::Compressor; use vnt::core::Config; use crate::utils::*; @@ -118,6 +119,7 @@ pub fn new_config(env: &mut JNIEnv, config: JObject) -> Result { packet_loss_rate, packet_delay, port_mapping, + Compressor::None, ) { Ok(config) => config, Err(e) => { diff --git a/vnt/Cargo.toml b/vnt/Cargo.toml index a58f0a4..fd8954e 100644 --- a/vnt/Cargo.toml +++ b/vnt/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "vnt" -version = "1.2.9" +version = "1.2.10" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -37,7 +37,8 @@ dns-parser = "0.8.0" tokio = { version = "1.37.0", features = ["full"], optional = true } - +lz4_flex = { version = "0.11", default-features = false, optional = true } +zstd = { version = "0.13.1", optional = true } [target.'cfg(target_os = "windows")'.dependencies] libloading = "0.8.0" @@ -47,7 +48,7 @@ protobuf-codegen = "3.2.0" protoc-bin-vendored = "3.0.0" [features] -default = ["server_encrypt", "aes_gcm", "aes_cbc", "aes_ecb", "sm4_cbc", "ip_proxy", "port_mapping"] +default = ["server_encrypt", "aes_gcm", "aes_cbc", "aes_ecb", "sm4_cbc", "ip_proxy", "port_mapping", "lz4_compress","zstd_compress"] openssl = ["openssl-sys"] # 从源码编译 openssl-vendored = ["openssl-sys/vendored"] @@ -59,3 +60,5 @@ aes_gcm = ["aes-gcm"] server_encrypt = ["aes-gcm", "rsa", "spki"] ip_proxy = ["tokio"] port_mapping = ["tokio"] +lz4_compress = ["lz4_flex"] +zstd_compress = ["zstd"] diff --git a/vnt/src/channel/handler.rs b/vnt/src/channel/handler.rs index d9e8c0e..4715fb9 100644 --- a/vnt/src/channel/handler.rs +++ b/vnt/src/channel/handler.rs @@ -2,5 +2,11 @@ use crate::channel::context::ChannelContext; use crate::channel::RouteKey; pub trait RecvChannelHandler: Clone + Send + 'static { - fn handle(&mut self, buf: &mut [u8], route_key: RouteKey, context: &ChannelContext); + fn handle( + &mut self, + buf: &mut [u8], + extend: &mut [u8], + route_key: RouteKey, + context: &ChannelContext, + ); } diff --git a/vnt/src/channel/mod.rs b/vnt/src/channel/mod.rs index 5373f48..7536edb 100644 --- a/vnt/src/channel/mod.rs +++ b/vnt/src/channel/mod.rs @@ -18,7 +18,7 @@ pub mod sender; pub mod tcp_channel; pub mod udp_channel; -const BUFFER_SIZE: usize = 1024 * 16; +pub const BUFFER_SIZE: usize = 1024 * 16; #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum UseChannelType { Relay, diff --git a/vnt/src/channel/tcp_channel.rs b/vnt/src/channel/tcp_channel.rs index 5628937..07dbde4 100644 --- a/vnt/src/channel/tcp_channel.rs +++ b/vnt/src/channel/tcp_channel.rs @@ -87,6 +87,7 @@ where let mut read_map: HashMap, usize)> = HashMap::with_capacity(32); + let mut extend = [0; BUFFER_SIZE]; loop { poll.poll(&mut events, None)?; for event in events.iter() { @@ -132,9 +133,13 @@ where } token => { if event.is_readable() { - if let Err(e) = - readable_handle(&token, &mut read_map, &mut recv_handler, &context) - { + if let Err(e) = readable_handle( + &token, + &mut read_map, + &mut recv_handler, + &context, + &mut extend, + ) { closed_handle_r(&token, &mut read_map); log::warn!("{:?}", e); if let Err(e) = write_waker.notify(token, false) { @@ -339,6 +344,7 @@ fn readable_handle( map: &mut HashMap, usize)>, recv_handler: &mut H, context: &ChannelContext, + extend: &mut [u8], ) -> io::Result<()> where H: RecvChannelHandler, @@ -360,7 +366,7 @@ where } *begin += len; if end > 4 && *begin == end { - recv_handler.handle(&mut buf[4..end], *route_key, context); + recv_handler.handle(&mut buf[4..end], extend, *route_key, context); *begin = 0; } } diff --git a/vnt/src/channel/udp_channel.rs b/vnt/src/channel/udp_channel.rs index 3f3b1ec..b5adbd7 100644 --- a/vnt/src/channel/udp_channel.rs +++ b/vnt/src/channel/udp_channel.rs @@ -70,6 +70,7 @@ where { let mut events = Events::with_capacity(1024); let mut buf = [0; BUFFER_SIZE]; + let mut extend = [0; BUFFER_SIZE]; let mut read_map: HashMap = HashMap::with_capacity(32); loop { poll.poll(&mut events, None)?; @@ -115,6 +116,7 @@ where Ok((len, addr)) => { recv_handler.handle( &mut buf[..len], + &mut extend, RouteKey::new(false, token.0, addr), &context, ); @@ -252,6 +254,7 @@ where } let mut events = Events::with_capacity(udps.len()); + let mut extend = [0; BUFFER_SIZE]; loop { poll.poll(&mut events, None)?; for x in events.iter() { @@ -270,6 +273,7 @@ where Ok((len, addr)) => { recv_handler.handle( &mut buf[..len], + &mut extend, RouteKey::new(false, index, addr), &context, ); diff --git a/vnt/src/compression/lz4_compress.rs b/vnt/src/compression/lz4_compress.rs new file mode 100644 index 0000000..6c28abc --- /dev/null +++ b/vnt/src/compression/lz4_compress.rs @@ -0,0 +1,33 @@ +use anyhow::anyhow; + +use crate::protocol::NetPacket; + +#[derive(Clone)] +pub struct Lz4Compressor; + +impl Lz4Compressor { + pub fn compress, O: AsRef<[u8]> + AsMut<[u8]>>( + in_net_packet: &NetPacket, + out: &mut NetPacket, + ) -> anyhow::Result<()> { + out.set_data_len_max(); + let len = match lz4_flex::compress_into(in_net_packet.payload(), out.payload_mut()) { + Ok(len) => len, + Err(e) => Err(anyhow!("Lz4 compress {}", e))?, + }; + out.set_payload_len(len)?; + Ok(()) + } + pub fn decompress, O: AsRef<[u8]> + AsMut<[u8]>>( + in_net_packet: &NetPacket, + out: &mut NetPacket, + ) -> anyhow::Result<()> { + out.set_data_len_max(); + let len = match lz4_flex::decompress_into(in_net_packet.payload(), out.payload_mut()) { + Ok(len) => len, + Err(e) => Err(anyhow!("Lz4 decompress {}", e))?, + }; + out.set_payload_len(len)?; + Ok(()) + } +} diff --git a/vnt/src/compression/mod.rs b/vnt/src/compression/mod.rs new file mode 100644 index 0000000..797ed42 --- /dev/null +++ b/vnt/src/compression/mod.rs @@ -0,0 +1,218 @@ +use std::str::FromStr; + +use anyhow::anyhow; + +#[cfg(feature = "lz4_compress")] +use crate::compression::lz4_compress::Lz4Compressor; +#[cfg(feature = "zstd_compress")] +use crate::compression::zstd_compress::ZstdCompressor; +use crate::protocol::extension::CompressionAlgorithm; +#[cfg(feature = "zstd_compress")] +use zstd::zstd_safe::CompressionLevel; + +use crate::protocol::NetPacket; + +#[cfg(feature = "lz4_compress")] +mod lz4_compress; +#[cfg(feature = "zstd_compress")] +mod zstd_compress; + +#[derive(Clone, Copy, Debug)] +pub enum Compressor { + #[cfg(feature = "lz4_compress")] + Lz4, + #[cfg(feature = "zstd_compress")] + Zstd(CompressionLevel), + None, +} + +impl FromStr for Compressor { + type Err = String; + #[cfg(not(any(feature = "lz4_compress", feature = "zstd_compress")))] + fn from_str(s: &str) -> Result { + Err(format!("not match '{}', Compression not supported", s)) + } + #[cfg(any(feature = "lz4_compress", feature = "zstd_compress"))] + fn from_str(s: &str) -> Result { + let str = s.trim().to_lowercase(); + match str.as_str() { + #[cfg(feature = "lz4_compress")] + "lz4" => Ok(Compressor::Lz4), + #[cfg(feature = "zstd_compress")] + "zstd" => Ok(Compressor::Zstd(9)), + "none" => Ok(Compressor::None), + _ => { + #[cfg(feature = "zstd_compress")] + { + let string_array: Vec = str.split(',').map(|s| s.to_string()).collect(); + if string_array.len() != 2 || string_array[0] != "zstd" { + return Err(format!("not match '{}', exp: zstd,10", s)); + } + return match CompressionLevel::from_str(&string_array[1]) { + Ok(level) => Ok(Compressor::Zstd(level)), + Err(_) => Err(format!("not match '{}', exp: zstd,10", s)), + }; + } + #[cfg(not(feature = "zstd_compress"))] + #[cfg(feature = "lz4_compress")] + return Err(format!("not match '{}', exp: lz4", s)); + } + } + } +} + +#[cfg(not(any(feature = "lz4_compress", feature = "zstd_compress")))] +impl Compressor { + pub fn compress, O: AsRef<[u8]> + AsMut<[u8]>>( + &self, + _in_net_packet: &NetPacket, + _out: &mut NetPacket, + ) -> anyhow::Result { + Ok(false) + } + pub fn decompress, O: AsRef<[u8]> + AsMut<[u8]>>( + _algorithm: CompressionAlgorithm, + _in_net_packet: &NetPacket, + _out: &mut NetPacket, + ) -> anyhow::Result<()> { + Err(anyhow!("Unsupported decompress")) + } +} + +#[cfg(any(feature = "lz4_compress", feature = "zstd_compress"))] +impl Compressor { + pub fn compress, O: AsRef<[u8]> + AsMut<[u8]>>( + &self, + in_net_packet: &NetPacket, + out: &mut NetPacket, + ) -> anyhow::Result { + match self { + #[cfg(feature = "lz4_compress")] + Compressor::Lz4 => { + if in_net_packet.data_len() < 128 { + return Ok(false); + } + Lz4Compressor::compress(in_net_packet, out)?; + let mut compression_extension_tail = out.append_compression_extension_tail()?; + compression_extension_tail.set_algorithm(CompressionAlgorithm::Lz4); + //压缩没效果,则放弃压缩 + if out.data_len() >= in_net_packet.data_len() - 16 { + return Ok(false); + } + return Ok(true); + } + #[cfg(feature = "zstd_compress")] + Compressor::Zstd(level) => { + if in_net_packet.data_len() < 128 { + return Ok(false); + } + ZstdCompressor::compress(*level, in_net_packet, out)?; + let mut compression_extension_tail = out.append_compression_extension_tail()?; + compression_extension_tail.set_algorithm(CompressionAlgorithm::Zstd); + //压缩没效果,则放弃压缩 + if out.data_len() >= in_net_packet.data_len() - 16 { + return Ok(false); + } + return Ok(true); + } + Compressor::None => {} + } + Ok(false) + } + pub fn decompress, O: AsRef<[u8]> + AsMut<[u8]>>( + algorithm: CompressionAlgorithm, + in_net_packet: &NetPacket, + out: &mut NetPacket, + ) -> anyhow::Result<()> { + match algorithm { + #[cfg(feature = "lz4_compress")] + CompressionAlgorithm::Lz4 => Lz4Compressor::decompress(in_net_packet, out), + #[cfg(feature = "zstd_compress")] + CompressionAlgorithm::Zstd => ZstdCompressor::decompress(in_net_packet, out), + _ => Err(anyhow!("Unknown decompress {:?}", algorithm)), + } + } +} + +#[test] +fn test_lz4() { + use crate::protocol::extension::{CompressionAlgorithm, ExtensionTailPacket}; + let lz4 = Compressor::Lz4; + let in_packet = NetPacket::new([ + 65, 108, 105, 99, 101, 32, 119, 97, 116, 32, 98, 101, 103, 105, 110, 110, 105, 110, 103, + 32, 116, 111, 32, 103, 101, 116, 32, 118, 101, 114, 121, 32, 116, 105, 114, 101, 100, 32, + 111, 102, 32, 115, 105, 116, 116, 105, 110, 103, 32, 98, 121, 32, 104, 101, 114, 32, 115, + 105, 115, 116, 101, 114, 32, 111, 110, 32, 116, 104, 101, 32, 98, 97, 110, 107, 44, 32, 97, + 110, 100, 32, 111, 102, 32, 104, 97, 118, 105, 110, 103, 32, 110, 111, 116, 104, 105, 110, + 103, 32, 116, 111, 32, 100, 111, 58, 32, 111, 110, 99, 101, 32, 111, 114, 32, 116, 119, + 105, 99, 101, 32, 115, 104, 101, 32, 104, 97, 100, 32, 112, 101, 101, 112, 101, 100, 32, + 105, 110, 116, 111, 32, 116, 104, 101, 32, 98, 111, 111, 107, 32, 104, 101, 114, 32, 115, + 105, 115, 116, 101, 114, 32, 119, 97, 115, 32, 114, 101, 97, 100, 105, 110, 103, 44, 32, + 98, 117, 116, 32, 105, 116, 32, 104, 97, 100, 32, 110, 111, 32, 112, 105, 99, 116, 117, + 114, 101, 115, 32, 111, 114, 32, 99, 111, 110, 118, 101, 114, 115, 97, 116, 105, + ]) + .unwrap(); + let mut out_packet = NetPacket::new([0; 1000]).unwrap(); + let mut src_out_packet = NetPacket::new([0; 1000]).unwrap(); + lz4.compress(&in_packet, &mut out_packet).unwrap(); + let tail = out_packet.split_tail_packet().unwrap(); + match tail { + ExtensionTailPacket::Compression(c) => match c.algorithm() { + CompressionAlgorithm::Lz4 => { + Compressor::decompress(CompressionAlgorithm::Lz4, &out_packet, &mut src_out_packet) + .unwrap(); + } + _ => { + unimplemented!() + } + }, + ExtensionTailPacket::Unknown => { + unimplemented!() + } + } + assert!(!out_packet.is_extension()); + assert_eq!(in_packet.payload(), src_out_packet.payload()) +} +#[test] +fn test_zstd() { + use crate::protocol::extension::{CompressionAlgorithm, ExtensionTailPacket}; + let zstd = Compressor::Zstd(22); + let in_packet = NetPacket::new([ + 65, 108, 105, 99, 101, 32, 119, 97, 115, 32, 98, 101, 103, 105, 110, 110, 105, 110, 103, + 32, 116, 111, 32, 103, 101, 116, 32, 118, 101, 114, 121, 32, 116, 105, 114, 101, 100, 32, + 111, 102, 32, 115, 105, 116, 116, 105, 110, 103, 32, 98, 121, 32, 104, 101, 114, 32, 115, + 105, 115, 116, 101, 114, 32, 111, 110, 32, 116, 104, 101, 32, 98, 97, 110, 107, 44, 32, 97, + 110, 100, 32, 111, 102, 32, 104, 97, 118, 105, 110, 103, 32, 110, 111, 116, 104, 105, 110, + 103, 32, 116, 111, 32, 100, 111, 58, 32, 111, 110, 99, 101, 32, 111, 114, 32, 116, 119, + 105, 99, 101, 32, 115, 104, 101, 32, 104, 97, 100, 32, 112, 101, 101, 112, 101, 100, 32, + 105, 110, 116, 111, 32, 116, 104, 101, 32, 98, 111, 111, 107, 32, 104, 101, 114, 32, 115, + 105, 115, 116, 101, 114, 32, 119, 97, 115, 32, 114, 101, 97, 100, 105, 110, 103, 44, 32, + 98, 117, 116, 32, 105, 116, 32, 104, 97, 100, 32, 110, 111, 32, 112, 105, 99, 116, 117, + 114, 101, 115, 32, 111, 114, 32, 99, 111, 110, 118, 101, 114, 115, 97, 116, 105, + ]) + .unwrap(); + let mut out_packet = NetPacket::new([0; 1000]).unwrap(); + let mut src_out_packet = NetPacket::new([0; 1000]).unwrap(); + zstd.compress(&in_packet, &mut out_packet).unwrap(); + let tail = out_packet.split_tail_packet().unwrap(); + match tail { + ExtensionTailPacket::Compression(c) => match c.algorithm() { + CompressionAlgorithm::Zstd => { + Compressor::decompress( + CompressionAlgorithm::Zstd, + &out_packet, + &mut src_out_packet, + ) + .unwrap(); + } + _ => { + unimplemented!() + } + }, + ExtensionTailPacket::Unknown => { + unimplemented!() + } + } + assert!(!out_packet.is_extension()); + assert_eq!(in_packet.payload(), src_out_packet.payload()) +} diff --git a/vnt/src/compression/zstd_compress.rs b/vnt/src/compression/zstd_compress.rs new file mode 100644 index 0000000..d647ecb --- /dev/null +++ b/vnt/src/compression/zstd_compress.rs @@ -0,0 +1,38 @@ +use crate::protocol::NetPacket; +use anyhow::anyhow; +use zstd::zstd_safe::CompressionLevel; + +#[derive(Clone)] +pub struct ZstdCompressor; + +impl ZstdCompressor { + pub fn compress, O: AsRef<[u8]> + AsMut<[u8]>>( + compression_level: CompressionLevel, + in_net_packet: &NetPacket, + out: &mut NetPacket, + ) -> anyhow::Result<()> { + out.set_data_len_max(); + let len = match zstd::zstd_safe::compress( + out.payload_mut(), + in_net_packet.payload(), + compression_level, + ) { + Ok(len) => len, + Err(e) => Err(anyhow!("zstd compress {}", e))?, + }; + out.set_payload_len(len)?; + Ok(()) + } + pub fn decompress, O: AsRef<[u8]> + AsMut<[u8]>>( + in_net_packet: &NetPacket, + out: &mut NetPacket, + ) -> anyhow::Result<()> { + out.set_data_len_max(); + let len = match zstd::zstd_safe::decompress(out.payload_mut(), in_net_packet.payload()) { + Ok(len) => len, + Err(e) => Err(anyhow!("zstd decompress {}", e))?, + }; + out.set_payload_len(len)?; + Ok(()) + } +} diff --git a/vnt/src/core/conn.rs b/vnt/src/core/conn.rs index a380c50..bd80078 100644 --- a/vnt/src/core/conn.rs +++ b/vnt/src/core/conn.rs @@ -47,7 +47,7 @@ pub struct Vnt { impl Vnt { pub fn new(config: Config, callback: Call) -> anyhow::Result { - log::info!("config:{:?}", config); + log::info!("config.toml:{:?}", config); //服务端非对称加密 #[cfg(feature = "server_encrypt")] let rsa_cipher: Arc>> = Arc::new(Mutex::new(None)); @@ -136,7 +136,7 @@ impl Vnt { let device = tun_tap_device::create_device(&config)?; log::info!("创建tun成功"); let tun_info = DeviceInfo::new(device.name()?, device.version()?); - log::info!("tun信息{:?}",tun_info); + log::info!("tun信息{:?}", tun_info); callback.create_tun(tun_info); device }; @@ -180,6 +180,7 @@ impl Vnt { config.parallel, up_counter, device_list.clone(), + config.compressor, ); #[cfg(any(target_os = "windows", target_os = "linux", target_os = "macos"))] let device_adapter = DeviceAdapter::new(device.clone()); diff --git a/vnt/src/core/mod.rs b/vnt/src/core/mod.rs index 9ade95b..d64248c 100644 --- a/vnt/src/core/mod.rs +++ b/vnt/src/core/mod.rs @@ -7,6 +7,7 @@ pub use conn::Vnt; use crate::channel::punch::PunchModel; use crate::channel::UseChannelType; use crate::cipher::CipherModel; +use crate::compression::Compressor; use crate::util::{address_choose, dns_query_all}; mod conn; @@ -46,6 +47,7 @@ pub struct Config { // 端口映射 #[cfg(feature = "port_mapping")] pub port_mapping_list: Vec<(bool, SocketAddr, String)>, + pub compressor: Compressor, } impl Config { @@ -77,6 +79,7 @@ impl Config { packet_delay: u32, // 例如 [udp:127.0.0.1:80->10.26.0.10:8080,tcp:127.0.0.1:80->10.26.0.10:8080] #[cfg(feature = "port_mapping")] port_mapping_list: Vec, + compressor: Compressor, ) -> anyhow::Result { for x in stun_server.iter_mut() { if !x.contains(":") { @@ -140,6 +143,7 @@ impl Config { packet_delay, #[cfg(feature = "port_mapping")] port_mapping_list, + compressor, }) } } diff --git a/vnt/src/handle/extension/mod.rs b/vnt/src/handle/extension/mod.rs new file mode 100644 index 0000000..45e13cb --- /dev/null +++ b/vnt/src/handle/extension/mod.rs @@ -0,0 +1,24 @@ +use crate::compression::Compressor; +use crate::protocol::extension::ExtensionTailPacket; +use crate::protocol::NetPacket; +use anyhow::anyhow; + +pub fn handle_extension_tail + AsMut<[u8]>, O: AsRef<[u8]> + AsMut<[u8]>>( + in_net_packet: &mut NetPacket, + out: &mut NetPacket, +) -> anyhow::Result { + if in_net_packet.is_extension() { + let tail_packet = in_net_packet.split_tail_packet()?; + match tail_packet { + ExtensionTailPacket::Compression(extension) => { + let compression_algorithm = extension.algorithm(); + Compressor::decompress(compression_algorithm, &in_net_packet, out)?; + out.head_mut().copy_from_slice(in_net_packet.head()); + Ok(true) + } + ExtensionTailPacket::Unknown => Err(anyhow!("Unknown decompress")), + } + } else { + Ok(false) + } +} diff --git a/vnt/src/handle/mod.rs b/vnt/src/handle/mod.rs index fc0c10d..e59631d 100644 --- a/vnt/src/handle/mod.rs +++ b/vnt/src/handle/mod.rs @@ -2,6 +2,7 @@ use crossbeam_utils::atomic::AtomicCell; use std::net::{Ipv4Addr, SocketAddr}; pub mod callback; +mod extension; pub mod handshaker; pub mod maintain; pub mod recv_data; diff --git a/vnt/src/handle/recv_data/client.rs b/vnt/src/handle/recv_data/client.rs index ee748c6..70e8f99 100644 --- a/vnt/src/handle/recv_data/client.rs +++ b/vnt/src/handle/recv_data/client.rs @@ -14,6 +14,7 @@ use crate::channel::punch::NatInfo; use crate::channel::{Route, RouteKey}; use crate::cipher::Cipher; use crate::external_route::AllowExternalRoute; +use crate::handle::extension::handle_extension_tail; use crate::handle::maintain::PunchSender; use crate::handle::recv_data::PacketHandler; use crate::handle::CurrentDeviceInfo; @@ -70,14 +71,26 @@ impl PacketHandler for ClientPacketHandler { fn handle( &self, mut net_packet: NetPacket<&mut [u8]>, + mut extend: NetPacket<&mut [u8]>, route_key: RouteKey, context: &ChannelContext, current_device: &CurrentDeviceInfo, - ) -> io::Result<()> { + ) -> anyhow::Result<()> { self.client_cipher.decrypt_ipv4(&mut net_packet)?; context .route_table .update_read_time(&net_packet.source(), &route_key); + //处理扩展 + let net_packet = if net_packet.is_extension() { + //这样重用数组,减少一次数据拷贝 + if handle_extension_tail(&mut net_packet, &mut extend)? { + extend + } else { + net_packet + } + } else { + net_packet + }; match net_packet.protocol() { Protocol::Service => {} Protocol::Error => {} diff --git a/vnt/src/handle/recv_data/mod.rs b/vnt/src/handle/recv_data/mod.rs index b644eda..8c2bdea 100644 --- a/vnt/src/handle/recv_data/mod.rs +++ b/vnt/src/handle/recv_data/mod.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use std::net::Ipv4Addr; use std::sync::Arc; -use std::{io, thread}; +use std::thread; use crossbeam_utils::atomic::AtomicCell; use parking_lot::{Mutex, RwLock}; @@ -43,7 +43,13 @@ pub struct RecvDataHandler { } impl RecvChannelHandler for RecvDataHandler { - fn handle(&mut self, buf: &mut [u8], route_key: RouteKey, context: &ChannelContext) { + fn handle( + &mut self, + buf: &mut [u8], + extend: &mut [u8], + route_key: RouteKey, + context: &ChannelContext, + ) { //判断stun响应包 if !route_key.is_tcp() { if let Ok(rs) = self @@ -55,7 +61,7 @@ impl RecvChannelHandler for RecvDataHandler { } } } - if let Err(e) = self.handle0(buf, route_key, context) { + if let Err(e) = self.handle0(buf, extend, route_key, context) { log::error!("[{}]-{:?}", thread::current().name().unwrap_or(""), e); } } @@ -116,12 +122,14 @@ impl RecvDataHandler { fn handle0( &mut self, buf: &mut [u8], + extend: &mut [u8], route_key: RouteKey, context: &ChannelContext, - ) -> io::Result<()> { + ) -> anyhow::Result<()> { // 统计流量 self.counter.add(buf.len() as _); let net_packet = NetPacket::new(buf)?; + let extend = NetPacket::unchecked(extend); if net_packet.ttl() == 0 || net_packet.source_ttl() < net_packet.ttl() { log::warn!("丢弃过时包:{:?}", net_packet.head()); return Ok(()); @@ -139,16 +147,16 @@ impl RecvDataHandler { if net_packet.is_gateway() { //服务端-客户端包 self.server - .handle(net_packet, route_key, context, ¤t_device) + .handle(net_packet, extend, route_key, context, ¤t_device) } else { //客户端-客户端包 self.client - .handle(net_packet, route_key, context, ¤t_device) + .handle(net_packet, extend, route_key, context, ¤t_device) } } else { //转发包 self.turn - .handle(net_packet, route_key, context, ¤t_device) + .handle(net_packet, extend, route_key, context, ¤t_device) } } } @@ -157,8 +165,9 @@ pub trait PacketHandler { fn handle( &self, net_packet: NetPacket<&mut [u8]>, + extend: NetPacket<&mut [u8]>, route_key: RouteKey, context: &ChannelContext, current_device: &CurrentDeviceInfo, - ) -> io::Result<()>; + ) -> anyhow::Result<()>; } diff --git a/vnt/src/handle/recv_data/server.rs b/vnt/src/handle/recv_data/server.rs index 877ecd9..c42ec48 100644 --- a/vnt/src/handle/recv_data/server.rs +++ b/vnt/src/handle/recv_data/server.rs @@ -94,10 +94,11 @@ impl PacketHandler for ServerPacketHandler { fn handle( &self, mut net_packet: NetPacket<&mut [u8]>, + _extend: NetPacket<&mut [u8]>, route_key: RouteKey, context: &ChannelContext, current_device: &CurrentDeviceInfo, - ) -> io::Result<()> { + ) -> anyhow::Result<()> { context .route_table .update_read_time(&net_packet.source(), &route_key); diff --git a/vnt/src/handle/recv_data/turn.rs b/vnt/src/handle/recv_data/turn.rs index fa7e814..f4d70ff 100644 --- a/vnt/src/handle/recv_data/turn.rs +++ b/vnt/src/handle/recv_data/turn.rs @@ -3,6 +3,7 @@ use crate::channel::RouteKey; use crate::handle::recv_data::PacketHandler; use crate::handle::CurrentDeviceInfo; use crate::protocol::NetPacket; +use anyhow::Context; /// 处理客户端中转包 #[derive(Clone)] @@ -18,10 +19,11 @@ impl PacketHandler for TurnPacketHandler { fn handle( &self, mut net_packet: NetPacket<&mut [u8]>, + _extend: NetPacket<&mut [u8]>, route_key: RouteKey, context: &ChannelContext, _current_device: &CurrentDeviceInfo, - ) -> std::io::Result<()> { + ) -> anyhow::Result<()> { // ttl减一 let ttl = net_packet.incr_ttl(); if ttl > 0 { @@ -33,7 +35,9 @@ impl PacketHandler for TurnPacketHandler { return Ok(()); } if route.metric <= ttl { - return context.send_by_key(net_packet.buffer(), route.route_key()); + return context + .send_by_key(net_packet.buffer(), route.route_key()) + .context("转发失败"); } } //其他没有路由的不转发 diff --git a/vnt/src/handle/tun_tap/tun_handler.rs b/vnt/src/handle/tun_tap/tun_handler.rs index 94e1bd7..f5217be 100644 --- a/vnt/src/handle/tun_tap/tun_handler.rs +++ b/vnt/src/handle/tun_tap/tun_handler.rs @@ -5,6 +5,7 @@ use std::{io, thread}; use crossbeam_utils::atomic::AtomicCell; use parking_lot::Mutex; +use crate::channel::BUFFER_SIZE; use packet::icmp::icmp::IcmpPacket; use packet::icmp::Kind; use packet::ip::ipv4::packet::IpV4Packet; @@ -14,6 +15,7 @@ use tun::Device; use crate::channel::context::ChannelContext; use crate::cipher::Cipher; +use crate::compression::Compressor; use crate::external_route::ExternalRoute; use crate::handle::tun_tap::channel_group::channel_group; use crate::handle::{check_dest, CurrentDeviceInfo, PeerDeviceInfo}; @@ -27,7 +29,7 @@ use crate::protocol::ip_turn_packet::BroadcastPacket; use crate::protocol::{ip_turn_packet, NetPacket, MAX_TTL}; use crate::util::{SingleU64Adder, StopManager}; -fn icmp(device_writer: &Device, mut ipv4_packet: IpV4Packet<&mut [u8]>) -> io::Result<()> { +fn icmp(device_writer: &Device, mut ipv4_packet: IpV4Packet<&mut [u8]>) -> anyhow::Result<()> { if ipv4_packet.protocol() == Protocol::Icmp { let mut icmp = IcmpPacket::new(ipv4_packet.payload_mut())?; if icmp.kind() == Kind::EchoRequest { @@ -48,6 +50,7 @@ pub(crate) fn handle( context: &ChannelContext, data: &mut [u8], len: usize, + extend: &mut [u8], device_writer: &Device, current_device: CurrentDeviceInfo, ip_route: &ExternalRoute, @@ -55,7 +58,8 @@ pub(crate) fn handle( client_cipher: &Cipher, server_cipher: &Cipher, device_list: &Mutex<(u16, Vec)>, -) -> io::Result<()> { + compressor: &Compressor, +) -> anyhow::Result<()> { //忽略掉结构不对的情况(ipv6数据、win tap会读到空数据),不然日志打印太多了 let ipv4_packet = match IpV4Packet::new(&mut data[12..len]) { Ok(packet) => packet, @@ -70,6 +74,7 @@ pub(crate) fn handle( context, data, len, + extend, current_device, ip_route, #[cfg(feature = "ip_proxy")] @@ -77,6 +82,7 @@ pub(crate) fn handle( client_cipher, server_cipher, device_list, + compressor, ); } @@ -92,6 +98,7 @@ pub fn start( parallel: usize, mut up_counter: SingleU64Adder, device_list: Arc)>>, + compressor: Compressor, ) -> io::Result<()> { if parallel > 1 { let (sender, receivers) = channel_group::<(Vec, usize)>(parallel, 16); @@ -108,6 +115,7 @@ pub fn start( thread::Builder::new() .name(format!("tunHandler-{}", index)) .spawn(move || { + let mut extend = [0; BUFFER_SIZE]; while let Ok((mut buf, len)) = receiver.recv() { #[cfg(not(target_os = "macos"))] let start = 0; @@ -117,6 +125,7 @@ pub fn start( &context, &mut buf[start..], len, + &mut extend, &device, current_device.load(), &ip_route, @@ -125,6 +134,7 @@ pub fn start( &client_cipher, &server_cipher, &device_list, + &compressor, ) { Ok(_) => {} Err(e) => { @@ -162,6 +172,7 @@ pub fn start( server_cipher, &mut up_counter, device_list, + compressor, ) { log::warn!("stop:{}", e); } @@ -261,18 +272,21 @@ fn base_handle( context: &ChannelContext, buf: &mut [u8], data_len: usize, //数据总长度=12+ip包长度 + extend: &mut [u8], current_device: CurrentDeviceInfo, ip_route: &ExternalRoute, #[cfg(feature = "ip_proxy")] proxy_map: &Option, client_cipher: &Cipher, server_cipher: &Cipher, device_list: &Mutex<(u16, Vec)>, -) -> io::Result<()> { + compressor: &Compressor, +) -> anyhow::Result<()> { let ipv4_packet = IpV4Packet::new(&buf[12..data_len])?; let protocol = ipv4_packet.protocol(); let src_ip = ipv4_packet.source_ip(); let mut dest_ip = ipv4_packet.destination_ip(); let mut net_packet = NetPacket::new0(data_len, buf)?; + let mut out = NetPacket::unchecked(extend); net_packet.set_default_version(); net_packet.set_protocol(protocol::Protocol::IpTurn); net_packet.set_transport_protocol(ip_turn_packet::Protocol::Ipv4.into()); @@ -288,6 +302,17 @@ fn base_handle( } return Ok(()); } + let mut net_packet = if compressor.compress(&net_packet, &mut out)? { + out.set_default_version(); + out.set_protocol(protocol::Protocol::IpTurn); + out.set_transport_protocol(ip_turn_packet::Protocol::Ipv4.into()); + out.first_set_ttl(6); + out.set_source(src_ip); + out.set_destination(dest_ip); + out + } else { + net_packet + }; if dest_ip.is_multicast() { //当作广播处理 dest_ip = Ipv4Addr::BROADCAST; @@ -333,5 +358,6 @@ fn base_handle( &dest_ip, current_device.connect_server, current_device.status.online(), - ) + )?; + Ok(()) } diff --git a/vnt/src/handle/tun_tap/unix.rs b/vnt/src/handle/tun_tap/unix.rs index a397623..09b0d4e 100644 --- a/vnt/src/handle/tun_tap/unix.rs +++ b/vnt/src/handle/tun_tap/unix.rs @@ -1,5 +1,7 @@ use crate::channel::context::ChannelContext; +use crate::channel::BUFFER_SIZE; use crate::cipher::Cipher; +use crate::compression::Compressor; use crate::external_route::ExternalRoute; use crate::handle::tun_tap::channel_group::GroupSyncSender; use crate::handle::{CurrentDeviceInfo, PeerDeviceInfo}; @@ -30,6 +32,7 @@ pub(crate) fn start_simple( server_cipher: Cipher, up_counter: &mut SingleU64Adder, device_list: Arc)>>, + compressor: Compressor, ) -> io::Result<()> { let poll = Poll::new()?; let waker = Arc::new(Waker::new(poll.registry(), STOP)?); @@ -49,6 +52,7 @@ pub(crate) fn start_simple( server_cipher, up_counter, device_list, + compressor, ) { log::error!("{:?}", e); }; @@ -68,8 +72,10 @@ fn start_simple0( server_cipher: Cipher, up_counter: &mut SingleU64Adder, device_list: Arc)>>, + compressor: Compressor, ) -> io::Result<()> { - let mut buf = [0; 1024 * 16]; + let mut buf = [0; BUFFER_SIZE]; + let mut extend = [0; BUFFER_SIZE]; let fd = device.as_tun_fd(); fd.set_nonblock()?; SourceFd(&fd.as_raw_fd()).register(poll.registry(), FD, Interest::READABLE)?; @@ -102,6 +108,7 @@ fn start_simple0( context, &mut buf, len, + &mut extend, &device, current_device.load(), &ip_route, @@ -110,6 +117,7 @@ fn start_simple0( &client_cipher, &server_cipher, &device_list, + &compressor, ) { Ok(_) => {} Err(e) => { diff --git a/vnt/src/handle/tun_tap/windows.rs b/vnt/src/handle/tun_tap/windows.rs index 01d317a..f1c86c7 100644 --- a/vnt/src/handle/tun_tap/windows.rs +++ b/vnt/src/handle/tun_tap/windows.rs @@ -1,5 +1,7 @@ use crate::channel::context::ChannelContext; +use crate::channel::BUFFER_SIZE; use crate::cipher::Cipher; +use crate::compression::Compressor; use crate::external_route::ExternalRoute; use crate::handle::tun_tap::channel_group::GroupSyncSender; use crate::handle::{CurrentDeviceInfo, PeerDeviceInfo}; @@ -24,6 +26,7 @@ pub(crate) fn start_simple( server_cipher: Cipher, up_counter: &mut SingleU64Adder, device_list: Arc)>>, + compressor: Compressor, ) -> io::Result<()> { let worker = { let device = device.clone(); @@ -44,6 +47,7 @@ pub(crate) fn start_simple( server_cipher, up_counter, device_list, + compressor, ) { log::error!("{:?}", e); } @@ -60,8 +64,10 @@ fn start_simple0( server_cipher: Cipher, up_counter: &mut SingleU64Adder, device_list: Arc)>>, + compressor: Compressor, ) -> io::Result<()> { - let mut buf = [0; 1024 * 16]; + let mut buf = [0; BUFFER_SIZE]; + let mut extend = [0; BUFFER_SIZE]; loop { let len = device.read(&mut buf[12..])? + 12; //单线程的 @@ -72,6 +78,7 @@ fn start_simple0( context, &mut buf, len, + &mut extend, &device, current_device.load(), &ip_route, @@ -80,6 +87,7 @@ fn start_simple0( &client_cipher, &server_cipher, &device_list, + &compressor, ) { Ok(_) => {} Err(e) => { diff --git a/vnt/src/lib.rs b/vnt/src/lib.rs index a159e14..0a07456 100644 --- a/vnt/src/lib.rs +++ b/vnt/src/lib.rs @@ -16,3 +16,4 @@ pub mod tun_tap_device; pub mod util; pub use handle::callback::*; +pub mod compression; diff --git a/vnt/src/nat/stun.rs b/vnt/src/nat/stun.rs index ef4a636..05b9a58 100644 --- a/vnt/src/nat/stun.rs +++ b/vnt/src/nat/stun.rs @@ -71,11 +71,15 @@ pub fn stun_test_nat0(stun_servers: Vec) -> io::Result<(NatType, Vec io::Result> { diff --git a/vnt/src/protocol/extension.rs b/vnt/src/protocol/extension.rs new file mode 100644 index 0000000..1e43e06 --- /dev/null +++ b/vnt/src/protocol/extension.rs @@ -0,0 +1,141 @@ +/* 扩展协议 + 0 15 31 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | 扩展数据(n) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | 扩展数据(n) | type(8) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + 注:扩展数据的长度由type决定 +*/ + +use anyhow::anyhow; +use std::io; + +use crate::protocol::NetPacket; + +#[derive(Eq, PartialEq, Copy, Clone, Debug)] +pub enum ExtensionTailType { + Compression, + Unknown(u8), +} + +impl From for ExtensionTailType { + fn from(value: u8) -> Self { + if value == 0 { + ExtensionTailType::Compression + } else { + ExtensionTailType::Unknown(value) + } + } +} + +pub enum ExtensionTailPacket { + Compression(CompressionExtensionTail), + Unknown, +} + +impl + AsMut<[u8]>> NetPacket { + /// 分离尾部数据 + pub fn split_tail_packet(&mut self) -> anyhow::Result> { + if self.is_extension() { + let payload = self.payload(); + if let Some(v) = payload.last() { + return match ExtensionTailType::from(*v) { + ExtensionTailType::Compression => { + let data_len = self.data_len - 4; + self.set_data_len(data_len)?; + self.set_extension_flag(false); + Ok(ExtensionTailPacket::Compression( + CompressionExtensionTail::new( + &self.raw_buffer()[data_len..data_len + 4], + ), + )) + } + ExtensionTailType::Unknown(e) => Err(anyhow!("unknown extension {}", e)), + }; + } + } + Err(anyhow!("not extension")) + } + /// 追加压缩扩展 + pub fn append_compression_extension_tail( + &mut self, + ) -> io::Result> { + let len = self.data_len; + //增加数据长度 + self.set_data_len(self.data_len + 4)?; + self.set_extension_flag(true); + let mut tail = CompressionExtensionTail::new(&mut self.buffer_mut()[len..]); + tail.init(); + return Ok(tail); + } +} + +/* 扩展协议 + 0 15 31 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | algorithm(8) | | type(8) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + 注:扩展数据的长度由type决定 +*/ +/// 压缩扩展 +pub struct CompressionExtensionTail { + buffer: B, +} + +impl> CompressionExtensionTail { + pub fn new(buffer: B) -> CompressionExtensionTail { + assert_eq!(buffer.as_ref().len(), 4); + CompressionExtensionTail { buffer } + } +} + +impl> CompressionExtensionTail { + pub fn algorithm(&self) -> CompressionAlgorithm { + self.buffer.as_ref()[0].into() + } +} + +impl + AsMut<[u8]>> CompressionExtensionTail { + pub fn init(&mut self) { + self.buffer.as_mut().fill(0); + } + pub fn set_algorithm(&mut self, algorithm: CompressionAlgorithm) { + self.buffer.as_mut()[0] = algorithm.into() + } +} + +#[derive(Eq, PartialEq, Copy, Clone, Debug)] +pub enum CompressionAlgorithm { + #[cfg(feature = "lz4_compress")] + Lz4, + #[cfg(feature = "zstd_compress")] + Zstd, + Unknown(u8), +} + +impl From for CompressionAlgorithm { + fn from(value: u8) -> Self { + match value { + #[cfg(feature = "lz4_compress")] + 1 => CompressionAlgorithm::Lz4, + #[cfg(feature = "zstd_compress")] + 2 => CompressionAlgorithm::Zstd, + v => CompressionAlgorithm::Unknown(v), + } + } +} + +impl From for u8 { + fn from(value: CompressionAlgorithm) -> Self { + match value { + #[cfg(feature = "lz4_compress")] + CompressionAlgorithm::Lz4 => 1, + #[cfg(feature = "zstd_compress")] + CompressionAlgorithm::Zstd => 2, + CompressionAlgorithm::Unknown(val) => val, + } + } +} diff --git a/vnt/src/protocol/mod.rs b/vnt/src/protocol/mod.rs index 19bedca..0a96cc3 100644 --- a/vnt/src/protocol/mod.rs +++ b/vnt/src/protocol/mod.rs @@ -21,6 +21,7 @@ pub const HEAD_LEN: usize = 12; pub mod body; pub mod control_packet; pub mod error_packet; +pub mod extension; pub mod ip_turn_packet; pub mod other_turn_packet; pub mod service_packet; @@ -101,6 +102,10 @@ pub struct NetPacket { } impl> NetPacket { + pub fn unchecked(buffer: B) -> Self { + let data_len = buffer.as_ref().len(); + Self { data_len, buffer } + } pub fn new(buffer: B) -> io::Result> { let data_len = buffer.as_ref().len(); Self::new0(data_len, buffer) @@ -158,6 +163,10 @@ impl> NetPacket { pub fn is_gateway(&self) -> bool { self.buffer.as_ref()[0] & 0x40 == 0x40 } + /// 扩展协议 + pub fn is_extension(&self) -> bool { + self.buffer.as_ref()[0] & 0x20 == 0x20 + } pub fn version(&self) -> Version { Version::from(self.buffer.as_ref()[0] & 0x0F) } @@ -190,6 +199,9 @@ impl> NetPacket { } impl + AsMut<[u8]>> NetPacket { + pub fn head_mut(&mut self) -> &mut [u8] { + &mut self.buffer.as_mut()[..12] + } pub fn buffer_mut(&mut self) -> &mut [u8] { &mut self.buffer.as_mut()[..self.data_len] } @@ -208,6 +220,13 @@ impl + AsMut<[u8]>> NetPacket { self.buffer.as_mut()[0] = self.buffer.as_ref()[0] & 0xBF }; } + pub fn set_extension_flag(&mut self, is_extension: bool) { + if is_extension { + self.buffer.as_mut()[0] = self.buffer.as_ref()[0] | 0x20 + } else { + self.buffer.as_mut()[0] = self.buffer.as_ref()[0] & 0xDF + }; + } pub fn set_default_version(&mut self) { let v: u8 = Version::V2.into(); self.buffer.as_mut()[0] = (self.buffer.as_ref()[0] & 0xF0) | (0x0F & v); @@ -264,6 +283,10 @@ impl + AsMut<[u8]>> NetPacket { self.data_len = data_len; Ok(()) } + pub fn set_payload_len(&mut self, payload_len: usize) -> io::Result<()> { + let data_len = HEAD_LEN + payload_len; + self.set_data_len(data_len) + } pub fn set_data_len_max(&mut self) { self.data_len = self.buffer.as_ref().len(); } diff --git a/vnt/src/tun_tap_device/tun_create_helper.rs b/vnt/src/tun_tap_device/tun_create_helper.rs index 2f83423..67b7665 100644 --- a/vnt/src/tun_tap_device/tun_create_helper.rs +++ b/vnt/src/tun_tap_device/tun_create_helper.rs @@ -8,6 +8,7 @@ use tun::Device; use crate::channel::context::ChannelContext; use crate::cipher::Cipher; +use crate::compression::Compressor; use crate::external_route::ExternalRoute; use crate::handle::{CurrentDeviceInfo, PeerDeviceInfo}; #[cfg(feature = "ip_proxy")] @@ -78,6 +79,7 @@ struct TunDeviceHelperInner { parallel: usize, up_counter: SingleU64Adder, device_list: Arc)>>, + compressor: Compressor, } impl TunDeviceHelper { @@ -92,6 +94,7 @@ impl TunDeviceHelper { parallel: usize, up_counter: SingleU64Adder, device_list: Arc)>>, + compressor: Compressor, ) -> Self { Self { inner: Arc::new(AtomicCell::new(Some(TunDeviceHelperInner { @@ -106,6 +109,7 @@ impl TunDeviceHelper { parallel, up_counter, device_list, + compressor, }))), } } @@ -124,6 +128,7 @@ impl TunDeviceHelper { inner.parallel, inner.up_counter, inner.device_list, + inner.compressor, )?; Ok(()) } else {