|
45 | 45 | from nerfstudio.data.datasets.base_dataset import InputDataset |
46 | 46 | from nerfstudio.data.utils.data_utils import identity_collate |
47 | 47 | from nerfstudio.data.utils.dataloaders import ImageBatchStream, _undistort_image |
| 48 | +from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate |
48 | 49 | from nerfstudio.utils.misc import get_orig_class |
49 | 50 | from nerfstudio.utils.rich_utils import CONSOLE |
50 | 51 |
|
@@ -150,7 +151,7 @@ def __init__( |
150 | 151 | assert len(self.train_unseen_cameras) > 0, "No data found in dataset" |
151 | 152 | super().__init__() |
152 | 153 |
|
153 | | - def sample_train_cameras(self): |
| 154 | + def sample_train_cameras(self) -> List[int]: |
154 | 155 | """Return a list of camera indices sampled using the strategy specified by |
155 | 156 | self.config.train_cameras_sampling_strategy""" |
156 | 157 | num_train_cameras = len(self.train_dataset) |
@@ -326,7 +327,7 @@ def setup_train(self): |
326 | 327 | self.train_imagebatch_stream, |
327 | 328 | batch_size=self.config.batch_size, |
328 | 329 | num_workers=self.config.dataloader_num_workers, |
329 | | - collate_fn=identity_collate, |
| 330 | + collate_fn=nerfstudio_collate, |
330 | 331 | ) |
331 | 332 | self.iter_train_image_dataloader = iter(self.train_image_dataloader) |
332 | 333 |
|
@@ -382,53 +383,30 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]: |
382 | 383 | def get_train_rays_per_batch(self) -> int: |
383 | 384 | """Returns resolution of the image returned from datamanager.""" |
384 | 385 | camera = self.train_dataset.cameras[0].reshape(()) |
385 | | - return int(camera.width[0].item() * camera.height[0].item()) |
| 386 | + return int(camera.width[0].item() * camera.height[0].item()) * self.config.batch_size |
386 | 387 |
|
387 | 388 | def next_train(self, step: int) -> Tuple[Cameras, Dict]: |
388 | 389 | """Returns the next training batch |
389 | 390 | Returns a Camera instead of raybundle""" |
390 | 391 |
|
391 | 392 | self.train_count += 1 |
392 | 393 | if self.config.cache_images == "disk": |
393 | | - output = next(self.iter_train_image_dataloader) |
394 | | - print("Alex", output) |
395 | | - camera, data = output[0] |
396 | | - return camera, data |
| 394 | + cameras, data = next(self.iter_train_image_dataloader) |
| 395 | + return cameras, data |
397 | 396 |
|
398 | | - image_indices = [] |
| 397 | + camera_indices = [] |
399 | 398 | for _ in range(self.config.batch_size): |
400 | 399 | # Make sure to re-populate the unseen cameras list if we have exhausted it |
401 | 400 | if len(self.train_unseen_cameras) == 0: |
402 | 401 | self.train_unseen_cameras = self.sample_train_cameras() |
403 | | - image_indices.append(self.train_unseen_cameras.pop(0)) |
404 | | - |
405 | | - all_keys = self.cached_train[0].keys() |
406 | | - |
407 | | - data = {} |
408 | | - for key in all_keys: |
409 | | - if key == "image": |
410 | | - data[key] = torch.stack([self.cached_train[i][key] for i in image_indices]).to(self.device) |
411 | | - else: |
412 | | - data[key] = [self.cached_train[i][key] for i in image_indices] |
413 | | - |
414 | | - cameras = Cameras( |
415 | | - camera_to_worlds=self.train_cameras.camera_to_worlds[image_indices], |
416 | | - fx=self.train_cameras.fx[image_indices], |
417 | | - fy=self.train_cameras.fy[image_indices], |
418 | | - cx=self.train_cameras.cx[image_indices], |
419 | | - cy=self.train_cameras.cy[image_indices], |
420 | | - width=self.train_cameras.width[image_indices], |
421 | | - height=self.train_cameras.height[image_indices], |
422 | | - camera_type=self.train_cameras.camera_type[image_indices], |
423 | | - ).to(self.device) |
424 | | - |
425 | | - if self.train_cameras.distortion_params is not None: |
426 | | - cameras.distortion_params = self.train_cameras.distortion_params[image_indices] |
427 | | - |
428 | | - if cameras.metadata is None: |
429 | | - cameras.metadata = {} |
430 | | - |
431 | | - cameras.metadata["cam_idx"] = image_indices |
| 402 | + camera_indices.append(self.train_unseen_cameras.pop(0)) |
| 403 | + |
| 404 | + # NOTE: We're going to copy the data to make sure we don't mutate the cached dictionary. |
| 405 | + # This can cause a memory leak: https://github.com/nerfstudio-project/nerfstudio/issues/3335 |
| 406 | + data = nerfstudio_collate( |
| 407 | + [self.cached_train[i].copy() for i in camera_indices] |
| 408 | + ) # Note that this must happen before indexing cameras, as it can modify the cameras in the dataset during undistortion |
| 409 | + cameras = nerfstudio_collate([self.train_dataset.cameras[i : i + 1].to(self.device) for i in camera_indices]) |
432 | 410 |
|
433 | 411 | return cameras, data |
434 | 412 |
|
|
0 commit comments