Execute tantivy queries in a blocking context, tweaks to remote dir.
ellenhp committed Feb 11, 2024
1 parent c73c555 commit 09743d7
Showing 3 changed files with 30 additions and 104 deletions.
108 changes: 11 additions & 97 deletions airmail/src/directory.rs
@@ -5,6 +5,7 @@ use std::{
ops::Range,
path::{Path, PathBuf},
sync::{Arc, Mutex, OnceLock},
time::Duration,
};

use log::{error, info, warn};
@@ -17,7 +18,6 @@ use tantivy::{
Directory,
};
use tantivy_common::{file_slice::FileHandle, AntiCallToken, HasLen, OwnedBytes, TerminatingWrite};
use tokio::spawn;

thread_local! {
static BLOCKING_HTTP_CLIENT: reqwest::blocking::Client = reqwest::blocking::Client::new();
@@ -89,6 +89,9 @@ impl FileHandle for HttpFileHandle {
let response = BLOCKING_HTTP_CLIENT.with(|client| {
client
.get(&self.url)
.timeout(Duration::from_millis(
500 + (range.end - range.start) as u64 / 1024,
))
.header(
"Range",
dbg!(format!(
@@ -145,96 +148,6 @@ impl FileHandle for HttpFileHandle {
.to_vec(),
))
}

async fn read_bytes_async(&self, range: Range<usize>) -> io::Result<OwnedBytes> {
let chunk_start = range.start / CHUNK_SIZE;
let chunk_end = range.end / CHUNK_SIZE;
let cache =
LRU_CACHE.get_or_init(|| Mutex::new(LruCache::new(NonZeroUsize::new(40_000).unwrap())));
let mut accumulated_chunks = vec![0u8; (chunk_end - chunk_start + 1) * CHUNK_SIZE];
info!(
"Reading bytes: {:?} in chunks from {} to {}",
range, chunk_start, chunk_end
);
let mut handles = Vec::new();
for chunk in chunk_start..=chunk_end {
let key = CacheKey {
base_url: self.url.clone(),
path: self.url.clone(),
chunk,
};
{
let mut cache = cache.lock().unwrap();
if let Some(data) = cache.get(&key) {
accumulated_chunks[chunk * CHUNK_SIZE..(chunk + 1) * CHUNK_SIZE]
.copy_from_slice(data);
continue;
}
}
let url = self.url.clone();
let handle = spawn(async move {
let response = HTTP_CLIENT.with(|client| {
client
.get(&url)
.header(
"Range",
format!("{}-{}", chunk * CHUNK_SIZE, (chunk + 1) * CHUNK_SIZE),
)
.send()
});
let response = match response.await {
Ok(response) => response,
Err(e) => {
error!("Error: {:?}", e);
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Error fetching chunk",
));
}
};
if response.status() != 200 {
error!("Response: {:?}", response);
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Error fetching chunk: non-200 status",
));
} else {
let data = response.bytes().await.unwrap();
let data = data.to_vec();
{
let mut cache = cache.lock().unwrap();
cache.put(key, data.to_vec());
}
if data.len() < CHUNK_SIZE && chunk != chunk_end {
warn!("Short chunk: {}", data.len());
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Error fetching chunk: short response length",
));
}
Ok((chunk, data))
}
});
handles.push(handle);
}
for handle in handles {
if let Ok(Ok((chunk, data))) = handle.await {
accumulated_chunks[chunk * CHUNK_SIZE..(chunk + 1) * CHUNK_SIZE]
.copy_from_slice(&data);
} else {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Error fetching chunk",
));
}
}
info!("Accumulated chunks: {}", accumulated_chunks.len());
let chunk_start_offset = range.start % CHUNK_SIZE;
let chunk_end_offset = (chunk_end - chunk_start) * CHUNK_SIZE + range.end % CHUNK_SIZE;
Ok(OwnedBytes::new(
accumulated_chunks[chunk_start_offset..chunk_end_offset].to_vec(),
))
}
}

impl HasLen for HttpFileHandle {
@@ -249,24 +162,25 @@ impl HasLen for HttpFileHandle {

let url = format!("{}", self.url);
info!("Fetching length from: {}", url);
let response = BLOCKING_HTTP_CLIENT.with(|client| client.head(&url).send());
let response = BLOCKING_HTTP_CLIENT
.with(|client| client.head(&url).timeout(Duration::from_millis(500)).send());
if let Err(e) = response {
error!("Error: {:?}", e);
return 0;
error!("Error fetching length: {:?}", e);
panic!();
}
let response = response.unwrap();
if response.status() != 200 {
error!("Response: {:?}", response);
return 0;
panic!();
} else {
let length = response
.headers()
.get("Content-Length")
.unwrap()
.to_str()
.unwrap_or_default()
.unwrap()
.parse()
.unwrap_or_default();
.unwrap();
info!("Length: {}", length);
let mut lengths = lengths.lock().unwrap();
lengths.insert(PathBuf::from(&self.url), length);
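
For context on the remote-directory tweaks, here is a minimal sketch of a blocking ranged read with a size-scaled timeout, assuming a plain reqwest::blocking::Client. fetch_range is a hypothetical helper rather than an airmail API; only the 500 ms base and the roughly 1 ms-per-KiB scaling come from the diff above, and the exact Range header formatting in the crate may differ.

use std::{io, ops::Range, time::Duration};

// Hypothetical helper: fetch one byte range over HTTP on the calling thread.
fn fetch_range(
    client: &reqwest::blocking::Client,
    url: &str,
    range: Range<usize>,
) -> io::Result<Vec<u8>> {
    // Scale the timeout with the number of bytes requested so large reads are
    // not cut off by the 500 ms floor used for small ones.
    let timeout = Duration::from_millis(500 + (range.end - range.start) as u64 / 1024);
    let response = client
        .get(url)
        .timeout(timeout)
        // The HTTP Range header takes an inclusive end offset.
        .header("Range", format!("bytes={}-{}", range.start, range.end - 1))
        .send()
        .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
    if !response.status().is_success() {
        return Err(io::Error::new(
            io::ErrorKind::Other,
            format!("range request failed with status {}", response.status()),
        ));
    }
    let body = response
        .bytes()
        .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
    Ok(body.to_vec())
}

Scaling the timeout with the request size keeps small reads snappy while still giving larger reads enough time to finish.
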
2 changes: 1 addition & 1 deletion airmail_parser/src/component.rs
@@ -29,7 +29,7 @@ pub trait TriviallyConstructibleComponent: QueryComponent {
fn new(text: String) -> Self;
}

pub trait QueryComponent {
pub trait QueryComponent: Send + Sync {
fn text(&self) -> &str;

fn penalty_mult(&self) -> f32;
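
The Send + Sync supertrait added to QueryComponent is presumably what lets parsed components move into spawn_blocking along with the index handle. A minimal sketch of the mechanism, assuming tokio with the rt-multi-thread and macros features enabled; Component and Road are illustrative stand-ins, not the crate's real types.

use tokio::task::spawn_blocking;

// Stand-in for QueryComponent: making Send a supertrait makes every trait
// object of this trait Send as well.
trait Component: Send + Sync {
    fn text(&self) -> &str;
}

struct Road(String);

impl Component for Road {
    fn text(&self) -> &str {
        &self.0
    }
}

#[tokio::main]
async fn main() {
    let component: Box<dyn Component> = Box::new(Road("main st".to_string()));
    // Box<dyn Component> is Send because Send is a supertrait, so the closure
    // below satisfies spawn_blocking's Send + 'static bound. Without the bound
    // on the trait, this line would not compile.
    let len = spawn_blocking(move || component.text().len())
        .await
        .expect("blocking task panicked");
    println!("component text length: {len}");
}
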
24 changes: 18 additions & 6 deletions airmail_service/src/main.rs
@@ -12,6 +12,7 @@ use deunicode::deunicode;
use log::trace;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::task::spawn_blocking;

#[derive(Debug, Parser)]
struct Args {
@@ -41,7 +42,13 @@ async fn search(
if all_results.len() > 20 {
break;
}
let results = index.search(scenario).unwrap();
let results = {
let scenario = scenario.clone();
let index = index.clone();
spawn_blocking(move || index.search(&scenario).unwrap())
.await
.unwrap()
};
if results.is_empty() {
continue;
} else {
@@ -89,11 +96,16 @@ async fn main() -> Result<(), Box<dyn Error>> {
env_logger::init();
let args = Args::parse();
let index_path = args.index.clone();
let index = if index_path.starts_with("http") {
Arc::new(AirmailIndex::new_remote(&index_path)?)
} else {
Arc::new(AirmailIndex::new(&index_path)?)
};

let index = spawn_blocking(move || {
if index_path.starts_with("http") {
Arc::new(AirmailIndex::new_remote(&index_path).unwrap())
} else {
Arc::new(AirmailIndex::new(&index_path).unwrap())
}
})
.await
.unwrap();
let app = Router::new().route("/search", get(search).with_state(index));
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app).await.unwrap();
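
Both main.rs changes follow the same pattern: clone the Arc-wrapped index, move it into spawn_blocking, and await the join handle so synchronous tantivy work stays off the async runtime threads. Below is a self-contained sketch of that pattern, with Index::search standing in for AirmailIndex::search and the panic-on-failure error handling mirroring the unwraps in the diff rather than production code.

use std::sync::Arc;
use tokio::task::spawn_blocking;

struct Index;

impl Index {
    fn search(&self, query: &str) -> Vec<String> {
        // Imagine blocking disk or HTTP reads against a tantivy index here.
        vec![format!("result for {query}")]
    }
}

async fn handle(index: &Arc<Index>, query: String) -> Vec<String> {
    // Clone the cheap handle so the caller keeps its copy; the clone and the
    // owned query string move into the closure, satisfying the 'static bound.
    let index = Arc::clone(index);
    spawn_blocking(move || index.search(&query))
        .await
        .expect("search task panicked")
}

#[tokio::main]
async fn main() {
    let index = Arc::new(Index);
    let results = handle(&index, "coffee near pike place".to_string()).await;
    println!("{results:?}");
}
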
