Commit
* new logic to calculate the risk score where it can be a number between -1 (below threshold) and 1 (above threshold)
asofter committed Aug 29, 2024
1 parent d6aa71e commit 4bf0f1f
Showing 66 changed files with 321 additions and 286 deletions.
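Note: the updated scanners below delegate scoring to a shared `calculate_risk_score` helper imported from `llm_guard.util`, which is not among the files shown in this diff. As a rough sketch of the mapping the commit message describes (roughly -1 when the detection score is far below the threshold, 0 at the threshold, 1 when far above it), it could look like the function below; the name, rounding, and exact scaling are assumptions for illustration, not the library's actual implementation.

```python
# Illustrative sketch only: the real helper is llm_guard.util.calculate_risk_score,
# which is not shown in this commit and may be implemented differently.
# Assumes the detection score lies in [0, 1] and the threshold in (0, 1).
def sketch_calculate_risk_score(score: float, threshold: float) -> float:
    """Map a raw detection score onto [-1, 1] relative to a threshold.

    -1.0 means far below the threshold (clearly passing), 0.0 means exactly at
    the threshold, and 1.0 means far above it (clearly failing).
    """
    if threshold <= 0.0:
        # Degenerate threshold: treat any positive score as maximal risk.
        return 1.0 if score > 0.0 else -1.0
    if score <= threshold:
        # Below the threshold: scale into [-1, 0].
        return round((score - threshold) / threshold, 2)
    # Above the threshold: scale into (0, 1].
    return round((score - threshold) / (1.0 - threshold), 2)


# Example: with a threshold of 0.5, a detection score of 0.25 maps to -0.5
# and a score of 0.75 maps to 0.5.
assert sketch_calculate_risk_score(0.25, 0.5) == -0.5
assert sketch_calculate_risk_score(0.75, 0.5) == 0.5
```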
3 changes: 2 additions & 1 deletion docs/changelog.md
@@ -14,7 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-

### Changed
- -
+ - Stop substrings moved to the variables instead of JSON files.
+ - **[BREAKING]** New logic to calculate the risk score ([#182](https://github.com/protectai/llm-guard/issues/182)).

### Removed
-
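A note on the `[BREAKING]` entry above: callers that treated a risk score of exactly 0.0 as "scan passed" will now see negative scores for passing scans. The following is a hypothetical consumer-side sketch, not code from this repository; the scanner classes are taken from files touched by this commit, and instantiating them downloads the underlying models.

```python
# Hypothetical caller-side adaptation to the new [-1, 1] risk score range.
from llm_guard.input_scanners.ban_code import BanCode
from llm_guard.input_scanners.toxicity import Toxicity


def screen_prompt(prompt: str) -> tuple[str, bool]:
    """Run a couple of input scanners and report whether the prompt passes."""
    for scanner in (BanCode(), Toxicity()):
        prompt, is_valid, risk_score = scanner.scan(prompt)
        # Old convention: a passing scan always returned exactly 0.0, so checks
        # like `risk_score > 0` doubled as a failure test.
        # New convention: a passing scan returns a negative score (down to -1.0)
        # and a failing scan a positive one (up to 1.0), so rely on `is_valid`
        # or the sign of the score, not on 0.0 as a sentinel value.
        if not is_valid:
            return prompt, False
    return prompt, True
```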
4 changes: 2 additions & 2 deletions llm_guard/input_scanners/anonymize.py
@@ -310,7 +310,7 @@ def remove_single_quotes(text: str) -> str:
return text_without_single_quotes

def scan(self, prompt: str) -> tuple[str, bool, float]:
- risk_score = 0.0
+ risk_score = -1.0
if prompt.strip() == "":
return prompt, True, risk_score

@@ -354,4 +354,4 @@ def scan(self, prompt: str) -> tuple[str, bool, float]:

LOGGER.debug("Prompt does not have sensitive data to replace", risk_score=risk_score)

- return prompt, True, 0.0
+ return prompt, True, -1.0
4 changes: 2 additions & 2 deletions llm_guard/input_scanners/ban_code.py
@@ -66,7 +66,7 @@ def __init__(

def scan(self, prompt: str) -> tuple[str, bool, float]:
if prompt.strip() == "":
- return prompt, True, 0.0
+ return prompt, True, -1.0

# Hack: Improve accuracy
new_prompt = remove_markdown(prompt) # Remove markdown
@@ -91,4 +91,4 @@ def scan(self, prompt: str) -> tuple[str, bool, float]:
"No code detected in the text", score=score, threshold=self._threshold, text=new_prompt
)

- return prompt, True, 0.0
+ return prompt, True, calculate_risk_score(score, self._threshold)
7 changes: 2 additions & 5 deletions llm_guard/input_scanners/ban_competitors.py
@@ -1,6 +1,6 @@
from __future__ import annotations

- from typing import TYPE_CHECKING, Sequence, cast
+ from typing import Sequence, cast

from presidio_anonymizer.core.text_replace_builder import TextReplaceBuilder

@@ -10,9 +10,6 @@

from .base import Scanner

- if TYPE_CHECKING:
- import transformers
-
LOGGER = get_logger()

MODEL_V1 = Model(
@@ -112,4 +109,4 @@ def scan(self, prompt: str) -> tuple[str, bool, float]:

LOGGER.debug("None of the competitors were detected")

- return prompt, True, 0.0
+ return prompt, True, -1.0
33 changes: 25 additions & 8 deletions llm_guard/input_scanners/ban_substrings.py
@@ -1,6 +1,5 @@
from __future__ import annotations

- import os
import re
from enum import Enum

@@ -10,12 +9,30 @@

LOGGER = get_logger()

- stop_file_path = os.path.join(
- os.path.dirname(os.path.abspath(__file__)),
- "..",
- "resources",
- "prompt_stop_substrings.json",
- )
+ PROMPT_STOP_SUBSTRINGS = [
+ "arbitrary code execution",
+ "a backdoor",
+ "code injection",
+ "cross-site scripting",
+ "email spoofing",
+ "a keylogger",
+ "privilege escalation",
+ "a rootkit",
+ "a bootkit",
+ "screen scraping",
+ "corrupting a disk partition",
+ "cryptojacking",
+ "a botnet",
+ "adware",
+ "ransomware",
+ "Windows\\System32\\Config\\SAM",
+ "etc/shadow",
+ "escalate privilege",
+ "read the encrypted system password file",
+ "X5O!P%@AP[4\\PZX54(P^)7CC)7}$EICAR-STANDARD-ANTIVIRUS-TEST-FILE!$H+H*", # EICAR test file
+ "XJS*C4JDBQADN1.NSBN3*2IDNEN*GTUBE-STANDARD-ANTI-UBE-TEST-EMAIL*C.34X", # GTUBE test email
+ "XJS*C4JDBQADN1.NSBN3*2IDNEN*GTPHISH-STANDARD-ANTI-PHISH-TEST-EMAIL*C.34X", # GTPHISH test email
+ ]


class MatchType(Enum):
@@ -119,4 +136,4 @@ def scan(self, prompt: str) -> tuple[str, bool, float]:

LOGGER.debug("No banned substrings found")

- return sanitized_prompt, True, 0.0
+ return sanitized_prompt, True, -1.0
4 changes: 2 additions & 2 deletions llm_guard/input_scanners/ban_topics.py
@@ -137,7 +137,7 @@ def __init__(

def scan(self, prompt: str) -> tuple[str, bool, float]:
if prompt.strip() == "":
- return prompt, True, 0.0
+ return prompt, True, -1.0

output_model = self._classifier(prompt, self._topics, multi_label=False)
label_score = dict(zip(output_model["labels"], output_model["scores"]))
@@ -156,4 +156,4 @@ def scan(self, prompt: str) -> tuple[str, bool, float]:
scores=label_score,
)

- return prompt, True, 0.0
+ return prompt, True, calculate_risk_score(max_score, self._threshold)
6 changes: 3 additions & 3 deletions llm_guard/input_scanners/code.py
@@ -127,7 +127,7 @@ def _extract_code_blocks(self, markdown: str) -> list[str]:

def scan(self, prompt: str) -> tuple[str, bool, float]:
if prompt.strip() == "":
- return prompt, True, 0.0
+ return prompt, True, -1.0

# Try to extract code snippets from Markdown
code_blocks = self._extract_code_blocks(prompt)
@@ -162,11 +162,11 @@ def scan(self, prompt: str) -> tuple[str, bool, float]:
LOGGER.debug(
"Language is allowed", language_name=language["label"], score=score
)
- return prompt, True, 0.0
+ return prompt, True, calculate_risk_score(score, self._threshold)

if self._is_blocked:
LOGGER.debug("No blocked languages detected")
- return prompt, True, 0.0
+ return prompt, True, -1.0

LOGGER.warning("No allowed languages detected")
return prompt, False, 1.0
4 changes: 2 additions & 2 deletions llm_guard/input_scanners/gibberish.py
@@ -82,7 +82,7 @@ def __init__(

def scan(self, prompt: str) -> tuple[str, bool, float]:
if prompt.strip() == "":
- return prompt, True, 0.0
+ return prompt, True, -1.0

highest_score = 0.0
results_all = self._classifier(self._match_type.get_inputs(prompt))
@@ -107,4 +107,4 @@ def scan(self, prompt: str) -> tuple[str, bool, float]:
"No gibberish in the text", highest_score=highest_score, threshold=self._threshold
)

- return prompt, True, 0.0
+ return prompt, True, calculate_risk_score(highest_score, self._threshold)
2 changes: 1 addition & 1 deletion llm_guard/input_scanners/invisible_text.py
@@ -26,7 +26,7 @@ def contains_unicode(text: str) -> bool:

def scan(self, prompt: str) -> tuple[str, bool, float]:
if not self.contains_unicode(prompt):
- return prompt, True, 0.0
+ return prompt, True, -1.0

chars = []
for char in prompt:
4 changes: 2 additions & 2 deletions llm_guard/input_scanners/language.py
@@ -86,7 +86,7 @@ def __init__(

def scan(self, prompt: str) -> tuple[str, bool, float]:
if prompt.strip() == "":
- return prompt, True, 0.0
+ return prompt, True, -1.0

results_all = self._pipeline(self._match_type.get_inputs(prompt))
for result_chunk in results_all:
@@ -106,4 +106,4 @@ def scan(self, prompt: str) -> tuple[str, bool, float]:

LOGGER.debug("Only valid languages are found in the text.")

- return prompt, True, 0.0
+ return prompt, True, -1.0
4 changes: 2 additions & 2 deletions llm_guard/input_scanners/prompt_injection.py
@@ -167,7 +167,7 @@ def __init__(

def scan(self, prompt: str) -> tuple[str, bool, float]:
if prompt.strip() == "":
- return prompt, True, 0.0
+ return prompt, True, -1.0

highest_score = 0.0
results_all = self._pipeline(self._match_type.get_inputs(prompt))
@@ -187,4 +187,4 @@ def scan(self, prompt: str) -> tuple[str, bool, float]:

LOGGER.debug("No prompt injection detected", highest_score=highest_score)

- return prompt, True, 0.0
+ return prompt, True, calculate_risk_score(highest_score, self._threshold)
4 changes: 2 additions & 2 deletions llm_guard/input_scanners/regex.py
@@ -80,11 +80,11 @@ def scan(self, prompt: str) -> tuple[str, bool, float]:
return text_replace_builder.output_text, False, 1.0

LOGGER.debug("Pattern matched the text", pattern=pattern)
- return text_replace_builder.output_text, True, 0.0
+ return text_replace_builder.output_text, True, -1.0

if self._is_blocked:
LOGGER.debug("None of the patterns were found in the text")
- return text_replace_builder.output_text, True, 0.0
+ return text_replace_builder.output_text, True, -1.0

LOGGER.warning("None of the patterns matched the text")
return text_replace_builder.output_text, False, 1.0
4 changes: 2 additions & 2 deletions llm_guard/input_scanners/secrets.py
@@ -459,7 +459,7 @@ def redact_value(value: str, mode: str) -> str:
def scan(self, prompt: str) -> tuple[str, bool, float]:
secrets = SecretsCollection()

- risk_score = 0.0
+ risk_score = -1.0
if prompt.strip() == "":
return prompt, True, risk_score

@@ -499,4 +499,4 @@ def scan(self, prompt: str) -> tuple[str, bool, float]:

LOGGER.debug("No secrets detected in the prompt")

- return prompt, True, 0.0
+ return prompt, True, -1.0
13 changes: 7 additions & 6 deletions llm_guard/input_scanners/sentiment.py
@@ -1,4 +1,4 @@
- from llm_guard.util import get_logger, lazy_load_dep
+ from llm_guard.util import calculate_risk_score, get_logger, lazy_load_dep

from .base import Scanner

@@ -12,12 +12,12 @@ class Sentiment(Scanner):
has a sentiment score lower than the threshold, indicating a negative sentiment.
"""

- def __init__(self, *, threshold: float = -0.1, lexicon: str = _lexicon) -> None:
+ def __init__(self, *, threshold: float = -0.3, lexicon: str = _lexicon) -> None:
"""
Initializes Sentiment with a threshold and a chosen lexicon.
Parameters:
- threshold (float): Threshold for the sentiment score (from -1 to 1). Default is -0.1.
+ threshold (float): Threshold for the sentiment score (from -1 to 1). Default is 0.3.
lexicon (str): Lexicon for the SentimentIntensityAnalyzer. Default is 'vader_lexicon'.
Raises:
@@ -32,6 +32,9 @@ def __init__(self, *, threshold: float = -0.1, lexicon: str = _lexicon) -> None:
self._threshold = threshold

def scan(self, prompt: str) -> tuple[str, bool, float]:
+ if not prompt:
+ return prompt, True, -1.0
+
sentiment_score = self._sentiment_analyzer.polarity_scores(prompt)
sentiment_score_compound = sentiment_score["compound"]
if sentiment_score_compound > self._threshold:
@@ -49,6 +52,4 @@ def scan(self, prompt: str) -> tuple[str, bool, float]:
threshold=self._threshold,
)

- # Normalize such that -1 maps to 1 and threshold maps to 0
- score = round((sentiment_score_compound - (-1)) / (self._threshold - (-1)), 2)
- return prompt, False, score
+ return prompt, False, calculate_risk_score(abs(sentiment_score_compound), self._threshold)
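For the sentiment scanner above, a minimal usage sketch under the new scoring; the import path follows the file shown, the example text is illustrative, and the VADER lexicon is fetched on first use.

```python
# Minimal sketch of the Sentiment scanner after this change; assumes the NLTK
# VADER lexicon can be downloaded on first use.
from llm_guard.input_scanners.sentiment import Sentiment

scanner = Sentiment(threshold=-0.3)  # compound scores at or below -0.3 fail the scan
prompt, is_valid, risk_score = scanner.scan("I absolutely hate this and everyone involved.")

# A passing prompt now yields a negative risk_score instead of 0.0, and a failing
# prompt yields a score from calculate_risk_score rather than the removed
# hand-rolled normalization above.
print(is_valid, risk_score)
```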
4 changes: 2 additions & 2 deletions llm_guard/input_scanners/token_limit.py
@@ -60,14 +60,14 @@ def _split_text_on_tokens(self, text: str) -> tuple[list[str], int]:

def scan(self, prompt: str) -> tuple[str, bool, float]:
if prompt.strip() == "":
- return prompt, True, 0.0
+ return prompt, True, -1.0

chunks, num_tokens = self._split_text_on_tokens(text=prompt)
if num_tokens < self._limit:
LOGGER.debug(
"Prompt fits the maximum tokens", num_tokens=num_tokens, threshold=self._limit
)
- return prompt, True, 0.0
+ return prompt, True, -1.0

LOGGER.warning(
"Prompt is too big. Splitting into chunks", num_tokens=num_tokens, chunks=chunks
4 changes: 2 additions & 2 deletions llm_guard/input_scanners/toxicity.py
@@ -95,7 +95,7 @@ def __init__(

def scan(self, prompt: str) -> tuple[str, bool, float]:
if prompt.strip() == "":
- return prompt, True, 0.0
+ return prompt, True, -1.0

inputs = self._match_type.get_inputs(prompt)

@@ -120,4 +120,4 @@ def scan(self, prompt: str) -> tuple[str, bool, float]:

LOGGER.debug("Not toxicity found in the text", results=results_all)

- return prompt, True, 0.0
+ return prompt, True, calculate_risk_score(highest_toxicity_score, self._threshold)
23 changes: 16 additions & 7 deletions llm_guard/output_scanners/ban_substrings.py
@@ -1,19 +1,28 @@
from __future__ import annotations

- import logging
- import os

from llm_guard.input_scanners.ban_substrings import BanSubstrings as InputBanSubstrings, MatchType

from .base import Scanner

- log = logging.getLogger(__name__)
- stop_file_path = os.path.join(
- os.path.dirname(os.path.abspath(__file__)),
- "..",
- "resources",
- "output_stop_substrings.json",
- )

+ OUTPUT_STOP_SUBSTRINGS = [
+ "[DAN]",
+ "DAN:",
+ "DAN Mode enabled",
+ "ANTI-DAN:",
+ "[🔓JAILBREAK]",
+ "successfully jailbroken",
+ "STAN:",
+ "DUDE:",
+ "(🔓Developer Mode Output)",
+ "![http",
+ "X5O!P%@AP[4\\PZX54(P^)7CC)7}$EICAR-STANDARD-ANTIVIRUS-TEST-FILE!$H+H*", # EICAR test file
+ "XJS*C4JDBQADN1.NSBN3*2IDNEN*GTUBE-STANDARD-ANTI-UBE-TEST-EMAIL*C.34X", # GTUBE test file
+ "XJS*C4JDBQADN1.NSBN3*2IDNEN*GTPHISH-STANDARD-ANTI-PHISH-TEST-EMAIL*C.34X", # GTPHISH test file
+ ]


class BanSubstrings(Scanner):
4 changes: 2 additions & 2 deletions llm_guard/output_scanners/bias.py
@@ -79,7 +79,7 @@ def __init__(

def scan(self, prompt: str, output: str) -> tuple[str, bool, float]:
if output.strip() == "":
- return output, True, 0.0
+ return output, True, -1.0

highest_score = 0.0
results_all = self._classifier(self._match_type.get_inputs(prompt + "\n" + output))
@@ -101,4 +101,4 @@ def scan(self, prompt: str, output: str) -> tuple[str, bool, float]:

LOGGER.debug("Not biased result", highest_score=highest_score, threshold=self._threshold)

- return output, True, 0.0
+ return output, True, calculate_risk_score(highest_score, self._threshold)
2 changes: 1 addition & 1 deletion llm_guard/output_scanners/deanonymize.py
@@ -149,4 +149,4 @@ def scan(self, prompt: str, output: str) -> tuple[str, bool, float]:

output = self._matching_strategy.match(output, vault_items)

- return output, True, 0.0
+ return output, True, -1.0
12 changes: 8 additions & 4 deletions llm_guard/output_scanners/factual_consistency.py
@@ -5,7 +5,7 @@
from llm_guard.input_scanners.ban_topics import MODEL_DEBERTA_BASE_V2
from llm_guard.model import Model
from llm_guard.transformers_helpers import get_tokenizer_and_model_for_classification
- from llm_guard.util import device, get_logger, lazy_load_dep
+ from llm_guard.util import calculate_risk_score, device, get_logger, lazy_load_dep

from .base import Scanner

@@ -55,7 +55,7 @@ def __init__(

def scan(self, prompt: str, output: str) -> tuple[str, bool, float]:
if prompt.strip() == "":
- return output, True, 0.0
+ return output, True, -1.0

tokenized_input_seq_pair = self._tokenizer(
output, prompt, padding=True, truncation=True, return_tensors="pt"
@@ -77,8 +77,12 @@ def scan(self, prompt: str, output: str) -> tuple[str, bool, float]:
if entailment_score < self._minimum_score:
LOGGER.warning("Entailment score is below the threshold", prediction=prediction)

- return output, False, prediction["not_entailment"]
+ return (
+ output,
+ False,
+ calculate_risk_score(prediction["not_entailment"], self._minimum_score),
+ )

LOGGER.debug("The output is factually consistent", prediction=prediction)

- return output, True, 0.0
+ return output, True, calculate_risk_score(prediction["not_entailment"], self._minimum_score)