Skip to content

Commit

Permalink
* v2 model
Browse files Browse the repository at this point in the history
* ONNX hack cleanup
  • Loading branch information
asofter committed Apr 22, 2024
1 parent bcdb2e4 commit 6fed22f
Show file tree
Hide file tree
Showing 17 changed files with 33 additions and 60 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `BanTopics`, `FactualConsistency`: support of the [new zero-shot-classification models](https://huggingface.co/collections/MoritzLaurer/zeroshot-classifiers-6548b4ff407bb19ff5c3ad6f).
- `PromptInjection` can support more match types for better accuracy.
- `API` relies on the lighter models for faster inference but with a bit lower accuracy. You can remove the change and build from source to use the full models.
- `PromptInjection` scanner uses the [new v2 model](https://huggingface.co/protectai/deberta-v3-base-prompt-injection-v2) for better accuracy.

### Removed
- `model_kwargs` and `pipeline_kwargs` as they are part of the `Model` object.
Expand Down
2 changes: 1 addition & 1 deletion docs/input_scanners/prompt_injection.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Prompt injection attacks are particularly potent in the following scenarios:

Choose models you would like to validate against:

[ProtectAI/deberta-v3-base-prompt-injection](https://huggingface.co/ProtectAI/deberta-v3-base-prompt-injection).
[ProtectAI/deberta-v3-base-prompt-injection-v2](https://huggingface.co/ProtectAI/deberta-v3-base-prompt-injection-v2).
This model is a fine-tuned version of `microsoft/deberta-v3-base` on multiple datasets of prompt injections and
normal prompts to classify text.
It aims to identify prompt injections, classifying inputs into two categories: `0` for no injection and `1` for
Expand Down
8 changes: 4 additions & 4 deletions docs/tutorials/notebooks/local_models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"outputs": [],
"source": [
"!git lfs install\n",
"!git clone [email protected]:protectai/deberta-v3-base-prompt-injection\n",
"!git clone [email protected]:protectai/deberta-v3-base-prompt-injection-v2\n",
"!git clone [email protected]:MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33\n",
"!git clone [email protected]:tomaarsen/span-marker-bert-base-orgs\n",
"!git clone [email protected]:unitary/unbiased-toxic-roberta\n",
Expand Down Expand Up @@ -159,8 +159,8 @@
"from llm_guard import scan_prompt\n",
"from llm_guard.input_scanners import PromptInjection, Anonymize, BanTopics, BanCompetitors, Toxicity, Code, Gibberish, Language\n",
"from llm_guard.vault import Vault\n",
"from llm_guard.input_scanners.prompt_injection import DEFAULT_MODEL as PROMPT_INJECTION_MODEL\n",
"from llm_guard.input_scanners.ban_topics import MODEL_BASE as BAN_TOPICS_MODEL\n",
"from llm_guard.input_scanners.prompt_injection import V2_MODEL as PROMPT_INJECTION_MODEL\n",
"from llm_guard.input_scanners.ban_topics import MODEL_DEBERTA_BASE_V2 as BAN_TOPICS_MODEL\n",
"from llm_guard.input_scanners.ban_competitors import MODEL_BASE as BAN_COMPETITORS_MODEL\n",
"from llm_guard.input_scanners.toxicity import DEFAULT_MODEL as TOXICITY_MODEL\n",
"from llm_guard.input_scanners.anonymize_helpers import DEBERTA_AI4PRIVACY_v2_CONF\n",
Expand All @@ -169,7 +169,7 @@
"from llm_guard.input_scanners.language import DEFAULT_MODEL as LANGUAGE_MODEL\n",
"\n",
"PROMPT_INJECTION_MODEL.kwargs[\"local_files_only\"] = True\n",
"PROMPT_INJECTION_MODEL.path = \"./deberta-v3-base-prompt-injection\"\n",
"PROMPT_INJECTION_MODEL.path = \"./deberta-v3-base-prompt-injection-v2\"\n",
"\n",
"DEBERTA_AI4PRIVACY_v2_CONF[\"DEFAULT_MODEL\"].path = \"./deberta-v3-base_finetuned_ai4privacy_v2\"\n",
"DEBERTA_AI4PRIVACY_v2_CONF[\"DEFAULT_MODEL\"].kwargs[\"local_files_only\"] = True\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,6 @@ def _load_pipeline(
"optimum[onnxruntime]" if device().type != "cuda" else "optimum[onnxruntime-gpu]",
)

if self.model.onnx_enable_hack:
tf_tokenizer.model_input_names = ["input_ids", "attention_mask"]

tf_model = optimum_onnxruntime.ORTModelForTokenClassification.from_pretrained(
self.model.onnx_path,
export=False,
Expand Down
6 changes: 2 additions & 4 deletions llm_guard/input_scanners/ban_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,15 @@
revision="caa3d167fd262c76c7da23cd72c1d24cfdcafd0f",
onnx_path="protectai/vishnun-codenlbert-sm-onnx",
onnx_revision="2b1d298410bd98832e41e3da82e20f6d8dff1bc7",
pipeline_kwargs={"truncation": True, "max_length": 128},
onnx_enable_hack=False,
pipeline_kwargs={"max_length": 128, "return_token_type_ids": True},
)

MODEL_TINY = Model(
path="vishnun/codenlbert-tiny",
revision="2caf5a621b29c50038ee081479a82f192e9a5e69",
onnx_path="protectai/vishnun-codenlbert-tiny-onnx",
onnx_revision="84148cb4b3f08fe44705e2d8ed81505450ae8abd",
pipeline_kwargs={"truncation": True, "max_length": 128},
onnx_enable_hack=False,
pipeline_kwargs={"max_length": 128, "return_token_type_ids": True},
)


Expand Down
20 changes: 0 additions & 20 deletions llm_guard/input_scanners/ban_topics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@
onnx_path="MoritzLaurer/deberta-v3-large-zeroshot-v2.0",
onnx_revision="cf44676c28ba7312e5c5f8f8d2c22b3e0c9cdae2",
onnx_subfolder="onnx",
pipeline_kwargs={
"max_length": 512,
"truncation": True,
},
)

# The most performant base model. 0.18 B parameters, 369 MB.
Expand All @@ -31,10 +27,6 @@
onnx_path="MoritzLaurer/deberta-v3-base-zeroshot-v2.0",
onnx_subfolder="onnx",
onnx_revision="8e7e5af5983a0ddb1a5b45a38b129ab69e2258e8",
pipeline_kwargs={
"max_length": 512,
"truncation": True,
},
)

# The most performant multilingual model. 0.57 B parameters, 1.14 GB.
Expand All @@ -45,10 +37,6 @@
onnx_path="MoritzLaurer/bge-m3-zeroshot-v2.0",
onnx_subfolder="onnx",
onnx_revision="cd3f8598c7359a3b5cbce164d7fcdafb83a36484",
pipeline_kwargs={
"max_length": 8192,
"truncation": True,
},
)

# Less performant than deberta-v3 variants, but a bit faster and compatible with flash attention and TEI containers.
Expand All @@ -60,10 +48,6 @@
onnx_path="MoritzLaurer/roberta-large-zeroshot-v2.0-c",
onnx_subfolder="onnx",
onnx_revision="4c24ed4bba5af8d3162604abc2a141b9d2183ecc",
pipeline_kwargs={
"max_length": 512,
"truncation": True,
},
)

