from torch.nn import Parameter
from tqdm import tqdm

+from nerfstudio.cameras.camera_utils import fisheye624_project, fisheye624_unproject_helper
from nerfstudio.cameras.cameras import Cameras, CameraType
from nerfstudio.configs.dataparser_configs import AnnotatedDataParserUnion
from nerfstudio.data.datamanagers.base_datamanager import DataManager, DataManagerConfig, TDataset
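The two helpers added here, fisheye624_project and fisheye624_unproject_helper, are the projection/unprojection routines for the Fisheye624 camera model (used, for example, by Project Aria captures) and are consumed by the new FISHEYE624 branch further down. Both consume a 16-element parameter vector of intrinsics plus distortion; a minimal sketch of assembling one, with hypothetical intrinsics and all twelve distortion coefficients set to zero:

import torch

# Hypothetical intrinsics for illustration; the 12 trailing values are the
# Fisheye624 distortion coefficients (all zero here).
fx, fy, cx, cy = 400.0, 400.0, 320.0, 240.0
fisheye624_params = torch.tensor([fx, fy, cx, cy] + [0.0] * 12)
assert fisheye624_params.shape == (16,)  # layout passed as `params` to the fisheye624_* helpers in this diff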
@@ -135,70 +136,20 @@ def cache_images(self, cache_images_option):
                continue
            distortion_params = camera.distortion_params.numpy()
            image = data["image"].numpy()
-            if camera.camera_type.item() == CameraType.PERSPECTIVE.value:
-                distortion_params = np.array(
-                    [
-                        distortion_params[0],
-                        distortion_params[1],
-                        distortion_params[4],
-                        distortion_params[5],
-                        distortion_params[2],
-                        distortion_params[3],
-                        0,
-                        0,
-                    ]
-                )
-                if np.any(distortion_params):
-                    newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
-                    image = cv2.undistort(image, K, distortion_params, None, newK)  # type: ignore
-                else:
-                    newK = K
-                    roi = 0, 0, image.shape[1], image.shape[0]
-                # crop the image and update the intrinsics accordingly
-                x, y, w, h = roi
-                image = image[y : y + h, x : x + w]
-                if "depth_image" in data:
-                    data["depth_image"] = data["depth_image"][y : y + h, x : x + w]
-                # update the width, height
-                self.train_dataset.cameras.width[i] = w
-                self.train_dataset.cameras.height[i] = h
-                if "mask" in data:
-                    mask = data["mask"].numpy()
-                    mask = mask.astype(np.uint8) * 255
-                    if np.any(distortion_params):
-                        mask = cv2.undistort(mask, K, distortion_params, None, newK)  # type: ignore
-                    mask = mask[y : y + h, x : x + w]
-                    data["mask"] = torch.from_numpy(mask).bool()
-                K = newK
-
-            elif camera.camera_type.item() == CameraType.FISHEYE.value:
-                distortion_params = np.array(
-                    [distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]]
-                )
-                newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
-                    K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0
-                )
-                map1, map2 = cv2.fisheye.initUndistortRectifyMap(
-                    K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1
-                )
-                # and then remap:
-                image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
-                if "mask" in data:
-                    mask = data["mask"].numpy()
-                    mask = mask.astype(np.uint8) * 255
-                    mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK)
-                    data["mask"] = torch.from_numpy(mask).bool()
-                K = newK
-            else:
-                raise NotImplementedError("Only perspective and fisheye cameras are supported")
+
+            K, image, mask = _undistort_image(camera, distortion_params, data, image, K)

            data["image"] = torch.from_numpy(image)
+            if mask is not None:
+                data["mask"] = mask

            cached_train.append(data)

            self.train_dataset.cameras.fx[i] = float(K[0, 0])
            self.train_dataset.cameras.fy[i] = float(K[1, 1])
            self.train_dataset.cameras.cx[i] = float(K[0, 2])
            self.train_dataset.cameras.cy[i] = float(K[1, 2])
+            self.train_dataset.cameras.width[i] = image.shape[1]
+            self.train_dataset.cameras.height[i] = image.shape[0]

        CONSOLE.log("Caching / undistorting eval images")
        for i in tqdm(range(len(self.eval_dataset)), leave=False):
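For reference, the perspective branch removed above first reorders nerfstudio's distortion coefficients into the layout cv2.undistort expects. A minimal standalone sketch of that index shuffle, assuming nerfstudio's [k1, k2, k3, k4, p1, p2] ordering (the helper name is illustrative only):

import numpy as np

def to_opencv_distortion(d: np.ndarray) -> np.ndarray:
    # Reorder [k1, k2, k3, k4, p1, p2] -> [k1, k2, p1, p2, k3, k4, 0, 0],
    # padding the unused k5/k6 slots with zeros.
    return np.array([d[0], d[1], d[4], d[5], d[2], d[3], 0.0, 0.0])

With alpha=0, cv2.getOptimalNewCameraMatrix then returns both the new intrinsics matrix and a valid-pixel ROI, which is why the branch crops to that ROI and updates the stored width and height.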
@@ -210,68 +161,20 @@ def cache_images(self, cache_images_option):
                continue
            distortion_params = camera.distortion_params.numpy()
            image = data["image"].numpy()
-            if camera.camera_type.item() == CameraType.PERSPECTIVE.value:
-                distortion_params = np.array(
-                    [
-                        distortion_params[0],
-                        distortion_params[1],
-                        distortion_params[4],
-                        distortion_params[5],
-                        distortion_params[2],
-                        distortion_params[3],
-                        0,
-                        0,
-                    ]
-                )
-                if np.any(distortion_params):
-                    newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
-                    image = cv2.undistort(image, K, distortion_params, None, newK)  # type: ignore
-                else:
-                    newK = K
-                    roi = 0, 0, image.shape[1], image.shape[0]
-                # crop the image and update the intrinsics accordingly
-                x, y, w, h = roi
-                image = image[y : y + h, x : x + w]
-                # update the width, height
-                self.eval_dataset.cameras.width[i] = w
-                self.eval_dataset.cameras.height[i] = h
-                if "mask" in data:
-                    mask = data["mask"].numpy()
-                    mask = mask.astype(np.uint8) * 255
-                    if np.any(distortion_params):
-                        mask = cv2.undistort(mask, K, distortion_params, None, newK)  # type: ignore
-                    mask = mask[y : y + h, x : x + w]
-                    data["mask"] = torch.from_numpy(mask).bool()
-                K = newK
-
-            elif camera.camera_type.item() == CameraType.FISHEYE.value:
-                distortion_params = np.array(
-                    [distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]]
-                )
-                newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
-                    K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0
-                )
-                map1, map2 = cv2.fisheye.initUndistortRectifyMap(
-                    K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1
-                )
-                # and then remap:
-                image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
-                if "mask" in data:
-                    mask = data["mask"].numpy()
-                    mask = mask.astype(np.uint8) * 255
-                    mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK)
-                    data["mask"] = torch.from_numpy(mask).bool()
-                K = newK
-            else:
-                raise NotImplementedError("Only perspective and fisheye cameras are supported")
+
+            K, image, mask = _undistort_image(camera, distortion_params, data, image, K)

            data["image"] = torch.from_numpy(image)
+            if mask is not None:
+                data["mask"] = mask

            cached_eval.append(data)

            self.eval_dataset.cameras.fx[i] = float(K[0, 0])
            self.eval_dataset.cameras.fy[i] = float(K[1, 1])
            self.eval_dataset.cameras.cx[i] = float(K[0, 2])
            self.eval_dataset.cameras.cy[i] = float(K[1, 2])
+            self.eval_dataset.cameras.width[i] = image.shape[1]
+            self.eval_dataset.cameras.height[i] = image.shape[0]

        if cache_images_option == "gpu":
            for cache in cached_train:
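The fisheye branch removed above (now folded into _undistort_image below) follows OpenCV's standard fisheye pipeline: estimate a new pinhole intrinsics matrix, build a pixel remapping, then resample. A condensed sketch of that path, assuming a 4-coefficient Kannala-Brandt distortion vector D; the helper name undistort_fisheye is illustrative only:

from typing import Tuple

import cv2
import numpy as np

def undistort_fisheye(image: np.ndarray, K: np.ndarray, D: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    # Mirrors the removed FISHEYE branch: new intrinsics -> remap tables -> resample.
    size = (image.shape[1], image.shape[0])
    new_K = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(K, D, size, np.eye(3), balance=0)
    map1, map2 = cv2.fisheye.initUndistortRectifyMap(K, D, np.eye(3), new_K, size, cv2.CV_32FC1)
    undistorted = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR)
    return undistorted, new_K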
@@ -416,3 +319,156 @@ def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:
        assert len(self.eval_dataset.cameras.shape) == 1, "Assumes single batch dimension"
        camera = self.eval_dataset.cameras[image_idx : image_idx + 1].to(self.device)
        return camera, data
+
+
+def _undistort_image(
+    camera: Cameras, distortion_params: np.ndarray, data: dict, image: np.ndarray, K: np.ndarray
+) -> Tuple[np.ndarray, np.ndarray, Optional[torch.Tensor]]:
+    mask = None
+    if camera.camera_type.item() == CameraType.PERSPECTIVE.value:
+        distortion_params = np.array(
+            [
+                distortion_params[0],
+                distortion_params[1],
+                distortion_params[4],
+                distortion_params[5],
+                distortion_params[2],
+                distortion_params[3],
+                0,
+                0,
+            ]
+        )
+        if np.any(distortion_params):
+            newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
+            image = cv2.undistort(image, K, distortion_params, None, newK)  # type: ignore
+        else:
+            newK = K
+            roi = 0, 0, image.shape[1], image.shape[0]
+        # crop the image and update the intrinsics accordingly
+        x, y, w, h = roi
+        image = image[y : y + h, x : x + w]
+        if "depth_image" in data:
+            data["depth_image"] = data["depth_image"][y : y + h, x : x + w]
+        if "mask" in data:
+            mask = data["mask"].numpy()
+            mask = mask.astype(np.uint8) * 255
+            if np.any(distortion_params):
+                mask = cv2.undistort(mask, K, distortion_params, None, newK)  # type: ignore
+            mask = mask[y : y + h, x : x + w]
+            mask = torch.from_numpy(mask).bool()
+        K = newK
+
+    elif camera.camera_type.item() == CameraType.FISHEYE.value:
+        distortion_params = np.array(
+            [distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]]
+        )
+        newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
+            K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0
+        )
+        map1, map2 = cv2.fisheye.initUndistortRectifyMap(
+            K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1
+        )
+        # and then remap:
+        image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR)
+        if "mask" in data:
+            mask = data["mask"].numpy()
+            mask = mask.astype(np.uint8) * 255
+            mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK)
+            mask = torch.from_numpy(mask).bool()
+        K = newK
+    elif camera.camera_type.item() == CameraType.FISHEYE624.value:
+        fisheye624_params = torch.cat(
+            [camera.fx, camera.fy, camera.cx, camera.cy, torch.from_numpy(distortion_params)], dim=0
+        )
+        assert fisheye624_params.shape == (16,)
+        assert (
+            "mask" not in data
+            and camera.metadata is not None
+            and "fisheye_crop_radius" in camera.metadata
+            and isinstance(camera.metadata["fisheye_crop_radius"], float)
+        )
+        fisheye_crop_radius = camera.metadata["fisheye_crop_radius"]
+
+        # Approximate the FOV of the unmasked region of the camera.
+        upper, lower, left, right = fisheye624_unproject_helper(
+            torch.tensor(
+                [
+                    [camera.cx, camera.cy - fisheye_crop_radius],
+                    [camera.cx, camera.cy + fisheye_crop_radius],
+                    [camera.cx - fisheye_crop_radius, camera.cy],
+                    [camera.cx + fisheye_crop_radius, camera.cy],
+                ],
+                dtype=torch.float32,
+            )[None],
+            params=fisheye624_params[None],
+        ).squeeze(dim=0)
+        fov_radians = torch.max(
+            torch.acos(torch.sum(upper * lower / torch.linalg.norm(upper) / torch.linalg.norm(lower))),
+            torch.acos(torch.sum(left * right / torch.linalg.norm(left) / torch.linalg.norm(right))),
+        )
+
+        # Heuristics to determine parameters of an undistorted image.
+        undist_h = int(fisheye_crop_radius * 2)
+        undist_w = int(fisheye_crop_radius * 2)
+        undistort_focal = undist_h / (2 * torch.tan(fov_radians / 2.0))
+        undist_K = torch.eye(3)
+        undist_K[0, 0] = undistort_focal  # fx
+        undist_K[1, 1] = undistort_focal  # fy
+        undist_K[0, 2] = (undist_w - 1) / 2.0  # cx; for a 1x1 image, center should be at (0, 0).
+        undist_K[1, 2] = (undist_h - 1) / 2.0  # cy
+
+        # Undistorted 2D coordinates -> rays -> reproject to distorted UV coordinates.
+        undist_uv_homog = torch.stack(
+            [
+                *torch.meshgrid(
+                    torch.arange(undist_w, dtype=torch.float32),
+                    torch.arange(undist_h, dtype=torch.float32),
+                ),
+                torch.ones((undist_w, undist_h), dtype=torch.float32),
+            ],
+            dim=-1,
+        )
+        assert undist_uv_homog.shape == (undist_w, undist_h, 3)
+        dist_uv = (
+            fisheye624_project(
+                xyz=(
+                    torch.einsum(
+                        "ij,bj->bi",
+                        torch.linalg.inv(undist_K),
+                        undist_uv_homog.reshape((undist_w * undist_h, 3)),
+                    )[None]
+                ),
+                params=fisheye624_params[None, :],
+            )
+            .reshape((undist_w, undist_h, 2))
+            .numpy()
+        )
+        map1 = dist_uv[..., 1]
+        map2 = dist_uv[..., 0]
+
+        # Use correspondence to undistort image.
+        image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR)
+
+        # Compute undistorted mask as well.
+        dist_h = camera.height.item()
+        dist_w = camera.width.item()
+        mask = np.mgrid[:dist_h, :dist_w]
+        mask[0, ...] -= dist_h // 2
+        mask[1, ...] -= dist_w // 2
+        mask = np.linalg.norm(mask, axis=0) < fisheye_crop_radius
+        mask = torch.from_numpy(
+            cv2.remap(
+                mask.astype(np.uint8) * 255,
+                map1,
+                map2,
+                interpolation=cv2.INTER_LINEAR,
+                borderMode=cv2.BORDER_CONSTANT,
+                borderValue=0,
+            )
+            / 255.0
+        ).bool()
+        K = undist_K.numpy()
+    else:
+        raise NotImplementedError("Only perspective and fisheye cameras are supported")
+
+    return K, image, mask
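The undistort_focal heuristic in the FISHEYE624 branch is the pinhole relation f = H / (2 * tan(fov / 2)): the undistorted image is sized to the crop-circle diameter, and the focal length is chosen so that this diameter spans the FOV estimated from the unprojected boundary rays. A quick numeric check with hypothetical values (1408 px diameter, 140 degree FOV):

import math

fov_radians = math.radians(140.0)  # hypothetical FOV of the unmasked region
undist_h = 1408                    # hypothetical crop-circle diameter in pixels
undistort_focal = undist_h / (2.0 * math.tan(fov_radians / 2.0))
print(f"{undistort_focal:.1f}")    # ~256.2 px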