Skip to content

Commit

Permalink
tf-idf query selector and document heap
Browse files Browse the repository at this point in the history
  • Loading branch information
tomfran committed Dec 24, 2023
1 parent c7e5e30 commit 214ab2e
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 8 deletions.
14 changes: 10 additions & 4 deletions src/index/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@ pub struct Index {
tokenizer: Tokenizer,
}

/// Postings returned by `Index::get_term` for a single term.
///
/// `documents` holds the per-document entries and `collection_frequency` is
/// the term-level counter that the query module uses as the idf denominator.
// NOTE(review): presumably `collection_frequency` is the total number of
// occurrences of the term across the collection — confirm against the builder.
#[derive(Debug)]
pub struct PostingList {
documents: Vec<PostingEntry>,
collection_frequency: u32,
pub documents: Vec<PostingEntry>,
pub collection_frequency: u32,
}

/// One element of a posting list: a document id plus a per-document counter.
///
/// The query module multiplies `document_frequency` by the term's idf to get
/// the document's score contribution.
// NOTE(review): despite its name, `document_frequency` presumably counts the
// occurrences of the term inside this document — confirm against the builder.
#[derive(Debug)]
pub struct PostingEntry {
document_id: u32,
document_frequency: u32,
pub document_id: u32,
pub document_frequency: u32,
}

impl Index {
Expand All @@ -45,6 +47,10 @@ impl Index {
}
}

/// Returns the number of entries in `doc_lenghts`, narrowed to `u32`.
///
/// NOTE(review): presumably one length is stored per indexed document, which
/// would make this the collection size — confirm against the index builder.
pub fn get_num_documents(&self) -> u32 {
    let document_count = self.doc_lenghts.len();
    document_count as u32
}

