CUDA Out-of-Memory crash when running inference on more than 10 frames #20

@sipkode

Description

I can run the inference script on a single frame at a time, or on a small handful of frames, but when I try to run it on more than 10 frames (as in the Tiny examples), it makes it almost all the way through the process before crashing with a CUDA Out-of-Memory exception.

I am running on a single RTX 4090 with 24 GB of VRAM. Here's the output I get:

[2025-12-18 22:37:12,041][src.samplers.sliding_iterative_sampler][ERROR] - Failed on task 13/16: {'alt': 3, 'domain': 'spatial', 'domain_label': 'undistorted_Frame0013'}
Error executing job with overrides: ['exp=goprotest_tiny', 'data.scene_label=goprotest', 'data.data_dir=./data']
╭──────────────────────────────────────────────────────── Traceback (most recent call last) ─────────────────────────────────────────────────────────╮
│ F:\Users\sipkode\Dev\GitHub\Diffuman4D\inference.py:71 in <module>                                                                               │
│                                                                                                                                                    │
│   68                                                                                                                                               │
│   69                                                                                                                                               │
│   70 if __name__ == "__main__":                                                                                                                    │
│ ❱ 71 │   main()                                                                                                                                    │
│   72                                                                                                                                               │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\hydra\main.py:94 in decorated_main                                                   │
│                                                                                                                                                    │
│    91 │   │   │   │   else:                                                                                                                        │
│    92 │   │   │   │   │   # no return value from run_hydra() as it may sometime actually run the task_function                                     │
│    93 │   │   │   │   │   # multiple times (--multirun)                                                                                            │
│ ❱  94 │   │   │   │   │   _run_hydra(                                                                                                              │
│    95 │   │   │   │   │   │   args=args,                                                                                                           │
│    96 │   │   │   │   │   │   args_parser=args_parser,                                                                                             │
│    97 │   │   │   │   │   │   task_function=task_function,                                                                                         │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\hydra\_internal\utils.py:394 in _run_hydra                                           │
│                                                                                                                                                    │
│   391 │   │                                                                                                                                        │
│   392 │   │   if args.run or args.multirun:                                                                                                        │
│   393 │   │   │   run_mode = hydra.get_mode(config_name=config_name, overrides=overrides)                                                          │
│ ❱ 394 │   │   │   _run_app(                                                                                                                        │
│   395 │   │   │   │   run=args.run,                                                                                                                │
│   396 │   │   │   │   multirun=args.multirun,                                                                                                      │
│   397 │   │   │   │   mode=run_mode,                                                                                                               │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\hydra\_internal\utils.py:457 in _run_app                                             │
│                                                                                                                                                    │
│   454 │   │   │   overrides.extend(["hydra.mode=MULTIRUN"])                                                                                        │
│   455 │                                                                                                                                            │
│   456 │   if mode == RunMode.RUN:                                                                                                                  │
│ ❱ 457 │   │   run_and_report(                                                                                                                      │
│   458 │   │   │   lambda: hydra.run(                                                                                                               │
│   459 │   │   │   │   config_name=config_name,                                                                                                     │
│   460 │   │   │   │   task_function=task_function,                                                                                                 │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\hydra\_internal\utils.py:223 in run_and_report                                       │
│                                                                                                                                                    │
│   220 │   │   return func()                                                                                                                        │
│   221 │   except Exception as ex:                                                                                                                  │
│   222 │   │   if _is_env_set("HYDRA_FULL_ERROR") or is_under_debugger():                                                                           │
│ ❱ 223 │   │   │   raise ex                                                                                                                         │
│   224 │   │   else:                                                                                                                                │
│   225 │   │   │   try:                                                                                                                             │
│   226 │   │   │   │   if isinstance(ex, CompactHydraException):                                                                                    │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\hydra\_internal\utils.py:220 in run_and_report                                       │
│                                                                                                                                                    │
│   217                                                                                                                                              │
│   218 def run_and_report(func: Any) -> Any:                                                                                                        │
│   219 │   try:                                                                                                                                     │
│ ❱ 220 │   │   return func()                                                                                                                        │
│   221 │   except Exception as ex:                                                                                                                  │
│   222 │   │   if _is_env_set("HYDRA_FULL_ERROR") or is_under_debugger():                                                                           │
│   223 │   │   │   raise ex                                                                                                                         │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\hydra\_internal\utils.py:458 in <lambda>                                             │
│                                                                                                                                                    │
│   455 │                                                                                                                                            │
│   456 │   if mode == RunMode.RUN:                                                                                                                  │
│   457 │   │   run_and_report(                                                                                                                      │
│ ❱ 458 │   │   │   lambda: hydra.run(                                                                                                               │
│   459 │   │   │   │   config_name=config_name,                                                                                                     │
│   460 │   │   │   │   task_function=task_function,                                                                                                 │
│   461 │   │   │   │   overrides=overrides,                                                                                                         │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\hydra\_internal\hydra.py:132 in run                                                  │
│                                                                                                                                                    │
│   129 │   │   callbacks.on_run_end(config=cfg, config_name=config_name, job_return=ret)                                                            │
│   130 │   │                                                                                                                                        │
│   131 │   │   # access the result to trigger an exception in case the job failed.                                                                  │
│ ❱ 132 │   │   _ = ret.return_value                                                                                                                 │
│   133 │   │                                                                                                                                        │
│   134 │   │   return ret                                                                                                                           │
│   135                                                                                                                                              │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\hydra\core\utils.py:260 in return_value                                              │
│                                                                                                                                                    │
│   257 │   │   │   sys.stderr.write(                                                                                                                │
│   258 │   │   │   │   f"Error executing job with overrides: {self.overrides}" + os.linesep                                                         │
│   259 │   │   │   )                                                                                                                                │
│ ❱ 260 │   │   │   raise self._return_value                                                                                                         │
│   261 │                                                                                                                                            │
│   262 │   @return_value.setter                                                                                                                     │
│   263 │   def return_value(self, value: Any) -> None:                                                                                              │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\hydra\core\utils.py:186 in run_job                                                   │
│                                                                                                                                                    │
│   183 │   │   with env_override(hydra_cfg.hydra.job.env_set):                                                                                      │
│   184 │   │   │   callbacks.on_job_start(config=config, task_function=task_function)                                                               │
│   185 │   │   │   try:                                                                                                                             │
│ ❱ 186 │   │   │   │   ret.return_value = task_function(task_cfg)                                                                                   │
│   187 │   │   │   │   ret.status = JobStatus.COMPLETED                                                                                             │
│   188 │   │   │   except Exception as e:                                                                                                           │
│   189 │   │   │   │   ret.return_value = e                                                                                                         │
│                                                                                                                                                    │
│ F:\Users\sipkode\Dev\GitHub\Diffuman4D\inference.py:67 in main                                                                                   │
│                                                                                                                                                    │
│   64 │   """Main entry point"""                                                                                                                    │
│   65 │   print_config_tree(cfg, resolve=True, save_to_file=True)                                                                                   │
│   66 │                                                                                                                                             │
│ ❱ 67 │   inference(cfg)                                                                                                                            │
│   68                                                                                                                                               │
│   69                                                                                                                                               │
│   70 if __name__ == "__main__":                                                                                                                    │
│                                                                                                                                                    │
│ F:\Users\sipkode\Dev\GitHub\Diffuman4D\inference.py:51 in inference                                                                              │
│                                                                                                                                                    │
│   48 │                                                                                                                                             │
│   49 │   if cfg.sampling:                                                                                                                          │
│   50 │   │   log.info("Sampling...")                                                                                                               │
│ ❱ 51 │   │   runner.inference()                                                                                                                    │
│   52 │                                                                                                                                             │
│   53 │   if cfg.to_nerfstudio:                                                                                                                     │
│   54 │   │   log.info("Converting results to nerfstudio format...")                                                                                │
│                                                                                                                                                    │
│ F:\Users\sipkode\Dev\GitHub\Diffuman4D\src\samplers\sampling_runner.py:80 in inference                                                           │
│                                                                                                                                                    │
│    77 │   │   │   ):                                                                                                                               │
│    78 │   │   │   │   raise ValueError("Sampling failed.")                                                                                         │
│    79 │   │   else:                                                                                                                                │
│ ❱  80 │   │   │   self.sampler.execute_tasks()                                                                                                     │
│    81 │                                                                                                                                            │
│    82 │   def evaluate(self):                                                                                                                      │
│    83 │   │   evaluate_results(                                                                                                                    │
│                                                                                                                                                    │
│ F:\Users\sipkode\Dev\GitHub\Diffuman4D\src\samplers\sliding_iterative_sampler.py:275 in execute_tasks                                            │
│                                                                                                                                                    │
│   272 │   │   │   for j, task in enumerate(tasks):                                                                                                 │
│   273 │   │   │   │   try:                                                                                                                         │
│   274 │   │   │   │   │   log.debug(f"Executing task {j+1}/{len(tasks)}: {task}")                                                                  │
│ ❱ 275 │   │   │   │   │   self.execute_one_task(task)                                                                                              │
│   276 │   │   │   │   except Exception as e:                                                                                                       │
│   277 │   │   │   │   │   log.error(f"Failed on task {j+1}/{len(tasks)}: {task}")                                                                  │
│   278 │   │   │   │   │   log.error(f"Error: {str(e)}")                                                                                            │
│                                                                                                                                                    │
│ F:\Users\sipkode\Dev\GitHub\Diffuman4D\src\samplers\sliding_iterative_sampler.py:246 in execute_one_task                                         │
│                                                                                                                                                    │
│   243 │   │   │                                                                                                                                    │
│   244 │   │   │   sample = self.load_sample(**task)                                                                                                │
│   245 │   │   │   log.debug(f"Loaded sample for task: {task}")                                                                                     │
│ ❱ 246 │   │   │   sample = self.denoise(sample, pipe_idx=pipe_idx)                                                                                 │
│   247 │   │   │   log.debug(f"Denoising complete for task: {task}")                                                                                │
│   248 │   │   │   save_sampling_results(sample, output_dir=self.output_dir)                                                                        │
│   249 │   │   │   log.debug(f"Successfully completed task: {task}")                                                                                │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\torch\utils\_contextlib.py:120 in decorate_context                                   │
│                                                                                                                                                    │
│   117 │   @functools.wraps(func)                                                                                                                   │
│   118 │   def decorate_context(*args, **kwargs):                                                                                                   │
│   119 │   │   with ctx_factory():                                                                                                                  │
│ ❱ 120 │   │   │   return func(*args, **kwargs)                                                                                                     │
│   121 │                                                                                                                                            │
│   122 │   return decorate_context                                                                                                                  │
│   123                                                                                                                                              │
│                                                                                                                                                    │
│ F:\Users\sipkode\Dev\GitHub\Diffuman4D\src\samplers\sliding_iterative_sampler.py:187 in denoise                                                  │
│                                                                                                                                                    │
│   184 │   │   │   effective_window = self.window_size                                                                                              │
│   185 │   │                                                                                                                                        │
│   186 │   │   # denoise a spatial or temporal sample sequence                                                                                      │
│ ❱ 187 │   │   result = pipeline.sliding_iterative_denoise(                                                                                         │
│   188 │   │   │   pixel_values=sample["pixel_values"],                                                                                             │
│   189 │   │   │   plucker_embeds=sample["plucker_embeds"],                                                                                         │
│   190 │   │   │   skeletons=sample["skeletons"],                                                                                                   │
│                                                                                                                                                    │
│ F:\Users\sipkode\Dev\GitHub\Diffuman4D\src\diffusers\pipelines\diffuman4d\pipeline_diffuman4d.py:552 in sliding_iterative_denoise                │
│                                                                                                                                                    │
│   549 │   │   │   get_slice = lambda x: x[window] if x is not None else None                                                                       │
│   550 │   │   │                                                                                                                                    │
│   551 │   │   │   # few-step denoising for each window                                                                                             │
│ ❱ 552 │   │   │   latents_window = self(                                                                                                           │
│   553 │   │   │   │   pixel_values_latents=get_slice(pixel_values_latents),                                                                        │
│   554 │   │   │   │   plucker_embeds_latents=get_slice(plucker_embeds_latents),                                                                    │
│   555 │   │   │   │   skeletons_latents=get_slice(skeletons_latents),                                                                              │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\torch\utils\_contextlib.py:120 in decorate_context                                   │
│                                                                                                                                                    │
│   117 │   @functools.wraps(func)                                                                                                                   │
│   118 │   def decorate_context(*args, **kwargs):                                                                                                   │
│   119 │   │   with ctx_factory():                                                                                                                  │
│ ❱ 120 │   │   │   return func(*args, **kwargs)                                                                                                     │
│   121 │                                                                                                                                            │
│   122 │   return decorate_context                                                                                                                  │
│   123                                                                                                                                              │
│                                                                                                                                                    │
│ F:\Users\sipkode\Dev\GitHub\Diffuman4D\src\diffusers\pipelines\diffuman4d\pipeline_diffuman4d.py:413 in __call__                                 │
│                                                                                                                                                    │
│   410 │   │   │   │   latent_model_input = torch.cat(latent_model_input, dim=1)                                                                    │
│   411 │   │   │   │                                                                                                                                │
│   412 │   │   │   │   # predict the noise residual                                                                                                 │
│ ❱ 413 │   │   │   │   noise_pred = self.unet(                                                                                                      │
│   414 │   │   │   │   │   latent_model_input,                                                                                                      │
│   415 │   │   │   │   │   timestep=timestep,                                                                                                       │
│   416 │   │   │   │   │   skeletons=skeletons_latents,                                                                                             │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\torch\nn\modules\module.py:1773 in _wrapped_call_impl                                │
│                                                                                                                                                    │
│   1770 │   │   if self._compiled_call_impl is not None:                                                                                            │
│   1771 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]                                                          │
│   1772 │   │   else:                                                                                                                               │
│ ❱ 1773 │   │   │   return self._call_impl(*args, **kwargs)                                                                                         │
│   1774 │                                                                                                                                           │
│   1775 │   # torchrec tests the code consistency with the following code                                                                           │
│   1776 │   # fmt: off                                                                                                                              │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\torch\nn\modules\module.py:1784 in _call_impl                                        │
│                                                                                                                                                    │
│   1781 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks                          │
│   1782 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                                                                     │
│   1783 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                                                                     │
│ ❱ 1784 │   │   │   return forward_call(*args, **kwargs)                                                                                            │
│   1785 │   │                                                                                                                                       │
│   1786 │   │   result = None                                                                                                                       │
│   1787 │   │   called_always_called_hooks = set()                                                                                                  │
│                                                                                                                                                    │
│ F:\Users\sipkode\Dev\GitHub\Diffuman4D\src\diffusers\models\unets\unet_multiview_condition.py:561 in forward                                     │
│                                                                                                                                                    │
│   558 │   │   for i, downsample_block in enumerate(self.down_blocks):                                                                              │
│   559 │   │   │   if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:                                    │
│   560 │   │   │   │   num_frames_block = num_frames if len(self.down_blocks) - i - 1 < self.config.num_3d_attn_blocks else 1                       │
│ ❱ 561 │   │   │   │   sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames_block)                          │
│   562 │   │   │   else:                                                                                                                            │
│   563 │   │   │   │   sample, res_samples = downsample_block(hidden_states=sample, temb=emb)                                                       │
│   564                                                                                                                                              │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\torch\nn\modules\module.py:1773 in _wrapped_call_impl                                │
│                                                                                                                                                    │
│   1770 │   │   if self._compiled_call_impl is not None:                                                                                            │
│   1771 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]                                                          │
│   1772 │   │   else:                                                                                                                               │
│ ❱ 1773 │   │   │   return self._call_impl(*args, **kwargs)                                                                                         │
│   1774 │                                                                                                                                           │
│   1775 │   # torchrec tests the code consistency with the following code                                                                           │
│   1776 │   # fmt: off                                                                                                                              │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\torch\nn\modules\module.py:1784 in _call_impl                                        │
│                                                                                                                                                    │
│   1781 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks                          │
│   1782 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                                                                     │
│   1783 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                                                                     │
│ ❱ 1784 │   │   │   return forward_call(*args, **kwargs)                                                                                            │
│   1785 │   │                                                                                                                                       │
│   1786 │   │   result = None                                                                                                                       │
│   1787 │   │   called_always_called_hooks = set()                                                                                                  │
│                                                                                                                                                    │
│ F:\Users\sipkode\Dev\GitHub\Diffuman4D\src\diffusers\models\unets\unet_multiview_blocks.py:519 in forward                                        │
│                                                                                                                                                    │
│   516 │   │   │   │   )[0]                                                                                                                         │
│   517 │   │   │   else:                                                                                                                            │
│   518 │   │   │   │   hidden_states = resnet(hidden_states, temb)                                                                                  │
│ ❱ 519 │   │   │   │   hidden_states = attn(                                                                                                        │
│   520 │   │   │   │   │   hidden_states,                                                                                                           │
│   521 │   │   │   │   │   encoder_hidden_states=encoder_hidden_states,                                                                             │
│   522 │   │   │   │   │   cross_attention_kwargs=cross_attention_kwargs,                                                                           │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\torch\nn\modules\module.py:1773 in _wrapped_call_impl                                │
│                                                                                                                                                    │
│   1770 │   │   if self._compiled_call_impl is not None:                                                                                            │
│   1771 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]                                                          │
│   1772 │   │   else:                                                                                                                               │
│ ❱ 1773 │   │   │   return self._call_impl(*args, **kwargs)                                                                                         │
│   1774 │                                                                                                                                           │
│   1775 │   # torchrec tests the code consistency with the following code                                                                           │
│   1776 │   # fmt: off                                                                                                                              │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\torch\nn\modules\module.py:1784 in _call_impl                                        │
│                                                                                                                                                    │
│   1781 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks                          │
│   1782 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                                                                     │
│   1783 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                                                                     │
│ ❱ 1784 │   │   │   return forward_call(*args, **kwargs)                                                                                            │
│   1785 │   │                                                                                                                                       │
│   1786 │   │   result = None                                                                                                                       │
│   1787 │   │   called_always_called_hooks = set()                                                                                                  │
│                                                                                                                                                    │
│ F:\Users\sipkode\Dev\GitHub\Diffuman4D\src\diffusers\models\transformers\transformer_multiview.py:196 in forward                                 │
│                                                                                                                                                    │
│   193 │   │   │   │   │   **ckpt_kwargs,                                                                                                           │
│   194 │   │   │   │   )                                                                                                                            │
│   195 │   │   │   else:                                                                                                                            │
│ ❱ 196 │   │   │   │   hidden_states = block(                                                                                                       │
│   197 │   │   │   │   │   hidden_states,                                                                                                           │
│   198 │   │   │   │   │   attention_mask=attention_mask,                                                                                           │
│   199 │   │   │   │   │   encoder_hidden_states=encoder_hidden_states,                                                                             │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\torch\nn\modules\module.py:1773 in _wrapped_call_impl                                │
│                                                                                                                                                    │
│   1770 │   │   if self._compiled_call_impl is not None:                                                                                            │
│   1771 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]                                                          │
│   1772 │   │   else:                                                                                                                               │
│ ❱ 1773 │   │   │   return self._call_impl(*args, **kwargs)                                                                                         │
│   1774 │                                                                                                                                           │
│   1775 │   # torchrec tests the code consistency with the following code                                                                           │
│   1776 │   # fmt: off                                                                                                                              │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\torch\nn\modules\module.py:1784 in _call_impl                                        │
│                                                                                                                                                    │
│   1781 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks                          │
│   1782 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                                                                     │
│   1783 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                                                                     │
│ ❱ 1784 │   │   │   return forward_call(*args, **kwargs)                                                                                            │
│   1785 │   │                                                                                                                                       │
│   1786 │   │   result = None                                                                                                                       │
│   1787 │   │   called_always_called_hooks = set()                                                                                                  │
│                                                                                                                                                    │
│ F:\Users\sipkode\Dev\GitHub\Diffuman4D\src\diffusers\models\attention.py:142 in forward                                                          │
│                                                                                                                                                    │
│   139 │   │   │   # "feed_forward_chunk_size" can be used to save memory                                                                           │
│   140 │   │   │   ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)                                │
│   141 │   │   else:                                                                                                                                │
│ ❱ 142 │   │   │   ff_output = self.ff(norm_hidden_states)                                                                                          │
│   143 │   │                                                                                                                                        │
│   144 │   │   if self.norm_type == "ada_norm_zero":                                                                                                │
│   145 │   │   │   ff_output = gate_mlp.unsqueeze(1) * ff_output                                                                                    │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\torch\nn\modules\module.py:1773 in _wrapped_call_impl                                │
│                                                                                                                                                    │
│   1770 │   │   if self._compiled_call_impl is not None:                                                                                            │
│   1771 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]                                                          │
│   1772 │   │   else:                                                                                                                               │
│ ❱ 1773 │   │   │   return self._call_impl(*args, **kwargs)                                                                                         │
│   1774 │                                                                                                                                           │
│   1775 │   # torchrec tests the code consistency with the following code                                                                           │
│   1776 │   # fmt: off                                                                                                                              │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\torch\nn\modules\module.py:1784 in _call_impl                                        │
│                                                                                                                                                    │
│   1781 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks                          │
│   1782 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                                                                     │
│   1783 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                                                                     │
│ ❱ 1784 │   │   │   return forward_call(*args, **kwargs)                                                                                            │
│   1785 │   │                                                                                                                                       │
│   1786 │   │   result = None                                                                                                                       │
│   1787 │   │   called_always_called_hooks = set()                                                                                                  │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\diffusers\models\attention.py:1250 in forward                                        │
│                                                                                                                                                    │
│   1247 │   │   │   deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an er │
│   1248 │   │   │   deprecate("scale", "1.0.0", deprecation_message)                                                                                │
│   1249 │   │   for module in self.net:                                                                                                             │
│ ❱ 1250 │   │   │   hidden_states = module(hidden_states)                                                                                           │
│   1251 │   │   return hidden_states                                                                                                                │
│   1252                                                                                                                                             │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\torch\nn\modules\module.py:1773 in _wrapped_call_impl                                │
│                                                                                                                                                    │
│   1770 │   │   if self._compiled_call_impl is not None:                                                                                            │
│   1771 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]                                                          │
│   1772 │   │   else:                                                                                                                               │
│ ❱ 1773 │   │   │   return self._call_impl(*args, **kwargs)                                                                                         │
│   1774 │                                                                                                                                           │
│   1775 │   # torchrec tests the code consistency with the following code                                                                           │
│   1776 │   # fmt: off                                                                                                                              │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\torch\nn\modules\module.py:1784 in _call_impl                                        │
│                                                                                                                                                    │
│   1781 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks                          │
│   1782 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                                                                     │
│   1783 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                                                                     │
│ ❱ 1784 │   │   │   return forward_call(*args, **kwargs)                                                                                            │
│   1785 │   │                                                                                                                                       │
│   1786 │   │   result = None                                                                                                                       │
│   1787 │   │   called_always_called_hooks = set()                                                                                                  │
│                                                                                                                                                    │
│ F:\Users\sipkode\miniconda3\envs\diffuman\Lib\site-packages\diffusers\models\activations.py:123 in forward                                       │
│                                                                                                                                                    │
│   120 │   │   │   return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]                                                              │
│   121 │   │   else:                                                                                                                                │
│   122 │   │   │   hidden_states, gate = hidden_states.chunk(2, dim=-1)                                                                             │
│ ❱ 123 │   │   │   return hidden_states * self.gelu(gate)                                                                                           │
│   124                                                                                                                                              │
│   125                                                                                                                                              │
│   126 class SwiGLU(nn.Module):                                                                                                                     │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
OutOfMemoryError: CUDA out of memory. Tried to allocate 960.00 MiB. GPU 0 has a total capacity of 23.99 GiB of which 15.52 GiB is free. Of the allocated memory 6.78 GiB is allocated by PyTorch, and 76.36 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

I've already followed the error message's suggestion and set `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`, but the result is the same. I've also tried:

  • switching from bf16 to fp16
  • manually cleaning up in between denoising steps with torch.cuda.empty_cache() and gc.collect()
  • reducing the sampler window size

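For reference, this is roughly how I'm applying the allocator setting and the manual cleanup between denoising steps (the `free_cuda_cache` helper is just my own wrapper, not project code):

```python
import gc
import os

# Must be set before torch first initializes the CUDA allocator, or the
# setting is ignored; expandable_segments is meant to reduce
# fragmentation-driven OOMs.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


def free_cuda_cache() -> None:
    """Best-effort cleanup I call between denoising steps (my own helper)."""
    gc.collect()
    try:
        import torch
        if torch.cuda.is_available():
            # Returns cached-but-unused blocks to the driver.
            torch.cuda.empty_cache()
    except ImportError:
        pass  # torch unavailable; nothing to free
```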
It really does look like a fragmentation issue: watching the memory allocation pattern in perfmon, I can see that in the couple of denoising steps before the crash, the allocations/deallocations start spiking wildly.

Has anyone run into this before? Any other suggestions for how to fix this?
