Training PointPillars on nuScenes mini dataset - "... variables needed for gradient computation has been modified by an inplace operation" #1057

Open
nout-kleef opened this issue Nov 17, 2021 · 14 comments

Comments

@nout-kleef

nout-kleef commented Nov 17, 2021

Checklist

  1. I have searched related issues but cannot get the expected help.
  2. The issue has not been fixed in the latest version.

Describe the issue
I am trying to train PointPillars on nuScenes, but I get a "RuntimeError: CUDA out of memory" error when using the provided implementation.

Reproduction

./tools/dist_train.sh configs/pointpillars/hv_pointpillars_fpn_sbn-all_4x8_2x_nus-3d.py 1
  1. Which config did you run?
configs/pointpillars/hv_pointpillars_fpn_sbn-all_4x8_2x_nus-3d.py
  1. Did you make any modifications on the code or config? Did you understand what you have modified?

The OOM error occurs without any modification, i.e. when executing the reproduction command.
I tried to solve the issue by adjusting the batch size to a smaller number (I tried 1 and 2). However, this results in a different error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 256, 100, 100]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

(I am not sure whether this is the correct way to decrease the batch size.)

  1. What dataset did you use?

NuScenes (configs/_base_/datasets/nus-3d.py)

Environment

  1. Please run python mmdet3d/utils/collect_env.py to collect necessary environment information and paste it here.
$ python mmdet3d/utils/collect_env.py
sys.platform: linux
Python: 3.7.11 (default, Jul 27 2021, 14:32:16) [GCC 7.5.0]
CUDA available: True
GPU 0: Tesla P4
CUDA_HOME: /usr/local/cuda-11
NVCC: Build cuda_11.5.r11.5/compiler.30411180_0
GCC: gcc (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
PyTorch: 1.10.0
PyTorch compiling details: PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) oneAPI Math Kernel Library Version 2021.4-Product Build 20210904 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.2.3 (Git Hash 7336ca9f055cf1bfa13efb658fe15dc9b41f0740)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX512
  - CUDA Runtime 11.3
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=compute_37
  - CuDNN 8.2
  - Magma 2.5.2
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.3, CUDNN_VERSION=8.2.0, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.10.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, 

TorchVision: 0.11.1
OpenCV: 4.5.4-dev
MMCV: 1.3.16
MMCV Compiler: GCC 9.3
MMCV CUDA Compiler: 11.5
MMDetection: 2.14.0
MMSegmentation: 0.14.1
MMDetection3D: 0.17.2+0cd000b

This is my first time doing deep learning and using CUDA, so I'm not sure what the best way forward is to solve this issue (other than upgrading my VM to use a bigger GPU).

@Tai-Wang
Member

How did you change the batch size? You only need to change the samples_per_gpu parameter in the config. (Changing only the batch size should not produce the error message you reported.)

@nout-kleef
Author

That's what happens on the most recent master commit when I update samples_per_gpu. I also don't see how that could affect backpropagation...

@nout-kleef
Author

After upgrading to a GPU with more memory, the inplace operation error remains (with unmodified samples_per_gpu):

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [4, 256, 100, 100]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

This is happening on the most recent master commit, without any modifications. Does training start normally when you execute ./tools/dist_train.sh configs/pointpillars/hv_pointpillars_fpn_sbn-all_4x8_2x_nus-3d.py 1?

@nout-kleef
Author

Using torch.autograd.set_detect_anomaly(True) gives the following stacktrace:

[W python_anomaly_mode.cpp:104] Warning: Error detected in ReluBackward0. Traceback of forward call that caused the error:
  File "/home/noutkleef1/Downloads/pycharm-community-2021.2.3/plugins/python-ce/helpers/pydev/pydevd.py", line 2173, in <module>
    main()
  File "/home/noutkleef1/Downloads/pycharm-community-2021.2.3/plugins/python-ce/helpers/pydev/pydevd.py", line 2164, in main
    globals = debugger.run(setup['file'], None, None, is_module)
  File "/home/noutkleef1/Downloads/pycharm-community-2021.2.3/plugins/python-ce/helpers/pydev/pydevd.py", line 1476, in run
    return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
  File "/home/noutkleef1/Downloads/pycharm-community-2021.2.3/plugins/python-ce/helpers/pydev/pydevd.py", line 1483, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/home/noutkleef1/Downloads/pycharm-community-2021.2.3/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "tools/train.py", line 225, in <module>
    main()
  File "tools/train.py", line 221, in main
    meta=meta)
  File "/home/noutkleef1/mmdetection3d/mmdet3d/apis/train.py", line 35, in train_model
    meta=meta)
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmdet/apis/train.py", line 170, in train_detector
    runner.run(data_loaders, cfg.workflow)
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 127, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 50, in train
    self.run_iter(data_batch, train_mode=True, **kwargs)
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 30, in run_iter
    **kwargs)
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/parallel/distributed.py", line 52, in train_step
    output = self.module.train_step(*inputs[0], **kwargs[0])
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmdet/models/detectors/base.py", line 237, in train_step
    losses = self(**data)
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/fp16_utils.py", line 98, in new_func
    return old_func(*args, **kwargs)
  File "/home/noutkleef1/mmdetection3d/mmdet3d/models/detectors/base.py", line 59, in forward
    return self.forward_train(**kwargs)
  File "/home/noutkleef1/mmdetection3d/mmdet3d/models/detectors/mvx_two_stage.py", line 274, in forward_train
    points, img=img, img_metas=img_metas)
  File "/home/noutkleef1/mmdetection3d/mmdet3d/models/detectors/mvx_two_stage.py", line 208, in extract_feat
    pts_feats = self.extract_pts_feat(points, img_feats, img_metas)
  File "/home/noutkleef1/mmdetection3d/mmdet3d/models/detectors/mvx_two_stage.py", line 202, in extract_pts_feat
    x = self.pts_neck(x)
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/fp16_utils.py", line 98, in new_func
    return old_func(*args, **kwargs)
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmdet/models/necks/fpn.py", line 158, in forward
    for i, lateral_conv in enumerate(self.lateral_convs)
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmdet/models/necks/fpn.py", line 158, in <listcomp>
    for i, lateral_conv in enumerate(self.lateral_convs)
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/cnn/bricks/conv_module.py", line 205, in forward
    x = self.activate(x)
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/modules/activation.py", line 98, in forward
    return F.relu(input, inplace=self.inplace)
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/functional.py", line 1299, in relu
    result = torch.relu(input)
 (function _print_stack)
Traceback (most recent call last):
  File "tools/train.py", line 221, in main
    meta=meta)
  File "/home/noutkleef1/mmdetection3d/mmdet3d/apis/train.py", line 35, in train_model
    meta=meta)
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmdet/apis/train.py", line 170, in train_detector
    runner.run(data_loaders, cfg.workflow)
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 127, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 51, in train
    self.call_hook('after_train_iter')
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/base_runner.py", line 307, in call_hook
    getattr(hook, fn_name)(self)
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/hooks/optimizer.py", line 35, in after_train_iter
    runner.outputs['loss'].backward()
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/noutkleef1/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/autograd/__init__.py", line 156, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [4, 256, 100, 100]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

It seems that, for some reason, torch.relu fails to compute the gradient.
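For context, the forward traceback printed by anomaly detection points at the op whose saved tensor was later overwritten, which is not necessarily an op that is itself faulty. The following is a minimal CPU-only sketch, not the mmdet code, that appears to reproduce the same class of error on PyTorch 1.10 or newer. The tensor names are illustrative, and `_version` is a private attribute used here only for inspection:

```python
import torch

torch.autograd.set_detect_anomaly(True)  # the switch suggested by the error hint

x = torch.randn(4, 8, requires_grad=True)
y = torch.relu(x)            # recent PyTorch keeps this output for ReLU's backward
version_before = y._version  # 0: untouched since it was created
y += 1.0                     # in-place update; the version counter becomes 1
version_after = y._version

try:
    y.sum().backward()
    error_message = None
except RuntimeError as err:
    error_message = str(err)
finally:
    torch.autograd.set_detect_anomaly(False)

print(version_before, version_after)
print(error_message)
```

The "is at version 1; expected version 0" part of the message corresponds to the two version values captured here: the backward pass notices that the tensor it saved has been changed since the forward pass.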

@Tai-Wang
Member

Tai-Wang commented Nov 23, 2021

If you train on a single GPU only, please replace SyncBN with regular BN for the nuScenes models, although we do not recommend this setup. You could also start with experiments on KITTI; its size is friendlier to limited computing resources.

@nout-kleef
Author

nout-kleef commented Nov 24, 2021

Thanks for the tips. I've replaced SyncBN with general BN:

configs/dissertation/diss.py:

_base_ = '../pointpillars/hv_pointpillars_fpn_sbn-all_4x8_2x_nus-3d.py'

# use normal BN instead of SyncBN
model = dict(
    pts_voxel_encoder=dict(
        norm_cfg=dict(type='BN1d')),
    pts_backbone=dict(
        norm_cfg=dict(type='BN2d')),
    pts_neck=dict(
        norm_cfg=dict(type='BN2d'))
)

However, the error remains:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [4, 256, 100, 100]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead.

(Now using python tools/train.py configs/dissertation/diss.py)

I have been using the nuScenes-mini dataset thus far, but I will see if the problem goes away when using KITTI.

nout-kleef changed the title from 'How to solve "CUDA out of memory" error' to 'Training PointPillars on nuScenes mini dataset - "... variables needed for gradient computation has been modified by an inplace operation"' on Nov 24, 2021
@nout-kleef
Author

Training works fine on KITTI.
Could it be that using the mini version of nuScenes is causing these issues? I was hoping to delay using the full nuScenes dataset since it is so big...

@Tai-Wang
Member

Then I recommend that you first train models on KITTI, because training on nuScenes-mini usually cannot achieve decent performance. We will try to reproduce your bug ASAP.

@wHao-Wu
Contributor

wHao-Wu commented Nov 25, 2021

Hi, @nout-kleef

I cannot reproduce your error on my machine. Would you please run python mmdet3d/utils/collect_env.py to collect your environment information and paste it here?

@nout-kleef
Author

Hi @wHao-Wu ,
I just tried it with the new environment that I created to make KITTI training work (downgraded to CUDA 10.1, PyTorch 1.8.0, GCC 8.4.0).

Now, training commences without problems (./tools/dist_train.sh configs/pointpillars/hv_pointpillars_fpn_sbn-all_4x8_2x_nus-3d.py 1 with samples_per_gpu == 2 to prevent OOM error).

New env:

$ python mmdet3d/utils/collect_env.py 
sys.platform: linux
Python: 3.7.11 (default, Jul 27 2021, 14:32:16) [GCC 7.5.0]
CUDA available: True
GPU 0: Tesla T4
CUDA_HOME: /usr/local/cuda
NVCC: Build cuda_11.5.r11.5/compiler.30411180_0
GCC: gcc (Ubuntu 8.4.0-3ubuntu2) 8.4.0
PyTorch: 1.8.0
PyTorch compiling details: PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) oneAPI Math Kernel Library Version 2021.4-Product Build 20210904 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v1.7.0 (Git Hash 7aed236906b1f7a05c0917e5257a1af05e9ff683)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 10.1
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_37,code=compute_37
  - CuDNN 7.6.3
  - Magma 2.5.2
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=10.1, CUDNN_VERSION=7.6.3, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.8.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, 

TorchVision: 0.9.0
OpenCV: 4.5.4
MMCV: 1.3.18
MMCV Compiler: GCC 7.3
MMCV CUDA Compiler: 10.1
MMDetection: 2.18.1
MMSegmentation: 0.19.0
MMDetection3D: 0.17.2+00f4e95

Here is a diff against the env I posted initially:

$ diff old.txt new.txt -y
sys.platform: linux                                             sys.platform: linux
Python: 3.7.11 (default, Jul 27 2021, 14:32:16) [GCC 7.5.0]     Python: 3.7.11 (default, Jul 27 2021, 14:32:16) [GCC 7.5.0]
CUDA available: True                                            CUDA available: True
GPU 0: Tesla P4                                               | GPU 0: Tesla T4
CUDA_HOME: /usr/local/cuda-11                                 | CUDA_HOME: /usr/local/cuda
NVCC: Build cuda_11.5.r11.5/compiler.30411180_0                 NVCC: Build cuda_11.5.r11.5/compiler.30411180_0
GCC: gcc (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0                 | GCC: gcc (Ubuntu 8.4.0-3ubuntu2) 8.4.0
PyTorch: 1.10.0                                               | PyTorch: 1.8.0
PyTorch compiling details: PyTorch built with:                  PyTorch compiling details: PyTorch built with:
  - GCC 7.3                                                       - GCC 7.3
  - C++ Version: 201402                                           - C++ Version: 201402
  - Intel(R) oneAPI Math Kernel Library Version 2021.4-Produc     - Intel(R) oneAPI Math Kernel Library Version 2021.4-Produc
  - Intel(R) MKL-DNN v2.2.3 (Git Hash 7336ca9f055cf1bfa13efb6 |   - Intel(R) MKL-DNN v1.7.0 (Git Hash 7aed236906b1f7a05c0917e
  - OpenMP 201511 (a.k.a. OpenMP 4.5)                             - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)               <
  - NNPACK is enabled                                             - NNPACK is enabled
  - CPU capability usage: AVX512                              |   - CPU capability usage: AVX2
  - CUDA Runtime 11.3                                         |   - CUDA Runtime 10.1
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm |   - NVCC architecture flags: -gencode;arch=compute_37,code=sm
  - CuDNN 8.2                                                 |   - CuDNN 7.6.3
  - Magma 2.5.2                                                   - Magma 2.5.2
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_V |   - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_V

TorchVision: 0.11.1                                           | TorchVision: 0.9.0
OpenCV: 4.5.4-dev                                             | OpenCV: 4.5.4
MMCV: 1.3.16                                                  | MMCV: 1.3.18
MMCV Compiler: GCC 9.3                                        | MMCV Compiler: GCC 7.3
MMCV CUDA Compiler: 11.5                                      | MMCV CUDA Compiler: 10.1
MMDetection: 2.14.0                                           | MMDetection: 2.18.1
MMSegmentation: 0.14.1                                        | MMSegmentation: 0.19.0
MMDetection3D: 0.17.2+0cd000b                                 | MMDetection3D: 0.17.2+00f4e95

LD_LIBRARY_PATH: /usr/local/cuda-10.1/lib64:/usr/local/cuda-10.1/extras/CUPTI/lib64:/lib/nccl/cuda-10:/usr/lib/mesa-diverted/x86_64-linux-gnu:/usr/lib/x86_64-linux-gnu/mesa:/usr/lib/x86_64-linux-gnu/dri:/usr/lib/x86_64-linux-gnu/gallium-pipe:/usr/local/cuda-11/lib64:

My guess is that the issue was caused by PyTorch 1.10.0?

@wHao-Wu
Contributor

wHao-Wu commented Nov 26, 2021

The environment on my machine uses PyTorch 1.5.0. We will try to reproduce this error with CUDA 11.3 and PyTorch 1.10.0. Thanks a lot for your feedback and contribution.

@fcakyon

fcakyon commented Dec 27, 2021

@wHao-Wu we are getting the same error when starting a training run with PointPillars. We cannot downgrade the CUDA and PyTorch versions because our GPU is new. What causes this error in the up-to-date PyTorch version?

@TheGreatGalaxy

> @wHao-Wu we are getting the same error when starting a training run with PointPillars. We cannot downgrade the CUDA and PyTorch versions because our GPU is new. What causes this error in the up-to-date PyTorch version?

I have encountered the same error with PyTorch 1.10. It seems to be caused by the new PyTorch version, but I can't fix it. Have you fixed it yet?

@TheGreatGalaxy

Finally found the cause. The FPN neck should use an out-of-place addition ("x = x + y") instead of the in-place "x += y": https://github.com/open-mmlab/mmdetection/blob/56e42e72cdf516bebb676e586f408b98f854d84c/mmdet/models/necks/fpn.py#L169. Newer versions of MMDetection are OK. Reference: https://discuss.pytorch.org/t/element-0-of-tensors-does-not-require-grad-and-does-not-have-a-grad-fn/32908/112
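To illustrate the difference, here is a simplified sketch of an FPN-style top-down merge. It is not the actual mmdet implementation: `top_down` and `try_backward` are hypothetical helpers, the lateral convolutions are replaced by a bare ReLU, and the behaviour described assumes PyTorch 1.10 or newer as reported in this thread:

```python
import torch
import torch.nn.functional as F


def top_down(feats, inplace):
    """Toy FPN-style top-down merge over ReLU-activated lateral features."""
    laterals = [torch.relu(f) for f in feats]
    # Walk from the coarsest level to the finest, adding upsampled features.
    for i in range(len(laterals) - 1, 0, -1):
        up = F.interpolate(
            laterals[i], size=laterals[i - 1].shape[2:], mode='nearest')
        if inplace:
            laterals[i - 1] += up  # overwrites the saved ReLU output
        else:
            laterals[i - 1] = laterals[i - 1] + up  # allocates a new tensor
    return sum(l.sum() for l in laterals)


def try_backward(inplace):
    """Return None if backward succeeds, else the error message."""
    feats = [torch.randn(1, 2, 8, 8, requires_grad=True),
             torch.randn(1, 2, 4, 4, requires_grad=True)]
    try:
        top_down(feats, inplace).backward()
        return None
    except RuntimeError as err:
        return str(err)


print("out-of-place:", try_backward(inplace=False))
print("in-place:", try_backward(inplace=True))
```

The out-of-place form costs one extra allocation per pyramid level but leaves the tensors that autograd saved during the forward pass untouched.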

5 participants