-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WIP: compute CLIP scores for another dataset
- Loading branch information
1 parent
ee25068
commit 477a294
Showing
4 changed files
with
83 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import argparse | ||
from typing import Any | ||
|
||
|
||
# Copied from https://github.com/allenai/allennlp/blob/3aafb92/allennlp/commands/__init__.py | ||
class ArgumentParserWithDefaults(argparse.ArgumentParser): | ||
"""Custom argument parser that will display the default value for an argument in the help message. """ | ||
|
||
_action_defaults_to_ignore = {"help", "store_true", "store_false", "store_const"} | ||
|
||
@staticmethod | ||
def _is_empty_default(default: Any) -> bool: | ||
return default is None or (isinstance(default, (str, list, tuple, set)) and not default) | ||
|
||
def add_argument(self, *args, **kwargs) -> argparse.Action: | ||
# Add default value to the help message when the default is meaningful. | ||
default = kwargs.get("default") | ||
if kwargs.get("action") not in self._action_defaults_to_ignore and not self._is_empty_default(default): | ||
kwargs["help"] = f"{kwargs.get('help', '')} (default = {default})" | ||
return super().add_argument(*args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
#!/usr/bin/env python | ||
import argparse | ||
import os | ||
import random | ||
|
||
import numpy as np | ||
import torch | ||
from datasets import load_dataset | ||
from tqdm.auto import tqdm | ||
from transformers import AutoModel, AutoProcessor | ||
|
||
from argparse_with_defaults import ArgumentParserWithDefaults | ||
|
||
|
||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
|
||
|
||
def parse_args() -> argparse.Namespace: | ||
parser = ArgumentParserWithDefaults() | ||
parser.add_argument("--seed", type=int, default=42) | ||
parser.add_argument("--dataset", default="red_caps", | ||
help="See options at https://huggingface.co/datasets?" | ||
"task_categories=task_categories:image-to-text") | ||
parser.add_argument("--model-name-or-path", default="openai/clip-vit-large-patch14", | ||
help="See options at https://huggingface.co/models?pipeline_tag=zero-shot-image-classification") | ||
parser.add_argument("--output-path", default="output.pt") | ||
return parser.parse_args() | ||
|
||
def main() -> None: | ||
args = parse_args() | ||
|
||
print(args) | ||
|
||
random.seed(args.seed) | ||
np.random.seed(args.seed) | ||
torch.manual_seed(args.seed) | ||
torch.cuda.manual_seed_all(args.seed) | ||
|
||
torch.use_deterministic_algorithms(True) | ||
# https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility | ||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" | ||
|
||
processor = AutoProcessor.from_pretrained(args.model_name_or_path) | ||
model = AutoModel.from_pretrained(args.model_name_or_path).to(DEVICE).eval() | ||
|
||
scores = [] | ||
|
||
with torch.inference_mode(): | ||
for batch in tqdm(load_dataset(args.dataset, split="train", streaming=True)): | ||
batch = batch.to(DEVICE) | ||
output = model(**batch) | ||
scores.append(output.logits.cpu()) | ||
|
||
torch.save(scores, args.output_path) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters