Skip to content

Commit

Permalink
support different color spaces
Browse files Browse the repository at this point in the history
  • Loading branch information
elxy authored and ernestcao committed Oct 13, 2022
1 parent ce5c55a commit 4013631
Show file tree
Hide file tree
Showing 16 changed files with 167 additions and 61 deletions.
4 changes: 2 additions & 2 deletions basicsr/data/data_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from basicsr.utils import img2tensor, scandir


def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False, color_space='rgb'):
"""Read a sequence of images from a given folder path.
Args:
Expand All @@ -30,7 +30,7 @@ def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):

if require_mod_crop:
imgs = [mod_crop(img, scale) for img in imgs]
imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
imgs = img2tensor(imgs, color_space=color_space, float32=True)
imgs = torch.stack(imgs, dim=0)

if return_imgname:
Expand Down
6 changes: 4 additions & 2 deletions basicsr/data/ffhq_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torchvision.transforms.functional import normalize

from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils import ColorSpace, FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


Expand Down Expand Up @@ -70,8 +70,10 @@ def __getitem__(self, index):

# random horizontal flip
img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
# color space transform
color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
img_gt = img2tensor(img_gt, color_space=color_space, float32=True)
# normalize
normalize(img_gt, self.mean, self.std, inplace=True)
return {'gt': img_gt, 'gt_path': gt_path}
Expand Down
12 changes: 5 additions & 7 deletions basicsr/data/paired_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor
from basicsr.utils import ColorSpace, FileClient, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


Expand Down Expand Up @@ -83,18 +83,16 @@ def __getitem__(self, index):
# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

# color space transform
if 'color' in self.opt and self.opt['color'] == 'y':
img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None]
img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None]

# crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
# TODO: It is better to update the datasets, rather than force to crop
if self.opt['phase'] != 'train':
img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]

# color space transform
color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']

# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
img_gt, img_lq = img2tensor([img_gt, img_lq], color_space=color_space, float32=True)
# normalize
if self.mean is not None or self.std is not None:
normalize(img_lq, self.mean, self.std, inplace=True)
Expand Down
6 changes: 4 additions & 2 deletions basicsr/data/realesrgan_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils import ColorSpace, FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


Expand Down Expand Up @@ -181,8 +181,10 @@ def __getitem__(self, index):
else:
sinc_kernel = self.pulse_tensor

# color space transform
color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
img_gt = img2tensor([img_gt], color_space=color_space, float32=True)[0]
kernel = torch.FloatTensor(kernel)
kernel2 = torch.FloatTensor(kernel2)

Expand Down
7 changes: 5 additions & 2 deletions basicsr/data/realesrgan_paired_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor
from basicsr.utils import ColorSpace, FileClient, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


Expand Down Expand Up @@ -93,8 +93,11 @@ def __getitem__(self, index):
# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

# color space transform
color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']

# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
img_gt, img_lq = img2tensor([img_gt, img_lq], color_space=color_space, float32=True)
# normalize
if self.mean is not None or self.std is not None:
normalize(img_lq, self.mean, self.std, inplace=True)
Expand Down
16 changes: 12 additions & 4 deletions basicsr/data/reds_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch.utils import data as data

from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils import ColorSpace, FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.flow_util import dequantize_flow
from basicsr.utils.registry import DATASET_REGISTRY

Expand Down Expand Up @@ -182,12 +182,16 @@ def __getitem__(self, index):
else:
img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

img_results = img2tensor(img_results)
# color space transform
color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']

# BGR to RGB, HWC to CHW, numpy to tensor
img_results = img2tensor(img_results, color_space=color_space)
img_lqs = torch.stack(img_results[0:-1], dim=0)
img_gt = img_results[-1]

if self.flow_root is not None:
img_flows = img2tensor(img_flows)
img_flows = img2tensor(img_flows, color_space=ColorSpace.RAW)
# add the zero center flow
img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0]))
img_flows = torch.stack(img_flows, dim=0)
Expand Down Expand Up @@ -339,7 +343,11 @@ def __getitem__(self, index):
img_lqs.extend(img_gts)
img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

img_results = img2tensor(img_results)
# color space transform
color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']

# BGR to RGB, HWC to CHW, numpy to tensor
img_results = img2tensor(img_results, color_space=color_space)
img_gts = torch.stack(img_results[len(img_lqs) // 2:], dim=0)
img_lqs = torch.stack(img_results[:len(img_lqs) // 2], dim=0)

Expand Down
7 changes: 3 additions & 4 deletions basicsr/data/single_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torchvision.transforms.functional import normalize

from basicsr.data.data_util import paths_from_lmdb
from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir
from basicsr.utils import ColorSpace, FileClient, imfrombytes, img2tensor, scandir
from basicsr.utils.registry import DATASET_REGISTRY


Expand Down Expand Up @@ -54,11 +54,10 @@ def __getitem__(self, index):
img_lq = imfrombytes(img_bytes, float32=True)

# color space transform
if 'color' in self.opt and self.opt['color'] == 'y':
img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]
color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']

# BGR to RGB, HWC to CHW, numpy to tensor
img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
img_lq = img2tensor(img_lq, color_space=color_space, float32=True)
# normalize
if self.mean is not None or self.std is not None:
normalize(img_lq, self.mean, self.std, inplace=True)
Expand Down
14 changes: 11 additions & 3 deletions basicsr/data/vimeo90k_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.utils import data as data

from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils import ColorSpace, FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


Expand Down Expand Up @@ -120,7 +120,11 @@ def __getitem__(self, index):
img_lqs.append(img_gt)
img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

img_results = img2tensor(img_results)
# color space transform
color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']

# BGR to RGB, HWC to CHW, numpy to tensor
img_results = img2tensor(img_results, color_space=color_space)
img_lqs = torch.stack(img_results[0:-1], dim=0)
img_gt = img_results[-1]

Expand Down Expand Up @@ -182,7 +186,11 @@ def __getitem__(self, index):
img_lqs.extend(img_gts)
img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

img_results = img2tensor(img_results)
# color space transform
color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']

# BGR to RGB, HWC to CHW, numpy to tensor
img_results = img2tensor(img_results, color_space=color_space)
img_lqs = torch.stack(img_results[:7], dim=0)
img_gts = torch.stack(img_results[7:], dim=0)

Expand Down
6 changes: 3 additions & 3 deletions basicsr/metrics/test_metrics/test_psnr_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from basicsr.metrics import calculate_psnr, calculate_ssim
from basicsr.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt
from basicsr.utils import img2tensor
from basicsr.utils import ColorSpace, img2tensor


def test(img_path, img_path2, crop_border, test_y_channel=False):
Expand All @@ -16,8 +16,8 @@ def test(img_path, img_path2, crop_border, test_y_channel=False):
print(f'\tNumpy\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')

# --------------------- PyTorch (CPU) ---------------------
img = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
img2 = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
img = img2tensor(img / 255., color_space=ColorSpace.RGB, float32=True).unsqueeze_(0)
img2 = img2tensor(img2 / 255., color_space=ColorSpace.RGB, float32=True).unsqueeze_(0)

psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
Expand Down
5 changes: 3 additions & 2 deletions basicsr/models/sr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):

def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
dataset_name = dataloader.dataset.opt['name']
color_space = dataloader.dataset.opt['color'] if 'color' in dataloader.dataset.opt else 'rgb'
with_metrics = self.opt['val'].get('metrics') is not None
use_pbar = self.opt['val'].get('pbar', False)

Expand All @@ -205,10 +206,10 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
self.test()

visuals = self.get_current_visuals()
sr_img = tensor2img([visuals['result']])
sr_img = tensor2img([visuals['result']], color_space=color_space)
metric_data['img'] = sr_img
if 'gt' in visuals:
gt_img = tensor2img([visuals['gt']])
gt_img = tensor2img([visuals['gt']], color_space=color_space)
metric_data['img2'] = gt_img
del self.gt

Expand Down
3 changes: 2 additions & 1 deletion basicsr/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .diffjpeg import DiffJPEG
from .file_client import FileClient
from .img_process_util import USMSharp, usm_sharp
from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
from .img_util import ColorSpace, crop_border, imfrombytes, img2tensor, imwrite, tensor2img
from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
from .options import yaml_load
Expand All @@ -17,6 +17,7 @@
# file_client.py
'FileClient',
# img_util.py
'ColorSpace',
'img2tensor',
'tensor2img',
'imfrombytes',
Expand Down
4 changes: 4 additions & 0 deletions basicsr/utils/color_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def ycbcr2rgb(img):
"""
img_type = img.dtype
img = _convert_input_type_range(img) * 255
if img.shape[2] == 1: # only y channel
img = np.pad(img, ((0, 0), (0, 0), (0, 2)), 'constant')
out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
[0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126
out_img = _convert_output_type_range(out_img, img_type)
Expand Down Expand Up @@ -120,6 +122,8 @@ def ycbcr2bgr(img):
"""
img_type = img.dtype
img = _convert_input_type_range(img) * 255
if img.shape[2] == 1: # only y channel
img = np.pad(img, ((0, 0), (0, 0), (0, 2)), 'constant')
out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
[0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126
out_img = _convert_output_type_range(out_img, img_type)
Expand Down
Loading

0 comments on commit 4013631

Please sign in to comment.