Skip to content

Commit 6389cac

Browse files
authored
Gaussian splatting support for Aria (#2785)
* Gaussian splatting support for Aria * Respect masks in splatting loss function
1 parent c27b5cb commit 6389cac

File tree

5 files changed

+216
-126
lines changed

5 files changed

+216
-126
lines changed

nerfstudio/cameras/camera_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,7 @@ def fisheye624_unproject_helper(uv, params, max_iters: int = 5):
720720
function so this solves an optimization problem using Newton's method to get
721721
the inverse.
722722
Inputs:
723-
uv: BxNx3 tensor of 2D pixels to be projected
723+
uv: BxNx2 tensor of 2D pixels to be unprojected
724724
params: Bx16 tensor of Fisheye624 parameters formatted like this:
725725
[f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]
726726
or Bx15 tensor of Fisheye624 parameters formatted like this:

nerfstudio/cameras/cameras.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -864,7 +864,7 @@ def _compute_rays_for_vr180(
864864

865865
assert distortion_params is not None
866866
masked_coords = pcoord_stack[coord_mask, :]
867-
# The fisheye unprojection does not rely on planar/pinhold unprojection, thus the method needs
867+
# The fisheye unprojection does not rely on planar/pinhole unprojection, thus the method needs
868868
# to access the focal length and principle points directly.
869869
camera_params = torch.cat(
870870
[

nerfstudio/data/datamanagers/full_images_datamanager.py

Lines changed: 166 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from torch.nn import Parameter
3535
from tqdm import tqdm
3636

37+
from nerfstudio.cameras.camera_utils import fisheye624_project, fisheye624_unproject_helper
3738
from nerfstudio.cameras.cameras import Cameras, CameraType
3839
from nerfstudio.configs.dataparser_configs import AnnotatedDataParserUnion
3940
from nerfstudio.data.datamanagers.base_datamanager import DataManager, DataManagerConfig, TDataset
@@ -135,70 +136,20 @@ def cache_images(self, cache_images_option):
135136
continue
136137
distortion_params = camera.distortion_params.numpy()
137138
image = data["image"].numpy()
138-
if camera.camera_type.item() == CameraType.PERSPECTIVE.value:
139-
distortion_params = np.array(
140-
[
141-
distortion_params[0],
142-
distortion_params[1],
143-
distortion_params[4],
144-
distortion_params[5],
145-
distortion_params[2],
146-
distortion_params[3],
147-
0,
148-
0,
149-
]
150-
)
151-
if np.any(distortion_params):
152-
newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
153-
image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore
154-
else:
155-
newK = K
156-
roi = 0, 0, image.shape[1], image.shape[0]
157-
# crop the image and update the intrinsics accordingly
158-
x, y, w, h = roi
159-
image = image[y : y + h, x : x + w]
160-
if "depth_image" in data:
161-
data["depth_image"] = data["depth_image"][y : y + h, x : x + w]
162-
# update the width, height
163-
self.train_dataset.cameras.width[i] = w
164-
self.train_dataset.cameras.height[i] = h
165-
if "mask" in data:
166-
mask = data["mask"].numpy()
167-
mask = mask.astype(np.uint8) * 255
168-
if np.any(distortion_params):
169-
mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore
170-
mask = mask[y : y + h, x : x + w]
171-
data["mask"] = torch.from_numpy(mask).bool()
172-
K = newK
173-
174-
elif camera.camera_type.item() == CameraType.FISHEYE.value:
175-
distortion_params = np.array(
176-
[distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]]
177-
)
178-
newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
179-
K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0
180-
)
181-
map1, map2 = cv2.fisheye.initUndistortRectifyMap(
182-
K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1
183-
)
184-
# and then remap:
185-
image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
186-
if "mask" in data:
187-
mask = data["mask"].numpy()
188-
mask = mask.astype(np.uint8) * 255
189-
mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK)
190-
data["mask"] = torch.from_numpy(mask).bool()
191-
K = newK
192-
else:
193-
raise NotImplementedError("Only perspective and fisheye cameras are supported")
139+
140+
K, image, mask = _undistort_image(camera, distortion_params, data, image, K)
194141
data["image"] = torch.from_numpy(image)
142+
if mask is not None:
143+
data["mask"] = mask
195144

196145
cached_train.append(data)
197146

198147
self.train_dataset.cameras.fx[i] = float(K[0, 0])
199148
self.train_dataset.cameras.fy[i] = float(K[1, 1])
200149
self.train_dataset.cameras.cx[i] = float(K[0, 2])
201150
self.train_dataset.cameras.cy[i] = float(K[1, 2])
151+
self.train_dataset.cameras.width[i] = image.shape[1]
152+
self.train_dataset.cameras.height[i] = image.shape[0]
202153

203154
CONSOLE.log("Caching / undistorting eval images")
204155
for i in tqdm(range(len(self.eval_dataset)), leave=False):
@@ -210,68 +161,20 @@ def cache_images(self, cache_images_option):
210161
continue
211162
distortion_params = camera.distortion_params.numpy()
212163
image = data["image"].numpy()
213-
if camera.camera_type.item() == CameraType.PERSPECTIVE.value:
214-
distortion_params = np.array(
215-
[
216-
distortion_params[0],
217-
distortion_params[1],
218-
distortion_params[4],
219-
distortion_params[5],
220-
distortion_params[2],
221-
distortion_params[3],
222-
0,
223-
0,
224-
]
225-
)
226-
if np.any(distortion_params):
227-
newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
228-
image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore
229-
else:
230-
newK = K
231-
roi = 0, 0, image.shape[1], image.shape[0]
232-
# crop the image and update the intrinsics accordingly
233-
x, y, w, h = roi
234-
image = image[y : y + h, x : x + w]
235-
# update the width, height
236-
self.eval_dataset.cameras.width[i] = w
237-
self.eval_dataset.cameras.height[i] = h
238-
if "mask" in data:
239-
mask = data["mask"].numpy()
240-
mask = mask.astype(np.uint8) * 255
241-
if np.any(distortion_params):
242-
mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore
243-
mask = mask[y : y + h, x : x + w]
244-
data["mask"] = torch.from_numpy(mask).bool()
245-
K = newK
246-
247-
elif camera.camera_type.item() == CameraType.FISHEYE.value:
248-
distortion_params = np.array(
249-
[distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]]
250-
)
251-
newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
252-
K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0
253-
)
254-
map1, map2 = cv2.fisheye.initUndistortRectifyMap(
255-
K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1
256-
)
257-
# and then remap:
258-
image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
259-
if "mask" in data:
260-
mask = data["mask"].numpy()
261-
mask = mask.astype(np.uint8) * 255
262-
mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK)
263-
data["mask"] = torch.from_numpy(mask).bool()
264-
K = newK
265-
else:
266-
raise NotImplementedError("Only perspective and fisheye cameras are supported")
164+
165+
K, image, mask = _undistort_image(camera, distortion_params, data, image, K)
267166
data["image"] = torch.from_numpy(image)
167+
if mask is not None:
168+
data["mask"] = mask
268169

