Fix clip score OOM
lixiang007666 committed Jul 25, 2024
1 parent 1c2e690 commit e2d3858
Showing 8 changed files with 294 additions and 47 deletions.
20 changes: 7 additions & 13 deletions T2IBenchmark/loaders.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Callable, Any
+from typing import List, Optional, Callable, Any, Dict
 from abc import ABC, abstractmethod
 import os
 from PIL import Image
@@ -118,29 +118,23 @@ def __str__(self) -> str:


 class CaptionImageDataset(Dataset):

-    def __init__(
-        self,
-        images_paths: List[str],
-        captions: List[str],
-        preprocess_fn: Optional[Callable[[Image.Image], Any]] = None,
-    ):
-        assert len(images_paths) == len(captions)
+    def __init__(self, images_paths: List[str], captions_mapping: Dict[str, str], preprocess_fn: Optional[Callable[[Image.Image], Any]] = None):
         self.images_paths = images_paths
-        self.captions = captions
+        self.captions_mapping = captions_mapping
         self.preprocess_fn = preprocess_fn if preprocess_fn else lambda x: x

     def __len__(self) -> int:
         return len(self.images_paths)

     def __getitem__(self, idx: int) -> tuple:
-        image = Image.open(self.images_paths[idx])
-        return self.preprocess_fn(image), self.captions[idx]
+        image_path = self.images_paths[idx]
+        image = Image.open(image_path)
+        caption = self.captions_mapping[image_path]
+        return self.preprocess_fn(image), caption

     def __str__(self) -> str:
         return f"CaptionImageDataset({self.__len__()} items)"


 def get_images_from_folder(folder_path: str) -> ImageDataset:
     filepaths = []
     for root, dirs, files in os.walk(folder_path):
40 changes: 19 additions & 21 deletions T2IBenchmark/pipelines.py
@@ -180,17 +180,14 @@ def calculate_clip_score(
     seed: Optional[int] = 42,
     batch_size: int = 128,
     dataloader_workers: int = 16,
-    verbose: bool = True,
+    verbose: bool = False,
 ):
     if seed:
         set_all_seeds(seed)

     model, preprocess = clip.load("ViT-B/32", device=device)
-    dataset = CaptionImageDataset(
-        images_paths=image_paths,
-        captions=list(map(lambda x: captions_mapping[x], image_paths)),
-        preprocess_fn=preprocess,
-    )
+    dataset = CaptionImageDataset(images_paths=image_paths, captions_mapping=captions_mapping, preprocess_fn=preprocess)

     dataloader = DataLoader(
         dataset,
         batch_size=batch_size,
@@ -200,24 +197,25 @@
     )

     score_acc = 0.0
-    num_samples = 0.0
+    num_samples = 0

-    for image, caption in tqdm(dataloader):
-        image_embedding = model.encode_image(image.to(device))
-        caption_embedding = model.encode_text(clip.tokenize(caption).to(device))
+    for images, captions in tqdm(dataloader):
+        images = images.to(device)
+        captions = [clip.tokenize(caption).to(device) for caption in captions]

-        image_features = image_embedding / image_embedding.norm(dim=1, keepdim=True).to(
-            torch.float32
-        )
-        caption_features = caption_embedding / caption_embedding.norm(
-            dim=1, keepdim=True
-        ).to(torch.float32)
+        with torch.no_grad():
+            image_embeddings = model.encode_image(images)
+            caption_embeddings = model.encode_text(torch.cat(captions))

-        score = (image_features * caption_features).sum()
-        score_acc += score
-        num_samples += image.shape[0]
+            image_features = image_embeddings / image_embeddings.norm(dim=1, keepdim=True)
+            caption_features = caption_embeddings / caption_embeddings.norm(dim=1, keepdim=True)

+            score = (image_features * caption_features).sum(dim=1).mean().item()
+            score_acc += score * images.size(0)
+            num_samples += images.size(0)

     clip_score = score_acc / num_samples
-    dprint(verbose, f"CLIP score is {clip_score}")
+    if verbose:
+        print(f"CLIP score is {clip_score}")

-    return clip_score
\ No newline at end of file
+    return clip_score
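The memory fix above batches the encoding under torch.no_grad() and accumulates a plain Python float via .item(), so no autograd graph or per-batch GPU tensors survive across iterations. As a minimal, self-contained sketch of the accumulation pattern (illustrative only, not part of the commit; random tensors stand in for CLIP embeddings), each batch contributes its mean cosine similarity weighted by its size, so an uneven final batch does not skew the dataset-level average:

```python
import torch

score_acc, num_samples = 0.0, 0
for n in (128, 128, 40):  # e.g. 296 samples split into uneven batches
    # Stand-ins for the L2-normalized CLIP image/text features.
    img = torch.randn(n, 512)
    txt = torch.randn(n, 512)
    img = img / img.norm(dim=1, keepdim=True)
    txt = txt / txt.norm(dim=1, keepdim=True)

    score = (img * txt).sum(dim=1).mean().item()  # batch-mean cosine similarity
    score_acc += score * n                        # weight by batch size
    num_samples += n

clip_score = score_acc / num_samples              # dataset-level mean
print(clip_score)
```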
2 changes: 1 addition & 1 deletion metrics/aesthetic_score.py
@@ -60,7 +60,7 @@ def normalized(a, axis=-1, order=2):

 def evaluate_images(folder_path):
     path = Path(folder_path)
-    images = list(path.rglob("*.jpg"))
+    images = list(path.rglob("*.png"))
     scores = []

     for img_path in images:
28 changes: 28 additions & 0 deletions metrics/clip_score.py
@@ -0,0 +1,28 @@
+import pandas as pd
+from glob import glob
+from T2IBenchmark import calculate_clip_score
+
+cat_paths = sorted(glob('/home/lixiang/data/fid_kolors_nexfort/*.png'), key=lambda x: int(x.split('_')[-1].split('.')[0]))
+
+print(f"Number of image files: {len(cat_paths)}")
+
+csv_path = '/home/lixiang/odeval/MS-COCO_val2014_30k_captions.csv'
+try:
+    captions_df = pd.read_csv(csv_path)
+    print(f"Number of captions read: {len(captions_df)}")
+
+    if len(cat_paths) != len(captions_df):
+        print("Error: The number of images does not match the number of captions.")
+    else:
+        captions_mapping = {cat_paths[i]: captions_df.iloc[i, 1] for i in range(len(cat_paths))}
+        print("Captions mapping created successfully.")
+
+        clip_score = calculate_clip_score(cat_paths, captions_mapping=captions_mapping)
+        print(f"CLIP Score: {clip_score}")
+
+except FileNotFoundError:
+    print(f"Error: The file {csv_path} was not found.")
+except pd.errors.EmptyDataError:
+    print("Error: No data found in the CSV file.")
+except Exception as e:
+    print(f"An error occurred: {e}")
2 changes: 1 addition & 1 deletion metrics/fid.py
@@ -2,6 +2,6 @@
 from T2IBenchmark.datasets import get_coco_fid_stats

 fid, _ = calculate_fid(
-    '/home/lixiang/data/fid_kolors_nexfort',
+    '/home/lixiang/data/fid_kolors_torch',
     get_coco_fid_stats()
 )
22 changes: 11 additions & 11 deletions metrics/structural_similarity.py
@@ -22,17 +22,17 @@ def compare_images(standard_output, current_output):

 def average_metrics(folder1, folder2):
     metrics = {'ssim': [], 'mse': [], 'mae': []}
-    for subfolder in os.listdir(folder1):
-        subfolder_path1 = os.path.join(folder1, subfolder)
-        subfolder_path2 = os.path.join(folder2, subfolder)
-        images1 = [os.path.join(subfolder_path1, f) for f in os.listdir(subfolder_path1) if f.endswith(('.png', '.jpg', '.jpeg'))]
-        images2 = [os.path.join(subfolder_path2, f) for f in os.listdir(subfolder_path2) if f.endswith(('.png', '.jpg', '.jpeg'))]
-        for img1, img2 in zip(images1, images2):
-            image1 = imread(img1)
-            image2 = imread(img2)
-            results = compare_images(image1, image2)
-            for key in metrics:
-                metrics[key].append(results[key])
+    # for subfolder in os.listdir(folder1):
+    #     subfolder_path1 = os.path.join(folder1, subfolder)
+    #     subfolder_path2 = os.path.join(folder2, subfolder)
+    images1 = [os.path.join(folder1, f) for f in os.listdir(folder1) if f.endswith(('.png', '.jpg', '.jpeg'))]
+    images2 = [os.path.join(folder2, f) for f in os.listdir(folder2) if f.endswith(('.png', '.jpg', '.jpeg'))]
+    for img1, img2 in zip(images1, images2):
+        image1 = imread(img1)
+        image2 = imread(img2)
+        results = compare_images(image1, image2)
+        for key in metrics:
+            metrics[key].append(results[key])
     average_results = {k: np.mean(v) for k, v in metrics.items()}
     return average_results
(Diffs for the remaining 2 changed files are not shown.)
