Update README.md #21

29 changes: 27 additions & 2 deletions README.md
MetaCLIP uses 500,000 queries as [metadata](metadata.json) to align the training data.
We modify OpenCLIP to match the default CLIP training setup (with [ViT-B-16-quickgelu](src/open_clip/model_configs/ViT-B-16-quickgelu.json), [ViT-L-14-quickgelu](src/open_clip/model_configs/ViT-L-14-quickgelu.json), and [ViT-H-14-quickgelu](src/open_clip/model_configs/ViT-H-14-quickgelu.json)). Most OpenCLIP models use `nn.GELU`, not the `quickgelu` activation used by vanilla CLIP. We hope this helps research with controlled experiments in the "CLIP era of ImageNet".

```python
# Import necessary libraries and modules
import torch
from PIL import Image
import open_clip

# Create an OpenCLIP model, specify the architecture, and load a pretrained model checkpoint
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='metaclip/b32_400m.pt')

# Load an image and preprocess it using the defined transformations
image = preprocess(Image.open("CLIP.png")).unsqueeze(0)

# Tokenize a list of textual descriptions
text = open_clip.tokenize(["a diagram", "a dog", "a cat"])

# Perform image and text encoding using the model
with torch.no_grad():
    # Encode the image
    image_features = model.encode_image(image)

    # Encode the text
    text_features = model.encode_text(text)

# Normalize the image and text features
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)

# Calculate label probabilities by computing the dot product between image and text features
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

# Print the label probabilities
print("Label probs:", text_probs)
```
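
The same call works for the larger configs listed above. Here is a minimal sketch, assuming you have downloaded the corresponding checkpoint; the path `metaclip/l14_400m.pt` below is a placeholder, not a published filename:

```python
# Minimal sketch: loading the larger ViT-L-14-quickgelu config from the list above.
# NOTE: 'metaclip/l14_400m.pt' is a placeholder path; point it at the
# checkpoint file you actually downloaded.
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-L-14-quickgelu', pretrained='metaclip/l14_400m.pt'
)
```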

We have a [demo notebook](demo.ipynb) to show how the proposed algorithm works.
CLIP curation can still help as online balancing (Table 6 in the paper). We wrap CLIP curation in two key functions: [substring matching](metaclip/substr_matching.py) (recommended to run offline) and [balancing](metaclip/balancing.py) (either offline or online; see `metaclip.balancing:main`).

```python
# Import necessary libraries and modules
import json
import numpy as np
from metaclip.substr_matching import substr_matching
from metaclip.balancing import balance_sampling

# Load metadata from a JSON file
with open("metadata.json") as f:
    metadata = json.load(f)

# Load entry counts for our 1.6B (pool) -> 400M (curated) data; please check
# `metaclip.balancing:main`, and run substring matching and counting on your own data.
with open("metaclip/entry_counts_400m.json") as f:
    entry_count_json = json.load(f)

# Convert entry counts to a NumPy array; uint64 is safe for scaling
entry_count = np.array([entry_count_json[entry] for entry in metadata], dtype=np.uint64)

# Set a threshold value t for entry counts; entries rarer than t are
# raised to t so they are always kept
t = 20000
entry_count[entry_count < t] = t

# Per-entry sampling probability: tail entries get probability 1 (t/t),
# head entries are downsampled proportionally to 1/count
entry_prob = t / entry_count

# Iterate through a list of texts
for text in ["jacksons chameleon", "battery plate"]:
    # Use substr_matching to find matching entry IDs for the text
    matched_entry_ids = substr_matching(text, metadata)

    # Perform balance sampling using the entry probabilities
    if balance_sampling(matched_entry_ids, entry_prob):
        print(f"'{text}' curated")
```
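
The thresholding above flattens the head of the distribution: tail entries keep probability 1 (t/t), while head entries are kept with probability t/count. A minimal self-contained sketch of that arithmetic on made-up counts (the values below are illustrative, not taken from the real metadata):

```python
import numpy as np

# Toy entry counts: one head (very frequent) entry and two tail (rare) entries.
toy_counts = np.array([2_000_000, 15_000, 500], dtype=np.uint64)

t = 20000
toy_counts = np.maximum(toy_counts, t)  # tail counts are raised to the threshold
toy_prob = t / toy_counts               # head entries are downsampled to t/count

print(toy_prob)  # [0.01 1.   1.  ] -> tail entries always kept, head kept ~1% of the time
```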