Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug] CID for other dataset? #3067

Open
2 tasks done
wux024 opened this issue Jun 12, 2024 · 0 comments
Open
2 tasks done

[Bug] CID for other dataset? #3067

wux024 opened this issue Jun 12, 2024 · 0 comments
Assignees

Comments

@wux024
Copy link

wux024 commented Jun 12, 2024

Prerequisite

Environment

OrderedDict([('sys.platform', 'linux'), ('Python', '3.10.14 (main, Mar 21 2024, 16:24:04) [GCC 11.2.0]'), ('CUDA available', True), ('MUSA available', False), ('numpy_random_seed', 2147483648), ('GPU 0,1', 'NVIDIA GeForce RTX 3090'), ('CUDA_HOME', ':/usr/local/cuda'), ('GCC', 'gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0'), ('PyTorch', '2.1.2+cu121'), ('PyTorch compiling details', 'PyTorch built with:\n - GCC 9.3\n - C++ Version: 201703\n - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications\n - Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)\n - OpenMP 201511 (a.k.a. OpenMP 4.5)\n - LAPACK is enabled (usually provided by MKL)\n - NNPACK is enabled\n - CPU capability usage: AVX512\n - CUDA Runtime 12.1\n - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90\n - CuDNN 8.9.2\n - Magma 2.6.1\n - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=8.9.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-invalid-partial-specialization -Wno-unused-private-field -Wno-aligned-allocation-unavailable -Wno-missing-braces 
-fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.1.2, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, \n'), ('TorchVision', '0.16.2+cu121'), ('OpenCV', '4.9.0'), ('MMEngine', '0.10.3'), ('MMPose', '1.3.1+0a8a104')])

mmcv 2.1.0
mmdet 3.3.0
mmengine 0.10.3
mmpose 1.3.1 /mnt/data1/wux024/mmpose
mmpretrain 1.2.0

Reproduces the problem - code sample

_base_ = ['../../../_base_/default_runtime.py']

# runtime

train_cfg = dict(max_epochs=140, val_interval=10)

# optimizer

optim_wrapper = dict(optimizer=dict(
type='Adam',
lr=1e-3,
))

# learning policy

param_scheduler = [
dict(
type='MultiStepLR',
begin=0,
end=140,
milestones=[90, 120],
gamma=0.1,
by_epoch=True)
]

# automatically scaling LR based on the actual training batch size

auto_scale_lr = dict(base_batch_size=160)

# hooks

default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))

# codec settings

codec = dict(
type='DecoupledHeatmap', input_size=(512, 512), heatmap_size=(128, 128))

# model settings

model = dict(
type='BottomupPoseEstimator',
data_preprocessor=dict(
type='PoseDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True),
backbone=dict(
type='HRNet',
in_channels=3,
extra=dict(
stage1=dict(
num_modules=1,
num_branches=1,
block='BOTTLENECK',
num_blocks=(4, ),
num_channels=(64, )),
stage2=dict(
num_modules=1,
num_branches=2,
block='BASIC',
num_blocks=(4, 4),
num_channels=(32, 64)),
stage3=dict(
num_modules=4,
num_branches=3,
block='BASIC',
num_blocks=(4, 4, 4),
num_channels=(32, 64, 128)),
stage4=dict(
num_modules=3,
num_branches=4,
block='BASIC',
num_blocks=(4, 4, 4, 4),
num_channels=(32, 64, 128, 256),
multiscale_output=True)),
init_cfg=dict(
type='Pretrained',
checkpoint='https://download.openmmlab.com/mmpose/'
'pretrain_models/hrnet_w32-36af842e.pth'),
),
neck=dict(
type='FeatureMapProcessor',
concat=True,
),
head=dict(
type='CIDHead',
in_channels=480,
num_keypoints=5,
gfd_channels=32,
coupled_heatmap_loss=dict(type='FocalHeatmapLoss', loss_weight=1.0),
decoupled_heatmap_loss=dict(type='FocalHeatmapLoss', loss_weight=4.0),
contrastive_loss=dict(
type='InfoNCELoss', temperature=0.05, loss_weight=1.0),
decoder=codec,
),
train_cfg=dict(max_train_instances=200),
test_cfg=dict(
multiscale_test=False,
flip_test=True,
shift_heatmap=False,
align_corners=False))

# base dataset settings

dataset_type = 'FishDataset'
data_mode = 'bottomup'
data_root = 'data/fish/'

# pipelines

train_pipeline = [
dict(type='LoadImage'),
dict(type='BottomupRandomAffine', input_size=codec['input_size']),
dict(type='RandomFlip', direction='horizontal'),
dict(type='GenerateTarget', encoder=codec),
dict(type='BottomupGetHeatmapMask'),
dict(type='PackPoseInputs'),
]
val_pipeline = [
dict(type='LoadImage'),
dict(
type='BottomupResize',
input_size=codec['input_size'],
size_factor=64,
resize_mode='expand'),
dict(
type='PackPoseInputs',
meta_keys=('id', 'img_id', 'img_path', 'crowd_index', 'ori_shape',
'img_shape', 'input_size', 'input_center', 'input_scale',
'flip', 'flip_direction', 'flip_indices', 'raw_ann_info',
'skeleton_links'))
]

# data loaders

train_dataloader = dict(
batch_size=32,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/train.json',
data_prefix=dict(img='images/train/'),
pipeline=train_pipeline,
))
val_dataloader = dict(
batch_size=32,
num_workers=8,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/val.json',
data_prefix=dict(img='images/val/'),
test_mode=True,
pipeline=val_pipeline,
))
test_dataloader = dict(
batch_size=32,
num_workers=8,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/test.json',
data_prefix=dict(img='images/test/'),
test_mode=True,
pipeline=val_pipeline,
))