# Same model but smaller, more efficient version.
Expand All @@ -72,10 +56,6 @@
revision="d825e740e0c59881cf0b0b1481ccf726b6d65341",
onnx_path="protectai/MoritzLaurer-roberta-base-zeroshot-v2.0-c-onnx",
onnx_revision="fde5343dbad32f1a5470890505c72ec656db6dbe",
pipeline_kwargs={
"max_length": 512,
"truncation": True,
},
)


Expand Down
2 changes: 1 addition & 1 deletion llm_guard/input_scanners/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
onnx_path="philomath-1209/programming-language-identification",
onnx_revision="9090d38e7333a2c6ff00f154ab981a549842c20f",
onnx_subfolder="onnx",
pipeline_kwargs={"truncation": True, "top_k": None},
pipeline_kwargs={"top_k": None},
)

SUPPORTED_LANGUAGES = [
Expand Down
1 change: 0 additions & 1 deletion llm_guard/input_scanners/gibberish.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
onnx_path="madhurjindal/autonlp-Gibberish-Detector-492513457",
onnx_revision="fddf42c3008ad61cc481f90d02dd0712ba1ee2d8",
onnx_subfolder="onnx",
pipeline_kwargs={"truncation": True},
)


Expand Down
2 changes: 0 additions & 2 deletions llm_guard/input_scanners/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
onnx_path="ProtectAI/xlm-roberta-base-language-detection-onnx",
onnx_revision="dce2fa14a0dc61b6f889537e9ad4fccf083b22bd",
pipeline_kwargs={
"max_length": 512,
"truncation": True,
"top_k": None,
},
)
Expand Down
19 changes: 12 additions & 7 deletions llm_guard/input_scanners/prompt_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,25 @@

LOGGER = get_logger()

PROMPT_CHARACTERS_LIMIT = 512
PROMPT_CHARACTERS_LIMIT = 256

