Commit f9d27d3 (1 parent: d8266ae)

Showing 25 changed files with 417 additions and 0 deletions.
@@ -0,0 +1 @@
env/**
@@ -1,2 +1,129 @@
# offensive_open_clip
A ready-to-use open_clip implementation that aims to catch offensive image file uploads.

# Why

This repository is meant to serve as an image-upload filtering mechanism built on https://github.com/mlfoundations/open_clip, blocking unwanted uploads from malicious users. It can also ease the pain of reviewing every image by hand (if that is part of your daily life or business).

# How

It works out of the box and can scan a local image file, an image fetched from a URL, or a local directory of images.

First install the dependencies, preferably by following the `instructions.txt` file:

```bash
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
pip install requests tabulate
pip install open-clip-torch==2.26.1
```
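
To confirm that the CUDA wheels actually see your GPU, you can run a quick check from a Python shell (the script's own `--runs-on` flag reports the same thing):

```python
import torch

# True: the tool will run on the GPU; False: it falls back to CPU.
print(torch.cuda.is_available())
```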

Then run the help dialog:

```bash
$ ./offensive_open_clip.py -h
usage: offensive_open_clip.py [-h]
                              (--image-path IMAGE_PATH | --image-url IMAGE_URL | --image-dir IMAGE_DIR | --list-models | --runs-on)
                              [--threshold THRESHOLD] [--model MODEL]
                              [--dataset DATASET]

options:
  -h, --help            show this help message and exit
  --image-path IMAGE_PATH
                        Local path to the image file to be processed.
  --image-url IMAGE_URL
                        URL of the image to be downloaded and processed.
  --image-dir IMAGE_DIR
                        Directory containing images to be processed.
  --list-models         List all available pretrained models.
  --runs-on             Show if running on cuda-GPU or CPU. GPU inference is
                        faster.
  --threshold THRESHOLD
                        Probability threshold for disallowing content (must be
                        between 0.0 and 1.0).
  --model MODEL         The model to use for predictions. Use '--list-models'
                        for available models. Defaults to 'ViT-B-32'.
  --dataset DATASET     The pretrained dataset to use for predictions. Use '--
                        list-models' for available datasets. Defaults to
                        'laion2b_s34b_b79k'.
```
# Models and datasets
Use the following command to list them:
```bash
$ ./offensive_open_clip.py --list-models
+----------------------------+------------------------------------+
| Model                      | Pretrained Version                 |
+============================+====================================+
| RN50                       | openai                             |
+----------------------------+------------------------------------+
| RN50                       | yfcc15m                            |
+----------------------------+------------------------------------+
| RN50                       | cc12m                              |
+----------------------------+------------------------------------+
| RN50-quickgelu             | openai                             |
+----------------------------+------------------------------------+
| RN50-quickgelu             | yfcc15m                            |
+----------------------------+------------------------------------+
| RN50-quickgelu             | cc12m                              |
+----------------------------+------------------------------------+
| RN101                      | openai                             |
+----------------------------+------------------------------------+
| RN101                      | yfcc15m                            |
+----------------------------+------------------------------------+
| RN101-quickgelu            | openai                             |

...

+----------------------------+------------------------------------+
| ViTamin-L2-256             | datacomp1b                         |
+----------------------------+------------------------------------+
| ViTamin-L2-336             | datacomp1b                         |
+----------------------------+------------------------------------+
| ViTamin-L2-384             | datacomp1b                         |
+----------------------------+------------------------------------+
| ViTamin-XL-256             | datacomp1b                         |
+----------------------------+------------------------------------+
| ViTamin-XL-336             | datacomp1b                         |
+----------------------------+------------------------------------+
| ViTamin-XL-384             | datacomp1b                         |
+----------------------------+------------------------------------+
```
The default model does its job pretty well. You can also specify another model, dataset, and threshold as shown:
```bash
$ ./offensive_open_clip.py --image-path=img/bank1.png --model=ViT-bigG-14 --dataset=laion2b_s39b_b160k --threshold=0.4
```
# Output
The output of an image scan has the following format:
```bash
$ ./offensive_open_clip.py --image-path=img/ass1.jpg
{
    "image": "img/ass1.jpg",
    "label": "ass",
    "probability": 0.9341,
    "prediction_time": 3.837,
    "state": "disallowed"
}

$ ./offensive_open_clip.py --image-path=img/product1.jpeg
{
    "image": "img/product1.jpeg",
    "label": "bomb",
    "probability": 0.2271,
    "prediction_time": 3.809,
    "state": "allowed"
}
```
If the state is `disallowed`, we are on to something. You can verify this by running the tool on the samples in the `img/` directory.
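
Since the verdict is plain JSON on stdout, wiring the scanner into an upload flow is straightforward. Below is a minimal, hypothetical sketch (the `is_upload_allowed` helper and the paths are illustrative, not part of this repository) that shells out to the script and rejects an upload when the returned state is `disallowed`:

```python
import json
import subprocess

def is_upload_allowed(image_path: str, threshold: float = 0.45) -> bool:
    """Hypothetical helper: shell out to the scanner and parse its JSON verdict."""
    # Assumes offensive_open_clip.py is executable in the current directory;
    # 0.45 mirrors the script's default threshold.
    completed = subprocess.run(
        ["./offensive_open_clip.py", f"--image-path={image_path}", f"--threshold={threshold}"],
        capture_output=True, text=True, check=True,
    )
    # Output keys: image, label, probability, prediction_time, state
    result = json.loads(completed.stdout)
    return result["state"] == "allowed"

# Example: reject an upload before persisting it.
if not is_upload_allowed("img/ass1.jpg"):
    print("Upload rejected.")
```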
# Notes
Larger models and datasets give better results, so balance speed against accuracy and tune your preferred threshold. \
I do not own the models or datasets. This is simply a wrapper around the existing `open_clip` tooling with a few tweaks and enhancements.
[20 binary files not shown.]
@@ -0,0 +1,3 @@
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
pip install requests tabulate
pip install open-clip-torch
@@ -0,0 +1,255 @@
#!/usr/bin/env python
import open_clip
from PIL import Image
import torch
import json
import argparse
from pathlib import Path
import sys
import requests
import time
from tabulate import tabulate
from urllib.parse import urlparse, unquote
from io import BytesIO
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

ALLOWED_IMG_FORMAT = ('.jpg', '.jpeg', '.png')

VIOLENCE_TEXT_LABELS = [
    "violence", "blood", "gore", "death", "murder", "assault", "fighting", "abuse",
    "gun", "weapon", "knife", "explosion", "war", "terrorism", "terrorist", "bomb",
    "mass shooting", "riot", "kidnapping", "hostage", "shooting", "execution",
    "war tank", "nuke",
]

SEXUAL_TEXT_LABELS = [
    "nudity", "sexual content", "explicit content", "profanity", "pornographic content", "porn",
    "sex", "rape", "molestation", "incest", "sexual harassment", "fetish", "child pornography",
    "ass", "asshole", "vagina", "pussy", "genitals", "dick", "boobs", "dildo", "gag", "sex toy", "anal",
    "fuck", "fucking", "pussy fucking", "ass fucking", "semen", "poop",
]

DRUGS_TEXT_LABELS = [
    "drugs", "drug use", "alcohol", "smoking", "addiction",
    "cocaine", "heroin", "marijuana", "opioids", "meth", "drug paraphernalia",
    "overdose", "pill", "syringe", "injection",
]

RACISM_TEXT_LABELS = [
    "racism", "hate speech", "offensive language", "discrimination",
    "sexism", "homophobia", "transphobia", "slur", "xenophobia", "anti-semitism",
    "misogyny", "bigotry", "white supremacy", "neo-nazi", "alt-right", "extremism",
    "hate crime", "terror group", "cult", "propaganda",
]

HARASSMENT_TEXT_LABELS = [
    "harassment", "bullying", "cyberbullying", "stalking",
    "child abuse", "animal abuse", "self-harm", "suicide", "sadism", "torture",
    "domestic violence", "spousal abuse", "verbal abuse", "psychological abuse",
]

ILLEGAL_ACTIVITIES_TEXT_LABELS = [
    "theft", "vandalism", "fraud", "money laundering", "bribery",
    "extortion", "blackmail", "corruption", "scam", "hacking", "phishing",
    "identity theft", "counterfeiting", "illegal trade", "human trafficking",
    "organized crime", "drug trafficking", "arms trafficking", "terrorism financing",
    "arson", "looting", "tax evasion", "embezzlement", "racketeering", "prohibited goods",
]

PII_TEXT_LABELS = [
    "passport", "ID card", "driver's license", "social security number",
    "credit card", "bank statement", "birth certificate", "visa", "medical record",
    "address", "phone number", "email address", "biometric data", "fingerprint",
    "face recognition", "personal document", "tax document", "insurance card",
]

OTHER_TEXT_LABELS = [
    "disease", "epidemic", "plague", "biological weapon", "biohazard",
    "vomit", "feces", "urine", "defecation", "rotting",
    "decay", "severed", "decomposition", "corpse", "dead body", "autopsy",
    "abnormal", "demonic", "occult", "ritual", "witchcraft", "satanic",
    "cult", "blood", "sacrifice", "sacrilege",
]

DISALLOWED_TEXT_LABELS = VIOLENCE_TEXT_LABELS + SEXUAL_TEXT_LABELS + DRUGS_TEXT_LABELS + RACISM_TEXT_LABELS + HARASSMENT_TEXT_LABELS + ILLEGAL_ACTIVITIES_TEXT_LABELS + PII_TEXT_LABELS + OTHER_TEXT_LABELS
def list_pretrained_models():
    models = open_clip.list_pretrained()

    # Prepare data for the table
    table_data = []
    for model_name, pretrained_model in models:
        table_data.append([model_name, pretrained_model])

    # Print the table
    print(tabulate(table_data, headers=["Model", "Pretrained Version"], tablefmt="grid"))


def runs_on():
    # Run on GPU if available, else on CPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return device


def predict_label(device, model, tokenizer, image, image_name):
    # Tokenize the disallowed labels and move both inputs to the target device
    text = tokenizer(DISALLOWED_TEXT_LABELS).to(device)
    image = image.to(device)

    start_time = time.time()

    # Encode image and text, normalize the features, then softmax over the similarities
    with torch.no_grad(), torch.amp.autocast(device):
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    end_time = time.time()

    # Find the index of the most probable label
    max_prob_index = text_probs.argmax(dim=-1).item()

    # Get the label
    predicted_label = DISALLOWED_TEXT_LABELS[max_prob_index]

    # Extract the winning probability as a plain float
    predicted_probability = text_probs[0, max_prob_index].item()

    result = {
        "image": str(image_name),
        "label": predicted_label,
        "probability": round(predicted_probability, 4),
        "prediction_time": round(end_time - start_time, 4)
    }

    return json.dumps(result, indent=4)


def evaluate_label(label, threshold):
    # Parse the JSON label result
    label_dict = json.loads(label)

    # Check if the probability exceeds the threshold
    if label_dict["probability"] > threshold:
        label_dict["state"] = 'disallowed'
    else:
        label_dict["state"] = 'allowed'

    # Return the updated label with the new state field as a JSON string
    return json.dumps(label_dict, indent=4)
def main(selected_model, selected_dataset, image_path, image_url, image_dir, threshold, device):
    # Create the model and transformations
    model, _, preprocess = open_clip.create_model_and_transforms(selected_model, pretrained=selected_dataset)

    # Move the model to the selected device and disable training mode
    model = model.to(device)
    model.eval()

    # Use the corresponding tokenizer
    tokenizer = open_clip.get_tokenizer(selected_model)

    # Read the image
    if image_path:
        image = preprocess(Image.open(image_path)).unsqueeze(0)
        # Predict
        label = predict_label(device, model, tokenizer, image, image_path)
        # Evaluate
        print(evaluate_label(label, threshold))
    elif image_url:
        # Fetch from URL
        response = requests.get(image_url)
        if response.status_code == 200:
            image = preprocess(Image.open(BytesIO(response.content))).unsqueeze(0)
            # Predict
            label = predict_label(device, model, tokenizer, image, image_url)
            # Evaluate
            print(evaluate_label(label, threshold))
        else:
            print(f"Failed to download image from '{image_url}'.")
            sys.exit(1)
    elif image_dir:
        image_files = [file for file in Path(image_dir).glob('*') if file.suffix.lower() in ALLOWED_IMG_FORMAT]
        if not image_files:
            print(f"Error: No files with {ALLOWED_IMG_FORMAT} found in {image_dir}.")
            sys.exit(1)
        for image_file in image_files:
            image = preprocess(Image.open(image_file)).unsqueeze(0)
            # Predict
            label = predict_label(device, model, tokenizer, image, image_file)
            # Evaluate
            print(evaluate_label(label, threshold), flush=True)


def valid_threshold(value):
    # Convert the input value to float and check that it's between 0.0 and 1.0
    try:
        fvalue = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"Invalid threshold value: {value}. Must be a float.")

    if fvalue < 0.0 or fvalue > 1.0:
        raise argparse.ArgumentTypeError(f"Threshold must be between 0.0 and 1.0. Got {value}.")

    return fvalue


def is_valid_image_url(url):
    parsed_url = urlparse(url)

    # Check if the URL starts with http or https
    if parsed_url.scheme not in ('http', 'https'):
        return False

    path = unquote(parsed_url.path)
    return path.lower().endswith(ALLOWED_IMG_FORMAT)
if __name__ == "__main__":
    # Parse the arguments
    parser = argparse.ArgumentParser()
    # Mutually exclusive group
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--image-path", type=str, help="Local path to the image file to be processed.")
    group.add_argument("--image-url", type=str, help="URL of the image to be downloaded and processed.")
    group.add_argument("--image-dir", type=str, help="Directory containing images to be processed.")
    group.add_argument("--list-models", action="store_true", help="List all available pretrained models.")
    group.add_argument("--runs-on", action="store_true", help="Show if running on cuda-GPU or CPU. GPU inference is faster.")
    parser.add_argument("--threshold", type=valid_threshold, default=0.45, help="Probability threshold for disallowing content (must be between 0.0 and 1.0).")
    parser.add_argument("--model", type=str, default="ViT-B-32", help="The model to use for predictions. Use '--list-models' for available models. Defaults to 'ViT-B-32'.")
    parser.add_argument("--dataset", type=str, default="laion2b_s34b_b79k", help="The pretrained dataset to use for predictions. Use '--list-models' for available datasets. Defaults to 'laion2b_s34b_b79k'.")

    args = parser.parse_args()

    if args.list_models:
        list_pretrained_models()
        sys.exit(0)

    if args.runs_on:
        device = runs_on()
        print(f"Running on '{device}'.")
        sys.exit(0)

    if args.image_path:
        if not Path(args.image_path).exists():
            print(f"No such image file '{args.image_path}'")
            sys.exit(1)
        if not args.image_path.lower().endswith(ALLOWED_IMG_FORMAT):
            print(f"Error: Only {ALLOWED_IMG_FORMAT} files allowed.")
            sys.exit(1)

    if args.image_url:
        if not is_valid_image_url(args.image_url):
            print(f"Error: Only {ALLOWED_IMG_FORMAT} files allowed.")
            sys.exit(1)

    if args.image_dir:
        if not Path(args.image_dir).is_dir():
            print(f"No such directory '{args.image_dir}'.")
            sys.exit(1)
    main(args.model, args.dataset, args.image_path, args.image_url, args.image_dir, args.threshold, runs_on())
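
Because the CLI entry point is guarded by `if __name__ == "__main__":`, the module can also be imported and driven from Python. A minimal sketch under that assumption (argument order follows the `main` signature above; `main` prints the JSON verdict rather than returning it):

```python
from offensive_open_clip import main, runs_on

# Scan one local image with the same defaults the CLI uses.
main("ViT-B-32", "laion2b_s34b_b79k", "img/product1.jpeg",
     None, None, 0.45, runs_on())
```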