First commit
Saint-lsy committed Sep 6, 2024
0 parents commit e2177b5
Showing 21 changed files with 2,226 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
__pycache__/
73 changes: 73 additions & 0 deletions GlobalRetrieval.py
@@ -0,0 +1,73 @@
import torch
import faiss
import numpy as np
import time
import cv2
import argparse
import os
from Matcher import Dinov2Matcher


# Normalize an embedding and add it to the FAISS index
def add_vector_to_index(embedding, index):
    # Convert to a float32 numpy array
    vector = np.float32(embedding)
    # Normalize the vector: important to avoid wrong results when searching
    faiss.normalize_L2(vector)
    # Add to the index
    index.add(vector)


# Build a FAISS index of type FlatL2 with 768 dimensions (DINOv2-base feature size)
def build_database(base_path, database_save_path):

    index = faiss.IndexFlatL2(768)
    # Extract features
    t0 = time.time()
    image_paths = []
    for root, dirs, files in os.walk(base_path):
        for file in files:
            # Check whether the file is an image
            if file.lower().endswith('.jpg'):
                file_path = os.path.join(root, file)
                image_paths.append(file_path)

    for image_path in image_paths:
        image = cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        with torch.no_grad():
            image_tensor, _, _ = dm.prepare_image(image)
            global_features = dm.extract_global_features(image_tensor)
            add_vector_to_index(global_features, index)

    # Store the index locally
    print('Extraction done in :', time.time() - t0)
    faiss.write_index(index, database_save_path)


# Retrieve the reference images most similar to the query image
def retrieval_from_database(image_path, database_path):
    # Extract the features of the query image
    with torch.no_grad():
        image = cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        image_tensor, _, _ = dm.prepare_image(image)
        global_features = dm.extract_global_features(image_tensor)

    vector = np.float32(global_features)
    faiss.normalize_L2(vector)

    # Read the index file and search for the top-5 images
    index = faiss.read_index(database_path)
    d, i = index.search(vector, 5)
    print('distances:', d, 'indexes:', i)


if __name__ == '__main__':
    dm = Dinov2Matcher(half_precision=False)

    parser = argparse.ArgumentParser(description='GlobalRetrieval')
    parser.add_argument('--database_path', type=str, required=True, help='Path to save the database')
    parser.add_argument('--data_path', type=str, required=True, help='Path of non-polyp reference images')
    parser.add_argument('--image_path', type=str, required=True, help='Path of the query image')
    args = parser.parse_args()

    build_database(args.data_path, args.database_path)
    retrieval_from_database(args.image_path, args.database_path)
121 changes: 121 additions & 0 deletions LocalMatching.py
@@ -0,0 +1,121 @@
from PIL import Image
import numpy as np
import cv2
from sklearn.neighbors import NearestNeighbors
import argparse
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
from sklearn.cluster import DBSCAN
from Matcher import Dinov2Matcher

patch_size = 14


def plot_matching_figure(image1, image2, xyA_list, xyB_list, save_path):
    fig = plt.figure(figsize=(10, 10))
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)

    ax1.imshow(image1)
    ax2.imshow(image2)

    for xyA, xyB in zip(xyA_list, xyB_list):
        con = ConnectionPatch(xyA=xyB, xyB=xyA, coordsA="data", coordsB="data",
                              axesA=ax2, axesB=ax1, color="red")
        ax2.add_artist(con)

    fig.tight_layout()
    fig.show()
    fig.savefig(save_path)


def MaskProposer(origin_image, origin_mask, target_image, target_mask, matching_figure_save_path=None):
    # Init Dinov2Matcher
    dm = Dinov2Matcher(half_precision=False)
    # Extract image1 features
    image1 = cv2.cvtColor(cv2.imread(origin_image, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
    mask1 = cv2.imread(origin_mask, cv2.IMREAD_COLOR)[:, :, 0] > 127
    image_tensor1, grid_size1, resize_scale1 = dm.prepare_image(image1)
    features1 = dm.extract_local_features(image_tensor1)
    print(features1.shape)
    # Extract image2 features
    image2 = cv2.cvtColor(cv2.imread(target_image, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
    image_tensor2, grid_size2, resize_scale2 = dm.prepare_image(image2)
    features2 = dm.extract_local_features(image_tensor2)

    # Build a kNN index on features from image1 and query it with all features from image2
    knn = NearestNeighbors(n_neighbors=1)
    knn.fit(features1)
    distances, match2to1 = knn.kneighbors(features2)
    match2to1 = np.array(match2to1)

    xyA_list = []
    xyB_list = []
    distances_list = []

    for idx2, (dist, idx1) in enumerate(zip(distances, match2to1)):
        row, col = dm.idx_to_source_position(idx1, grid_size1, resize_scale1)
        xyA = (col, row)
        if not mask1[int(row), int(col)]: continue  # skip if the feature is not on the object
        row, col = dm.idx_to_source_position(idx2, grid_size2, resize_scale2)
        xyB = (col, row)
        xyB_list.append(xyB)
        xyA_list.append(xyA)
        distances_list.append(dist[0])

    # Filter by distance: keep only the 30 closest matches
    if len(xyA_list) > 30:
        zip_list = list(zip(distances_list, xyA_list, xyB_list))
        zip_list.sort(key=lambda x: x[0])
        distances_list, xyA_list, xyB_list = zip(*zip_list)
        xyA_list = xyA_list[:30]
        xyB_list = xyB_list[:30]

    if matching_figure_save_path is not None:
        plot_matching_figure(image1, image2, xyA_list, xyB_list, matching_figure_save_path)

    # DBSCAN clustering of the matched points in image2
    X = np.array(xyB_list)
    clustering = DBSCAN(eps=2 * patch_size + 1, min_samples=1).fit(X)
    labels = clustering.labels_

    # Find the cluster with the largest number of points
    unique_labels, counts = np.unique(labels, return_counts=True)
    max_label = unique_labels[np.argmax(counts)]
    new_list = [xyB for i, xyB in enumerate(xyB_list) if labels[i] == max_label]

    # Find the min and max column of the cluster
    min_col = np.min([xy[0] for xy in new_list]) - patch_size // 2
    max_col = np.max([xy[0] for xy in new_list]) + patch_size // 2
    # Find the min and max row of the cluster
    min_row = np.min([xy[1] for xy in new_list]) - patch_size // 2
    max_row = np.max([xy[1] for xy in new_list]) + patch_size // 2

    mask = np.zeros((image2.shape[0], image2.shape[1]))
    mask[int(min_row):int(max_row), int(min_col):int(max_col)] = 255
    mask = mask.astype(np.uint8)
    mask = Image.fromarray(mask).convert('L')
    mask.save(target_mask)
    return mask


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='LocalMatching')
    parser.add_argument('--ref_image', type=str, required=True, help='Path of the reference image')
    parser.add_argument('--ref_mask', type=str, required=True, help='Path of the reference mask')
    parser.add_argument('--query_image', type=str, required=True, help='Path of the query image')
    parser.add_argument('--mask_proposal', type=str, required=True, help='Save path of the mask proposal')
    parser.add_argument('--save_fig', type=str, default=None, help='Save path of the matching figure')

    args = parser.parse_args()

    mask = MaskProposer(origin_image=args.ref_image,
                        origin_mask=args.ref_mask,
                        target_image=args.query_image,
                        target_mask=args.mask_proposal,
                        matching_figure_save_path=args.save_fig)
99 changes: 99 additions & 0 deletions Matcher.py
@@ -0,0 +1,99 @@
import torch
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
import torch.nn as nn


class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        # x is expected to have shape (batch_size, sequence_length, feature_dim);
        # apply GeM pooling over all sequence features of each sample
        x = x.clamp(min=self.eps).pow(self.p)
        x = torch.mean(x, dim=1)  # average along the sequence_length dimension
        x = x.pow(1. / self.p)
        return x

    def __repr__(self):
        return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'


class Dinov2Matcher:

    def __init__(self, repo_name="facebookresearch/dinov2", model_name="dinov2_vitb14", smaller_edge_size=512, half_precision=False, device="cuda"):
        self.repo_name = repo_name
        self.model_name = model_name
        # self.smaller_edge_size = smaller_edge_size
        self.half_precision = half_precision
        self.device = device

        if self.half_precision:
            self.model = torch.hub.load(repo_or_dir=repo_name, model=model_name).half().to(self.device)
        else:
            self.model = torch.hub.load(repo_or_dir=repo_name, model=model_name).to(self.device)

        # self.model = AutoModel.from_pretrained('/data/lsy/workspace/hf_ckp/models--facebook--dinov2-base').to(device)

        self.model.eval()

        self.transform = transforms.Compose([
            # transforms.Resize(size=smaller_edge_size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # ImageNet defaults
        ])
        self.gem_loss = GeM()

    # https://github.com/facebookresearch/dinov2/blob/255861375864acdd830f99fdae3d9db65623dafe/notebooks/features.ipynb
    def prepare_image(self, rgb_image_numpy):
        image = Image.fromarray(rgb_image_numpy)
        image_tensor = self.transform(image)
        resize_scale = image.width / image_tensor.shape[2]

        # Crop the image to dimensions that are a multiple of the patch size
        height, width = image_tensor.shape[1:]  # C x H x W
        cropped_width, cropped_height = width - width % self.model.patch_size, height - height % self.model.patch_size  # crop a bit from the right and bottom
        image_tensor = image_tensor[:, :cropped_height, :cropped_width]

        grid_size = (cropped_height // self.model.patch_size, cropped_width // self.model.patch_size)
        return image_tensor, grid_size, resize_scale

    def prepare_mask(self, mask_image_numpy, grid_size, resize_scale):
        cropped_mask_image_numpy = mask_image_numpy[:int(grid_size[0] * self.model.patch_size * resize_scale), :int(grid_size[1] * self.model.patch_size * resize_scale)]
        image = Image.fromarray(cropped_mask_image_numpy)
        resized_mask = image.resize((grid_size[1], grid_size[0]), resample=Image.Resampling.NEAREST)
        resized_mask = np.asarray(resized_mask).flatten()
        return resized_mask

    def extract_global_features(self, image_tensor):
        with torch.inference_mode():
            if self.half_precision:
                image_batch = image_tensor.unsqueeze(0).half().to(self.device)
            else:
                image_batch = image_tensor.unsqueeze(0).to(self.device)

            tokens = self.model.get_intermediate_layers(image_batch)[0].mean(dim=1).detach().cpu()

        return tokens.numpy()

    def extract_local_features(self, image_tensor):
        with torch.inference_mode():
            if self.half_precision:
                image_batch = image_tensor.unsqueeze(0).half().to(self.device)
            else:
                image_batch = image_tensor.unsqueeze(0).to(self.device)

            tokens = self.model.get_intermediate_layers(image_batch)[0].squeeze()
        return tokens.cpu().numpy()

    def idx_to_source_position(self, idx, grid_size, resize_scale):
        row = (idx // grid_size[1]) * self.model.patch_size * resize_scale + self.model.patch_size / 2
        col = (idx % grid_size[1]) * self.model.patch_size * resize_scale + self.model.patch_size / 2
        return row, col

39 changes: 39 additions & 0 deletions README.md
@@ -0,0 +1,39 @@
# Polyp-Gen

## 🛠Setup

```bash
git clone https://github.com/Saint-lsy/Polyp-Gen.git
cd Polyp-Gen
conda create -n PolypGen python=3.10
conda activate PolypGen
pip install -r requirements.txt
```

## Data Preparation
This model was trained on the [LDPolypVideo](https://github.com/dashishi/LDPolypVideo-Benchmark) dataset.

We filtered out low-quality images affected by blur, reflections, and ghosting, and finally selected 55,883 samples, including 29,640 polyp frames and 26,243 non-polyp frames.
## Checkpoint
You can download the checkpoints of our Polyp-Gen on [HuggingFace](https://huggingface.co/Saint-lsy/Polyp-Gen-sd2-inpainting/tree/main).

## Sampling with Specified Mask
```bash
python sample_one_image.py
```
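
For reference, inpainting with an SD2-inpainting checkpoint generally follows the pattern below. This is only a minimal sketch, assuming the released checkpoint can be loaded with diffusers' `StableDiffusionInpaintPipeline`; the prompt, resolution, and output file name are placeholders, and `sample_one_image.py` remains the authoritative entry point.

```python
# Minimal sketch (assumption): load the inpainting checkpoint with diffusers
# and generate a polyp inside the masked region of an endoscopic frame.
import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "Saint-lsy/Polyp-Gen-sd2-inpainting",  # checkpoint from the Checkpoint section
    torch_dtype=torch.float16,
).to("cuda")

image = Image.open("demos/demo_img.jpg").convert("RGB").resize((512, 512))
mask = Image.open("demos/demo_mask.jpg").convert("L").resize((512, 512))  # white = region to inpaint

result = pipe(prompt="a polyp", image=image, mask_image=mask).images[0]
result.save("generated.jpg")
```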

## Sampling with Mask Proposer
The first step is to build the database and perform Global Retrieval.
```bash
python GlobalRetrieval.py --data_path /path/of/non-polyp/images --database_path /path/to/build/database --image_path /path/of/query/image/
```
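
`GlobalRetrieval.py` prints the distances and indices of the top-5 matches. A minimal sketch of mapping those indices back to reference image paths is shown below; `list_reference_images` is a hypothetical helper that mirrors the enumeration in `build_database`, so the indices line up with the FAISS index (the paths are placeholders).

```python
# Sketch (assumption): resolve FAISS result indices to reference image paths.
import os
import cv2
import faiss
import numpy as np
import torch
from Matcher import Dinov2Matcher

def list_reference_images(base_path):
    # Same enumeration logic as build_database, so index i corresponds to paths[i]
    paths = []
    for root, dirs, files in os.walk(base_path):
        for file in files:
            if file.lower().endswith('.jpg'):
                paths.append(os.path.join(root, file))
    return paths

ref_paths = list_reference_images('/path/of/non-polyp/images')
index = faiss.read_index('/path/to/build/database')

dm = Dinov2Matcher(half_precision=False)
image = cv2.cvtColor(cv2.imread('/path/of/query/image', cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
with torch.no_grad():
    image_tensor, _, _ = dm.prepare_image(image)
    query = np.float32(dm.extract_global_features(image_tensor))
faiss.normalize_L2(query)

distances, indices = index.search(query, 5)
print([ref_paths[i] for i in indices[0]])  # top-5 most similar reference images
```
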
The second step is Local Matching for the query image.
```bash
python LocalMatching.py --ref_image /path/ref/image --ref_mask /path/ref/mask --query_image /path/query/image --mask_proposal /path/to/save/mask
```
A demo of LocalMatching:
```bash
python LocalMatching.py --ref_image demos/img_1513_neg.jpg --ref_mask demos/mask_1513.jpg --query_image demos/img_1592_neg.jpg --mask_proposal gen_mask.jpg
```
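
The same step can also be run from Python via the `MaskProposer` function in `LocalMatching.py`; a minimal sketch using the demo files above (the matching-figure path is a placeholder, and can be `None` to skip the figure):

```python
from LocalMatching import MaskProposer

# Propose a mask on the query image by matching DINOv2 patch features
# against the reference image/mask pair.
mask = MaskProposer(origin_image='demos/img_1513_neg.jpg',
                    origin_mask='demos/mask_1513.jpg',
                    target_image='demos/img_1592_neg.jpg',
                    target_mask='gen_mask.jpg',
                    matching_figure_save_path='matching.jpg')
```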

The third step is to use the generated mask for sampling, as in the "Sampling with Specified Mask" section above.
Empty file.
1 change: 1 addition & 0 deletions data/data.txt
@@ -0,0 +1 @@
# Download Modified LDPolypVideo
Binary file added demos/demo_img.jpg
Binary file added demos/demo_mask.jpg
Binary file added demos/img_1513.jpg
Binary file added demos/img_1513_neg.jpg
Binary file added demos/img_1592.jpg
Binary file added demos/img_1592_neg.jpg
Binary file added demos/mask_1513.jpg