diff --git a/src/client.rs b/src/client.rs index 22c18b5..aa6ba19 100644 --- a/src/client.rs +++ b/src/client.rs @@ -245,7 +245,8 @@ pub(crate) async fn create_ws_stream( mut stream: S, ) -> Result> { let client = config.client.as_ref().ok_or("client not exist")?; - let tunnel_path = config.tunnel_path.trim_matches('/'); + let tunnel_path = config.tunnel_path.extract().first().ok_or("tunnel path not exist")?.clone(); + let tunnel_path = tunnel_path.as_str().trim_matches('/'); let b64_dst = dst_addr.as_ref().map(|dst_addr| addess_to_b64str(dst_addr, false)); diff --git a/src/config.rs b/src/config.rs index 9d979a0..bed1d09 100644 --- a/src/config.rs +++ b/src/config.rs @@ -14,13 +14,73 @@ pub struct Config { pub remarks: Option, pub method: Option, pub password: Option, - pub tunnel_path: String, + pub tunnel_path: TunnelPath, #[serde(skip)] pub test_timeout_secs: u64, #[serde(skip)] pub is_server: bool, } +#[derive(Clone, Serialize, Deserialize, Debug)] +#[serde(untagged)] +pub enum TunnelPath { + Single(String), + Multiple(Vec), +} + +impl std::fmt::Display for TunnelPath { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TunnelPath::Single(s) => write!(f, "{}", s), + TunnelPath::Multiple(v) => { + let mut s = String::new(); + for (i, item) in v.iter().enumerate() { + if i > 0 { + s.push(','); + } + s.push_str(item); + } + write!(f, "{}", s) + } + } + } +} + +impl Default for TunnelPath { + fn default() -> Self { + TunnelPath::Single("/tunnel/".to_string()) + } +} + +impl TunnelPath { + pub fn is_empty(&self) -> bool { + match self { + TunnelPath::Single(s) => s.is_empty(), + TunnelPath::Multiple(v) => v.is_empty(), + } + } + + pub fn standardize(&mut self) { + match self { + TunnelPath::Single(s) => { + *s = format!("/{}/", s.trim().trim_matches('/')); + } + TunnelPath::Multiple(v) => { + for s in v.iter_mut() { + *s = format!("/{}/", s.trim().trim_matches('/')); + } + } + } + } + + pub fn extract(&self) -> Vec { + match self { + TunnelPath::Single(s) => vec![s.clone()], + TunnelPath::Multiple(v) => v.clone(), + } + } +} + #[derive(Clone, Serialize, Deserialize, Debug, Default)] pub struct Server { pub disable_tls: Option, @@ -70,7 +130,7 @@ impl Config { remarks: None, method: None, password: None, - tunnel_path: "/tunnel/".to_string(), + tunnel_path: TunnelPath::default(), server: None, client: None, test_timeout_secs: 5, @@ -177,9 +237,9 @@ impl Config { self.test_timeout_secs = 5; } if self.tunnel_path.is_empty() { - self.tunnel_path = "/tunnel/".to_string(); + self.tunnel_path = TunnelPath::default(); } else { - self.tunnel_path = format!("/{}/", self.tunnel_path.trim().trim_matches('/')); + self.tunnel_path.standardize(); } if let Some(server) = &mut self.server { @@ -238,7 +298,8 @@ impl Config { let remarks = crate::base64_encode(remarks.as_bytes(), engine); let domain = client.server_domain.as_ref().map_or("".to_string(), |d| d.clone()); let domain = crate::base64_encode(domain.as_bytes(), engine); - let tunnel_path = crate::base64_encode(self.tunnel_path.as_bytes(), engine); + let err = "tunnel_path is not set"; + let tunnel_path = crate::base64_encode(self.tunnel_path.extract().first().ok_or(err)?.as_bytes(), engine); let host = &client.server_host; let port = client.server_port; diff --git a/src/server.rs b/src/server.rs index 903c231..18df18a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -128,7 +128,7 @@ async fn handle_incoming( return Err(Error::from("empty request")); } - if !check_uri_path(&buf, &config.tunnel_path)? { + if !check_uri_path(&buf, &config.tunnel_path.extract())? { return forward_traffic_wrapper(stream, &buf, &config).await; } @@ -158,14 +158,16 @@ where Ok(()) } -fn check_uri_path(buf: &[u8], path: &str) -> Result { +fn check_uri_path(buf: &[u8], path: &[String]) -> Result { let mut headers = [httparse::EMPTY_HEADER; 512]; let mut req = httparse::Request::new(&mut headers); req.parse(buf)?; if let Some(p) = req.path { - if p == path { - return Ok(true); + for path in path { + if p == *path { + return Ok(true); + } } } Ok(false)