# evaluators

val_evaluator = [dict(
type='CocoMetric',
ann_file=data_root + 'annotations/val.json',
nms_thr=0.8,
score_mode='keypoint')
]
test_evaluator = [dict(
type='CocoMetric',
ann_file=data_root + 'annotations/test.json',
nms_thr=0.8,
score_mode='keypoint')
]

Reproduces the problem - command or script

python tools/train.py configs/...

Reproduces the problem - error message

06/12 11:17:08 - mmengine - WARNING - "FileClient" will be deprecated in future. Please use io functions in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io
06/12 11:17:08 - mmengine - WARNING - "HardDiskBackend" is the alias of "LocalBackend" and the former will be deprecated in future.
06/12 11:17:08 - mmengine - INFO - Checkpoints will be saved to /mnt/data1/wux024/mmpose/work_dirs/bottomup/fish/cid_hrnet-w32_8xb32-140e_fish-512x512.
/mnt/data1/wux024/mmpose/mmpose/codecs/utils/gaussian_heatmap.py:182: RuntimeWarning: invalid value encountered in divide
gaussian = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma[n]**2))
/mnt/data1/wux024/mmpose/mmpose/codecs/utils/gaussian_heatmap.py:182: RuntimeWarning: invalid value encountered in divide
gaussian = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma[n]**2))
Traceback (most recent call last):
File "/mnt/data1/wux024/mmpose/tools/train.py", line 162, in <module>
main()
File "/mnt/data1/wux024/mmpose/tools/train.py", line 158, in main
runner.train()
File "/home/wux024/.conda/envs/openmmlab/lib/python3.10/site-packages/mmengine/runner/runner.py", line 1777, in train
model = self.train_loop.run() # type: ignore
File "/home/wux024/.conda/envs/openmmlab/lib/python3.10/site-packages/mmengine/runner/loops.py", line 96, in run
self.run_epoch()
File "/home/wux024/.conda/envs/openmmlab/lib/python3.10/site-packages/mmengine/runner/loops.py", line 112, in run_epoch
self.run_iter(idx, data_batch)
File "/home/wux024/.conda/envs/openmmlab/lib/python3.10/site-packages/mmengine/runner/loops.py", line 128, in run_iter
outputs = self.runner.model.train_step(
File "/home/wux024/.conda/envs/openmmlab/lib/python3.10/site-packages/mmengine/model/base_model/base_model.py", line 114, in train_step
losses = self._run_forward(data, mode='loss') # type: ignore
File "/home/wux024/.conda/envs/openmmlab/lib/python3.10/site-packages/mmengine/model/base_model/base_model.py", line 361, in _run_forward
results = self(**data, mode=mode)
File "/home/wux024/.conda/envs/openmmlab/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/wux024/.conda/envs/openmmlab/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/data1/wux024/mmpose/mmpose/models/pose_estimators/base.py", line 155, in forward
return self.loss(inputs, data_samples)
File "/mnt/data1/wux024/mmpose/mmpose/models/pose_estimators/bottomup.py", line 70, in loss
self.head.loss(feats, data_samples, train_cfg=self.train_cfg))
File "/mnt/data1/wux024/mmpose/mmpose/models/heads/heatmap_heads/cid_head.py", line 641, in loss
gt_heatmaps = torch.stack(gt_heatmaps)
RuntimeError: stack expects each tensor to be equal size, but got [6, 128, 128] at entry 0 and [11, 128, 128] at entry 2

Additional information

No response

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants