From 2e450b159f161c12eec8ce2b17dc9d5a6c446567 Mon Sep 17 00:00:00 2001
From: SashaXser <24498484+SashaXser@users.noreply.github.com>
Date: Sat, 15 Feb 2025 23:03:19 +0400
Subject: [PATCH] Refactor hash calculation and token emission logic for improved clarity and performance

---
 src/lib.rs | 123 ++++++++++++++++++++++++-----------------------------
 1 file changed, 55 insertions(+), 68 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index a497e76..c385312 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -35,12 +35,11 @@ fn murmurhash2(mut k: u32, seed: u32) -> u32 {
 }
 
 impl Feature {
-    #[inline(always)]
     pub fn to_hash(&self) -> u32 {
-        match self {
-            Feature::AsciiNGram(ngram) => murmurhash2(*ngram, SEED),
-            Feature::Unicode(chr) => murmurhash2(*chr as u32 / 128, SEED ^ 2),
-            Feature::UnicodeClass(chr) => murmurhash2(classify_codepoint(*chr), SEED ^ 4),
+        match *self {
+            Feature::AsciiNGram(ngram) => murmurhash2(ngram, SEED),
+            Feature::Unicode(chr) => murmurhash2((chr as u32) >> 7, SEED ^ 2),
+            Feature::UnicodeClass(chr) => murmurhash2(classify_codepoint(chr), SEED ^ 4),
         }
     }
 }
@@ -48,77 +47,70 @@ impl Feature {
 pub fn detect_language(text: &str) -> Lang {
     let mut scores: [f32; NUM_LANGUAGES] = Default::default();
     let mut num_features: u32 = 0;
-    emit_tokens(
-        text,
-        #[inline(always)]
-        |token| {
-            num_features += 1u32;
-            let bucket = token.to_hash() % DIMENSION as u32;
-            let idx = bucket as usize * NUM_LANGUAGES;
-            let per_language_scores = &weights::WEIGHTS[idx..idx + NUM_LANGUAGES];
-            for i in 0..NUM_LANGUAGES {
-                scores[i] += per_language_scores[i];
-            }
-        },
-    );
+    emit_tokens(text, |token| {
+        num_features += 1;
+        let bucket = token.to_hash() % DIMENSION as u32;
+        let idx = bucket as usize * NUM_LANGUAGES;
+        scores
+            .iter_mut()
+            .zip(&weights::WEIGHTS[idx..idx + NUM_LANGUAGES])
+            .for_each(|(score, weight)| *score += weight);
+    });
     if num_features == 0 {
         // By default, we return English
         return Lang::Eng;
     }
-
-    let sqrt_inv_num_features = 1.0f32 / (num_features as f32).sqrt();
-    #[allow(clippy::needless_range_loop)]
-    for i in 0..NUM_LANGUAGES {
-        // Ok so the sqrt(num_features) is not really the norm, but whatever.
-        scores[i] = scores[i] * sqrt_inv_num_features + weights::INTERCEPTS[i];
+    let sqrt_inv = 1.0 / (num_features as f32).sqrt();
+    scores
+        .iter_mut()
+        .zip(weights::INTERCEPTS.iter())
+        .for_each(|(score, intercept)| *score = *score * sqrt_inv + intercept);
+    let mut lang_id = 0;
+    let mut max_score = scores[0];
+    for (i, &score) in scores.iter().enumerate().skip(1) {
+        if score > max_score {
+            max_score = score;
+            lang_id = i;
+        }
     }
-
-    let lang_id = scores
-        .iter()
-        .enumerate()
-        .max_by(|(_, &score_left), (_, &score_right)| score_left.partial_cmp(&score_right).unwrap())
-        .map(|(pos, _val)| pos)
-        .unwrap();
     weights::LANGUAGES[lang_id]
 }
 
 #[doc(hidden)]
 pub fn emit_tokens(text: &str, mut listener: impl FnMut(Feature)) {
     let mut prev = ' ' as u32;
-    let mut num_previous_ascii_chr = 1;
+    let mut num_prev_ascii = 1;
     for chr in text.chars() {
-        let code = chr.to_ascii_lowercase() as u32;
         if !chr.is_ascii() {
             listener(Feature::Unicode(chr));
             listener(Feature::UnicodeClass(chr));
-            num_previous_ascii_chr = 0;
-            continue;
-        }
-        prev = prev << 8 | code;
-        match num_previous_ascii_chr {
-            0 => {
-                num_previous_ascii_chr = 1;
+            num_prev_ascii = 0;
+        } else {
+            let code = chr.to_ascii_lowercase() as u32;
+            prev = prev << 8 | code;
+            match num_prev_ascii {
+                0 => {
+                    num_prev_ascii = 1;
+                }
+                1 => {
+                    listener(Feature::AsciiNGram(prev & BIGRAM_MASK));
+                    num_prev_ascii = 2;
+                }
+                2 => {
+                    listener(Feature::AsciiNGram(prev & BIGRAM_MASK));
+                    listener(Feature::AsciiNGram(prev & TRIGRAM_MASK));
+                    num_prev_ascii = 3;
+                }
+                3 => {
+                    listener(Feature::AsciiNGram(prev & BIGRAM_MASK));
+                    listener(Feature::AsciiNGram(prev & TRIGRAM_MASK));
+                    listener(Feature::AsciiNGram(prev));
+                }
+                _ => unreachable!(),
             }
-            1 => {
-                listener(Feature::AsciiNGram(prev & BIGRAM_MASK));
-                num_previous_ascii_chr = 2;
+            if !chr.is_alphanumeric() {
+                prev = ' ' as u32;
             }
-            2 => {
-                listener(Feature::AsciiNGram(prev & BIGRAM_MASK));
-                listener(Feature::AsciiNGram(prev & TRIGRAM_MASK));
-                num_previous_ascii_chr = 3;
-            }
-            3 => {
-                listener(Feature::AsciiNGram(prev & BIGRAM_MASK));
-                listener(Feature::AsciiNGram(prev & TRIGRAM_MASK));
-                listener(Feature::AsciiNGram(prev));
-            }
-            _ => {
-                unreachable!();
-            }
-        }
-        if !chr.is_alphanumeric() {
-            prev = ' ' as u32;
         }
     }
 }
@@ -195,15 +187,10 @@ fn classify_codepoint(chr: char) -> u32 {
 
 #[cfg(test)]
 mod tests {
-    use crate::detect_language;
-    use crate::emit_tokens;
-    use crate::Feature;
-    use crate::Lang;
+    use crate::{detect_language, emit_tokens, Feature, Lang};
 
     fn ascii_ngram_feature(text: &str) -> Feature {
-        assert!(text.is_ascii());
-        let mut bytes: [u8; 4] = [0u8; 4];
-        assert!(text.len() <= 4);
+        let mut bytes = [0; 4];
         bytes[4 - text.len()..].copy_from_slice(text.as_bytes());
         Feature::AsciiNGram(u32::from_be_bytes(bytes))
     }
@@ -213,8 +200,8 @@ mod tests {
         let mut tokens = Vec::new();
         emit_tokens("hello こん!", |token| tokens.push(token));
         assert_eq!(
-            &tokens,
-            &[
+            tokens,
+            vec![
                 ascii_ngram_feature(" h"),
                 ascii_ngram_feature("he"),
                 ascii_ngram_feature(" he"),