keypoint_proposal.py
import numpy as np
import torch
import cv2
from torch.nn.functional import interpolate
from kmeans_pytorch import kmeans
from utils import filter_points_by_bounds
from sklearn.cluster import MeanShift
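# Dependency note (inferred from the imports above, not stated in the original):
# torch, opencv-python, kmeans-pytorch, and scikit-learn are third-party
# packages; `utils` is the project-local module providing filter_points_by_bounds.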
class KeypointProposer:
    def __init__(self, config):
        self.config = config
        self.device = torch.device(self.config['device'])
        self.dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').eval().to(self.device)
        # alternative: load the backbone from a local checkpoint instead of torch.hub
        # local_model_path = '/omnigibson-src/ReKep/dinov2_vits14_pretrain.pth'
        # checkpoint = torch.load(local_model_path)
        # self.dinov2 = checkpoint
        self.bounds_min = np.array(self.config['bounds_min'])
        self.bounds_max = np.array(self.config['bounds_max'])
        self.mean_shift = MeanShift(bandwidth=self.config['min_dist_bt_keypoints'], bin_seeding=True, n_jobs=32)
        self.patch_size = 14  # dinov2 ViT-S/14 patch size
        np.random.seed(self.config['seed'])
        torch.manual_seed(self.config['seed'])
        torch.cuda.manual_seed(self.config['seed'])
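    # Illustrative only: a minimal config this class reads. The keys are the
    # ones accessed in this file; the values below are made up, not the
    # project's actual settings.
    #
    #   config = {
    #       'device': 'cuda',
    #       'seed': 0,
    #       'bounds_min': [-0.5, -0.5, 0.0],
    #       'bounds_max': [0.5, 0.5, 0.5],
    #       'min_dist_bt_keypoints': 0.05,
    #       'max_mask_ratio': 0.5,
    #       'num_candidates_per_mask': 5,
    #   }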
    def get_keypoints(self, rgb, points, masks):
        # preprocessing
        transformed_rgb, rgb, points, masks, shape_info = self._preprocess(rgb, points, masks)
        # get per-pixel dinov2 features
        features_flat = self._get_features(transformed_rgb, shape_info)
        # for each mask, cluster in feature space to get meaningful regions, and use their centers as keypoint candidates
        candidate_keypoints, candidate_pixels, candidate_rigid_group_ids = self._cluster_features(points, features_flat, masks)
        # exclude keypoints that are outside of the workspace
        within_space = filter_points_by_bounds(candidate_keypoints, self.bounds_min, self.bounds_max, strict=True)
        candidate_keypoints = candidate_keypoints[within_space]
        candidate_pixels = candidate_pixels[within_space]
        candidate_rigid_group_ids = candidate_rigid_group_ids[within_space]
        # merge close points by clustering in cartesian space
        merged_indices = self._merge_clusters(candidate_keypoints)
        candidate_keypoints = candidate_keypoints[merged_indices]
        candidate_pixels = candidate_pixels[merged_indices]
        candidate_rigid_group_ids = candidate_rigid_group_ids[merged_indices]
        # sort candidates by pixel location (primary key: column, then row)
        sort_idx = np.lexsort((candidate_pixels[:, 0], candidate_pixels[:, 1]))
        candidate_keypoints = candidate_keypoints[sort_idx]
        candidate_pixels = candidate_pixels[sort_idx]
        candidate_rigid_group_ids = candidate_rigid_group_ids[sort_idx]
        # draw the keypoints onto the image for visualization
        projected = self._project_keypoints_to_img(rgb, candidate_pixels, candidate_rigid_group_ids, masks, features_flat)
        return candidate_keypoints, projected
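    # Input shape notes (inferred from usage below, not guaranteed by the
    # original code): `rgb` is a uint8 torch tensor [H, W, 3], `points` a
    # per-pixel 3D point map [H, W, 3] aligned with `rgb`, and `masks` an
    # integer segmentation map [H, W] whose unique values identify rigid groups.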
    def _preprocess(self, rgb, points, masks):
        if masks.is_cuda:
            masks = masks.cpu()
        rgb = rgb.cpu().numpy()  # move to CPU if on GPU, then convert to numpy
        # convert the integer segmentation map into per-instance binary masks
        masks = [masks == uid for uid in np.unique(masks.numpy())]
        # resize so height and width are multiples of the patch size, as dinov2 requires
        H, W, _ = rgb.shape
        patch_h = H // self.patch_size
        patch_w = W // self.patch_size
        new_H = patch_h * self.patch_size
        new_W = patch_w * self.patch_size
        transformed_rgb = cv2.resize(rgb, (new_W, new_H))
        transformed_rgb = transformed_rgb.astype(np.float32) / 255.0  # float32 [H, W, 3] in [0, 1]
        # shape info
        shape_info = {
            'img_h': H,
            'img_w': W,
            'patch_h': patch_h,
            'patch_w': patch_w,
        }
        return transformed_rgb, rgb, points, masks, shape_info
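    # Worked example (hypothetical input size): a 480x640 frame with
    # patch_size 14 gives patch_h = 480 // 14 = 34 and patch_w = 640 // 14 = 45,
    # so the frame is resized to 476x630 before being fed to dinov2.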
    def _project_keypoints_to_img(self, rgb, candidate_pixels, candidate_rigid_group_ids, masks, features_flat):
        projected = rgb.copy()
        # overlay numbered keypoints on the image
        for keypoint_count, pixel in enumerate(candidate_pixels):
            displayed_text = f"{keypoint_count}"
            text_length = len(displayed_text)
            # draw a white box with a black border; note pixel is (row, col)
            box_width = 30 + 10 * (text_length - 1)
            box_height = 30
            cv2.rectangle(projected, (pixel[1] - box_width // 2, pixel[0] - box_height // 2), (pixel[1] + box_width // 2, pixel[0] + box_height // 2), (255, 255, 255), -1)
            cv2.rectangle(projected, (pixel[1] - box_width // 2, pixel[0] - box_height // 2), (pixel[1] + box_width // 2, pixel[0] + box_height // 2), (0, 0, 0), 2)
            # draw the keypoint index
            org = (pixel[1] - 7 * text_length, pixel[0] + 7)
            color = (255, 0, 0)
            cv2.putText(projected, displayed_text, org, cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
        return projected
    @torch.inference_mode()
    @torch.amp.autocast('cuda')
    def _get_features(self, transformed_rgb, shape_info):
        img_h = shape_info['img_h']
        img_w = shape_info['img_w']
        patch_h = shape_info['patch_h']
        patch_w = shape_info['patch_w']
        # get patch-level features from dinov2
        img_tensors = torch.from_numpy(transformed_rgb).permute(2, 0, 1).unsqueeze(0).to(self.device)  # float32 [1, 3, H, W]
        assert img_tensors.shape[1] == 3, "unexpected image shape"
        features_dict = self.dinov2.forward_features(img_tensors)
        raw_feature_grid = features_dict['x_norm_patchtokens']  # float32 [num_cams, patch_h*patch_w, feature_dim]
        raw_feature_grid = raw_feature_grid.reshape(1, patch_h, patch_w, -1)  # float32 [num_cams, patch_h, patch_w, feature_dim]
        # upsample the patch grid to per-pixel features using bilinear interpolation
        interpolated_feature_grid = interpolate(raw_feature_grid.permute(0, 3, 1, 2),  # float32 [num_cams, feature_dim, patch_h, patch_w]
                                                size=(img_h, img_w),
                                                mode='bilinear').permute(0, 2, 3, 1).squeeze(0)  # float32 [H, W, feature_dim]
        features_flat = interpolated_feature_grid.reshape(-1, interpolated_feature_grid.shape[-1])  # float32 [H*W, feature_dim]
        return features_flat
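    # For the dinov2_vits14 backbone loaded in __init__, feature_dim is 384, so
    # with the hypothetical 480x640 example above, features_flat would come out
    # as [480*640, 384] (per-pixel features at the original resolution).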
    def _cluster_features(self, points, features_flat, masks):
        candidate_keypoints = []
        candidate_pixels = []
        candidate_rigid_group_ids = []
        for rigid_group_id, binary_mask in enumerate(masks):
            binary_mask = binary_mask.cpu().numpy()
            # ignore masks that cover too much of the image (likely background)
            if np.mean(binary_mask) > self.config['max_mask_ratio']:
                continue
            # consider only foreground features
            obj_features_flat = features_flat[binary_mask.reshape(-1)]
            feature_pixels = np.argwhere(binary_mask)
            feature_points = points[binary_mask]
            # reduce dimensionality to be less sensitive to noise and texture
            obj_features_flat = obj_features_flat.double()
            (u, s, v) = torch.pca_lowrank(obj_features_flat, center=False)
            features_pca = torch.mm(obj_features_flat, v[:, :3])
            features_pca = (features_pca - features_pca.min(0)[0]) / (features_pca.max(0)[0] - features_pca.min(0)[0])
            X = features_pca
            # add the normalized 3D coordinates (feature_points) as extra dimensions
            # so that clusters are spatially coherent
            feature_points_torch = torch.tensor(feature_points, dtype=features_pca.dtype, device=features_pca.device)
            feature_points_torch = (feature_points_torch - feature_points_torch.min(0)[0]) / (feature_points_torch.max(0)[0] - feature_points_torch.min(0)[0])
            X = torch.cat([X, feature_points_torch], dim=-1)
            # cluster features to get meaningful regions
            cluster_ids_x, cluster_centers = kmeans(
                X=X,
                num_clusters=self.config['num_candidates_per_mask'],
                distance='euclidean',
                device=self.device,
            )
            cluster_centers = cluster_centers.to(self.device)
            # for each cluster, take the member closest to the cluster center as the candidate;
            # note only the first 3 dims (the PCA features) of the 6-D center are compared
            for cluster_id in range(self.config['num_candidates_per_mask']):
                cluster_center = cluster_centers[cluster_id][:3]
                member_idx = cluster_ids_x == cluster_id
                member_points = feature_points[member_idx]
                member_pixels = feature_pixels[member_idx]
                member_features = features_pca[member_idx]
                dist = torch.norm(member_features - cluster_center, dim=-1)
                closest_idx = torch.argmin(dist)
                candidate_keypoints.append(member_points[closest_idx])
                candidate_pixels.append(member_pixels[closest_idx])
                candidate_rigid_group_ids.append(rigid_group_id)
        candidate_keypoints = np.array(candidate_keypoints)
        candidate_pixels = np.array(candidate_pixels)
        candidate_rigid_group_ids = np.array(candidate_rigid_group_ids)
        return candidate_keypoints, candidate_pixels, candidate_rigid_group_ids
    def _merge_clusters(self, candidate_keypoints):
        # cluster candidates in cartesian space with MeanShift, then keep, for
        # each cluster center, the original candidate closest to it
        self.mean_shift.fit(candidate_keypoints)
        cluster_centers = self.mean_shift.cluster_centers_
        merged_indices = []
        for center in cluster_centers:
            dist = np.linalg.norm(candidate_keypoints - center, axis=-1)
            merged_indices.append(np.argmin(dist))
        return merged_indices
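# Illustrative usage sketch (not part of the original file). Assumes a config
# like the one sketched near __init__ and camera observations matching the
# shape notes on get_keypoints; `rgb`, `points`, and `masks` are hypothetical:
#
#   proposer = KeypointProposer(config)
#   keypoints, projected_img = proposer.get_keypoints(rgb, points, masks)
#   print(f'proposed {len(keypoints)} keypoints')
#   cv2.imwrite('keypoints_overlay.png', projected_img[..., ::-1])  # RGB -> BGR for OpenCV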