diff --git a/Cargo.lock b/Cargo.lock index 57b837b35..161a6ae09 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1222,7 +1222,10 @@ dependencies = [ "sha2", "shadowsocks", "smoltcp 0.12.0", + "snafu", "socket2 0.5.8", + "stack-error", + "stack-error-macro", "tempfile", "thiserror 2.0.11", "time", @@ -1728,7 +1731,7 @@ dependencies = [ "proc-macro2", "quote", "sha3", - "strum", + "strum 0.26.3", "syn 2.0.96", "void", ] @@ -5646,6 +5649,27 @@ dependencies = [ "managed", ] +[[package]] +name = "snafu" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "223891c85e2a29c3fe8fb900c1fae5e69c2e42415e3177752e8718475efa5019" +dependencies = [ + "snafu-derive", +] + +[[package]] +name = "snafu-derive" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03c3c6b7927ffe7ecaa769ee0e3994da3b8cafc8f444578982c83ecb161af917" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "socket2" version = "0.4.10" @@ -5746,6 +5770,24 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "stack-error" +version = "0.1.0" +dependencies = [ + "snafu", + "strum 0.25.0", +] + +[[package]] +name = "stack-error-macro" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", + "syn 2.0.96", +] + [[package]] name = "static_assertions" version = "1.1.0" @@ -5764,13 +5806,35 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "strum" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" +dependencies = [ + "strum_macros 0.25.3", +] + [[package]] name = "strum" version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" dependencies = [ - "strum_macros", + "strum_macros 0.26.4", +] + +[[package]] +name = "strum_macros" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.96", ] [[package]] @@ -6385,7 +6449,7 @@ dependencies = [ "serde", "serde-value", "serde_ignored", - "strum", + "strum 0.26.3", "thiserror 2.0.11", "toml", "tor-basic-utils", @@ -6481,7 +6545,7 @@ dependencies = [ "scopeguard", "serde", "signature", - "strum", + "strum 0.26.3", "thiserror 2.0.11", "time", "tor-async-utils", @@ -6514,7 +6578,7 @@ dependencies = [ "paste", "retry-error", "static_assertions", - "strum", + "strum 0.26.3", "thiserror 2.0.11", "tracing", "void", @@ -6555,7 +6619,7 @@ dependencies = [ "rand 0.8.5", "safelog", "serde", - "strum", + "strum 0.26.3", "thiserror 2.0.11", "tor-async-utils", "tor-basic-utils", @@ -6592,7 +6656,7 @@ dependencies = [ "retry-error", "safelog", "slotmap-careful", - "strum", + "strum 0.26.3", "thiserror 2.0.11", "tor-async-utils", "tor-basic-utils", @@ -6719,7 +6783,7 @@ dependencies = [ "safelog", "serde", "serde_with", - "strum", + "strum 0.26.3", "thiserror 2.0.11", "tor-basic-utils", "tor-bytes", @@ -6825,7 +6889,7 @@ dependencies = [ "rand 0.8.5", "serde", "static_assertions", - "strum", + "strum 0.26.3", "thiserror 2.0.11", "time", "tor-basic-utils", @@ -7036,7 +7100,7 @@ dependencies = [ "pin-project", "priority-queue", "slotmap-careful", - "strum", + "strum 0.26.3", "thiserror 2.0.11", "tor-error", "tor-general-addr", diff --git a/Cargo.toml b/Cargo.toml index 3be675cc9..142e12fc3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,8 @@ members = [ "clash", "clash_lib", "clash_doc", + "crates/stack-error-macro", + "crates/stack-error", "clash_ffi", ] diff --git a/clash_lib/Cargo.toml b/clash_lib/Cargo.toml index 0cd7964b5..5aa55d001 100644 --- a/clash_lib/Cargo.toml +++ b/clash_lib/Cargo.toml @@ -17,6 +17,11 @@ zero_copy = [] tokio-console = ["tokio/tracing"] [dependencies] + +stack-error-macro = { path = "../crates/stack-error-macro" } +stack-error = { path = "../crates/stack-error" } +snafu = "0.8" + # Async tokio = { version = "1", features = ["full"] } tokio-util = { version = "0.7", features = ["net", "codec", "io", "compat"] } diff --git a/clash_lib/src/app/dns/config.rs b/clash_lib/src/app/dns/config.rs index c1642b64b..23a2e7cdf 100644 --- a/clash_lib/src/app/dns/config.rs +++ b/clash_lib/src/app/dns/config.rs @@ -149,7 +149,7 @@ impl Config { pub fn parse_fallback_ip_cidr( ipcidr: &[String], - ) -> anyhow::Result> { + ) -> std::result::Result, crate::Error> { let mut output = vec![]; for ip in ipcidr.iter() { @@ -164,7 +164,7 @@ impl Config { pub fn parse_hosts( hosts_mapping: &HashMap, - ) -> anyhow::Result> { + ) -> std::result::Result, crate::Error> { let mut tree = trie::StringTrie::new(); tree.insert( "localhost", @@ -172,7 +172,7 @@ impl Config { ); for (host, ip_str) in hosts_mapping.iter() { - let ip = ip_str.parse::()?; + let ip = ip_str.parse::().map_err(Error::StdNet)?; tree.insert(host.as_str(), Arc::new(ip)); } diff --git a/clash_lib/src/app/dns/dhcp.rs b/clash_lib/src/app/dns/dhcp.rs index 6dcea690c..533749771 100644 --- a/clash_lib/src/app/dns/dhcp.rs +++ b/clash_lib/src/app/dns/dhcp.rs @@ -3,12 +3,14 @@ use crate::{ dns_client::DNSNetMode, helper::make_clients, Client, EnhancedResolver, ThreadSafeDNSClient, }, + error::dns::{ClientTimeoutSnafu, IoSnafu}, proxy::utils::{new_udp_socket, Interface}, }; use async_trait::async_trait; use dhcproto::{Decodable, Encodable}; use futures::FutureExt; use network_interface::{Addr, NetworkInterfaceConfig}; +use snafu::ResultExt; use std::{ env, fmt::{Debug, Formatter}, @@ -56,8 +58,8 @@ impl Client for DhcpClient { format!("dhcp#{}", self.iface) } - async fn exchange(&self, msg: &Message) -> anyhow::Result { - let clients = self.resolve().await?; + async fn exchange(&self, msg: &Message) -> crate::error::DnsResult { + let clients = self.resolve().await.context(IoSnafu)?; let mut dbg_str = vec![]; for c in &clients { dbg_str.push(format!("{:?}", c)); @@ -67,7 +69,8 @@ impl Client for DhcpClient { DHCP_TIMEOUT, EnhancedResolver::batch_exchange(&clients, msg), ) - .await? + .await + .context(ClientTimeoutSnafu)? } } diff --git a/clash_lib/src/app/dns/dns_client.rs b/clash_lib/src/app/dns/dns_client.rs index 4ec18c507..6953430df 100644 --- a/clash_lib/src/app/dns/dns_client.rs +++ b/clash_lib/src/app/dns/dns_client.rs @@ -16,12 +16,14 @@ use hickory_proto::{ udp::UdpClientStream, ProtoError, }; use rustls::ClientConfig; +use snafu::ResultExt; use tokio::{sync::RwLock, task::JoinHandle}; use tracing::{info, warn}; use crate::{ common::tls::{self, GLOBAL_ROOT_STORE}, dns::{dhcp::DhcpClient, ThreadSafeDNSClient}, + error::dns::ProtoSnafu, proxy::utils::new_tcp_stream, }; use hickory_proto::{ @@ -277,7 +279,7 @@ impl Client for DnsClient { format!("{}#{}:{}", &self.net, &self.host, &self.port) } - async fn exchange(&self, msg: &Message) -> anyhow::Result { + async fn exchange(&self, msg: &Message) -> crate::error::DnsResult { let mut inner = self.inner.write().await; if let Some(bg) = &inner.bg_handle { @@ -309,14 +311,14 @@ impl Client for DnsClient { .send(req) .first_answer() .await - .map_err(|x| Error::DNSError(x.to_string()).into()) + .context(ProtoSnafu) .map(|x| x.into()) } } async fn dns_stream_builder( cfg: &DnsConfig, -) -> Result<(client::Client, JoinHandle>), Error> { +) -> crate::error::DnsResult<(client::Client, JoinHandle>)> { match cfg { DnsConfig::Udp(addr, iface) => { let stream = UdpClientStream::builder( @@ -329,7 +331,7 @@ async fn dns_stream_builder( client::Client::connect(stream) .await .map(|(x, y)| (x, tokio::spawn(y))) - .map_err(|x| Error::DNSError(x.to_string())) + .context(ProtoSnafu) } DnsConfig::Tcp(addr, iface) => { let (stream, sender) = TcpClientStream::new( @@ -342,7 +344,7 @@ async fn dns_stream_builder( client::Client::new(stream, sender, None) .await .map(|(x, y)| (x, tokio::spawn(y))) - .map_err(|x| Error::DNSError(x.to_string())) + .context(ProtoSnafu) } DnsConfig::Tls(addr, host, iface) => { let mut tls_config = ClientConfig::builder() @@ -379,7 +381,7 @@ async fn dns_stream_builder( ) .await .map(|(x, y)| (x, tokio::spawn(y))) - .map_err(|x| Error::DNSError(x.to_string())) + .context(ProtoSnafu) } DnsConfig::Https(addr, host, iface) => { let mut tls_config = ClientConfig::builder() @@ -402,7 +404,7 @@ async fn dns_stream_builder( client::Client::connect(stream) .await .map(|(x, y)| (x, tokio::spawn(y))) - .map_err(|x| Error::DNSError(x.to_string())) + .context(ProtoSnafu) } } } diff --git a/clash_lib/src/app/dns/mod.rs b/clash_lib/src/app/dns/mod.rs index 00fd112b9..6e4afc232 100644 --- a/clash_lib/src/app/dns/mod.rs +++ b/clash_lib/src/app/dns/mod.rs @@ -28,7 +28,10 @@ pub use server::get_dns_listener; pub trait Client: Sync + Send + Debug { /// used to identify the client for logging fn id(&self) -> String; - async fn exchange(&self, msg: &op::Message) -> anyhow::Result; + async fn exchange( + &self, + msg: &op::Message, + ) -> crate::error::DnsResult; } type ThreadSafeDNSClient = Arc; @@ -51,22 +54,25 @@ pub trait ClashResolver: Sync + Send { &self, host: &str, enhanced: bool, - ) -> anyhow::Result>; + ) -> std::result::Result, crate::error::DnsError>; async fn resolve_v4( &self, host: &str, enhanced: bool, - ) -> anyhow::Result>; + ) -> std::result::Result, crate::error::DnsError>; async fn resolve_v6( &self, host: &str, enhanced: bool, - ) -> anyhow::Result>; + ) -> std::result::Result, crate::error::DnsError>; async fn cached_for(&self, ip: std::net::IpAddr) -> Option; /// Used for DNS Server - async fn exchange(&self, message: &op::Message) -> anyhow::Result; + async fn exchange( + &self, + message: &op::Message, + ) -> std::result::Result; /// Only used for look up fake IP async fn reverse_lookup(&self, ip: std::net::IpAddr) -> Option; diff --git a/clash_lib/src/app/dns/resolver/enhanced.rs b/clash_lib/src/app/dns/resolver/enhanced.rs index 626e9751e..89affd39a 100644 --- a/clash_lib/src/app/dns/resolver/enhanced.rs +++ b/clash_lib/src/app/dns/resolver/enhanced.rs @@ -1,6 +1,7 @@ use async_trait::async_trait; use futures::{FutureExt, TryFutureExt}; use rand::seq::IndexedRandom; +use snafu::ResultExt; use std::{ net, sync::{ @@ -19,7 +20,10 @@ use crate::{ common::{mmdb::Mmdb, trie}, config::def::DNSMode, dns::{helper::make_clients, ThreadSafeDNSClient}, - Error, + error::dns::{ + ProtoSnafu, IPV6DisabledSnafu, InvalidQuerySnafu, NoRecordSnafu, + TimeoutSnafu, + }, }; use crate::dns::{ @@ -223,7 +227,7 @@ impl EnhancedResolver { pub async fn batch_exchange( clients: &Vec, message: &op::Message, - ) -> anyhow::Result { + ) -> crate::error::DnsResult { let mut queries = Vec::new(); for c in clients { queries.push( @@ -249,7 +253,7 @@ impl EnhancedResolver { Ok(r) => Ok(r.0), Err(e) => Err(e), }, - _ = timeout => Err(Error::DNSError("DNS query timeout".into()).into()) + _ = timeout => TimeoutSnafu.fail(), } } @@ -258,12 +262,13 @@ impl EnhancedResolver { &self, host: &str, record_type: rr::record_type::RecordType, - ) -> anyhow::Result> { + ) -> crate::error::DnsResult> { let mut m = op::Message::new(); let mut q = op::Query::new(); let name = rr::Name::from_str_relaxed(host) - .map_err(|_x| anyhow!("invalid domain: {}", host))? - .append_domain(&rr::Name::root())?; // makes it FQDN + .context(ProtoSnafu)? + .append_domain(&rr::Name::root()) + .context(ProtoSnafu)?; // makes it FQDN q.set_name(name); q.set_query_type(record_type); m.add_query(q); @@ -275,14 +280,20 @@ impl EnhancedResolver { if !ip_list.is_empty() { Ok(ip_list) } else { - Err(anyhow!("no record for hostname: {}", host)) + NoRecordSnafu { + host: host.to_owned(), + } + .fail() } } Err(e) => Err(e), } } - async fn exchange(&self, message: &op::Message) -> anyhow::Result { + async fn exchange( + &self, + message: &op::Message, + ) -> crate::error::DnsResult { if let Some(q) = message.query() { if let Some(lru) = &self.lru_cache { if let Some(cached) = lru.read().await.peek(q.to_string().as_str()) { @@ -294,14 +305,17 @@ impl EnhancedResolver { } self.exchange_no_cache(message).await } else { - Err(anyhow!("invalid query")) + InvalidQuerySnafu { + queries: message.queries().to_vec(), + } + .fail() } } async fn exchange_no_cache( &self, message: &op::Message, - ) -> anyhow::Result { + ) -> crate::error::DnsResult { let q = message.query().unwrap(); let query = async move { @@ -367,7 +381,7 @@ impl EnhancedResolver { async fn ip_exchange( &self, message: &op::Message, - ) -> anyhow::Result { + ) -> crate::error::DnsResult { if let Some(matched) = self.match_policy(message) { return EnhancedResolver::batch_exchange(matched, message).await; } @@ -473,7 +487,7 @@ impl ClashResolver for EnhancedResolver { &self, host: &str, enhanced: bool, - ) -> anyhow::Result> { + ) -> crate::error::DnsResult> { match self.ipv6.load(Relaxed) { true => { let fut1 = self @@ -502,7 +516,7 @@ impl ClashResolver for EnhancedResolver { &self, host: &str, enhanced: bool, - ) -> anyhow::Result> { + ) -> crate::error::DnsResult> { if enhanced { if let Some(hosts) = &self.hosts { if let Some(v) = hosts.search(host) { @@ -543,9 +557,9 @@ impl ClashResolver for EnhancedResolver { &self, host: &str, enhanced: bool, - ) -> anyhow::Result> { + ) -> crate::error::DnsResult> { if !self.ipv6.load(Relaxed) { - return Err(Error::DNSError("ipv6 disabled".into()).into()); + return IPV6DisabledSnafu.fail(); } if enhanced { @@ -584,7 +598,10 @@ impl ClashResolver for EnhancedResolver { None } - async fn exchange(&self, message: &op::Message) -> anyhow::Result { + async fn exchange( + &self, + message: &op::Message, + ) -> crate::error::DnsResult { let rv = self.exchange(message).await?; let hostname = message .query() diff --git a/clash_lib/src/app/dns/resolver/system_linux.rs b/clash_lib/src/app/dns/resolver/system_linux.rs index 3ccc8d14a..84917370a 100644 --- a/clash_lib/src/app/dns/resolver/system_linux.rs +++ b/clash_lib/src/app/dns/resolver/system_linux.rs @@ -4,8 +4,12 @@ use async_trait::async_trait; use hickory_resolver::TokioResolver; use rand::seq::IteratorRandom; +use snafu::ResultExt; -use crate::app::dns::{ClashResolver, ResolverKind}; +use crate::{ + app::dns::{ClashResolver, ResolverKind}, + error::dns::{ResolveSnafu, UnsupportedSnafu}, +}; pub struct SystemResolver { inner: TokioResolver, @@ -14,9 +18,9 @@ pub struct SystemResolver { /// Bug in libc, use tokio impl instead: https://sourceware.org/bugzilla/show_bug.cgi?id=10652 impl SystemResolver { - pub fn new(ipv6: bool) -> anyhow::Result { + pub fn new(ipv6: bool) -> crate::error::DnsResult { Ok(Self { - inner: TokioResolver::tokio_from_system_conf()?, + inner: TokioResolver::tokio_from_system_conf().context(ResolveSnafu)?, ipv6: AtomicBool::new(ipv6), }) } @@ -28,8 +32,8 @@ impl ClashResolver for SystemResolver { &self, host: &str, _: bool, - ) -> anyhow::Result> { - let response = self.inner.lookup_ip(host).await?; + ) -> crate::error::DnsResult> { + let response = self.inner.lookup_ip(host).await.context(ResolveSnafu)?; Ok(response .iter() .filter(|x| self.ipv6() || x.is_ipv4()) @@ -40,8 +44,8 @@ impl ClashResolver for SystemResolver { &self, host: &str, _: bool, - ) -> anyhow::Result> { - let response = self.inner.ipv4_lookup(host).await?; + ) -> crate::error::DnsResult> { + let response = self.inner.ipv4_lookup(host).await.context(ResolveSnafu)?; Ok(response.iter().map(|x| x.0).choose(&mut rand::rng())) } @@ -49,8 +53,8 @@ impl ClashResolver for SystemResolver { &self, host: &str, _: bool, - ) -> anyhow::Result> { - let response = self.inner.ipv6_lookup(host).await?; + ) -> crate::error::DnsResult> { + let response = self.inner.ipv6_lookup(host).await.context(ResolveSnafu)?; Ok(response.iter().map(|x| x.0).choose(&mut rand::rng())) } @@ -61,8 +65,8 @@ impl ClashResolver for SystemResolver { async fn exchange( &self, _: &hickory_proto::op::Message, - ) -> anyhow::Result { - Err(anyhow::anyhow!("unsupported")) + ) -> crate::error::DnsResult { + UnsupportedSnafu.fail() } fn ipv6(&self) -> bool { diff --git a/clash_lib/src/app/router/mod.rs b/clash_lib/src/app/router/mod.rs index 05bd4a0ab..fd5daba59 100644 --- a/clash_lib/src/app/router/mod.rs +++ b/clash_lib/src/app/router/mod.rs @@ -343,8 +343,6 @@ pub fn map_rule_type( mod tests { use std::sync::Arc; - use anyhow::Ok; - use crate::{ app::dns::{MockClashResolver, SystemResolver}, common::{geodata::GeoData, http::new_http_client, mmdb::Mmdb}, diff --git a/clash_lib/src/error/dns.rs b/clash_lib/src/error/dns.rs new file mode 100644 index 000000000..4f61b2b47 --- /dev/null +++ b/clash_lib/src/error/dns.rs @@ -0,0 +1,144 @@ +use snafu::{Location, Snafu}; +use stack_error_macro::stack_trace_debug; + +#[derive(Snafu)] +#[snafu(visibility(pub))] +#[stack_trace_debug] +pub enum DnsError { + #[snafu(display("invalid domain: {domain}"))] + InvaldDomain { + domain: String, + #[snafu(implicit)] + location: Location, + }, + #[snafu(display("no record: {host}"))] + NoRecord { + host: String, + #[snafu(implicit)] + location: Location, + }, + #[snafu(display("invalid query: {queries:?}"))] + InvalidQuery { + queries: Vec, + #[snafu(implicit)] + location: Location, + }, + #[snafu(display("dns timeout"))] + Timeout { + #[snafu(implicit)] + location: Location, + }, + #[snafu(display("hickory proto"))] + Proto { + #[snafu(implicit)] + location: Location, + #[snafu(source)] + error: hickory_proto::ProtoError, + }, + #[snafu(display("hickory resolve"))] + ResolveError { + #[snafu(implicit)] + location: Location, + #[snafu(source)] + error: hickory_resolver::ResolveError, + }, + #[snafu(display("unsupported operation"))] + Unsupported { + #[snafu(implicit)] + location: Location, + }, + #[snafu(display("ipv6 disabled"))] + IPV6Disabled { + #[snafu(implicit)] + location: Location, + }, + #[snafu(display("empty dns"))] + EmptyDns { + #[snafu(implicit)] + location: Location, + }, + #[snafu(display("dns timeout"))] + ClientTimeout { + #[snafu(implicit)] + location: Location, + #[snafu(source)] + error: tokio::time::error::Elapsed, + }, + #[snafu(display("io error"))] + Io { + #[snafu(implicit)] + location: Location, + #[snafu(source)] + error: std::io::Error, + }, +} + +pub type DnsResult = std::result::Result; + +#[cfg(test)] +mod tests { + + use snafu::{Location, ResultExt, Snafu}; + use stack_error_macro::stack_trace_debug; + + // layer 1 + #[derive(Snafu)] + #[snafu(visibility(pub))] + #[stack_trace_debug] + pub enum RustError { + #[snafu(display("Failed to call Java"))] + Rust { + #[snafu(implicit)] + location: Location, + #[snafu(source)] + source: JavaError, + }, + } + + // layer 2 + #[derive(Snafu)] + #[snafu(visibility(pub))] + #[stack_trace_debug] + pub enum JavaError { + #[snafu(display("Failed to call Python"))] + Python { + #[snafu(implicit)] + location: Location, + #[snafu(source)] + source: PythonError, + }, + } + + // layer 3 + #[derive(Snafu)] + #[snafu(visibility(pub))] + #[stack_trace_debug] + pub enum PythonError { + #[snafu(display("IO Error"))] + IO { + #[snafu(implicit)] + location: Location, + #[snafu(source)] + source: std::io::Error, + }, + } + + fn fn1() -> Result<(), RustError> { + fn2().context(RustSnafu) + } + + fn fn2() -> Result<(), JavaError> { + fn3().context(PythonSnafu) + } + + fn fn3() -> Result<(), PythonError> { + let res = Err(std::io::Error::new(std::io::ErrorKind::Other, "error")); + res.context(IOSnafu) + } + + #[test] + fn test_snafu_error() { + let res = fn1(); + println!("{:?}", res); + } +} diff --git a/clash_lib/src/error/mod.rs b/clash_lib/src/error/mod.rs new file mode 100644 index 000000000..c8fcad847 --- /dev/null +++ b/clash_lib/src/error/mod.rs @@ -0,0 +1,25 @@ +pub mod dns; + +pub use dns::{DnsError, DnsResult}; + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error(transparent)] + StdNet(#[from] std::net::AddrParseError), + #[error(transparent)] + IpNet(#[from] ipnet::AddrParseError), + #[error(transparent)] + Io(#[from] std::io::Error), + #[error("invalid config: {0}")] + InvalidConfig(String), + #[error("profile error: {0}")] + ProfileError(String), + #[error("dns error: {0}")] + DNSError(String), + #[error(transparent)] + DNSServerError(#[from] watfaq_dns::DNSError), + #[error("crypto error: {0}")] + Crypto(String), + #[error("operation error: {0}")] + Operation(String), +} \ No newline at end of file diff --git a/clash_lib/src/lib.rs b/clash_lib/src/lib.rs index a115dcc16..d0d4848f3 100644 --- a/clash_lib/src/lib.rs +++ b/clash_lib/src/lib.rs @@ -26,7 +26,7 @@ use config::def::LogLevel; use once_cell::sync::OnceCell; use proxy::tun::get_tun_runner; -use std::{io, path::PathBuf, sync::Arc}; +use std::{path::PathBuf, sync::Arc}; use thiserror::Error; use tokio::{ sync::{broadcast, mpsc, oneshot, Mutex}, @@ -37,6 +37,7 @@ use tracing::{debug, error, info}; mod app; mod common; mod config; +pub mod error; mod proxy; mod session; @@ -46,24 +47,15 @@ pub use config::{ DNSListen as ClashDNSListen, RuntimeConfig as ClashRuntimeConfig, }; +pub use error::Error; + #[derive(Error, Debug)] -pub enum Error { - #[error(transparent)] - IpNet(#[from] ipnet::AddrParseError), - #[error(transparent)] - Io(#[from] io::Error), - #[error("invalid config: {0}")] - InvalidConfig(String), - #[error("profile error: {0}")] - ProfileError(String), - #[error("dns error: {0}")] - DNSError(String), - #[error(transparent)] - DNSServerError(#[from] watfaq_dns::DNSError), - #[error("crypto error: {0}")] - Crypto(String), - #[error("operation error: {0}")] - Operation(String), +pub struct DnsError(pub String); + +impl std::fmt::Display for DnsError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "dns error: {}", self.0) + } } pub type Runner = futures::future::BoxFuture<'static, Result<(), Error>>; diff --git a/clash_lib/src/proxy/hysteria2/error.rs b/clash_lib/src/proxy/hysteria2/error.rs new file mode 100644 index 000000000..91a3ecda3 --- /dev/null +++ b/clash_lib/src/proxy/hysteria2/error.rs @@ -0,0 +1,72 @@ +use snafu::{Location, Snafu}; +use stack_error_macro::stack_trace_debug; + +#[derive(Snafu)] +#[snafu(visibility(pub))] +#[stack_trace_debug] +pub enum Error { + #[snafu(display("Failed to call Dns"))] + Dns { + #[snafu(implicit)] + location: Location, + source: crate::error::DnsError, + }, + #[snafu(display("empty dns"))] + DsnEmpty { + #[snafu(implicit)] + location: Location, + }, + #[snafu(display("Failed to call Io"))] + Io { + #[snafu(implicit)] + location: Location, + #[snafu(source)] + error: std::io::Error, + }, + #[snafu(display("Failed to call Quinn"))] + QuinnConnect { + #[snafu(implicit)] + location: Location, + #[snafu(source)] + error: quinn::ConnectError, + }, + #[snafu(display("Failed to call Quinn"))] + QuinnConnection { + #[snafu(implicit)] + location: Location, + #[snafu(source)] + error: quinn::ConnectionError, + }, + #[snafu(display("Failed to call H3"))] + H3 { + #[snafu(implicit)] + location: Location, + #[snafu(source)] + error: h3::Error, + }, + #[snafu(display("Failed to call Auth, status code: {status}"))] + Auth { status: u16 }, + #[snafu(display("Failed to call Auth, msg: {msg}"))] + AuthOther { msg: String }, + #[snafu(display("Failed to call to_str"))] + ToStr { + #[snafu(implicit)] + location: Location, + #[snafu(source)] + error: http::header::ToStrError, + }, + #[snafu(display("Failed to call ParseInt"))] + ParseInt { + #[snafu(implicit)] + location: Location, + #[snafu(source)] + error: std::num::ParseIntError, + }, + #[snafu(display("Failed to call ParseInt"))] + ParseBool { + #[snafu(implicit)] + location: Location, + #[snafu(source)] + error: core::str::ParseBoolError, + }, +} diff --git a/clash_lib/src/proxy/hysteria2/mod.rs b/clash_lib/src/proxy/hysteria2/mod.rs index 7288b6bcd..1004772ed 100644 --- a/clash_lib/src/proxy/hysteria2/mod.rs +++ b/clash_lib/src/proxy/hysteria2/mod.rs @@ -10,10 +10,18 @@ use std::{ }; mod codec; mod congestion; +mod error; mod salamander; mod udp_hop; +type Result = std::result::Result; + use bytes::Bytes; +use error::{ + AuthOtherSnafu, AuthSnafu, DnsSnafu, DsnEmptySnafu, H3Snafu, IoSnafu, + ParseBoolSnafu, ParseIntSnafu, QuinnConnectSnafu, QuinnConnectionSnafu, + ToStrSnafu, +}; use futures::{SinkExt, StreamExt}; use h3::client::SendRequest; use h3_quinn::OpenStreams; @@ -29,6 +37,7 @@ use rustls::{ }, ClientConfig as RustlsClientConfig, }; +use snafu::{OptionExt, ResultExt}; use tokio::{ io::{AsyncRead, AsyncWrite, ReadBuf}, sync::Mutex, @@ -122,7 +131,7 @@ impl ServerCertVerifier for CertVerifyOption { server_name: &rustls::pki_types::ServerName<'_>, ocsp_response: &[u8], now: rustls::pki_types::UnixTime, - ) -> Result { + ) -> std::result::Result { if let Some(ref fingerprint) = self.fingerprint { let cert_hex = encode_hex(&sha256(end_entity.as_ref())); if &cert_hex != fingerprint { @@ -155,7 +164,10 @@ impl ServerCertVerifier for CertVerifyOption { message: &[u8], cert: &rustls::pki_types::CertificateDer<'_>, dss: &rustls::DigitallySignedStruct, - ) -> Result { + ) -> std::result::Result< + rustls::client::danger::HandshakeSignatureValid, + rustls::Error, + > { self.pki.verify_tls12_signature(message, cert, dss) } @@ -164,7 +176,10 @@ impl ServerCertVerifier for CertVerifyOption { message: &[u8], cert: &rustls::pki_types::CertificateDer<'_>, dss: &rustls::DigitallySignedStruct, - ) -> Result { + ) -> std::result::Result< + rustls::client::danger::HandshakeSignatureValid, + rustls::Error, + > { self.pki.verify_tls13_signature(message, cert, dss) } } @@ -177,7 +192,7 @@ enum CcRx { impl FromStr for CcRx { type Err = ParseIntError; - fn from_str(s: &str) -> Result { + fn from_str(s: &str) -> std::result::Result { if s.eq_ignore_ascii_case("auto") { Ok(Self::Auto) } else { @@ -208,7 +223,7 @@ impl Handler { const DEFAULT_MAX_IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300); - pub fn new(opts: HystOption) -> anyhow::Result { + pub fn new(opts: HystOption) -> Result { let verify = CertVerifyOption::new( opts.fingerprint.clone(), opts.ca.clone(), @@ -253,7 +268,7 @@ impl Handler { &self, sess: &Session, resolver: ThreadSafeDNSResolver, - ) -> anyhow::Result<(Connection, SendRequest)> { + ) -> Result<(Connection, SendRequest)> { // Everytime we enstablish a new session, we should lookup the server // address. maybe it changed since it use ddns let server_socket_addr = match self.opts.addr.clone() { @@ -261,8 +276,9 @@ impl Handler { SocksAddr::Domain(d, port) => { let ip = resolver .resolve(d.as_str(), true) - .await? - .ok_or_else(|| anyhow!("resolve domain {} failed", d))?; + .await + .context(DnsSnafu)? + .context(DsnEmptySnafu)?; SocketAddr::new(ip, port) } }; @@ -292,18 +308,20 @@ impl Handler { let mut ep = if let Some(obfs) = self.opts.obfs.as_ref() { match obfs { Obfs::Salamander(salamander_obfs) => { - let socket = create_socket().await?; + let socket = create_socket().await.context(IoSnafu)?; let obfs = salamander::Salamander::new( - socket.into_std()?, + socket.into_std().context(IoSnafu)?, salamander_obfs.key.to_vec(), - )?; + ) + .context(IoSnafu)?; quinn::Endpoint::new_with_abstract_socket( self.ep_config.clone(), None, Arc::new(obfs), Arc::new(TokioRuntime), - )? + ) + .context(IoSnafu)? } } } else if let Some(port_gen) = self.opts.ports.as_ref() { @@ -311,13 +329,15 @@ impl Handler { server_socket_addr.port(), port_gen.clone(), None, - )?; + ) + .context(IoSnafu)?; quinn::Endpoint::new_with_abstract_socket( self.ep_config.clone(), None, Arc::new(udp_hop), Arc::new(TokioRuntime), - )? + ) + .context(IoSnafu)? } else { let socket = { if resolver.ipv6() { @@ -327,7 +347,8 @@ impl Handler { #[cfg(any(target_os = "linux", target_os = "android"))] sess.so_mark, ) - .await? + .await + .context(IoSnafu)? } else { new_udp_socket( Some((Ipv4Addr::UNSPECIFIED, 0).into()), @@ -335,23 +356,27 @@ impl Handler { #[cfg(any(target_os = "linux", target_os = "android"))] sess.so_mark, ) - .await? + .await + .context(IoSnafu)? } }; quinn::Endpoint::new( self.ep_config.clone(), None, - socket.into_std()?, + socket.into_std().context(IoSnafu)?, Arc::new(TokioRuntime), - )? + ) + .context(IoSnafu)? }; ep.set_default_client_config(self.client_config.clone()); let session = ep - .connect(server_socket_addr, self.opts.sni.as_deref().unwrap_or(""))? - .await?; + .connect(server_socket_addr, self.opts.sni.as_deref().unwrap_or("")) + .context(QuinnConnectSnafu)? + .await + .context(QuinnConnectionSnafu)?; let (guard, _rx, udp) = Self::auth(&session, &self.opts.passwd).await?; *self.support_udp.write().unwrap() = udp; // todo set congestion controller according to cc_rx @@ -375,11 +400,13 @@ impl Handler { async fn auth( conn: &quinn::Connection, passwd: &str, - ) -> anyhow::Result<(SendRequest, CcRx, bool)> { + ) -> Result<(SendRequest, CcRx, bool)> { let h3_conn = h3_quinn::Connection::new(conn.clone()); - let (_, mut sender) = - h3::client::builder().build::<_, _, Bytes>(h3_conn).await?; + let (_, mut sender) = h3::client::builder() + .build::<_, _, Bytes>(h3_conn) + .await + .context(H3Snafu)?; let req = http::Request::post("https://hysteria/auth") .header("Hysteria-Auth", passwd) @@ -387,14 +414,14 @@ impl Handler { .header("Hysteria-Padding", codec::padding(64..=512)) .body(()) .unwrap(); - let mut r = sender.send_request(req).await?; - r.finish().await?; + let mut r = sender.send_request(req).await.context(H3Snafu)?; + r.finish().await.context(H3Snafu)?; - let r = r.recv_response().await?; + let r = r.recv_response().await.context(H3Snafu)?; const HYSTERIA_STATUS_OK: u16 = 233; if r.status() != HYSTERIA_STATUS_OK { - return Err(anyhow!("auth failed: response status code {}", r.status())); + return AuthSnafu { status: r.status() }.fail(); } // MUST have Hysteria-CC-RX and Hysteria-UDP headers according to hysteria2 @@ -402,16 +429,24 @@ impl Handler { let cc_rx = r .headers() .get("Hysteria-CC-RX") - .ok_or_else(|| anyhow!("auth failed: missing Hysteria-CC-RX header"))? - .to_str()? - .parse()?; + .context(AuthOtherSnafu { + msg: "missing Hysteria-CC-RX header".to_owned(), + })? + .to_str() + .context(ToStrSnafu)? + .parse() + .context(ParseIntSnafu)?; let support_udp = r .headers() .get("Hysteria-UDP") - .ok_or_else(|| anyhow!("auth failed: missing Hysteria-UDP header"))? - .to_str()? - .parse()?; + .context(AuthOtherSnafu { + msg: "missing Hysteria-UDP header".to_owned(), + })? + .to_str() + .context(ToStrSnafu)? + .parse() + .context(ParseBoolSnafu)?; Ok((sender, cc_rx, support_udp)) } diff --git a/clash_lib/src/proxy/utils/test_utils/noop.rs b/clash_lib/src/proxy/utils/test_utils/noop.rs index fe1807e23..17b858f97 100644 --- a/clash_lib/src/proxy/utils/test_utils/noop.rs +++ b/clash_lib/src/proxy/utils/test_utils/noop.rs @@ -8,6 +8,7 @@ use crate::{ dispatcher::{BoxedChainedDatagram, BoxedChainedStream}, dns::{ClashResolver, ResolverKind, ThreadSafeDNSResolver}, }, + error::dns::UnsupportedSnafu, proxy::{ConnectorType, DialWithConnector, OutboundHandler, OutboundType}, session::Session, }; @@ -20,7 +21,7 @@ impl ClashResolver for NoopResolver { &self, _host: &str, _enhanced: bool, - ) -> anyhow::Result> { + ) -> crate::error::DnsResult> { Ok(None) } @@ -28,7 +29,7 @@ impl ClashResolver for NoopResolver { &self, _host: &str, _enhanced: bool, - ) -> anyhow::Result> { + ) -> crate::error::DnsResult> { Ok(None) } @@ -36,7 +37,7 @@ impl ClashResolver for NoopResolver { &self, _host: &str, _enhanced: bool, - ) -> anyhow::Result> { + ) -> crate::error::DnsResult> { Ok(None) } @@ -45,8 +46,11 @@ impl ClashResolver for NoopResolver { } /// Used for DNS Server - async fn exchange(&self, _message: &op::Message) -> anyhow::Result { - Err(anyhow::anyhow!("unsupported")) + async fn exchange( + &self, + _message: &op::Message, + ) -> crate::error::DnsResult { + UnsupportedSnafu {}.fail() } /// Only used for look up fake IP diff --git a/crates/stack-error-macro/Cargo.toml b/crates/stack-error-macro/Cargo.toml new file mode 100644 index 000000000..6d519e2db --- /dev/null +++ b/crates/stack-error-macro/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "stack-error-macro" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1.0.66" +quote = "1.0" +syn = "1.0" +syn2 = { version = "2.0", package = "syn", features = [ + "derive", + "parsing", + "printing", + "clone-impls", + "proc-macro", + "extra-traits", + "full", +] } \ No newline at end of file diff --git a/crates/stack-error-macro/src/lib.rs b/crates/stack-error-macro/src/lib.rs new file mode 100644 index 000000000..0c947fd0f --- /dev/null +++ b/crates/stack-error-macro/src/lib.rs @@ -0,0 +1,286 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! implement `::stack_error::ext::StackError` + +use proc_macro2::{Span, TokenStream as TokenStream2}; +use quote::{quote, quote_spanned}; +use syn2::spanned::Spanned; +use syn2::{parenthesized, Attribute, Ident, ItemEnum, Variant}; + +#[proc_macro_attribute] +pub fn stack_trace_debug( + args: proc_macro::TokenStream, + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + stack_trace_style_impl(args.into(), input.into()).into() +} + +fn stack_trace_style_impl(args: TokenStream2, input: TokenStream2) -> TokenStream2 { + let input_cloned: TokenStream2 = input.clone(); + + let error_enum_definition: ItemEnum = syn2::parse2(input_cloned).unwrap(); + let enum_name = error_enum_definition.ident; + + let mut variants = vec![]; + + for error_variant in error_enum_definition.variants { + let variant = ErrorVariant::from_enum_variant(error_variant); + variants.push(variant); + } + + let debug_fmt_fn = build_debug_fmt_impl(enum_name.clone(), variants.clone()); + let next_fn = build_next_impl(enum_name.clone(), variants); + let debug_impl = build_debug_impl(enum_name.clone()); + + quote! { + #args + #input + + impl ::stack_error::ext::StackError for #enum_name { + #debug_fmt_fn + #next_fn + } + + #debug_impl + } +} + +/// Generate `debug_fmt` fn. +/// +/// The generated fn will be like: +/// ```rust, ignore +/// fn debug_fmt(&self, layer: usize, buf: &mut Vec); +/// ``` +fn build_debug_fmt_impl(enum_name: Ident, variants: Vec) -> TokenStream2 { + let match_arms = variants + .iter() + .map(|v| v.to_debug_match_arm()) + .collect::>(); + + quote! { + fn debug_fmt(&self, layer: usize, buf: &mut Vec) { + use #enum_name::*; + match self { + #(#match_arms)* + } + } + } +} + +/// Generate `next` fn. +/// +/// The generated fn will be like: +/// ```rust, ignore +/// fn next(&self) -> Option<&dyn ::stack_error::ext::StackError>; +/// ``` +fn build_next_impl(enum_name: Ident, variants: Vec) -> TokenStream2 { + let match_arms = variants + .iter() + .map(|v| v.to_next_match_arm()) + .collect::>(); + + quote! { + fn next(&self) -> Option<&dyn ::stack_error::ext::StackError> { + use #enum_name::*; + match self { + #(#match_arms)* + } + } + } +} + +/// Implement [std::fmt::Debug] via `debug_fmt` +fn build_debug_impl(enum_name: Ident) -> TokenStream2 { + quote! { + impl std::fmt::Debug for #enum_name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use ::stack_error::ext::StackError; + let mut buf = vec![]; + self.debug_fmt(0, &mut buf); + write!(f, "{}", buf.join("\n")) + } + } + } +} + +#[derive(Clone, Debug)] +struct ErrorVariant { + name: Ident, + fields: Vec, + has_location: bool, + has_source: bool, + has_external_cause: bool, + display: TokenStream2, + span: Span, + cfg_attr: Option, +} + +impl ErrorVariant { + /// Construct self from [Variant] + fn from_enum_variant(variant: Variant) -> Self { + let span = variant.span(); + let mut has_location = false; + let mut has_source = false; + let mut has_external_cause = false; + + for field in &variant.fields { + if let Some(ident) = &field.ident { + if ident == "location" { + has_location = true; + } else if ident == "source" { + has_source = true; + } else if ident == "error" { + has_external_cause = true; + } + } + } + + let mut display = None; + let mut cfg_attr = None; + for attr in variant.attrs { + if attr.path().is_ident("snafu") { + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("display") { + let content; + parenthesized!(content in meta.input); + let display_ts: TokenStream2 = content.parse()?; + display = Some(display_ts); + Ok(()) + } else { + Err(meta.error("unrecognized repr")) + } + }) + .expect("Each error should contains a display attribute"); + } + + if attr.path().is_ident("cfg") { + cfg_attr = Some(attr); + } + } + + let field_ident = variant + .fields + .iter() + .map(|f| f.ident.clone().unwrap_or_else(|| Ident::new("_", f.span()))) + .collect(); + + Self { + name: variant.ident, + fields: field_ident, + has_location, + has_source, + has_external_cause, + display: display.unwrap(), + span, + cfg_attr, + } + } + + /// Convert self into an match arm that will be used in [build_debug_impl]. + /// + /// The generated match arm will be like: + /// ```rust, ignore + /// ErrorKindWithSource { source, .. } => { + /// debug_fmt(source, layer + 1, buf); + /// }, + /// ErrorKindWithoutSource { .. } => { + /// buf.push(format!("{layer}: {}, at {}", format!(#display), location))); + /// } + /// ``` + /// + /// The generated code assumes fn `debug_fmt`, var `layer`, var `buf` are in scope. + fn to_debug_match_arm(&self) -> TokenStream2 { + let name = &self.name; + let fields = &self.fields; + let display = &self.display; + let cfg = if let Some(cfg) = &self.cfg_attr { + quote_spanned!(cfg.span() => #cfg) + } else { + quote! {} + }; + + match (self.has_location, self.has_source, self.has_external_cause) { + (true, true, _) => quote_spanned! { + self.span => #cfg #[allow(unused_variables)] #name { #(#fields),*, } => { + buf.push(format!("{layer}: {}, at {}", format!(#display), location)); + source.debug_fmt(layer + 1, buf); + }, + }, + (true, false, true) => quote_spanned! { + self.span => #cfg #[allow(unused_variables)] #name { #(#fields),* } => { + buf.push(format!("{layer}: {}, at {}", format!(#display), location)); + buf.push(format!("{}: {:?}", layer + 1, error)); + }, + }, + (true, false, false) => quote_spanned! { + self.span => #cfg #[allow(unused_variables)] #name { #(#fields),* } => { + buf.push(format!("{layer}: {}, at {}", format!(#display), location)); + }, + }, + (false, true, _) => quote_spanned! { + self.span => #cfg #[allow(unused_variables)] #name { #(#fields),* } => { + buf.push(format!("{layer}: {}", format!(#display))); + source.debug_fmt(layer + 1, buf); + }, + }, + (false, false, true) => quote_spanned! { + self.span => #cfg #[allow(unused_variables)] #name { #(#fields),* } => { + buf.push(format!("{layer}: {}", format!(#display))); + buf.push(format!("{}: {:?}", layer + 1, error)); + }, + }, + (false, false, false) => quote_spanned! { + self.span => #cfg #[allow(unused_variables)] #name { #(#fields),* } => { + buf.push(format!("{layer}: {}", format!(#display))); + }, + }, + } + } + + /// Convert self into an match arm that will be used in [build_next_impl]. + /// + /// The generated match arm will be like: + /// ```rust, ignore + /// ErrorKindWithSource { source, .. } => { + /// Some(source) + /// }, + /// ErrorKindWithoutSource { .. } => { + /// None + /// } + /// ``` + fn to_next_match_arm(&self) -> TokenStream2 { + let name = &self.name; + let fields = &self.fields; + let cfg = if let Some(cfg) = &self.cfg_attr { + quote_spanned!(cfg.span() => #cfg) + } else { + quote! {} + }; + + if self.has_source { + quote_spanned! { + self.span => #cfg #[allow(unused_variables)] #name { #(#fields),* } => { + Some(source) + }, + } + } else { + quote_spanned! { + self.span => #cfg #[allow(unused_variables)] #name { #(#fields),* } =>{ + None + } + } + } + } +} diff --git a/crates/stack-error/Cargo.toml b/crates/stack-error/Cargo.toml new file mode 100644 index 000000000..cc0861a7f --- /dev/null +++ b/crates/stack-error/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "stack-error" +version = "0.1.0" +edition = "2021" + +[dependencies] +snafu = "0.8" +strum = { version = "0.25", features = ["derive"] } \ No newline at end of file diff --git a/crates/stack-error/src/ext.rs b/crates/stack-error/src/ext.rs new file mode 100644 index 000000000..a777b1e45 --- /dev/null +++ b/crates/stack-error/src/ext.rs @@ -0,0 +1,211 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::any::Any; +use std::sync::Arc; + +use crate::status_code::StatusCode; + +/// Extension to [`Error`](std::error::Error) in std. +pub trait ErrorExt: StackError { + /// Map this error to [StatusCode]. + fn status_code(&self) -> StatusCode { + StatusCode::Unknown + } + + /// Returns the error as [Any](std::any::Any) so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; + + fn output_msg(&self) -> String + where + Self: Sized, + { + match self.status_code() { + StatusCode::Unknown | StatusCode::Internal => { + // masks internal error from end user + format!("Internal error: {}", self.status_code() as u32) + } + _ => { + let error = self.last(); + if let Some(external_error) = error.source() { + let external_root = external_error.sources().last().unwrap(); + + if error.to_string().is_empty() { + format!("{external_root}") + } else { + format!("{error}: {external_root}") + } + } else { + format!("{error}") + } + } + } + } +} + +pub trait StackError: std::error::Error { + fn debug_fmt(&self, layer: usize, buf: &mut Vec); + + fn next(&self) -> Option<&dyn StackError>; + + fn last(&self) -> &dyn StackError + where + Self: Sized, + { + let Some(mut result) = self.next() else { + return self; + }; + while let Some(err) = result.next() { + result = err; + } + result + } +} + +impl StackError for Arc { + fn debug_fmt(&self, layer: usize, buf: &mut Vec) { + self.as_ref().debug_fmt(layer, buf) + } + + fn next(&self) -> Option<&dyn StackError> { + self.as_ref().next() + } +} + +impl StackError for Box { + fn debug_fmt(&self, layer: usize, buf: &mut Vec) { + self.as_ref().debug_fmt(layer, buf) + } + + fn next(&self) -> Option<&dyn StackError> { + self.as_ref().next() + } +} + +/// An opaque boxed error based on errors that implement [ErrorExt] trait. +pub struct BoxedError { + inner: Box, +} + +impl BoxedError { + pub fn new(err: E) -> Self { + Self { + inner: Box::new(err), + } + } +} + +impl std::fmt::Debug for BoxedError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut buf = vec![]; + self.debug_fmt(0, &mut buf); + write!(f, "{}", buf.join("\n")) + } +} + +impl std::fmt::Display for BoxedError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.inner) + } +} + +impl std::error::Error for BoxedError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.inner.source() + } +} + +impl crate::ext::ErrorExt for BoxedError { + fn status_code(&self) -> crate::status_code::StatusCode { + self.inner.status_code() + } + + fn as_any(&self) -> &dyn std::any::Any { + self.inner.as_any() + } +} + +// Implement ErrorCompat for this opaque error so the backtrace is also available +// via `ErrorCompat::backtrace()`. +impl crate::snafu::ErrorCompat for BoxedError { + fn backtrace(&self) -> Option<&crate::snafu::Backtrace> { + None + } +} + +impl StackError for BoxedError { + fn debug_fmt(&self, layer: usize, buf: &mut Vec) { + self.inner.debug_fmt(layer, buf) + } + + fn next(&self) -> Option<&dyn StackError> { + self.inner.next() + } +} + +/// Error type with plain error message +#[derive(Debug)] +pub struct PlainError { + msg: String, + status_code: StatusCode, +} + +impl PlainError { + pub fn new(msg: String, status_code: StatusCode) -> Self { + Self { msg, status_code } + } +} + +impl std::fmt::Display for PlainError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.msg) + } +} + +impl std::error::Error for PlainError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + None + } +} + +impl crate::ext::ErrorExt for PlainError { + fn status_code(&self) -> crate::status_code::StatusCode { + self.status_code + } + + fn as_any(&self) -> &dyn std::any::Any { + self as _ + } +} + +impl StackError for PlainError { + fn debug_fmt(&self, layer: usize, buf: &mut Vec) { + buf.push(format!("{}: {}", layer, self.msg)) + } + + fn next(&self) -> Option<&dyn StackError> { + None + } +} + +impl StackError for std::io::Error { + fn debug_fmt(&self, layer: usize, buf: &mut Vec) { + buf.push(format!("{}: {}", layer, self)) + } + + fn next(&self) -> Option<&dyn StackError> { + None + } +} diff --git a/crates/stack-error/src/lib.rs b/crates/stack-error/src/lib.rs new file mode 100644 index 000000000..c5c0e6efe --- /dev/null +++ b/crates/stack-error/src/lib.rs @@ -0,0 +1,26 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![feature(error_iter)] + +pub mod ext; +pub mod mock; +pub mod status_code; + +pub use snafu; + +// HACK - these headers are here for shared in gRPC services. For common HTTP headers, +// please define in `src/servers/src/http/header.rs`. +pub const GREPTIME_DB_HEADER_ERROR_CODE: &str = "x-greptime-err-code"; +pub const GREPTIME_DB_HEADER_ERROR_MSG: &str = "x-greptime-err-msg"; diff --git a/crates/stack-error/src/mock.rs b/crates/stack-error/src/mock.rs new file mode 100644 index 000000000..572e11dea --- /dev/null +++ b/crates/stack-error/src/mock.rs @@ -0,0 +1,73 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Utils for mock. + +use std::any::Any; +use std::fmt; + +use crate::ext::{ErrorExt, StackError}; +use crate::status_code::StatusCode; + +/// A mock error mainly for test. +#[derive(Debug)] +pub struct MockError { + pub code: StatusCode, + source: Option>, +} + +impl MockError { + /// Create a new [MockError] without backtrace. + pub fn new(code: StatusCode) -> MockError { + MockError { code, source: None } + } + + /// Create a new [MockError] with source. + pub fn with_source(source: MockError) -> MockError { + MockError { + code: source.code, + source: Some(Box::new(source)), + } + } +} + +impl fmt::Display for MockError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.code) + } +} + +impl std::error::Error for MockError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.source.as_ref().map(|e| e as _) + } +} + +impl ErrorExt for MockError { + fn status_code(&self) -> StatusCode { + self.code + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +impl StackError for MockError { + fn debug_fmt(&self, _: usize, _: &mut Vec) {} + + fn next(&self) -> Option<&dyn StackError> { + None + } +} diff --git a/crates/stack-error/src/status_code.rs b/crates/stack-error/src/status_code.rs new file mode 100644 index 000000000..a9d61eed5 --- /dev/null +++ b/crates/stack-error/src/status_code.rs @@ -0,0 +1,239 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt; + +use strum::{AsRefStr, EnumIter, EnumString, FromRepr}; + +/// Common status code for public API. +#[derive(Debug, Clone, Copy, PartialEq, Eq, EnumString, AsRefStr, EnumIter, FromRepr)] +pub enum StatusCode { + // ====== Begin of common status code ============== + /// Success. + Success = 0, + + /// Unknown error. + Unknown = 1000, + /// Unsupported operation. + Unsupported = 1001, + /// Unexpected error, maybe there is a BUG. + Unexpected = 1002, + /// Internal server error. + Internal = 1003, + /// Invalid arguments. + InvalidArguments = 1004, + /// The task is cancelled. + Cancelled = 1005, + // ====== End of common status code ================ + + // ====== Begin of SQL related status code ========= + /// SQL Syntax error. + InvalidSyntax = 2000, + // ====== End of SQL related status code =========== + + // ====== Begin of query related status code ======= + /// Fail to create a plan for the query. + PlanQuery = 3000, + /// The query engine fail to execute query. + EngineExecuteQuery = 3001, + // ====== End of query related status code ========= + + // ====== Begin of catalog related status code ===== + /// Table already exists. + TableAlreadyExists = 4000, + TableNotFound = 4001, + TableColumnNotFound = 4002, + TableColumnExists = 4003, + DatabaseNotFound = 4004, + RegionNotFound = 4005, + RegionAlreadyExists = 4006, + RegionReadonly = 4007, + /// Region is not in a proper state to handle specific request. + RegionNotReady = 4008, + // If mutually exclusive operations are reached at the same time, + // only one can be executed, another one will get region busy. + RegionBusy = 4009, + // ====== End of catalog related status code ======= + + // ====== Begin of storage related status code ===== + /// Storage is temporarily unable to handle the request + StorageUnavailable = 5000, + /// Request is outdated, e.g., version mismatch + RequestOutdated = 5001, + // ====== End of storage related status code ======= + + // ====== Begin of server related status code ===== + /// Runtime resources exhausted, like creating threads failed. + RuntimeResourcesExhausted = 6000, + + /// Rate limit exceeded + RateLimited = 6001, + // ====== End of server related status code ======= + + // ====== Begin of auth related status code ===== + /// User not exist + UserNotFound = 7000, + /// Unsupported password type + UnsupportedPasswordType = 7001, + /// Username and password does not match + UserPasswordMismatch = 7002, + /// Not found http authorization header + AuthHeaderNotFound = 7003, + /// Invalid http authorization header + InvalidAuthHeader = 7004, + /// Illegal request to connect catalog-schema + AccessDenied = 7005, + /// User is not authorized to perform the operation + PermissionDenied = 7006, + // ====== End of auth related status code ===== + + // ====== Begin of flow related status code ===== + FlowAlreadyExists = 8000, + FlowNotFound = 8001, + // ====== End of flow related status code ===== +} + +impl StatusCode { + /// Returns `true` if `code` is success. + pub fn is_success(code: u32) -> bool { + Self::Success as u32 == code + } + + /// Returns `true` if the error with this code is retryable. + pub fn is_retryable(&self) -> bool { + match self { + StatusCode::StorageUnavailable + | StatusCode::RuntimeResourcesExhausted + | StatusCode::Internal + | StatusCode::RegionNotReady + | StatusCode::RegionBusy => true, + + StatusCode::Success + | StatusCode::Unknown + | StatusCode::Unsupported + | StatusCode::Unexpected + | StatusCode::InvalidArguments + | StatusCode::Cancelled + | StatusCode::InvalidSyntax + | StatusCode::PlanQuery + | StatusCode::EngineExecuteQuery + | StatusCode::TableAlreadyExists + | StatusCode::TableNotFound + | StatusCode::RegionAlreadyExists + | StatusCode::RegionNotFound + | StatusCode::FlowAlreadyExists + | StatusCode::FlowNotFound + | StatusCode::RegionReadonly + | StatusCode::TableColumnNotFound + | StatusCode::TableColumnExists + | StatusCode::DatabaseNotFound + | StatusCode::RateLimited + | StatusCode::UserNotFound + | StatusCode::UnsupportedPasswordType + | StatusCode::UserPasswordMismatch + | StatusCode::AuthHeaderNotFound + | StatusCode::InvalidAuthHeader + | StatusCode::AccessDenied + | StatusCode::PermissionDenied + | StatusCode::RequestOutdated => false, + } + } + + /// Returns `true` if we should print an error log for an error with + /// this status code. + pub fn should_log_error(&self) -> bool { + match self { + StatusCode::Unknown + | StatusCode::Unexpected + | StatusCode::Internal + | StatusCode::Cancelled + | StatusCode::PlanQuery + | StatusCode::EngineExecuteQuery + | StatusCode::StorageUnavailable + | StatusCode::RuntimeResourcesExhausted => true, + StatusCode::Success + | StatusCode::Unsupported + | StatusCode::InvalidArguments + | StatusCode::InvalidSyntax + | StatusCode::TableAlreadyExists + | StatusCode::TableNotFound + | StatusCode::RegionAlreadyExists + | StatusCode::RegionNotFound + | StatusCode::FlowAlreadyExists + | StatusCode::FlowNotFound + | StatusCode::RegionNotReady + | StatusCode::RegionBusy + | StatusCode::RegionReadonly + | StatusCode::TableColumnNotFound + | StatusCode::TableColumnExists + | StatusCode::DatabaseNotFound + | StatusCode::RateLimited + | StatusCode::UserNotFound + | StatusCode::UnsupportedPasswordType + | StatusCode::UserPasswordMismatch + | StatusCode::AuthHeaderNotFound + | StatusCode::InvalidAuthHeader + | StatusCode::AccessDenied + | StatusCode::PermissionDenied + | StatusCode::RequestOutdated => false, + } + } + + pub fn from_u32(value: u32) -> Option { + StatusCode::from_repr(value as usize) + } +} + +impl fmt::Display for StatusCode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // The current debug format is suitable to display. + write!(f, "{self:?}") + } +} + +#[cfg(test)] +mod tests { + use strum::IntoEnumIterator; + + use super::*; + + fn assert_status_code_display(code: StatusCode, msg: &str) { + let code_msg = format!("{code}"); + assert_eq!(msg, code_msg); + } + + #[test] + fn test_display_status_code() { + assert_status_code_display(StatusCode::Unknown, "Unknown"); + assert_status_code_display(StatusCode::TableAlreadyExists, "TableAlreadyExists"); + } + + #[test] + fn test_from_u32() { + for code in StatusCode::iter() { + let num = code as u32; + assert_eq!(StatusCode::from_u32(num), Some(code)); + } + + assert_eq!(StatusCode::from_u32(10000), None); + } + + #[test] + fn test_is_success() { + assert!(StatusCode::is_success(0)); + assert!(!StatusCode::is_success(1)); + assert!(!StatusCode::is_success(2)); + assert!(!StatusCode::is_success(3)); + } +}