diff --git a/Cargo.lock b/Cargo.lock index 811d8a9..1d39fb6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -81,13 +81,16 @@ version = "0.1.0" dependencies = [ "airmail", "anyhow", + "bincode", "clap", "crossbeam", "deunicode", "env_logger", "futures-util", "geo", + "geo-types", "geojson", + "geozero", "lazy_static", "lingua", "log", @@ -100,6 +103,7 @@ dependencies = [ "redb", "regex", "reqwest", + "rstar 0.12.0", "rustyline", "s2", "serde", @@ -115,6 +119,7 @@ name = "airmail_service" version = "0.1.0" dependencies = [ "airmail", + "anyhow", "axum", "clap", "deunicode", @@ -124,6 +129,7 @@ dependencies = [ "log", "serde", "serde_json", + "thiserror", "tokio", "tower-http", ] @@ -363,6 +369,15 @@ dependencies = [ "serde", ] +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "bindgen" version = "0.68.1" @@ -651,15 +666,6 @@ dependencies = [ "static_assertions", ] -[[package]] -name = "concurrent-queue" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" -dependencies = [ - "crossbeam-utils", -] - [[package]] name = "const-oid" version = "0.9.6" @@ -1016,14 +1022,9 @@ dependencies = [ [[package]] name = "event-listener" -version = "5.3.1" +version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba" -dependencies = [ - "concurrent-queue", - "parking", - "pin-project-lite", -] +checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" [[package]] name = "fastdivide" @@ -1281,7 +1282,7 @@ dependencies = [ "log", "num-traits", "robust", - "rstar", + "rstar 0.11.0", "spade", ] @@ -1293,7 +1294,8 @@ checksum = "9ff16065e5720f376fbced200a5ae0f47ace85fd70b7e54269790281353b6d61" dependencies = [ "approx 0.5.1", "num-traits", - "rstar", + "rstar 0.11.0", + "rstar 0.12.0", "serde", ] @@ -1319,6 +1321,22 @@ dependencies = [ "thiserror", ] +[[package]] +name = "geozero" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cd8fb67347739a057fd607b6d8b43ba4ed93619ed84b8f429fa3296f8ae504c" +dependencies = [ + "geo-types", + "geojson", + "log", + "scroll", + "serde_json", + "sqlx", + "thiserror", + "wkt", +] + [[package]] name = "getrandom" version = "0.2.15" @@ -1370,6 +1388,15 @@ dependencies = [ "byteorder", ] +[[package]] +name = "hash32" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47d60b12902ba28e2730cd37e95b8c9223af2808df9e902d4df49588d1470606" +dependencies = [ + "byteorder", +] + [[package]] name = "hashbrown" version = "0.14.5" @@ -1382,9 +1409,9 @@ dependencies = [ [[package]] name = "hashlink" -version = "0.9.1" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" dependencies = [ "hashbrown", ] @@ -1396,17 +1423,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f" dependencies = [ "atomic-polyfill", - "hash32", + "hash32 0.2.1", "rustc_version", "spin 0.9.8", "stable_deref_trait", ] +[[package]] 
+name = "heapless" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bfb9eb618601c89945a70e254898da93b13be0388091d42117462b265bb3fad" +dependencies = [ + "hash32 0.3.1", + "stable_deref_trait", +] + [[package]] name = "heck" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +dependencies = [ + "unicode-segmentation", +] [[package]] name = "heck" @@ -1775,9 +1815,9 @@ checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" [[package]] name = "libsqlite3-sys" -version = "0.28.0" +version = "0.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c10584274047cb335c23d3e61bcef8e323adae7c5c8c760540f73610177fc3f" +checksum = "cf4e226dcd58b4be396f7bd3c20da8fdee2911400705297ba7d2d7cc2c30f716" dependencies = [ "cc", "pkg-config", @@ -2404,12 +2444,6 @@ dependencies = [ "stable_deref_trait", ] -[[package]] -name = "parking" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" - [[package]] name = "parking_lot" version = "0.12.3" @@ -2833,11 +2867,23 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73111312eb7a2287d229f06c00ff35b51ddee180f017ab6dec1f69d62ac098d6" dependencies = [ - "heapless", + "heapless 0.7.17", "num-traits", "smallvec", ] +[[package]] +name = "rstar" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "133315eb94c7b1e8d0cb097e5a710d850263372fd028fff18969de708afc7008" +dependencies = [ + "heapless 0.8.0", + "num-traits", + "serde", + "smallvec", +] + [[package]] name = "rust-stemmers" version = "1.2.0" @@ -2963,6 +3009,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "scroll" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04c565b551bafbef4157586fa379538366e4385d42082f255bfd96e4fe8519da" + [[package]] name = "security-framework" version = "2.11.0" @@ -3132,9 +3184,6 @@ name = "smallvec" version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" -dependencies = [ - "serde", -] [[package]] name = "socket2" @@ -3195,9 +3244,9 @@ dependencies = [ [[package]] name = "sqlx" -version = "0.8.0" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27144619c6e5802f1380337a209d2ac1c431002dd74c6e60aebff3c506dc4f0c" +checksum = "c9a2ccff1a000a5a59cd33da541d9f2fdcd9e6e8229cc200565942bff36d0aaa" dependencies = [ "sqlx-core", "sqlx-macros", @@ -3208,10 +3257,11 @@ dependencies = [ [[package]] name = "sqlx-core" -version = "0.8.0" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a999083c1af5b5d6c071d34a708a19ba3e02106ad82ef7bbd69f5e48266b613b" +checksum = "24ba59a9342a3d9bab6c56c118be528b27c9b60e490080e9711a04dccac83ef6" dependencies = [ + "ahash", "atoi", "byteorder", "bytes", @@ -3224,7 +3274,6 @@ dependencies = [ "futures-intrusive", "futures-io", "futures-util", - "hashbrown", "hashlink", "hex", "indexmap", @@ -3247,26 +3296,26 @@ dependencies = [ [[package]] name = "sqlx-macros" -version = "0.8.0" 
+version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a23217eb7d86c584b8cbe0337b9eacf12ab76fe7673c513141ec42565698bb88" +checksum = "4ea40e2345eb2faa9e1e5e326db8c34711317d2b5e08d0d5741619048a803127" dependencies = [ "proc-macro2", "quote", "sqlx-core", "sqlx-macros-core", - "syn 2.0.66", + "syn 1.0.109", ] [[package]] name = "sqlx-macros-core" -version = "0.8.0" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a099220ae541c5db479c6424bdf1b200987934033c2584f79a0e1693601e776" +checksum = "5833ef53aaa16d860e92123292f1f6a3d53c34ba8b1969f152ef1a7bb803f3c8" dependencies = [ "dotenvy", "either", - "heck 0.5.0", + "heck 0.4.1", "hex", "once_cell", "proc-macro2", @@ -3276,9 +3325,8 @@ dependencies = [ "sha2", "sqlx-core", "sqlx-mysql", - "sqlx-postgres", "sqlx-sqlite", - "syn 2.0.66", + "syn 1.0.109", "tempfile", "tokio", "url", @@ -3286,12 +3334,12 @@ dependencies = [ [[package]] name = "sqlx-mysql" -version = "0.8.0" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5afe4c38a9b417b6a9a5eeffe7235d0a106716495536e7727d1c7f4b1ff3eba6" +checksum = "1ed31390216d20e538e447a7a9b959e06ed9fc51c37b514b46eb758016ecd418" dependencies = [ "atoi", - "base64 0.22.1", + "base64 0.21.7", "bitflags 2.5.0", "byteorder", "bytes", @@ -3328,12 +3376,12 @@ dependencies = [ [[package]] name = "sqlx-postgres" -version = "0.8.0" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1dbb157e65f10dbe01f729339c06d239120221c9ad9fa0ba8408c4cc18ecf21" +checksum = "7c824eb80b894f926f89a0b9da0c7f435d27cdd35b8c655b114e58223918577e" dependencies = [ "atoi", - "base64 0.22.1", + "base64 0.21.7", "bitflags 2.5.0", "byteorder", "crc", @@ -3366,9 +3414,9 @@ dependencies = [ [[package]] name = "sqlx-sqlite" -version = "0.8.0" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b2cdd83c008a622d94499c0006d8ee5f821f36c89b7d625c900e5dc30b5c5ee" +checksum = "b244ef0a8414da0bed4bb1910426e890b19e5e9bccc27ada6b797d05c55ae0aa" dependencies = [ "atoi", "flume", @@ -3381,10 +3429,10 @@ dependencies = [ "log", "percent-encoding", "serde", - "serde_urlencoded", "sqlx-core", "tracing", "url", + "urlencoding", ] [[package]] @@ -4009,6 +4057,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "userfaultfd" version = "0.8.1" @@ -4365,6 +4419,18 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "wkt" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3c2252781f8927974e8ba6a67c965a759a2b88ea2b1825f6862426bbb1c8f41" +dependencies = [ + "geo-types", + "log", + "num-traits", + "thiserror", +] + [[package]] name = "zerocopy" version = "0.7.34" diff --git a/airmail/Cargo.toml b/airmail/Cargo.toml index 2c7875a..a2404b1 100644 --- a/airmail/Cargo.toml +++ b/airmail/Cargo.toml @@ -37,5 +37,4 @@ anyhow = "1.0.86" thiserror = "1.0.63" [features] -invasive_logging = [] remote_index = ["tantivy/quickwit"] diff --git a/airmail/src/index.rs b/airmail/src/index.rs index a3765be..df2abe8 100644 --- a/airmail/src/index.rs +++ b/airmail/src/index.rs @@ -1,11 +1,10 @@ -use std::path::PathBuf; +use std::path::Path; use std::sync::Arc; use anyhow::Result; -use 
futures_util::future::join_all; use geo::Rect; use itertools::Itertools; -use log::debug; +use log::{trace, warn}; use s2::region::RegionCoverer; use std::collections::BTreeMap; use tantivy::schema::Value; @@ -112,7 +111,7 @@ impl AirmailIndex { self.tantivy_index.schema().get_field(FIELD_TAGS).unwrap() } - pub fn create(index_dir: &PathBuf) -> Result { + pub fn create(index_dir: &Path) -> Result { let schema = Self::schema(); let tantivy_index = tantivy::Index::open_or_create(MmapDirectory::open(index_dir)?, schema)?; @@ -346,70 +345,66 @@ impl AirmailIndex { let query_string = query.trim().replace("'s", "s"); let start = std::time::Instant::now(); - let (top_docs, searcher) = { - let query = self - .construct_query( - &searcher, - &query_string, - tags, - bbox, - boost_regions, - request_leniency, - ) - .await; - - #[cfg(feature = "invasive_logging")] - { - dbg!(&query); + + let query = self + .construct_query( + &searcher, + &query_string, + tags, + bbox, + boost_regions, + request_leniency, + ) + .await; + + trace!("Search query: {:?}", &query); + + // Perform the search and then resolve the returned documents + let top_docs: Result> = spawn_blocking(move || { + let doc_addresses = searcher.search(&query, &TopDocs::with_limit(10))?; + let mut docs = vec![]; + for (score, doc_address) in doc_addresses { + if let Ok(doc) = searcher.doc::(doc_address) { + docs.push((score, doc)); + } } - let (top_docs, searcher) = spawn_blocking(move || { - (searcher.search(&query, &TopDocs::with_limit(10)), searcher) - }) - .await?; - let top_docs = top_docs?; - debug!( - "Search took {:?} and yielded {} results", - start.elapsed(), - top_docs.len() - ); - (top_docs, searcher) - }; + Ok(docs) + }) + .await?; - let mut scores = Vec::new(); - let mut futures = Vec::new(); - for (score, doc_id) in top_docs { - let searcher = searcher.clone(); - let doc = spawn_blocking(move || searcher.doc::(doc_id)); - scores.push(score); - futures.push(doc); - } - let mut results = Vec::new(); - let top_docs = join_all(futures).await; - for (score, doc_future) in scores.iter().zip(top_docs) { - let doc = doc_future??; - let source = doc - .get_first(self.field_source()) - .map(|value| value.as_str().unwrap().to_string()) - .unwrap_or_default(); - let s2cell = doc - .get_first(self.field_s2cell()) - .unwrap() - .as_u64() - .unwrap(); - let cellid = s2::cellid::CellID(s2cell); - let latlng = s2::latlng::LatLng::from(cellid); - let tags: Vec<(String, String)> = doc - .get_first(self.field_tags()) - .unwrap() - .as_object() - .unwrap() - .map(|(k, v)| (k.to_string(), v.as_str().unwrap().to_string())) - .collect(); - - let poi = AirmailPoi::new(source, latlng.lat.deg(), latlng.lng.deg(), tags)?; - results.push((poi, *score)); - } + let top_docs = top_docs.map_err(|e| { + warn!("Search failed: {:?}", e); + e + })?; + + trace!( + "Search took {:?} and yielded {} results", + start.elapsed(), + top_docs.len() + ); + + let results = top_docs + .into_iter() + .flat_map(|(score, doc)| { + let source = doc + .get_first(self.field_source()) + .map(|value| value.as_str().unwrap_or_default().to_string()) + .unwrap_or_default(); + let s2cell = doc.get_first(self.field_s2cell())?.as_u64()?; + let cellid = s2::cellid::CellID(s2cell); + let latlng = s2::latlng::LatLng::from(cellid); + let tags: Vec<(String, String)> = doc + .get_first(self.field_tags())? + .as_object()? 
+ .map(|(k, v)| (k.to_string(), v.as_str().unwrap_or_default().to_string())) + .collect(); + + AirmailPoi::new(source, latlng.lat.deg(), latlng.lng.deg(), tags) + .ok() + .map(|poi| (poi, score)) + }) + .collect::>(); Ok(results) } @@ -430,7 +425,7 @@ impl AirmailIndexWriter { for content in poi.content { self.process_field(&mut doc, &content); } - doc.add_text(self.schema.get_field(FIELD_SOURCE).unwrap(), source); + doc.add_text(self.schema.get_field(FIELD_SOURCE)?, source); let indexed_keys = [ "natural", "amenity", "shop", "leisure", "tourism", "historic", "cuisine", @@ -443,22 +438,22 @@ impl AirmailIndexWriter { .any(|prefix| key.starts_with(prefix)) { doc.add_text( - self.schema.get_field(FIELD_INDEXED_TAG).unwrap(), + self.schema.get_field(FIELD_INDEXED_TAG)?, format!("{}={}", key, value).as_str(), ); } } doc.add_object( - self.schema.get_field(FIELD_TAGS).unwrap(), + self.schema.get_field(FIELD_TAGS)?, poi.tags .iter() .map(|(k, v)| (k.to_string(), OwnedValue::Str(v.to_string()))) .collect::>(), ); - doc.add_u64(self.schema.get_field(FIELD_S2CELL).unwrap(), poi.s2cell); + doc.add_u64(self.schema.get_field(FIELD_S2CELL)?, poi.s2cell); for parent in poi.s2cell_parents { - doc.add_u64(self.schema.get_field(FIELD_S2CELL_PARENTS).unwrap(), parent); + doc.add_u64(self.schema.get_field(FIELD_S2CELL_PARENTS)?, parent); } self.tantivy_writer.add_document(doc)?; diff --git a/airmail/src/lib.rs b/airmail/src/lib.rs index 3970723..75fbf74 100644 --- a/airmail/src/lib.rs +++ b/airmail/src/lib.rs @@ -1,3 +1,6 @@ +#![forbid(unsafe_code)] +#![warn(clippy::missing_panics_doc)] + #[macro_use] extern crate lazy_static; diff --git a/airmail_indexer/Cargo.toml b/airmail_indexer/Cargo.toml index f432ef7..9928442 100644 --- a/airmail_indexer/Cargo.toml +++ b/airmail_indexer/Cargo.toml @@ -69,10 +69,14 @@ lingua = { version = "1.6.2", default-features = false, features = [ ] } redb = "1.5.0" anyhow = "1.0.86" -sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite"] } -osmx = { version = "0.1.0", optional = true } +sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite"] } +osmx = { version = "0.1.0" } thiserror = "1.0.63" +rstar = { version = "0.12.0", features = ["serde"] } +geo-types = { version = "0.7.11", features = ["use-rstar_0_12"] } +bincode = { version = "1.3.3" } +geozero = { version = "0.13.0", features = ["with-geo", "with-gpkg"] } [features] -default = ["remote_index", "dep:osmx"] +default = ["remote_index"] remote_index = ["airmail/remote_index"] diff --git a/airmail_indexer/src/importer.rs b/airmail_indexer/src/importer.rs index 9cef51b..7b16930 100644 --- a/airmail_indexer/src/importer.rs +++ b/airmail_indexer/src/importer.rs @@ -14,29 +14,44 @@ use std::{ use tokio::{spawn, task::spawn_blocking}; use crate::{ - populate_admin_areas, wof::WhosOnFirst, WofCacheItem, TABLE_AREAS, TABLE_LANGS, TABLE_NAMES, + pip_tree::PipTree, + populate_admin_areas, + wof::{ConcisePipResponse, WhosOnFirst}, + WofCacheItem, TABLE_AREAS, TABLE_LANGS, TABLE_NAMES, }; pub struct ImporterBuilder { - admin_cache: Option, + index: AirmailIndex, + admin_cache_path: Option, wof_db_path: PathBuf, + pip_tree_path: Option, } impl ImporterBuilder { - pub fn new(whosonfirst_spatialite_path: &Path) -> Self { - Self { - admin_cache: None, - wof_db_path: whosonfirst_spatialite_path.to_path_buf(), - } + pub fn new(airmail_index_path: &Path, wof_db_path: &Path) -> Result { + // Create the index + let index = AirmailIndex::create(airmail_index_path)?; + + Ok(Self { + index, + admin_cache_path: None, + 
wof_db_path: wof_db_path.to_path_buf(), + pip_tree_path: None, + }) } pub fn admin_cache(mut self, admin_cache: &Path) -> Self { - self.admin_cache = Some(admin_cache.to_path_buf()); + self.admin_cache_path = Some(admin_cache.to_path_buf()); + self + } + + pub fn pip_tree_cache(mut self, pip_tree_cache: &Path) -> Self { + self.pip_tree_path = Some(pip_tree_cache.to_path_buf()); self } pub async fn build(self) -> Result { - let admin_cache = if let Some(admin_cache) = self.admin_cache { + let admin_cache = if let Some(admin_cache) = self.admin_cache_path { admin_cache } else { std::env::temp_dir().join("admin_cache.db") @@ -54,27 +69,40 @@ impl ImporterBuilder { let wof_db = WhosOnFirst::new(&self.wof_db_path).await?; - Ok(Importer { - admin_cache: Arc::new(db), - wof_db, - }) + let pip_tree = if let Some(pip_tree_cache) = self.pip_tree_path { + Some(PipTree::new_or_load(&wof_db, &pip_tree_cache).await?) + } else { + None + }; + + Importer::new(self.index, db, wof_db, pip_tree).await } } pub struct Importer { + index: AirmailIndex, admin_cache: Arc, wof_db: WhosOnFirst, + pip_tree: Option>, } impl Importer { - pub async fn run_import( - &self, - mut index: AirmailIndex, - source: &str, - receiver: Receiver, - ) -> Result<()> { + pub async fn new( + index: AirmailIndex, + admin_cache: Database, + wof_db: WhosOnFirst, + pip_tree: Option>, + ) -> Result { + Ok(Self { + index, + admin_cache: Arc::new(admin_cache), + wof_db, + pip_tree, + }) + } + + pub async fn run_import(mut self, source: &str, receiver: Receiver) -> Result<()> { let source = source.to_string(); - // let mut nonblocking_join_handles = Vec::new(); let (to_cache_sender, to_cache_receiver): (Sender, Receiver) = crossbeam::channel::bounded(1024); let (to_index_sender, to_index_receiver): (Sender, Receiver) = @@ -116,10 +144,11 @@ impl Importer { } })); + let mut writer = self.index.writer()?; + handles.push(spawn_blocking(move || { let start = std::time::Instant::now(); - let mut writer = index.writer().unwrap(); let mut count = 0; loop { { @@ -152,6 +181,7 @@ impl Importer { let to_cache_sender = to_cache_sender.clone(); let admin_cache = self.admin_cache.clone(); let wof_db = self.wof_db.clone(); + let pip_tree = self.pip_tree.clone(); handles.push(spawn(async move { let mut read = admin_cache.begin_read().unwrap(); @@ -168,8 +198,14 @@ impl Importer { ); } - match populate_admin_areas(&read, to_cache_sender.clone(), &mut poi, &wof_db) - .await + match populate_admin_areas( + &read, + to_cache_sender.clone(), + &mut poi, + &wof_db, + &pip_tree, + ) + .await { Ok(()) => { let poi = SchemafiedPoi::from(poi); diff --git a/airmail_indexer/src/lib.rs b/airmail_indexer/src/lib.rs index 606e490..ca2299c 100644 --- a/airmail_indexer/src/lib.rs +++ b/airmail_indexer/src/lib.rs @@ -1,5 +1,6 @@ pub mod error; mod importer; +mod pip_tree; mod query_pip; mod wof; @@ -12,9 +13,10 @@ use airmail::poi::ToIndexPoi; use anyhow::Result; use crossbeam::channel::Sender; use lingua::{IsoCode639_3, Language}; +use pip_tree::PipTree; use redb::{ReadTransaction, TableDefinition}; use std::str::FromStr; -use wof::WhosOnFirst; +use wof::{ConcisePipResponse, WhosOnFirst}; pub(crate) const TABLE_AREAS: TableDefinition = TableDefinition::new("admin_areas"); pub(crate) const TABLE_NAMES: TableDefinition = TableDefinition::new("admin_names"); @@ -58,8 +60,10 @@ pub(crate) async fn populate_admin_areas( to_cache_sender: Sender, poi: &mut ToIndexPoi, wof_db: &WhosOnFirst, + pip_tree: &Option>, ) -> Result<()> { - let pip_response = query_pip::query_pip(read, 
to_cache_sender, poi.s2cell, wof_db).await?; + let pip_response = + query_pip::query_pip(read, to_cache_sender, poi.s2cell, wof_db, pip_tree).await?; for admin in pip_response.admin_names { poi.admins.push(admin); } diff --git a/airmail_indexer/src/main.rs b/airmail_indexer/src/main.rs index 50a8890..6e18aaa 100644 --- a/airmail_indexer/src/main.rs +++ b/airmail_indexer/src/main.rs @@ -1,12 +1,13 @@ -use std::path::PathBuf; +#![forbid(unsafe_code)] +#![warn(clippy::pedantic)] -use airmail::index::AirmailIndex; use airmail_indexer::ImporterBuilder; use anyhow::Result; use clap::Parser; use env_logger::Env; use futures_util::future::join_all; use log::warn; +use std::path::PathBuf; use tokio::{select, spawn, task::spawn_blocking}; mod openstreetmap; @@ -32,8 +33,16 @@ struct Args { #[clap(long, short)] admin_cache: Option, + /// Path to `WhosOnFirst` spatial index for point-in-polygon lookups. If this is specified + /// we'll use the spatial index instead of sqlite geospatial lookups. This will speed up imports, + /// after the index is built. It'll be faster for planet scale imports, or frequent imports + /// but will use 10GB of memory and takes a few minutes to build. `mod_spatialite` is not required + /// if this is specified. + #[clap(long, short)] + pip_tree: Option, + // ============================ OSM-specific options =================================== - /// Path to an OSMExpress file to import. + /// Path to an `OSMExpress` file to import. #[clap(long, short)] osmx: PathBuf, } @@ -44,27 +53,27 @@ async fn main() -> Result<()> { let args = Args::parse(); let mut handles = vec![]; - // Create the index - let index = AirmailIndex::create(&args.index)?; - // Setup the import pipeline - let mut import_builder = ImporterBuilder::new(&args.wof_db); + let mut import_builder = ImporterBuilder::new(&args.index, &args.wof_db)?; if let Some(admin_cache) = args.admin_cache { import_builder = import_builder.admin_cache(&admin_cache); } + if let Some(pip_tree) = args.pip_tree { + import_builder = import_builder.pip_tree_cache(&pip_tree); + } let importer = import_builder.build().await?; // Send POIs from the OSM parser to the importer - let (poi_sender, poi_receiver) = crossbeam::channel::bounded(4096); + let (poi_sender, poi_receiver) = crossbeam::channel::bounded(16384); // Spawn the OSM parser handles.push(spawn_blocking(move || { - openstreetmap::parse_osm(&args.osmx, poi_sender) + openstreetmap::parse_osm(&args.osmx, &poi_sender) })); // Spawn the importer handles.push(spawn(async move { - importer.run_import(index, "osm", poi_receiver).await + importer.run_import("osm", poi_receiver).await })); // Wait for the first thing to finish diff --git a/airmail_indexer/src/openstreetmap.rs b/airmail_indexer/src/openstreetmap.rs index b33d9df..678e249 100644 --- a/airmail_indexer/src/openstreetmap.rs +++ b/airmail_indexer/src/openstreetmap.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, error::Error, path::Path}; +use std::{collections::HashMap, path::Path}; use airmail::poi::ToIndexPoi; use airmail_indexer::error::IndexerError; @@ -9,20 +9,9 @@ use log::{debug, info, warn}; use osmx::{Database, Locations, Transaction, Way}; fn tags_to_poi(tags: &HashMap, lat: f64, lng: f64) -> Option { - if tags.is_empty() { - return None; - } - if tags.contains_key("highway") - || tags.contains_key("natural") - || tags.contains_key("boundary") - || tags.contains_key("admin_level") - { - return None; - } - - let house_number = tags.get("addr:housenumber").map(|s| s.to_string()); - let road = 
tags.get("addr:street").map(|s| s.to_string()); - let unit = tags.get("addr:unit").map(|s| s.to_string()); + let house_number = tags.get("addr:housenumber").map(ToString::to_string); + let road = tags.get("addr:street").map(ToString::to_string); + let unit = tags.get("addr:unit").map(ToString::to_string); let names = { let mut names = Vec::new(); @@ -88,26 +77,40 @@ fn index_way( tags_to_poi(tags, lat, lng) } -fn tags<'a, I: Iterator>( - tag_iterator: I, -) -> Result, Box> { +fn valid_tags(tags: &HashMap) -> bool { + if tags.is_empty() { + return false; + } + if tags.contains_key("highway") + || tags.contains_key("natural") + || tags.contains_key("boundary") + || tags.contains_key("admin_level") + { + return false; + } + + true +} + +fn tags<'a, I: Iterator>(tag_iterator: I) -> HashMap { let mut tags = HashMap::new(); for (key, value) in tag_iterator { tags.insert(key.to_string(), value.to_string()); } - Ok(tags) + + tags } -/// Parse an OSMExpress file and send POIs for indexing. -pub(crate) fn parse_osm(osmx_path: &Path, sender: Sender) -> Result<()> { +/// Parse an `OSMExpress` file and send POIs for indexing. +pub(crate) fn parse_osm(osmx_path: &Path, sender: &Sender) -> Result<()> { info!("Loading osmx from path: {:?}", osmx_path); - let db = Database::open(osmx_path).unwrap(); + let db = Database::open(osmx_path).map_err(IndexerError::from)?; + let osm = Transaction::begin(&db).map_err(IndexerError::from)?; + let locations = osm.locations().map_err(IndexerError::from)?; let mut interesting = 0; let mut total = 0; info!("Processing nodes"); { - let osm = Transaction::begin(&db).map_err(IndexerError::from)?; - let locations = osm.locations().map_err(IndexerError::from)?; for (node_id, node) in osm.nodes().map_err(IndexerError::from)?.iter() { total += 1; if interesting % 10000 == 0 { @@ -120,7 +123,7 @@ pub(crate) fn parse_osm(osmx_path: &Path, sender: Sender) -> Result< } let tags = tags(node.tags()); - if let Ok(tags) = tags { + if valid_tags(&tags) { let location = locations.get(node_id).expect("Nodes must have locations"); if let Some(poi) = tags_to_poi(&tags, location.lat(), location.lon()) { sender.send(poi).map_err(|e| { @@ -134,8 +137,6 @@ pub(crate) fn parse_osm(osmx_path: &Path, sender: Sender) -> Result< } info!("Processing ways"); { - let osm = Transaction::begin(&db).map_err(IndexerError::from)?; - let locations = osm.locations().map_err(IndexerError::from)?; for (_way_id, way) in osm.ways().map_err(IndexerError::from)?.iter() { if interesting % 10000 == 0 { debug!( @@ -145,9 +146,8 @@ pub(crate) fn parse_osm(osmx_path: &Path, sender: Sender) -> Result< sender.len() ); } - let tags = tags(way.tags()); - if let Ok(tags) = tags { + if valid_tags(&tags) { if let Some(poi) = index_way(&tags, &way, &locations) { sender.send(poi).map_err(|e| { warn!("Error from sender: {}", e); diff --git a/airmail_indexer/src/pip_tree.rs b/airmail_indexer/src/pip_tree.rs new file mode 100644 index 0000000..864497c --- /dev/null +++ b/airmail_indexer/src/pip_tree.rs @@ -0,0 +1,158 @@ +use std::{path::Path, sync::Arc}; + +use anyhow::Result; +use geo::Polygon; +use geo_types::{Geometry, Point}; +use log::{debug, info}; +use rstar::{primitives::GeomWithData, RTree, AABB}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use tokio::task::spawn_blocking; + +use crate::wof::{ConcisePipResponse, PipWithGeometry, WhosOnFirst}; + +/// A spatial index to hold hold and efficiently query polygons +#[derive(Serialize, Deserialize, Clone)] +pub struct PipTree { + tree: Arc>>, +} + +impl PipTree { 
+ /// Either load a `PipTree` from disk, or create a new one + /// from a `WhosOnFirst` database, and write it to disk. + /// + /// Either load a constructed `WhoIsOnFirst` `PipTree` from disk, + /// or assemble a new one from a `WhosOnFirst` database. + pub async fn new_or_load(wof_db: &WhosOnFirst, path: &Path) -> Result { + if path.exists() { + Self::new_from_disk(path).await + } else { + let pip_tree = Self::new_from_wof_db(wof_db).await?; + pip_tree.write_to_disk(path).await?; + Ok(pip_tree) + } + } + + /// Create a new `PipTree` from a `WhosOnFirst` database. + pub async fn new_from_wof_db(wof_db: &WhosOnFirst) -> Result { + info!("Creating PipTree from WhosOnFirst database"); + let features: Vec = wof_db.all_polygons().await?; + Ok(Self::new(features)) + } + + /// Load a `PipTree` from disk. + pub async fn new_from_disk(path: &Path) -> Result { + let path = path.to_path_buf(); + + let handle = spawn_blocking(move || { + info!("Loading PipTree from disk: {:?}", path); + + let file = std::fs::File::open(path)?; + let reader = std::io::BufReader::new(file); + let tree = bincode::deserialize_from(reader)?; + + Ok(tree) + }); + + handle.await? + } +} + +/// A semi-generic spatial index to hold and efficiently query polygons +impl PipTree +where + T: Clone + DeserializeOwned + Serialize + Send + Sync + 'static, +{ + /// Create a new `PipTree` from a list of features. + /// The features ordinarily contain both geometry and properties, + /// so they need to be split into their component parts for storage. + /// E.g. `impl From for (Option>, T)` + #[must_use] + pub fn new(features: Vec) -> Self + where + S: Into<(Option>, T)>, + { + let features: Vec> = features + .into_iter() + .filter_map(|feature| { + let (geom, t) = feature.into(); + if let Some(Geometry::Polygon(polygon)) = geom { + Some(GeomWithData::new(polygon, t)) + } else { + None + } + }) + .collect(); + + info!("Creating PipTree with {} polygons", features.len()); + let tree = RTree::bulk_load(features); + debug!("PipTree created"); + + Self { + tree: Arc::new(tree), + } + } + + /// Write the `PipTree` to disk. + pub async fn write_to_disk(&self, destination: &Path) -> Result<()> { + let destination = destination.to_path_buf(); + let tree = self.clone(); + + let handle = spawn_blocking(move || { + let size = tree.tree.size(); + debug!( + "Writing PipTree to disk: {:?}, tree size: {}", + destination, size + ); + + let file = std::fs::File::create(destination)?; + let writer = std::io::BufWriter::new(file); + bincode::serialize_into(writer, &tree)?; + + Ok(()) + }); + + handle.await? + } + + /// Find all polygons containing a given point. + pub async fn point_in_polygon(&self, lng: f64, lat: f64) -> Result> { + let self_c = self.clone(); + let handle = spawn_blocking(move || { + let polygons = self_c + .geo_point_in_polygon(Point::new(lng, lat)) + .unwrap_or_default(); + + Ok(polygons) + }); + + handle.await? + } + + /// Find all polygons within a given bounding box. + fn geo_point_in_polygon(&self, point: Point) -> Option> { + let point = AABB::from_point(point); + let found_ids = self + .tree + .locate_in_envelope_intersecting(&point) + .map(|f| f.data.clone()) + .collect::>(); + + if found_ids.is_empty() { + None + } else { + Some(found_ids) + } + } + + /// Size of the `PipTree`. + #[allow(dead_code)] + pub fn len(&self) -> usize { + self.tree.size() + } + + /// Is the `PipTree` empty? 
+ #[allow(dead_code)] + pub fn is_empty(&self) -> bool { + self.tree.size() == 0 + } +} diff --git a/airmail_indexer/src/query_pip.rs b/airmail_indexer/src/query_pip.rs index 5ef9f87..342067a 100644 --- a/airmail_indexer/src/query_pip.rs +++ b/airmail_indexer/src/query_pip.rs @@ -8,7 +8,8 @@ use serde::Deserialize; use crate::{ error::IndexerError, - wof::{PipLangsResponse, WhosOnFirst}, + pip_tree::PipTree, + wof::{ConcisePipResponse, PipLangsResponse, WhosOnFirst}, WofCacheItem, COUNTRIES, TABLE_AREAS, TABLE_LANGS, TABLE_NAMES, }; @@ -32,6 +33,7 @@ async fn query_pip_inner( read: &'_ ReadTransaction<'_>, to_cache_sender: Sender, wof_db: &WhosOnFirst, + pip_tree: &Option>, ) -> Result { let desired_level = 15; let cell = s2::cellid::CellID(s2cell); @@ -72,19 +74,26 @@ async fn query_pip_inner( let lat_lng = s2::latlng::LatLng::from(cell); let lat = lat_lng.lat.deg(); let lng = lat_lng.lng.deg(); - let response = wof_db.point_in_polygon(lng, lat).await?; + + // Prefer the pip_tree, if in use + let response = if let Some(pip_tree) = pip_tree { + pip_tree.point_in_polygon(lng, lat).await? + } else { + wof_db.point_in_polygon(lng, lat).await? + }; + let mut response_ids = Vec::new(); for concise_response in response { let admin_id: u64 = concise_response.id.parse()?; - // These filters are also applied in SQL - if concise_response.r#type == "planet" - || concise_response.r#type == "marketarea" - || concise_response.r#type == "county" - || concise_response.r#type == "timezone" - { - continue; - } + // These filters are applied in SQL, to reduce allocations + // if concise_response.r#type == "planet" + // || concise_response.r#type == "marketarea" + // || concise_response.r#type == "county" + // || concise_response.r#type == "timezone" + // { + // continue; + // } response_ids.push(admin_id); } @@ -184,8 +193,9 @@ pub(crate) async fn query_pip( to_cache_sender: Sender, s2cell: u64, wof_db: &WhosOnFirst, + pip_tree: &Option>, ) -> Result { - let wof_ids = query_pip_inner(s2cell, read, to_cache_sender.clone(), wof_db).await?; + let wof_ids = query_pip_inner(s2cell, read, to_cache_sender.clone(), wof_db, pip_tree).await?; let mut response = PipResponse::default(); let mut admin_name_futures = vec![]; let mut lang_futures = vec![]; diff --git a/airmail_indexer/src/wof.rs b/airmail_indexer/src/wof.rs index df38826..b5c4bdf 100644 --- a/airmail_indexer/src/wof.rs +++ b/airmail_indexer/src/wof.rs @@ -1,6 +1,6 @@ use anyhow::Result; use log::debug; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use sqlx::{ sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions}, Pool, Sqlite, @@ -122,22 +122,57 @@ impl WhosOnFirst { Ok(rows) } + + /// Retrieve a flat representation of all polygons in the database. + /// This call can be 10GB+ of data. + pub async fn all_polygons(&self) -> Result> { + // Geometry is stored as spatialite blob, so decode to WKB (geopackage compatible). + let rows = sqlx::query_as::<_, PipWithGeometry>( + r" + SELECT + place.source, + place.id, + place.class, + place.type, + AsGPB(shard.geom) as geom + FROM shard + LEFT JOIN place USING (source, id) + WHERE place.source IS NOT NULL + AND ( + place.type != 'planet' + AND place.type != 'marketarea' + AND place.type != 'county' + AND place.type != 'timezone' + ) + ", + ) + .fetch_all(&self.pool) + .await?; + + Ok(rows) + } } +/// A key-value pair from the WhosOnFirst database. 
#[derive(Debug, Clone, Deserialize, sqlx::FromRow)] pub struct WofKV { key: String, value: String, } -#[derive(Debug, Clone, Deserialize, sqlx::FromRow)] +/// A concise representation of a place in the WhosOnFirst database. +#[derive(Debug, Clone, sqlx::FromRow, Serialize, Deserialize)] pub struct ConcisePipResponse { - // #[allow(dead_code)] - // pub source: String, + /// WOF data source, usually wof + pub source: String, + + /// WOF ID of the place pub id: String, - // #[allow(dead_code)] - // pub class: String, - #[serde(rename = "type")] + + /// High level bucket of human activity - https://whosonfirst.org/docs/categories/ + /// POINT-OF-VIEW > CLASS > CATEGORY + pub class: String, + pub r#type: String, } @@ -156,6 +191,25 @@ pub struct PipLangsResponse { pub langs: Option, } +/// Represents a place in the WhosOnFirst database with a geometry. +#[derive(sqlx::FromRow)] +pub struct PipWithGeometry { + /// WOF data source, usually wof + pub source: String, + + /// WOF ID of the place + pub id: String, + + /// High level bucket of human activity - https://whosonfirst.org/docs/categories/ + /// POINT-OF-VIEW > CLASS > CATEGORY + pub class: String, + + pub r#type: String, + + pub geom: geozero::wkb::Decode>, +} + +/// Convert from a list of key-value pairs to a PipLangsResponse. impl From> for PipLangsResponse { fn from(value: Vec) -> Self { let mut langs = None; @@ -167,3 +221,18 @@ impl From> for PipLangsResponse { Self { langs } } } + +/// Deconstruct a PipWithGeometry into a geometry and a concise response. +impl From for (Option>, ConcisePipResponse) { + fn from(value: PipWithGeometry) -> Self { + ( + value.geom.geometry, + ConcisePipResponse { + source: value.source, + id: value.id, + class: value.class, + r#type: value.r#type, + }, + ) + } +} diff --git a/airmail_indexer/src/wof_tests.rs b/airmail_indexer/src/wof_tests.rs index e7536a9..c48bbae 100644 --- a/airmail_indexer/src/wof_tests.rs +++ b/airmail_indexer/src/wof_tests.rs @@ -4,20 +4,25 @@ use std::path::Path; use log::debug; -use crate::wof::WhosOnFirst; +const DEFAULT_LOG_LEVEL: &str = "debug,sqlx=info"; +const DEFAULT_WOF_DB: &str = "../data/whosonfirst-data-admin-latest.spatial.db"; +const DEFAULT_PIP_TREE: &str = "../data/pip_tree.bin"; +use crate::{ + pip_tree::PipTree, + wof::{ConcisePipResponse, PipLangsResponse, WhosOnFirst}, +}; + +/// Connect to WOF and perform some queries #[tokio::test] async fn wof_read() -> Result<()> { - let _ = env_logger::Builder::from_env(Env::default().default_filter_or("debug")) + let _ = env_logger::Builder::from_env(Env::default().default_filter_or(DEFAULT_LOG_LEVEL)) .is_test(true) .try_init(); // Connect to the WhosOnFirst database. - // Ensuring the database is present, and the mod_spatialite extension is loaded. - let wof = WhosOnFirst::new(Path::new( - "../data/whosonfirst-data-admin-latest.spatial.db", - )) - .await?; + // Ensuring the database is present, and mod_spatialite extension is loaded. + let wof = WhosOnFirst::new(Path::new(DEFAULT_WOF_DB)).await?; // Test point_in_polygon. // This should exist in Global and Australia. 
@@ -33,8 +38,51 @@ async fn wof_read() -> Result<()> { assert!(!place_name.is_empty()); // Lookup a country - let country = wof.properties_for_id(85632793).await?; + let country: PipLangsResponse = wof.properties_for_id(85632793).await?.into(); debug!("country: {:?}", country); + // Ensure eng is in the languages + assert!(country.langs.unwrap().contains("eng")); + + Ok(()) +} + +/// Connect to WOF, create a PipTree, and perform some queries +#[tokio::test] +async fn wof_pip_tree() -> Result<()> { + let _ = env_logger::Builder::from_env(Env::default().default_filter_or(DEFAULT_LOG_LEVEL)) + .is_test(true) + .try_init(); + + // Connect to the WhosOnFirst database. + // Ensuring the database is present, and mod_spatialite extension is loaded. + let wof = WhosOnFirst::new(Path::new(DEFAULT_WOF_DB)).await?; + + // Create a PipTree from the WhosOnFirst database. + let pip_tree = + PipTree::::new_or_load(&wof, Path::new(DEFAULT_PIP_TREE)).await?; + debug!("Tree size: {}", pip_tree.len()); + assert!(!pip_tree.is_empty()); + + // Should not match (coords are backwards) + let invalid_pip = pip_tree + .point_in_polygon(-33.88246407738443, 150.93800658307805) + .await?; + debug!("invalid_pip: {:?}", invalid_pip); + assert!(invalid_pip.is_empty()); + + let pip_from_tree = pip_tree + .point_in_polygon(150.93800658307805, -33.88246407738443) + .await?; + debug!("pip_from_tree: {:?}", pip_from_tree); + assert!(!pip_from_tree.is_empty()); + + // Ensure the PipTree matches the database + let pip_from_db = wof + .point_in_polygon(150.93800658307805, -33.88246407738443) + .await?; + debug!("pip_from_db: {:?}", pip_from_db); + assert_eq!(pip_from_tree.len(), pip_from_db.len()); + Ok(()) } diff --git a/airmail_service/Cargo.toml b/airmail_service/Cargo.toml index ebc9975..243f0c6 100644 --- a/airmail_service/Cargo.toml +++ b/airmail_service/Cargo.toml @@ -18,8 +18,9 @@ serde_json = "1" futures-util = "0.3.30" tower-http = { version = "0.5.1", features = ["cors"] } geo = "0.27.0" +anyhow = "1.0.86" +thiserror = "1.0.63" [features] -invasive_logging = ["airmail/invasive_logging"] remote_index = ["airmail/remote_index"] -default = ["remote_index"] \ No newline at end of file +default = ["remote_index"] diff --git a/airmail_service/src/api.rs b/airmail_service/src/api.rs new file mode 100644 index 0000000..e92b57b --- /dev/null +++ b/airmail_service/src/api.rs @@ -0,0 +1,93 @@ +use std::sync::Arc; + +use airmail::{index::AirmailIndex, poi::AirmailPoi}; +use anyhow::Result; +use axum::{ + extract::{Query, State}, + response::IntoResponse, + Json, +}; +use deunicode::deunicode; +use geo::{Coord, Rect}; +use log::debug; +use serde::{Deserialize, Serialize}; + +use crate::error::AirmailServiceError; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchQueryParams { + q: String, + + #[serde(default, skip_serializing_if = "Option::is_none")] + tags: Option, + + #[serde(default, skip_serializing_if = "Option::is_none")] + leniency: Option, + + #[serde(default, skip_serializing_if = "Option::is_none")] + bbox: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Response { + metadata: MetadataResponse, + features: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct MetadataResponse { + query: SearchQueryParams, +} + +fn parse_bbox(s: &str) -> Option { + let mut parts = s.split(','); + let min_lng: f64 = parts.next()?.parse().ok()?; + let min_lat: f64 = parts.next()?.parse().ok()?; + let max_lng: f64 = parts.next()?.parse().ok()?; + let max_lat: f64 = 
parts.next()?.parse().ok()?; + + Some(Rect::new( + Coord { + y: min_lat, + x: min_lng, + }, + Coord { + y: max_lat, + x: max_lng, + }, + )) +} + +pub async fn search( + Query(params): Query, + State(index): State>, +) -> Result { + let query = deunicode(params.q.trim()).to_lowercase(); + let tags: Option> = params + .tags + .clone() + .map(|s| s.split(',').map(std::string::ToString::to_string).collect()); + let leniency = params.leniency.unwrap_or_default(); + let bbox = params.bbox.clone().and_then(|s| parse_bbox(&s)); + + let start = std::time::Instant::now(); + + let results = index.search(&query, leniency, tags, bbox, &[]).await?; + + debug!( + "Query: {:?} produced: {} results found in {:?}", + params, + results.len(), + start.elapsed() + ); + + let response = Response { + metadata: MetadataResponse { query: params }, + features: results + .into_iter() + .map(|(results, _)| results) + .collect::>(), + }; + + Ok(Json(serde_json::to_value(response)?)) +} diff --git a/airmail_service/src/error.rs b/airmail_service/src/error.rs new file mode 100644 index 0000000..a9b2166 --- /dev/null +++ b/airmail_service/src/error.rs @@ -0,0 +1,38 @@ +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, +}; +use log::warn; +use thiserror::Error; + +#[derive(Error, Debug)] +#[allow(clippy::module_name_repetitions)] +pub enum AirmailServiceError { + #[error("general error: `{0}`")] + InternalAnyhowError(Box), + + #[error("failed to encode response")] + SerdeEncodeError(#[from] serde_json::Error), +} + +// Tell axum how to convert `AppError` into a response. +impl IntoResponse for AirmailServiceError { + fn into_response(self) -> Response { + match &self { + Self::InternalAnyhowError(e) => { + warn!("InternalAnyhowError: {:#}", self); + (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response() + } + Self::SerdeEncodeError(e) => { + warn!("SerdeEncodeError: {:#}", self); + (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response() + } // _ => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response(), + } + } +} + +impl From for AirmailServiceError { + fn from(e: anyhow::Error) -> Self { + Self::InternalAnyhowError(Box::new(e)) + } +} diff --git a/airmail_service/src/main.rs b/airmail_service/src/main.rs index 13e914f..4428157 100644 --- a/airmail_service/src/main.rs +++ b/airmail_service/src/main.rs @@ -1,114 +1,75 @@ -use std::{collections::HashMap, error::Error, sync::Arc}; +#![forbid(unsafe_code)] +#![warn(clippy::pedantic)] -use airmail::{index::AirmailIndex, poi::AirmailPoi}; -use axum::{ - extract::{Query, State}, - http::HeaderValue, - routing::get, - Json, Router, -}; +use std::future::IntoFuture; +use std::sync::Arc; + +use airmail::index::AirmailIndex; +use anyhow::Result; +use api::search; +use axum::{http::HeaderValue, routing::get, Router}; use clap::Parser; -use deunicode::deunicode; -use geo::{Coord, Rect}; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use tokio::task::spawn_blocking; +use env_logger::Env; +use log::{debug, info, warn}; +use tokio::net::TcpListener; +use tokio::select; use tower_http::cors::CorsLayer; +mod api; +mod error; + #[derive(Debug, Parser)] struct Args { + /// The path to the index to load #[arg(short, long, env = "AIRMAIL_INDEX")] index: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct Response { - metadata: HashMap, - features: Vec, -} -fn parse_bbox(s: &str) -> Option { - let mut parts = s.split(','); - let min_lng: f64 = parts.next()?.parse().ok()?; - let min_lat: f64 = 
parts.next()?.parse().ok()?; - let max_lng: f64 = parts.next()?.parse().ok()?; - let max_lat: f64 = parts.next()?.parse().ok()?; + /// The address to bind to + #[arg(short, long, env = "AIRMAIL_BIND", default_value = "127.0.0.1:3000")] + bind: String, - Some(Rect::new( - Coord { - y: min_lat, - x: min_lng, - }, - Coord { - y: max_lat, - x: max_lng, - }, - )) + /// Cors origins to allow + #[arg( + short, + long, + env = "AIRMAIL_CORS", + default_value = "http://localhost:5173" + )] + cors: Option>, } -async fn search( - Query(params): Query>, - State(index): State>, -) -> Json { - let query = params.get("q").unwrap(); - let query = deunicode(query.trim()).to_lowercase(); - let tags: Option> = params - .get("tags") - .map(|s| s.split(',').map(|s| s.to_string()).collect()); - let leniency = params - .get("lenient") - .map(|s| s.parse().unwrap()) - .unwrap_or(false); - - let bbox = params.get("bbox").map(|s| parse_bbox(s)).flatten(); - - let start = std::time::Instant::now(); - - let results = index - .search(&query, leniency, tags, bbox, &[]) - .await - .unwrap(); - - println!("{} results found in {:?}", results.len(), start.elapsed()); +#[tokio::main] +async fn main() -> Result<()> { + env_logger::Builder::from_env(Env::default().default_filter_or("info")).init(); + let args = Args::parse(); - let mut response = Response { - metadata: HashMap::new(), - features: results - .clone() - .into_iter() - .map(|(results, _)| results.clone()) - .collect::>(), + debug!("Loading index from {}", args.index); + let index = if args.index.starts_with("http") { + Arc::new(AirmailIndex::new_remote(&args.index)?) + } else { + Arc::new(AirmailIndex::new(&args.index)?) }; - response - .metadata - .insert("query".to_string(), Value::String(query)); + let mut cors = CorsLayer::new(); + for origin in args.cors.unwrap_or_default() { + cors = cors.allow_origin(origin.parse::()?); + } - Json(serde_json::to_value(response).unwrap()) -} + info!("Loaded {} docs from index", index.num_docs().await?); + let app = Router::new() + .route("/search", get(search).with_state(index)) + .layer(cors); -#[tokio::main] -async fn main() -> Result<(), Box> { - env_logger::init(); - let args = Args::parse(); - let index_path = args.index.clone(); + info!("Listening at: {}/search?q=query", args.bind); + let listener = TcpListener::bind(args.bind).await?; + let server = axum::serve(listener, app.into_make_service()).into_future(); - let index = spawn_blocking(move || { - if index_path.starts_with("http") { - Arc::new(AirmailIndex::new_remote(&index_path).unwrap()) - } else { - Arc::new(AirmailIndex::new(&index_path).unwrap()) + select! { + _ = server => {} + _ = tokio::signal::ctrl_c() => { + warn!("Received ctrl-c, shutting down"); } - }) - .await - .unwrap(); - println!("Have {} docs", index.num_docs().await?); - let app = Router::new() - .route("/search", get(search).with_state(index)) - .layer( - CorsLayer::new().allow_origin("http://localhost:5173".parse::().unwrap()), - ); - let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); - axum::serve(listener, app).await.unwrap(); + } + Ok(()) }
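
The following is not part of the diff: it is a minimal sketch of how the pieces introduced above are intended to fit together, mirroring the new `ImporterBuilder` API and `airmail_indexer/src/main.rs`. The data paths are placeholders, and no real POI producer is wired in.

use std::path::Path;

use airmail::poi::ToIndexPoi;
use airmail_indexer::ImporterBuilder;
use anyhow::Result;

#[tokio::main]
async fn main() -> Result<()> {
    // Creates (or opens) the Tantivy index and connects to the WhosOnFirst
    // spatialite database (paths are placeholders).
    let importer = ImporterBuilder::new(
        Path::new("data/index"),
        Path::new("data/whosonfirst-data-admin-latest.spatial.db"),
    )?
    // Optional on-disk cache of admin-area lookups; a temp file is used otherwise.
    .admin_cache(Path::new("data/admin_cache.db"))
    // Optional serialized R-tree cache; when present, point-in-polygon lookups
    // use the in-memory tree instead of spatialite queries.
    .pip_tree_cache(Path::new("data/pip_tree.bin"))
    .build()
    .await?;

    // POIs arrive over a bounded crossbeam channel; a producer such as the
    // bundled OSMExpress parser would send ToIndexPoi values on `sender`.
    let (sender, receiver) = crossbeam::channel::bounded::<ToIndexPoi>(16_384);
    drop(sender); // placeholder: no producer in this sketch

    importer.run_import("osm", receiver).await
}

On the service side, the reworked `/search` handler takes typed query parameters (`q`, plus optional `tags`, `leniency`, and `bbox` as `min_lng,min_lat,max_lng,max_lat`) and binds to `127.0.0.1:3000` by default, so a request looks like `http://127.0.0.1:3000/search?q=coffee&bbox=-122.4,47.5,-122.2,47.7` (the query and bbox values here are only illustrative).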