Commit e2177b5 (initial commit, 0 parents) — 21 changed files with 2,226 additions and 0 deletions.
**.gitignore**

```
__pycache__/
```
**GlobalRetrieval.py**

```python
import torch
import faiss
import numpy as np
import time
import cv2
import argparse
import os
from Matcher import Dinov2Matcher


def add_vector_to_index(embedding, index):
    """Normalize an embedding and add it to the FAISS index."""
    # Convert to a float32 numpy array
    vector = np.float32(embedding)
    # Normalize the vector: important to avoid wrong results when searching
    faiss.normalize_L2(vector)
    # Add to the index
    index.add(vector)


def build_database(base_path, database_save_path):
    # Create a FAISS index of type FlatL2 with 768 dimensions (the embedding
    # size of DINOv2-base)
    index = faiss.IndexFlatL2(768)

    # Collect all image paths under base_path
    t0 = time.time()
    image_paths = []
    for root, dirs, files in os.walk(base_path):
        for file in files:
            # Check whether the file is an image
            if file.lower().endswith('.jpg'):
                file_path = os.path.join(root, file)
                image_paths.append(file_path)

    # Extract a global feature for each image and add it to the index
    for image_path in image_paths:
        image = cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        with torch.no_grad():
            image_tensor, _, _ = dm.prepare_image(image)
            global_features = dm.extract_global_features(image_tensor)
        add_vector_to_index(global_features, index)

    # Store the index locally
    print('Extraction done in:', time.time() - t0)
    faiss.write_index(index, database_save_path)


def retrieval_from_database(image_path, database_path):
    # Extract the features of the query image
    with torch.no_grad():
        image = cv2.cvtColor(cv2.imread(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        image_tensor, _, _ = dm.prepare_image(image)
        global_features = dm.extract_global_features(image_tensor)

    vector = np.float32(global_features)
    faiss.normalize_L2(vector)

    # Read the index file and search for the top-5 nearest images
    index = faiss.read_index(database_path)
    d, i = index.search(vector, 5)
    print('distances:', d, 'indexes:', i)


if __name__ == '__main__':
    # Used as a module-level global by the two functions above
    dm = Dinov2Matcher(half_precision=False)

    parser = argparse.ArgumentParser(description='GlobalRetrieval')
    parser.add_argument('--database_path', type=str, required=True, help='Path to save the database')
    parser.add_argument('--data_path', type=str, required=True, help='Path of non-polyp reference images')
    parser.add_argument('--image_path', type=str, required=True, help='Path of the query image')
    args = parser.parse_args()

    build_database(args.data_path, args.database_path)
    retrieval_from_database(args.image_path, args.database_path)
```
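One caveat: `retrieval_from_database` prints raw FAISS row indices, and the written index stores only vectors, so a later run cannot map those indices back to file names. A minimal sketch of one way to persist the path list alongside the index — the `.paths.json` sidecar file and helper names are assumptions, not part of the original script:

```python
import json

# Hypothetical helper: store the ordered path list next to the FAISS index,
# so that row i of a search result maps back to image_paths[i].
def save_paths(image_paths, database_save_path):
    with open(database_save_path + '.paths.json', 'w') as f:
        json.dump(image_paths, f)

def load_paths(database_path):
    with open(database_path + '.paths.json') as f:
        return json.load(f)

# After `d, i = index.search(vector, 5)` in retrieval_from_database:
#   top_matches = [load_paths(database_path)[j] for j in i[0]]
```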
**LocalMatching.py**

```python
from PIL import Image
import numpy as np
import cv2
from sklearn.neighbors import NearestNeighbors
import argparse
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
from sklearn.cluster import DBSCAN
from Matcher import Dinov2Matcher

patch_size = 14  # DINOv2 ViT-B/14 patch size


def plot_matching_figure(image1, image2, xyA_list, xyB_list, save_path):
    fig = plt.figure(figsize=(10, 10))
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)

    ax1.imshow(image1)
    ax2.imshow(image2)

    # Draw a line between each matched pair of points
    for xyA, xyB in zip(xyA_list, xyB_list):
        con = ConnectionPatch(xyA=xyB, xyB=xyA, coordsA="data", coordsB="data",
                              axesA=ax2, axesB=ax1, color="red")
        ax2.add_artist(con)

    fig.tight_layout()
    fig.savefig(save_path)


def MaskProposer(origin_image, origin_mask, target_image, target_mask, matching_figure_save_path=None):
    # Init Dinov2Matcher
    dm = Dinov2Matcher(half_precision=False)

    # Extract reference-image features
    image1 = cv2.cvtColor(cv2.imread(origin_image, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
    mask1 = cv2.imread(origin_mask, cv2.IMREAD_COLOR)[:, :, 0] > 127
    image_tensor1, grid_size1, resize_scale1 = dm.prepare_image(image1)
    features1 = dm.extract_local_features(image_tensor1)

    # Extract query-image features
    image2 = cv2.cvtColor(cv2.imread(target_image, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
    image_tensor2, grid_size2, resize_scale2 = dm.prepare_image(image2)
    features2 = dm.extract_local_features(image_tensor2)

    # Build a 1-NN index on features from image1 and query all features from image2
    knn = NearestNeighbors(n_neighbors=1)
    knn.fit(features1)
    distances, match2to1 = knn.kneighbors(features2)
    match2to1 = np.array(match2to1)

    xyA_list = []
    xyB_list = []
    distances_list = []

    for idx2, (dist, idx1) in enumerate(zip(distances, match2to1)):
        row, col = dm.idx_to_source_position(idx1, grid_size1, resize_scale1)
        xyA = (col, row)
        if not mask1[int(row), int(col)]:
            continue  # skip if the matched reference patch is not on the object
        row, col = dm.idx_to_source_position(idx2, grid_size2, resize_scale2)
        xyB = (col, row)
        xyB_list.append(xyB)
        xyA_list.append(xyA)
        distances_list.append(dist[0])

    # Keep only the 30 matches with the smallest feature distance
    if len(xyA_list) > 30:
        zip_list = list(zip(distances_list, xyA_list, xyB_list))
        zip_list.sort(key=lambda x: x[0])
        distances_list, xyA_list, xyB_list = zip(*zip_list)
        xyA_list = xyA_list[:30]
        xyB_list = xyB_list[:30]

    if matching_figure_save_path is not None:
        plot_matching_figure(image1, image2, xyA_list, xyB_list, matching_figure_save_path)

    # Cluster the matched points in the query image with DBSCAN
    X = np.array(xyB_list)
    clustering = DBSCAN(eps=2 * patch_size + 1, min_samples=1).fit(X)
    labels = clustering.labels_

    # Find the cluster with the most points
    unique_labels, counts = np.unique(labels, return_counts=True)
    max_label = unique_labels[np.argmax(counts)]
    new_list = [xyB for i, xyB in enumerate(xyB_list) if labels[i] == max_label]

    # Bounding box of the cluster, padded by half a patch on each side
    min_col = np.min([xy[0] for xy in new_list]) - patch_size // 2
    max_col = np.max([xy[0] for xy in new_list]) + patch_size // 2
    min_row = np.min([xy[1] for xy in new_list]) - patch_size // 2
    max_row = np.max([xy[1] for xy in new_list]) + patch_size // 2

    # Rasterize the bounding box into a binary mask and save it
    mask = np.zeros((image2.shape[0], image2.shape[1]))
    mask[int(min_row):int(max_row), int(min_col):int(max_col)] = 255
    mask = mask.astype(np.uint8)
    mask = Image.fromarray(mask).convert('L')
    mask.save(target_mask)
    return mask


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='LocalMatching')
    parser.add_argument('--ref_image', type=str, required=True, help='Path of the reference image')
    parser.add_argument('--ref_mask', type=str, required=True, help='Path of the reference mask')
    parser.add_argument('--query_image', type=str, required=True, help='Path of the query image')
    parser.add_argument('--mask_proposal', type=str, required=True, help='Save path of the mask proposal')
    parser.add_argument('--save_fig', type=str, default=None, help='Optional path to save the matching figure')

    args = parser.parse_args()

    mask = MaskProposer(origin_image=args.ref_image,
                        origin_mask=args.ref_mask,
                        target_image=args.query_image,
                        target_mask=args.mask_proposal,
                        matching_figure_save_path=args.save_fig)
```
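Beyond the CLI, `MaskProposer` can also be called directly from Python. A minimal sketch using the demo files from the README (the module name `LocalMatching` is assumed from the README's command):

```python
from LocalMatching import MaskProposer

# Transfer the reference polyp mask onto the query image; paths follow the
# demo command in the README. The matching-figure path is optional.
mask = MaskProposer(origin_image='demos/img_1513_neg.jpg',
                    origin_mask='demos/mask_1513.jpg',
                    target_image='demos/img_1592_neg.jpg',
                    target_mask='gen_mask.jpg',
                    matching_figure_save_path='matches.jpg')
```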
**Matcher.py**

```python
import torch
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
import torch.nn as nn


class GeM(nn.Module):
    """Generalized-mean (GeM) pooling over a sequence of token features."""

    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        # x has shape (batch_size, sequence_length, feature_dim);
        # apply GeM pooling over all sequence features of each sample
        x = x.clamp(min=self.eps).pow(self.p)
        x = torch.mean(x, dim=1)  # mean over the sequence_length dimension
        x = x.pow(1. / self.p)
        return x

    def __repr__(self):
        return self.__class__.__name__ + '(p={:.4f}, eps={})'.format(self.p.data.tolist()[0], self.eps)


class Dinov2Matcher:

    def __init__(self, repo_name="facebookresearch/dinov2", model_name="dinov2_vitb14",
                 smaller_edge_size=512,  # kept for API compatibility; resizing is currently disabled
                 half_precision=False, device="cuda"):
        self.repo_name = repo_name
        self.model_name = model_name
        self.half_precision = half_precision
        self.device = device

        if self.half_precision:
            self.model = torch.hub.load(repo_or_dir=repo_name, model=model_name).half().to(self.device)
        else:
            self.model = torch.hub.load(repo_or_dir=repo_name, model=model_name).to(self.device)

        self.model.eval()

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # ImageNet defaults
        ])
        self.gem_loss = GeM()  # GeM pooling module (unused by the methods below)

    # https://github.com/facebookresearch/dinov2/blob/255861375864acdd830f99fdae3d9db65623dafe/notebooks/features.ipynb
    def prepare_image(self, rgb_image_numpy):
        image = Image.fromarray(rgb_image_numpy)
        image_tensor = self.transform(image)
        resize_scale = image.width / image_tensor.shape[2]

        # Crop the image to dimensions that are a multiple of the patch size
        # (a bit is cropped from the right and bottom)
        height, width = image_tensor.shape[1:]  # C x H x W
        cropped_width = width - width % self.model.patch_size
        cropped_height = height - height % self.model.patch_size
        image_tensor = image_tensor[:, :cropped_height, :cropped_width]

        grid_size = (cropped_height // self.model.patch_size, cropped_width // self.model.patch_size)
        return image_tensor, grid_size, resize_scale

    def prepare_mask(self, mask_image_numpy, grid_size, resize_scale):
        # Crop the mask to match the cropped image, then downsample it to the patch grid
        cropped_mask_image_numpy = mask_image_numpy[
            :int(grid_size[0] * self.model.patch_size * resize_scale),
            :int(grid_size[1] * self.model.patch_size * resize_scale)]
        image = Image.fromarray(cropped_mask_image_numpy)
        resized_mask = image.resize((grid_size[1], grid_size[0]), resample=Image.Resampling.NEAREST)
        resized_mask = np.asarray(resized_mask).flatten()
        return resized_mask

    def extract_global_features(self, image_tensor):
        # One embedding per image: mean over the patch tokens
        with torch.inference_mode():
            if self.half_precision:
                image_batch = image_tensor.unsqueeze(0).half().to(self.device)
            else:
                image_batch = image_tensor.unsqueeze(0).to(self.device)

            tokens = self.model.get_intermediate_layers(image_batch)[0].mean(dim=1).detach().cpu()

        return tokens.numpy()

    def extract_local_features(self, image_tensor):
        # One embedding per patch token
        with torch.inference_mode():
            if self.half_precision:
                image_batch = image_tensor.unsqueeze(0).half().to(self.device)
            else:
                image_batch = image_tensor.unsqueeze(0).to(self.device)

            tokens = self.model.get_intermediate_layers(image_batch)[0].squeeze()
        return tokens.cpu().numpy()

    def idx_to_source_position(self, idx, grid_size, resize_scale):
        # Map a flattened patch index back to the pixel position of the patch
        # center in the original image
        row = (idx // grid_size[1]) * self.model.patch_size * resize_scale + self.model.patch_size / 2
        col = (idx % grid_size[1]) * self.model.patch_size * resize_scale + self.model.patch_size / 2
        return row, col
```
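To make the coordinate bookkeeping concrete, here is a short sketch of the round trip used by `LocalMatching.py`: extract patch tokens, then map a token index back to a pixel position. The image path is a placeholder.

```python
import cv2
from Matcher import Dinov2Matcher

dm = Dinov2Matcher(half_precision=False)

# Load an RGB image (the path is a placeholder)
image = cv2.cvtColor(cv2.imread('frame.jpg', cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
image_tensor, grid_size, resize_scale = dm.prepare_image(image)

# One 768-d token per 14x14 patch, flattened row-major over the patch grid
features = dm.extract_local_features(image_tensor)  # shape: (grid_h * grid_w, 768)

# Recover the pixel position of the center of patch 0 in the original image
row, col = dm.idx_to_source_position(0, grid_size, resize_scale)
```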
**README.md**

# Polyp-Gen

# 🛠 Setup

```bash
git clone https://github.com/Saint-lsy/Polyp-Gen.git
cd Polyp-Gen
conda create -n PolypGen python=3.10
conda activate PolypGen
pip install -r requirements.txt
```

## Data Preparation
This model was trained on the [LDPolypVideo](https://github.com/dashishi/LDPolypVideo-Benchmark) dataset.

We filtered out low-quality images affected by blur, reflections, and ghosting, and finally selected 55,883 samples, including 29,640 polyp frames and 26,243 non-polyp frames.

## Checkpoint
You can download the checkpoints of our Polyp-Gen on [HuggingFace](https://huggingface.co/Saint-lsy/Polyp-Gen-sd2-inpainting/tree/main).

## Sampling with a Specified Mask
```bash
python sample_one_image.py
```

## Sampling with the Mask Proposer
The first step is building the database and running Global Retrieval.
```bash
python GlobalRetrieval.py --data_path /path/of/non-polyp/images --database_path /path/to/build/database --image_path /path/of/query/image/
```
The second step is Local Matching for the query image.
```bash
python LocalMatching.py --ref_image /path/ref/image --ref_mask /path/ref/mask --query_image /path/query/image --mask_proposal /path/to/save/mask
```
A demo of Local Matching:
```bash
python LocalMatching.py --ref_image demos/img_1513_neg.jpg --ref_mask demos/mask_1513.jpg --query_image demos/img_1592_neg.jpg --mask_proposal gen_mask.jpg
```

The third step is using the generated mask for sampling.
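This step is handled by the repository's `sample_one_image.py`, whose interface is not shown in this commit. As a rough sketch, assuming the released checkpoint loads as a standard diffusers Stable Diffusion 2 inpainting pipeline (as its name suggests), inpainting with the generated mask could look like:

```python
import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

# Assumption: the HuggingFace checkpoint is diffusers-compatible SD2 inpainting.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "Saint-lsy/Polyp-Gen-sd2-inpainting", torch_dtype=torch.float16
).to("cuda")

image = Image.open("demos/img_1592_neg.jpg").convert("RGB")
mask = Image.open("gen_mask.jpg").convert("L")

# The prompt is illustrative; the prompt format used in training is not shown here.
result = pipe(prompt="a colonoscopy image with a polyp",
              image=image, mask_image=mask).images[0]
result.save("polyp_generated.jpg")
```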
The commit also adds an empty file and a placeholder file containing only the line:

```
# Download Modified LDPolypVideo
```
The remaining changed files in this commit are binary (likely the demo images referenced above) and cannot be displayed as text.