269170
cached_eval.append(data)
270171

271172
self.eval_dataset.cameras.fx[i] = float(K[0, 0])
272173
self.eval_dataset.cameras.fy[i] = float(K[1, 1])
273174
self.eval_dataset.cameras.cx[i] = float(K[0, 2])
274175
self.eval_dataset.cameras.cy[i] = float(K[1, 2])
176+
self.eval_dataset.cameras.width[i] = image.shape[1]
177+
self.eval_dataset.cameras.height[i] = image.shape[0]
275178

276179
if cache_images_option == "gpu":
277180
for cache in cached_train:
@@ -416,3 +319,156 @@ def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:
416319
assert len(self.eval_dataset.cameras.shape) == 1, "Assumes single batch dimension"
417320
camera = self.eval_dataset.cameras[image_idx : image_idx + 1].to(self.device)
418321
return camera, data
322+
323+
324+
def _undistort_image(
325+
camera: Cameras, distortion_params: np.ndarray, data: dict, image: np.ndarray, K: np.ndarray
326+
) -> Tuple[np.ndarray, np.ndarray, Optional[torch.Tensor]]:
327+
mask = None
328+
if camera.camera_type.item() == CameraType.PERSPECTIVE.value:
329+
distortion_params = np.array(
330+
[
331+
distortion_params[0],
332+
distortion_params[1],
333+
distortion_params[4],
334+
distortion_params[5],
335+
distortion_params[2],
336+
distortion_params[3],
337+
0,
338+
0,
339+
]
340+
)
341+
if np.any(distortion_params):
342+
newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
343+
image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore
344+
else:
345+
newK = K
346+
roi = 0, 0, image.shape[1], image.shape[0]
347+
# crop the image and update the intrinsics accordingly
348+
x, y, w, h = roi
349+
image = image[y : y + h, x : x + w]
350+
if "depth_image" in data:
351+
data["depth_image"] = data["depth_image"][y : y + h, x : x + w]
352+
if "mask" in data:
353+
mask = data["mask"].numpy()
354+
mask = mask.astype(np.uint8) * 255
355+
if np.any(distortion_params):
356+
mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore
357+
mask = mask[y : y + h, x : x + w]
358+
mask = torch.from_numpy(mask).bool()
359+
K = newK
360+
361+
elif camera.camera_type.item() == CameraType.FISHEYE.value:
362+
distortion_params = np.array(
363+
[distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]]
364+
)
365+
newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
366+
K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0
367+
)
368+
map1, map2 = cv2.fisheye.initUndistortRectifyMap(
369+
K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1
370+
)
371+
# and then remap:
372+
image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR)
373+
if "mask" in data:
374+
mask = data["mask"].numpy()
375+
mask = mask.astype(np.uint8) * 255
376+
mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK)
377+
mask = torch.from_numpy(mask).bool()
378+
K = newK
379+
elif camera.camera_type.item() == CameraType.FISHEYE624.value:
380+
fisheye624_params = torch.cat(
381+
[camera.fx, camera.fy, camera.cx, camera.cy, torch.from_numpy(distortion_params)], dim=0
382+
)
383+
assert fisheye624_params.shape == (16,)
384+
assert (
385+
"mask" not in data
386+
and camera.metadata is not None
387+
and "fisheye_crop_radius" in camera.metadata
388+
and isinstance(camera.metadata["fisheye_crop_radius"], float)
389+
)
390+
fisheye_crop_radius = camera.metadata["fisheye_crop_radius"]
391+
392+
# Approximate the FOV of the unmasked region of the camera.
393+
upper, lower, left, right = fisheye624_unproject_helper(
394+
torch.tensor(
395+
[
396+
[camera.cx, camera.cy - fisheye_crop_radius],
397+
[camera.cx, camera.cy + fisheye_crop_radius],
398+
[camera.cx - fisheye_crop_radius, camera.cy],
399+
[camera.cx + fisheye_crop_radius, camera.cy],
400+
],
401+
dtype=torch.float32,
402+
)[None],
403+
params=fisheye624_params[None],
404+
).squeeze(dim=0)
405+
fov_radians = torch.max(
406+
torch.acos(torch.sum(upper * lower / torch.linalg.norm(upper) / torch.linalg.norm(lower))),
407+
torch.acos(torch.sum(left * right / torch.linalg.norm(left) / torch.linalg.norm(right))),
408+
)
409+
410+
# Heuristics to determine parameters of an undistorted image.
411+
undist_h = int(fisheye_crop_radius * 2)
412+
undist_w = int(fisheye_crop_radius * 2)
413+
undistort_focal = undist_h / (2 * torch.tan(fov_radians / 2.0))
414+
undist_K = torch.eye(3)
415+
undist_K[0, 0] = undistort_focal # fx
416+
undist_K[1, 1] = undistort_focal # fy
417+
undist_K[0, 2] = (undist_w - 1) / 2.0 # cx; for a 1x1 image, center should be at (0, 0).
418+
undist_K[1, 2] = (undist_h - 1) / 2.0 # cy
419+
420+
# Undistorted 2D coordinates -> rays -> reproject to distorted UV coordinates.
421+
undist_uv_homog = torch.stack(
422+
[
423+
*torch.meshgrid(
424+
torch.arange(undist_w, dtype=torch.float32),
425+
torch.arange(undist_h, dtype=torch.float32),
426+
),
427+
torch.ones((undist_w, undist_h), dtype=torch.float32),
428+
],
429+
dim=-1,
430+
)
431+
assert undist_uv_homog.shape == (undist_w, undist_h, 3)
432+
dist_uv = (
433+
fisheye624_project(
434+
xyz=(
435+
torch.einsum(
436+
"ij,bj->bi",
437+
torch.linalg.inv(undist_K),
438+
undist_uv_homog.reshape((undist_w * undist_h, 3)),
439+
)[None]
440+
),
441+
params=fisheye624_params[None, :],
442+
)
443+
.reshape((undist_w, undist_h, 2))
444+
.numpy()
445+
)
446+
map1 = dist_uv[..., 1]
447+
map2 = dist_uv[..., 0]
448+
449+
# Use correspondence to undistort image.
450+
image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR)
451+
452+
# Compute undistorted mask as well.
453+
dist_h = camera.height.item()
454+
dist_w = camera.width.item()
455+
mask = np.mgrid[:dist_h, :dist_w]
456+
mask[0, ...] -= dist_h // 2
457+
mask[1, ...] -= dist_w // 2
458+
mask = np.linalg.norm(mask, axis=0) < fisheye_crop_radius
459+
mask = torch.from_numpy(
460+
cv2.remap(
461+
mask.astype(np.uint8) * 255,
462+
map1,
463+
map2,
464+
interpolation=cv2.INTER_LINEAR,
465+
borderMode=cv2.BORDER_CONSTANT,
466+
borderValue=0,
467+
)
468+
/ 255.0
469+
).bool()
470+
K = undist_K.numpy()
471+
else:
472+
raise NotImplementedError("Only perspective and fisheye cameras are supported")
473+
474+
return K, image, mask

0 commit comments

Comments
 (0)