diff --git a/src/block_ads.rs b/src/block_ads.rs index 51a825a..7fa8130 100644 --- a/src/block_ads.rs +++ b/src/block_ads.rs @@ -1,6 +1,8 @@ use std::net::IpAddr; use std::collections::HashSet; use std::fs::read_to_string; +use std::fs::{File, OpenOptions}; +use std::io::{self, Write}; use pnet::datalink::{self, Channel}; use pnet::packet::Packet; @@ -10,8 +12,6 @@ use pnet::packet::tcp::TcpPacket; use pnet::packet::ip::IpNextHeaderProtocols; use dns_lookup::lookup_addr; -use trust_dns_resolver::config; -use trust_dns_resolver::proto::rr::domain; pub fn parse_adsfile(arguments: &Vec, domains_list: &mut HashSet) -> Result<(), ()> { @@ -27,12 +27,42 @@ pub fn parse_adsfile(arguments: &Vec, domains_list: &mut HashSet domains_list.insert(content.to_string()); } } + block_ads(domains_list); Ok(()) } -fn config_hosts(domain_name: &str) -> Result<(), ()> { +fn block_ads(domains_list: &mut HashSet) -> io::Result<()> { let hosts_path = "/etc/hosts"; - let domains_to_block: HashSet<&str> = vec!["ads.example.com", "adserver.net"].into_iter().collect(); + let mut file = OpenOptions::new() + .write(true) + .append(true) + .open(hosts_path)?; + + let iterator = domains_list.iter(); + + for domain in iterator { + let entry = format!("127.0.0.1 {}\n", domain); + + file.write_all(entry.as_bytes())?; + } + Ok(()) +} + +fn write_content(hosts_path: &str, new_content: String) -> Result<(), ()> { + let mut file = File::create(&hosts_path).map_err(|error| { + eprintln!("Error while opening hosts file for writing: {}", error); + })?; + + if let Err(error) = write!(file, "{}", new_content) { + eprintln!("Error while writing to hosts file: {}", error); + return Err(()); + } + + Ok(()) +} + +fn config_hosts(domain_name: &str) -> Result<(), ()> { + let hosts_path = "hosts.txt"; let block_ip = IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)); let current_content = read_to_string(&hosts_path).map_err(|error| { @@ -48,7 +78,7 @@ fn config_hosts(domain_name: &str) -> Result<(), ()> { let parts: Vec<&str> = line.split_whitespace().collect(); if parts.len() >= 2 { let domain = parts[1]; - + if domain_name == domain { let new_line = format!("{} {}\n", block_ip, domain); new_content.push_str(&new_line); @@ -61,6 +91,7 @@ fn config_hosts(domain_name: &str) -> Result<(), ()> { new_content.push_str("\n"); } } + write_content(hosts_path, new_content)?; Ok(()) } @@ -99,7 +130,7 @@ pub fn catch_packets(interface_name: &str, blacklist: HashSet) continue; } if blacklist.contains(&domain_name) { - config_hosts(domain_name); + config_hosts(&domain_name); continue; } println!("Source IP: {}", src_ipv4);