pub fn get_term(&mut self, term: &str) -> Option<PostingList> {
let offset = self.term_offset_map.get(term)?;

Expand Down
87 changes: 87 additions & 0 deletions src/query/document_selector.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
use std::{cmp::Ordering, collections::BinaryHeap};

/// A document id paired with the relevance score computed for it.
///
/// Equality and ordering are all derived from one total order over `score`,
/// and that order is deliberately *reversed*: the entry with the lowest score
/// compares as the greatest. A max-oriented `BinaryHeap` therefore keeps the
/// worst candidate on top, where it is cheap to evict.
///
/// `id` never takes part in comparisons, so two entries with the same score
/// compare as equal even when they refer to different documents.
#[derive(Debug)]
struct DocumentScore {
    id: u32,
    score: f32,
}

impl DocumentScore {
    /// Score used for ranking: NaN is demoted to negative infinity so that a
    /// degenerate score can neither panic a comparison nor outrank a real one.
    fn rank_key(&self) -> f32 {
        if self.score.is_nan() {
            f32::NEG_INFINITY
        } else {
            self.score
        }
    }
}

impl PartialEq for DocumentScore {
    /// Two entries are equal exactly when `cmp` reports `Ordering::Equal`,
    /// which keeps `Eq` reflexive even for NaN scores.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for DocumentScore {}

impl PartialOrd for DocumentScore {
    /// Always `Some`: delegates to `cmp` so both traits agree on every pair.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for DocumentScore {
    /// Reversed total order on the ranking key: lower score => greater entry.
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands are swapped on purpose to invert the natural float order.
        other.rank_key().total_cmp(&self.rank_key())
    }
}

/// Bounded collector that retains the `capacity` best-scoring documents.
///
/// Entries live in a [`BinaryHeap`] of `DocumentScore`. Because that type
/// orders entries by reversed score, the heap's top is the weakest retained
/// candidate, which is the one dropped when the selector overflows.
pub struct DocumentSelector {
    heap: BinaryHeap<DocumentScore>,
    capacity: usize,
}

impl DocumentSelector {
    /// Creates an empty selector that keeps at most `capacity` documents.
    pub fn new(capacity: usize) -> DocumentSelector {
        DocumentSelector {
            heap: BinaryHeap::new(),
            capacity,
        }
    }

    /// Offers a document to the selector.
    ///
    /// The entry is always inserted first; if that pushes the heap past
    /// `capacity`, the heap's top entry is discarded again. With a capacity
    /// of zero nothing is ever retained.
    pub fn push(&mut self, id: u32, score: f32) {
        self.heap.push(DocumentScore { id, score });

        // One over budget: evict whatever the heap currently ranks on top.
        if self.heap.len() > self.capacity {
            self.heap.pop();
        }
    }

    /// Empties the selector and returns the retained ids, best score first.
    ///
    /// The heap yields its entries worst-first, so the ids are collected in
    /// pop order and then reversed. A second call returns an empty vector.
    pub fn get_sorted_ids(&mut self) -> Vec<u32> {
        let mut ids = Vec::with_capacity(self.heap.len());
        while let Some(entry) = self.heap.pop() {
            ids.push(entry.id);
        }

        ids.reverse();
        ids
    }
}

#[cfg(test)]
mod test {
    use super::*;

    /// Feeds every `(id, score)` pair, in order, into a fresh selector with
    /// the given capacity and returns the ids it reports.
    fn ranked(capacity: usize, entries: &[(u32, f32)]) -> Vec<u32> {
        let mut selector = DocumentSelector::new(capacity);
        for &(id, score) in entries {
            selector.push(id, score);
        }
        selector.get_sorted_ids()
    }

    /// More candidates than slots: only the two best survive, best first.
    #[test]
    fn test_top_k() {
        let ranking = ranked(2, &[(2, 0.4), (3, 0.3), (1, 0.5), (4, 0.2)]);

        assert_eq!(ranking, [1, 2]);
    }

    /// Fewer candidates than slots: everything is returned, best first.
    #[test]
    fn test_top_less_than_k() {
        let ranking = ranked(3, &[(1, 0.5), (2, 0.4)]);

        assert_eq!(ranking, [1, 2]);
    }
}
58 changes: 54 additions & 4 deletions src/query/mod.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,70 @@
use std::collections::HashMap;

use crate::index::Index;

struct QueryProcessor {
use self::document_selector::DocumentSelector;

mod document_selector;
/// Evaluates text queries against an [`Index`] and ranks the matching documents.
pub struct QueryProcessor {
// Index that tokenizes queries and serves the posting lists.
index: Index,
// Document count read from the index when the processor is built; used as
// the numerator of the idf ratio.
num_documents: u32,
}

impl QueryProcessor {
/// Builds a [`QueryProcessor`] on top of a loaded [`Index`].
///
/// Both path arguments are forwarded unchanged to `Index::load_index`, and
/// the document count reported by the loaded index is stored with it.
pub fn build_query_processor(
index_input_path: &str,
index_tokenizer_path: &str,
) -> QueryProcessor {
let index = Index::load_index(index_input_path, index_tokenizer_path);
let num_documents = index.get_num_documents();

QueryProcessor {
index: Index::load_index(index_input_path, index_tokenizer_path),
index,
num_documents,
}
}

pub fn query(query: &str) -> Vec<u32> {
todo!()
/// Scores every document that matches at least one query token and returns
/// the ids of the three best-scoring ones, highest score first.
///
/// For each token with a posting list, every listed document gains
/// `document_frequency * idf`, where `idf` is
/// `log2(num_documents / collection_frequency)` for that term. Tokens that
/// have no posting list contribute nothing. The query and each accumulated
/// document score are printed to stdout.
pub fn query(&mut self, query: &str) -> Vec<u32> {
    println!("\nQuery: {:?}", query);

    let total_documents = self.num_documents as f32;
    let mut accumulated: HashMap<u32, f32> = HashMap::new();

    for token in self.index.tokenize_query(query) {
        let term_postings = match self.index.get_term(&token) {
            Some(list) => list,
            // Term is not in the index: nothing to add for this token.
            None => continue,
        };

        let idf = (total_documents / term_postings.collection_frequency as f32).log2();

        for entry in term_postings.documents.iter() {
            let contribution = entry.document_frequency as f32 * idf;
            accumulated
                .entry(entry.document_id)
                .and_modify(|total| *total += contribution)
                .or_insert(contribution);
        }
    }

    // Keep only the three strongest candidates while reporting every score.
    let mut top_documents = DocumentSelector::new(3);
    for (document_id, document_score) in &accumulated {
        println!("- document: {:?}, score: {:?}", document_id, document_score);
        top_documents.push(*document_id, *document_score);
    }

    top_documents.get_sorted_ids()
}
}

// #[cfg(test)]
// mod test {
// use super::*;

// #[test]
// fn test_build() {
// let mut q = QueryProcessor::build_query_processor(
// "data/small/index/small",
// "data/small/bert-base-uncased",
// );
// q.query("google");
// q.query("apple");
// q.query("microsoft");
// }
// }

0 comments on commit 214ab2e

Please sign in to comment.