From b12a85913a067b08bdcd633a40cf7e3529faec4d Mon Sep 17 00:00:00 2001 From: Giovanni Cavallin <37183651+mawanda-jun@users.noreply.github.com> Date: Mon, 30 Sep 2024 12:04:37 +0200 Subject: [PATCH] Enable float32 image dtype in (Cached)MixUp The old method was forcing the `np.uint8` dtype on the output image, thus not permitting the float32 dtype (case of images in (0, 1) interval). --- mmdet/datasets/transforms/transforms.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mmdet/datasets/transforms/transforms.py b/mmdet/datasets/transforms/transforms.py index c50b987db33..acc8bd32583 100644 --- a/mmdet/datasets/transforms/transforms.py +++ b/mmdet/datasets/transforms/transforms.py @@ -2668,7 +2668,7 @@ def transform(self, results: dict) -> dict: target_h, target_w = ori_img.shape[:2] padded_img = np.ones((max(origin_h, target_h), max( origin_w, target_w), 3)) * self.pad_val - padded_img = padded_img.astype(np.uint8) + padded_img = padded_img.astype(retrieve_img.dtype) padded_img[:origin_h, :origin_w] = out_img x_offset, y_offset = 0, 0 @@ -2715,7 +2715,7 @@ def transform(self, results: dict) -> dict: mixup_gt_bboxes_labels = mixup_gt_bboxes_labels[inside_inds] mixup_gt_ignore_flags = mixup_gt_ignore_flags[inside_inds] - results['img'] = mixup_img.astype(np.uint8) + results['img'] = mixup_img.astype(retrieve_img.dtype) results['img_shape'] = mixup_img.shape[:2] results['gt_bboxes'] = mixup_gt_bboxes results['gt_bboxes_labels'] = mixup_gt_bboxes_labels @@ -3765,7 +3765,7 @@ def transform(self, results: dict) -> dict: target_h, target_w = ori_img.shape[:2] padded_img = np.ones((max(origin_h, target_h), max( origin_w, target_w), 3)) * self.pad_val - padded_img = padded_img.astype(np.uint8) + padded_img = padded_img.astype(retrieve_img.dtype) padded_img[:origin_h, :origin_w] = out_img x_offset, y_offset = 0, 0 @@ -3833,7 +3833,7 @@ def transform(self, results: dict) -> dict: if with_mask: mixup_gt_masks = mixup_gt_masks[inside_inds] - results['img'] = mixup_img.astype(np.uint8) + results['img'] = mixup_img.astype(retrieve_img.dtype) results['img_shape'] = mixup_img.shape[:2] results['gt_bboxes'] = mixup_gt_bboxes results['gt_bboxes_labels'] = mixup_gt_bboxes_labels