Skip to content

Commit

Permalink
feature: add hallucination detection to server
Browse files Browse the repository at this point in the history
  • Loading branch information
densumesh authored and skeptrunedev committed Dec 13, 2024
1 parent c732c7e commit 598b134
Show file tree
Hide file tree
Showing 21 changed files with 1,165 additions and 143 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ export const SingleRAGQuery = (props: SingleRAGQueryProps) => {
"M/d/yy h:mm a",
)}
</span>
<dl class="m-auto mt-5 grid grid-cols-1 divide-y divide-gray-200 overflow-hidden rounded-lg bg-white shadow md:grid-cols-4 md:divide-x md:divide-y-0">
<dl class="m-auto mt-5 grid grid-cols-1 divide-y divide-gray-200 overflow-hidden rounded-lg bg-white shadow md:grid-cols-5 md:divide-x md:divide-y-0">
<DataSquare label="RAG Type" value={props.rag_data.rag_type} />
<DataSquare
label="Dataset"
Expand All @@ -105,6 +105,14 @@ export const SingleRAGQuery = (props: SingleRAGQueryProps) => {
value={props.search_data?.top_score.toPrecision(4) ?? "N/A"}
/>
</Show>
<Show when={props.rag_data && props.rag_data.hallucination_score}>
<DataSquare
label="Hallucination Score"
value={
props.rag_data.hallucination_score?.toPrecision(4) ?? "N/A"
}
/>
</Show>
<Show
when={
props.rag_data.query_rating &&
Expand All @@ -126,6 +134,18 @@ export const SingleRAGQuery = (props: SingleRAGQueryProps) => {
</ul>
</Card>
</Show>
<Show
when={
props.rag_data.detected_hallucinations &&
props.rag_data.detected_hallucinations.length > 0
}
>
<Card title="Detected Hallucinations">
<ul>
<li>{props.rag_data.detected_hallucinations?.join(",")}</li>
</ul>
</Card>
</Show>
<Show
when={
(props.search_data?.results && props.search_data.results[0]) ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ export const RAGAnalyticsPage = () => {
);
},
},
{
accessorKey: "hallucination_score",
header: "Hallucination Score",
},
{
accessorKey: "query_rating",
header: "Query Rating",
Expand Down
2 changes: 2 additions & 0 deletions frontends/shared/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,8 @@ export interface RagQueryEvent {
note?: string;
rating: number;
};
hallucination_score?: number;
detected_hallucinations?: string[];
}

export interface EventData {
Expand Down
2 changes: 1 addition & 1 deletion hallucination-detection/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 8 additions & 8 deletions hallucination-detection/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "hallucination-detection"
version = "0.1.3"
version = "0.1.5"
edition = "2021"
license = "MIT"
repository = "https://github.com/devflowinc/trieve"
Expand All @@ -11,9 +11,10 @@ description = "Extremely fast Hallucination Detection for RAG using BERT NER, no
inherits = "release"

[features]
default = []
default = ["ner"]
ner = ["rust-bert"]
onnx = ["ort"]
download-onnx = ["ort?/download-binaries"]
onnx = ["ort", "rust-bert?/onnx"]

[dependencies]
# Core dependencies
Expand All @@ -22,15 +23,14 @@ regex = "1.11.1"
serde = { version = "1.0.215", features = ["derive"] }
tokio = { version = "1.42.0", features = ["full"] }
once_cell = "1.18"

# Optional dependencies for NER feature
rust-bert = { version = "0.23.0", features = ["onnx"], optional = true }
rust-bert = { version = "0.23.0", optional = true }
ort = { version = "1.16.3", features = [
"download-binaries",
"load-dynamic",
], optional = true }
], optional = true, default-features = false }


[dev-dependencies]
ort = { version = "1.16.3", features = ["download-binaries", "load-dynamic"] }
csv = "1.3.1"
dotenvy = "0.15.7"
openai_dive = "0.7.0"
Expand Down
23 changes: 22 additions & 1 deletion hallucination-detection/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,32 @@ Add this to your `Cargo.toml`:
hallucination-detection = "^0.1.3"
```

If you want to use NER and ONNX features:
If you want to use NER:

1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v2.4`; if that version is no longer listed on the "get started" page, you can still fetch it by editing the download URL directly, for example `https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.4.0%2Bcpu.zip` for the Linux CPU build.
2. Extract the library to a location of your choice
3. Set the following environment variables
##### Linux
```bash
export LIBTORCH=/path/to/libtorch
export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH
```
##### Windows
```powershell
$Env:LIBTORCH = "X:\path\to\libtorch"
$Env:Path += ";X:\path\to\libtorch\lib"
```

```toml
[dependencies]
hallucination-detection = { version = "^0.1.3", features = ["ner"] }
```

If you want to use ONNX for the NER models, you need to either [install the ort runtime](https://docs.rs/ort/1.16.3/ort/#how-to-get-binaries) or include it in your dependencies:

```toml
hallucination-detection = { version = "^0.1.3", features = ["ner", "onnx"] }
ort = { version = "...", features = [ "download-binaries" ] }
```

## Quick Start
Expand Down
2 changes: 1 addition & 1 deletion hallucination-detection/examples/rag_truth_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ async fn run_hallucination_test() -> Result<(), Box<dyn Error>> {
let start = std::time::Instant::now();
let hallucination_score = detector
.detect_hallucinations(&record.response, &[source_info.clone()])
.await;
.await.unwrap();
let elapsed = start.elapsed();
println!("Hallucination detection took: {:?}", elapsed);

Expand Down
3 changes: 2 additions & 1 deletion hallucination-detection/examples/vectara_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ async fn run_hallucination_test() -> Result<(), Box<dyn Error>> {
let start = std::time::Instant::now();
let hallucination_score = detector
.detect_hallucinations(&record.og_sum, &references)
.await;
.await
.unwrap();
let elapsed = start.elapsed();
println!("Hallucination detection took: {:?}", elapsed);

Expand Down
83 changes: 62 additions & 21 deletions hallucination-detection/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,32 @@ use std::{
};
use tokio::sync::OnceCell;

// The ONNX backend is only used for the NER models, so building with `onnx`
// but without `ner` is a configuration mistake: stop the build with a clear
// message instead of failing later on missing items.
#[cfg(all(feature = "onnx", not(feature = "ner")))]
compile_error!("NER feature must be enabled to use ONNX model");

#[cfg(feature = "ner")]
use {
rust_bert::{
pipelines::{
common::{ModelResource, ModelType, ONNXModelResources},
ner::{Entity, NERModel},
token_classification::{LabelAggregationOption, TokenClassificationConfig},
},
resources::RemoteResource,
pipelines::ner::{Entity, NERModel},
pipelines::token_classification::TokenClassificationConfig,
RustBertError,
},
std::error::Error,
std::sync::mpsc,
tokio::{sync::oneshot, task::JoinHandle},
};

#[cfg(feature = "onnx")]
use rust_bert::{
pipelines::{
common::ONNXModelResources,
common::{ModelResource, ModelType},
token_classification::LabelAggregationOption,
},
resources::RemoteResource,
};

const WORDS_URL: &str =
"https://raw.githubusercontent.com/dwyl/english-words/refs/heads/master/words.txt";
const CACHE_FILE: &str = "~/.cache/hallucination-detection/english_words_cache.txt";
Expand Down Expand Up @@ -72,6 +83,18 @@ pub struct HallucinationScore {
pub detected_hallucinations: Vec<String>,
}

/// Error returned when hallucination detection cannot complete.
///
/// Carries a human-readable description of what went wrong. The `Display`
/// output is the message prefixed with `Detector Error: `.
#[derive(Debug)]
pub struct DetectorError {
    message: String,
}

impl std::fmt::Display for DetectorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Detector Error: {}", self.message)
    }
}

// Implementing `std::error::Error` lets callers propagate this error with `?`
// into `Box<dyn Error>` (or similar) instead of unwrapping or re-wrapping it.
// There is no underlying cause to expose, so the default `source()` suffices.
impl std::error::Error for DetectorError {}

#[derive(Debug, Clone)]
pub struct ScoreWeights {
pub proper_noun_weight: f64,
Expand Down Expand Up @@ -201,13 +224,20 @@ impl HallucinationDetector {
&self,
llm_output: &String,
references: &[String],
) -> HallucinationScore {
) -> Result<HallucinationScore, DetectorError> {
let mut all_texts = vec![llm_output.to_string()];
all_texts.extend(references.iter().cloned());

let all_analyses = self.analyze_text(&all_texts).await;
let all_analyses = self.analyze_text(&all_texts).await?;

let (output_analysis, ref_analyses) = all_analyses.split_first().unwrap();
let (output_analysis, ref_analyses) = match all_analyses.split_first() {
Some((output_analysis, ref_analyses)) => (output_analysis, ref_analyses),
None => {
return Err(DetectorError {
message: "Failed to analyze text".to_string(),
});
}
};

let all_ref_proper_nouns: HashSet<_> = ref_analyses
.iter()
Expand Down Expand Up @@ -250,7 +280,7 @@ impl HallucinationDetector {
+ number_mismatch_score * self.options.weights.number_mismatch_weight)
.clamp(0.0, 1.0);

HallucinationScore {
Ok(HallucinationScore {
proper_noun_score,
unknown_word_score,
number_mismatch_score,
Expand All @@ -261,14 +291,19 @@ impl HallucinationDetector {
number_diff.iter().map(|n| n.to_string()).collect(),
]
.concat(),
}
})
}

#[allow(unused_variables)]
async fn analyze_text(&self, texts: &[String]) -> Vec<TextAnalysis> {
async fn analyze_text(&self, texts: &[String]) -> Result<Vec<TextAnalysis>, DetectorError> {
#[cfg(feature = "ner")]
let entities = if let Some(ner_model) = &self.ner_model {
ner_model.predict(texts.to_vec()).await.unwrap()
ner_model
.predict(texts.to_vec())
.await
.map_err(|e| DetectorError {
message: format!("Failed to predict entities: {:?}", e),
})?
} else {
vec![Vec::new(); texts.len()]
};
Expand Down Expand Up @@ -322,11 +357,11 @@ impl HallucinationDetector {
true
});
}
TextAnalysis {
Ok(TextAnalysis {
proper_nouns,
unknown_words,
numbers,
}
})
})
.collect()
}
Expand Down Expand Up @@ -381,7 +416,8 @@ mod tests {

let score = detector
.detect_hallucinations(&llm_output, &references)
.await;
.await
.unwrap();
println!("Zero Hallucination Score: {:?}", score);

assert_eq!(score.proper_noun_score, 0.0);
Expand All @@ -402,7 +438,8 @@ mod tests {

let score = detector
.detect_hallucinations(&llm_output, &references)
.await;
.await
.unwrap();
println!("Multiple References Score: {:?}", score);
assert_eq!(score.proper_noun_score, 0.0); // Both companies are in references
assert_eq!(score.number_mismatch_score, 0.0); // Number matches reference
Expand All @@ -415,13 +452,15 @@ mod tests {
// Empty input
let score_empty = detector
.detect_hallucinations(&String::from(""), &[String::from("")])
.await;
.await
.unwrap();
assert_eq!(score_empty.total_score, 0.0);

// Only numbers
let score_numbers = detector
.detect_hallucinations(&String::from("123 456.789"), &[String::from("123 456.789")])
.await;
.await
.unwrap();
assert_eq!(score_numbers.number_mismatch_score, 0.0);

// Only proper nouns
Expand All @@ -430,7 +469,8 @@ mod tests {
&String::from("John Smith"),
&[String::from("Different Person")],
)
.await;
.await
.unwrap();
assert!(score_nouns.proper_noun_score > 0.0);
}

Expand Down Expand Up @@ -508,7 +548,8 @@ mod tests {
&String::from(llm_output),
&references.into_iter().map(String::from).collect::<Vec<_>>(),
)
.await;
.await
.unwrap();

println!("Test '{}' Score: {:?}", test_name, score);

Expand Down
8 changes: 4 additions & 4 deletions pdf2md/server/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 598b134

Please sign in to comment.