diff --git a/src/aio/search.rs b/src/aio/search.rs index 03a4950a..97fececd 100644 --- a/src/aio/search.rs +++ b/src/aio/search.rs @@ -11,7 +11,7 @@ use tokio::time::timeout; use crate::aio::Gateway; use crate::common::{messages, parsing, SearchOptions}; use crate::errors::SearchError; -use crate::search::validate_url; +use crate::search::{check_is_ip_spoofed, validate_url}; const MAX_RESPONSE_SIZE: usize = 1500; @@ -26,37 +26,7 @@ pub async fn search_gateway(options: SearchOptions) -> Result { - if src_ip.ip() != url_ip.ip() { - return Err(SearchError::SpoofedIp { - src_ip: (*src_ip.ip()).into(), - url_ip: (*url_ip.ip()).into(), - }); - } - } - (SocketAddr::V6(src_ip), SocketAddr::V6(url_ip)) => { - if src_ip.ip() != url_ip.ip() { - return Err(SearchError::SpoofedIp { - src_ip: (*src_ip.ip()).into(), - url_ip: (*url_ip.ip()).into(), - }); - } - } - (SocketAddr::V6(src_ip), SocketAddr::V4(url_ip)) => { - return Err(SearchError::SpoofedIp { - src_ip: (*src_ip.ip()).into(), - url_ip: (*url_ip.ip()).into(), - }) - } - (SocketAddr::V4(src_ip), SocketAddr::V6(url_ip)) => { - return Err(SearchError::SpoofedIp { - src_ip: (*src_ip.ip()).into(), - url_ip: (*url_ip.ip()).into(), - }) - } - } + check_is_ip_spoofed(&from, &addr)?; let (control_schema_url, control_url) = run_with_timeout(options.http_timeout, get_control_urls(&addr, &root_url)).await??; diff --git a/src/search.rs b/src/search.rs index 142f3e1b..ccdcb8de 100644 --- a/src/search.rs +++ b/src/search.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; -use std::net::IpAddr; +use std::net::{IpAddr, SocketAddr}; use std::net::{SocketAddrV4, UdpSocket}; -use std::str; +use std::str::FromStr; use crate::common::{messages, parsing, SearchOptions}; use crate::errors::SearchError; @@ -31,11 +31,13 @@ pub fn search_gateway(options: SearchOptions) -> Result { loop { let mut buf = [0u8; 1500]; - let (read, _) = socket.recv_from(&mut buf)?; - let text = str::from_utf8(&buf[..read])?; + let (read, from) = socket.recv_from(&mut buf)?; + let text = std::str::from_utf8(&buf[..read])?; let (addr, root_url) = parsing::parse_search_result(text)?; + check_is_ip_spoofed(&from, &addr.into())?; + let (control_schema_url, control_url) = match get_control_urls(&addr, &root_url) { Ok(o) => o, Err(e) => { @@ -58,13 +60,19 @@ pub fn search_gateway(options: SearchOptions) -> Result { } }; - return Ok(Gateway { + let gateway = Gateway { addr, root_url, control_url, control_schema_url, control_schema, - }); + }; + + let gateway_url = reqwest::Url::from_str(&format!("{gateway}"))?; + + validate_url((*addr.ip()).into(), &gateway_url)?; + + return Ok(gateway); } } @@ -98,6 +106,40 @@ fn get_control_schemas( parsing::parse_schemas(body.as_ref()) } +pub fn check_is_ip_spoofed(from: &SocketAddr, addr: &SocketAddr) -> Result<(), SearchError> { + match (from, addr) { + (SocketAddr::V4(src_ip), SocketAddr::V4(url_ip)) => { + if src_ip.ip() != url_ip.ip() { + return Err(SearchError::SpoofedIp { + src_ip: (*src_ip.ip()).into(), + url_ip: (*url_ip.ip()).into(), + }); + } + } + (SocketAddr::V6(src_ip), SocketAddr::V6(url_ip)) => { + if src_ip.ip() != url_ip.ip() { + return Err(SearchError::SpoofedIp { + src_ip: (*src_ip.ip()).into(), + url_ip: (*url_ip.ip()).into(), + }); + } + } + (SocketAddr::V6(src_ip), SocketAddr::V4(url_ip)) => { + return Err(SearchError::SpoofedIp { + src_ip: (*src_ip.ip()).into(), + url_ip: (*url_ip.ip()).into(), + }) + } + (SocketAddr::V4(src_ip), SocketAddr::V6(url_ip)) => { + return Err(SearchError::SpoofedIp { + src_ip: (*src_ip.ip()).into(), + url_ip: (*url_ip.ip()).into(), + }) + } + } + Ok(()) +} + pub fn validate_url(src_ip: IpAddr, url: &reqwest::Url) -> Result<(), SearchError> { match url.host_str() { Some(url_host) if url_host != src_ip.to_string() => Err(SearchError::SpoofedUrl {