Skip to content

Commit

Permalink
WIP: compute CLIP scores for another dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
bryant1410 committed Mar 17, 2023
1 parent ee25068 commit 477a294
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 2 deletions.
20 changes: 20 additions & 0 deletions argparse_with_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import argparse
from typing import Any


# Copied from https://github.com/allenai/allennlp/blob/3aafb92/allennlp/commands/__init__.py
class ArgumentParserWithDefaults(argparse.ArgumentParser):
"""Custom argument parser that will display the default value for an argument in the help message. """

_action_defaults_to_ignore = {"help", "store_true", "store_false", "store_const"}

@staticmethod
def _is_empty_default(default: Any) -> bool:
return default is None or (isinstance(default, (str, list, tuple, set)) and not default)

def add_argument(self, *args, **kwargs) -> argparse.Action:
# Add default value to the help message when the default is meaningful.
default = kwargs.get("default")
if kwargs.get("action") not in self._action_defaults_to_ignore and not self._is_empty_default(default):
kwargs["help"] = f"{kwargs.get('help', '')} (default = {default})"
return super().add_argument(*args, **kwargs)
58 changes: 58 additions & 0 deletions compute_clip_scores.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#!/usr/bin/env python
import argparse
import os
import random

import numpy as np
import torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoModel, AutoProcessor

from argparse_with_defaults import ArgumentParserWithDefaults


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def parse_args() -> argparse.Namespace:
parser = ArgumentParserWithDefaults()
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--dataset", default="red_caps",
help="See options at https://huggingface.co/datasets?"
"task_categories=task_categories:image-to-text")
parser.add_argument("--model-name-or-path", default="openai/clip-vit-large-patch14",
help="See options at https://huggingface.co/models?pipeline_tag=zero-shot-image-classification")
parser.add_argument("--output-path", default="output.pt")
return parser.parse_args()

def main() -> None:
args = parse_args()

print(args)

random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

torch.use_deterministic_algorithms(True)
# https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

processor = AutoProcessor.from_pretrained(args.model_name_or_path)
model = AutoModel.from_pretrained(args.model_name_or_path).to(DEVICE).eval()

scores = []

with torch.inference_mode():
for batch in tqdm(load_dataset(args.dataset, split="train", streaming=True)):
batch = batch.to(DEVICE)
output = model(**batch)
scores.append(output.logits.cpu())

torch.save(scores, args.output_path)


if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from statsmodels.tools.tools import pinv_extended
from tqdm.auto import tqdm

from argparse_with_defaults import ArgumentParserWithDefaults
from features import VALID_LEVIN_RETURN_MODES, is_feature_binary, is_feature_multi_label, is_feature_string, \
load_features

Expand Down Expand Up @@ -252,7 +253,7 @@ def compute_mean_diff_and_corr(features: pd.DataFrame, dependent_variable: pd.Se


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser = ArgumentParserWithDefaults()
parser.add_argument("--model", default="mean-diff-and-corr", choices=MODELS)
parser.add_argument("--input-path", default="data/merged.csv")

Expand Down
4 changes: 3 additions & 1 deletion merge_csvs_and_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

import pandas as pd

from argparse_with_defaults import ArgumentParserWithDefaults


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser = ArgumentParserWithDefaults()
parser.add_argument("--probes_path", default="data/svo_probes.csv")
parser.add_argument("--neg_path", default="data/neg_d.csv")
return parser.parse_args()
Expand Down

0 comments on commit 477a294

Please sign in to comment.