# This model is proprietary but open source.
DEFAULT_MODEL = Model(
V1_MODEL = Model(
path="protectai/deberta-v3-base-prompt-injection",
revision="f51c3b2a5216ae1af467b511bc7e3b78dc4a99c9",
onnx_path="ProtectAI/deberta-v3-base-prompt-injection",
onnx_revision="f51c3b2a5216ae1af467b511bc7e3b78dc4a99c9",
onnx_subfolder="onnx",
onnx_filename="model.onnx",
pipeline_kwargs={
"max_length": 512,
"truncation": True,
},
)

# Second-generation prompt-injection classifier.
# The ONNX weights are loaded from the same Hugging Face repository as the
# PyTorch weights (mirroring how V1_MODEL pairs its two paths): `onnx_revision`
# is the same commit hash as `revision`, and a commit hash only identifies a
# commit inside the repository it was created in, so it must not be combined
# with the v1 repository name.
# NOTE(review): the documentation links to
# `protectai/deberta-v3-base-prompt-injection-v2`; confirm whether the
# timestamped repository name below is the final published one.
V2_MODEL = Model(
    path="protectai/deberta-v3-base-prompt-injection-v2-2024-04-20-16-52",
    revision="69d3788a68e9d6c64b43e78284380a31182651d5",
    onnx_path="ProtectAI/deberta-v3-base-prompt-injection-v2-2024-04-20-16-52",
    onnx_revision="69d3788a68e9d6c64b43e78284380a31182651d5",
    onnx_subfolder="onnx",
    onnx_filename="model.onnx",
)


Expand Down Expand Up @@ -106,7 +111,7 @@ def __init__(
ValueError: If non-existent models were provided.
"""
if model is None:
model = DEFAULT_MODEL
model = V2_MODEL

if isinstance(match_type, str):
match_type = MatchType(match_type)
Expand Down
1 change: 0 additions & 1 deletion llm_guard/input_scanners/toxicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"padding": "max_length",
"top_k": None,
"function_to_apply": "sigmoid",
"truncation": True,
},
)

Expand Down
15 changes: 11 additions & 4 deletions llm_guard/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,18 @@ class Model:
onnx_revision: Optional[str] = None
onnx_subfolder: str = ""
onnx_filename: str = "model.onnx"
onnx_enable_hack: bool = True # Enable hack to remove token_type_ids from input
kwargs: Dict = dataclasses.field(default_factory=dict)
pipeline_kwargs: Dict = dataclasses.field(
default_factory=lambda: {"batch_size": 1, "device": device()}
)
pipeline_kwargs: Dict = dataclasses.field(default_factory=dict)

def __post_init__(self):
    """Merge the baseline pipeline settings into ``pipeline_kwargs``.

    Values supplied explicitly in ``pipeline_kwargs`` take precedence over
    the baseline; keys that were not supplied are filled in from it.
    """
    merged_kwargs = dict(
        max_length=512,
        truncation=True,
        batch_size=1,
        device=device(),
        return_token_type_ids=False,
    )
    # Caller-provided entries override the baseline values set above.
    merged_kwargs.update(self.pipeline_kwargs)
    self.pipeline_kwargs = merged_kwargs

def __str__(self):
return self.path
1 change: 0 additions & 1 deletion llm_guard/output_scanners/bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
revision="c1e4a2773522c3acc929a7b2c9af2b7e4137b96d",
onnx_path="ProtectAI/distilroberta-bias-onnx",
onnx_revision="3e64d057d20d7ef43fa4f831b992bad28d72640e",
pipeline_kwargs={"truncation": True},
)


Expand Down
2 changes: 0 additions & 2 deletions llm_guard/output_scanners/malicious_urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
onnx_path="ProtectAI/codebert-base-Malicious_URLs-onnx",
onnx_revision="7bc4fa926eeae5e752d0790cc42faa24eb32fa64",
pipeline_kwargs={
"max_length": 512,
"truncation": True,
"top_k": None,
},
)
Expand Down
4 changes: 0 additions & 4 deletions llm_guard/output_scanners/no_refusal.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@
onnx_path="ProtectAI/distilroberta-base-rejection-v1",
onnx_revision="65584967c3f22ff7723e5370c65e0e76791e6055",
onnx_subfolder="onnx",
pipeline_kwargs={
"max_length": 512,
"truncation": True,
},
)


Expand Down
4 changes: 0 additions & 4 deletions llm_guard/transformers_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,6 @@ def get_tokenizer_and_model_for_classification(

return tf_tokenizer, tf_model

# Hack for some models
if model.onnx_enable_hack:
tf_tokenizer.model_input_names = ["input_ids", "attention_mask"]

tf_model = _ort_model_for_sequence_classification(model)

return tf_tokenizer, tf_model
Expand Down
2 changes: 1 addition & 1 deletion tests/input_scanners/test_prompt_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@
),
],
)
def test_scan_model_default(
def test_scan(
match_type: MatchType,
prompt: str,
expected_prompt: str,
Expand Down

0 comments on commit 6fed22f

Please sign in to comment.