Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support different color spaces #574

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions basicsr/data/data_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from basicsr.utils import img2tensor, scandir


def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False, color_space='rgb'):
"""Read a sequence of images from a given folder path.

Args:
Expand All @@ -30,7 +30,7 @@ def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):

if require_mod_crop:
imgs = [mod_crop(img, scale) for img in imgs]
imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
imgs = img2tensor(imgs, color_space=color_space, float32=True)
imgs = torch.stack(imgs, dim=0)

if return_imgname:
Expand Down
6 changes: 4 additions & 2 deletions basicsr/data/ffhq_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torchvision.transforms.functional import normalize

from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils import ColorSpace, FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


Expand Down Expand Up @@ -70,8 +70,10 @@ def __getitem__(self, index):

# random horizontal flip
img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
# color space transform
color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
img_gt = img2tensor(img_gt, color_space=color_space, float32=True)
# normalize
normalize(img_gt, self.mean, self.std, inplace=True)
return {'gt': img_gt, 'gt_path': gt_path}
Expand Down
12 changes: 5 additions & 7 deletions basicsr/data/paired_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor
from basicsr.utils import ColorSpace, FileClient, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


Expand Down Expand Up @@ -83,18 +83,16 @@ def __getitem__(self, index):
# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

# color space transform
if 'color' in self.opt and self.opt['color'] == 'y':
img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None]
img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None]

# crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
# TODO: It is better to update the datasets, rather than force to crop
if self.opt['phase'] != 'train':
img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]

# color space transform
color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']

# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
img_gt, img_lq = img2tensor([img_gt, img_lq], color_space=color_space, float32=True)
# normalize
if self.mean is not None or self.std is not None:
normalize(img_lq, self.mean, self.std, inplace=True)
Expand Down
6 changes: 4 additions & 2 deletions basicsr/data/realesrgan_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils import ColorSpace, FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


Expand Down Expand Up @@ -181,8 +181,10 @@ def __getitem__(self, index):
else:
sinc_kernel = self.pulse_tensor

# color space transform
color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
img_gt = img2tensor([img_gt], color_space=color_space, float32=True)[0]
kernel = torch.FloatTensor(kernel)
kernel2 = torch.FloatTensor(kernel2)

Expand Down
7 changes: 5 additions & 2 deletions basicsr/data/realesrgan_paired_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor
from basicsr.utils import ColorSpace, FileClient, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


Expand Down Expand Up @@ -93,8 +93,11 @@ def __getitem__(self, index):
# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

# color space transform
color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']

# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
img_gt, img_lq = img2tensor([img_gt, img_lq], color_space=color_space, float32=True)
# normalize
if self.mean is not None or self.std is not None:
normalize(img_lq, self.mean, self.std, inplace=True)
Expand Down
16 changes: 12 additions & 4 deletions basicsr/data/reds_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch.utils import data as data

from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils import ColorSpace, FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.flow_util import dequantize_flow
from basicsr.utils.registry import DATASET_REGISTRY

Expand Down Expand Up @@ -182,12 +182,16 @@ def __getitem__(self, index):
else:
img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

img_results = img2tensor(img_results)
# color space transform
color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']

# BGR to RGB, HWC to CHW, numpy to tensor
img_results = img2tensor(img_results, color_space=color_space)
img_lqs = torch.stack(img_results[0:-1], dim=0)
img_gt = img_results[-1]

if self.flow_root is not None:
img_flows = img2tensor(img_flows)
img_flows = img2tensor(img_flows, color_space=ColorSpace.RAW)
# add the zero center flow
img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0]))
img_flows = torch.stack(img_flows, dim=0)
Expand Down Expand Up @@ -339,7 +343,11 @@ def __getitem__(self, index):
img_lqs.extend(img_gts)
img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

img_results = img2tensor(img_results)
# color space transform
color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']

# BGR to RGB, HWC to CHW, numpy to tensor
img_results = img2tensor(img_results, color_space=color_space)
img_gts = torch.stack(img_results[len(img_lqs) // 2:], dim=0)
img_lqs = torch.stack(img_results[:len(img_lqs) // 2], dim=0)

Expand Down
7 changes: 3 additions & 4 deletions basicsr/data/single_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torchvision.transforms.functional import normalize

from basicsr.data.data_util import paths_from_lmdb
from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir
from basicsr.utils import ColorSpace, FileClient, imfrombytes, img2tensor, scandir
from basicsr.utils.registry import DATASET_REGISTRY


Expand Down Expand Up @@ -54,11 +54,10 @@ def __getitem__(self, index):
img_lq = imfrombytes(img_bytes, float32=True)

# color space transform
if 'color' in self.opt and self.opt['color'] == 'y':
img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]
color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']

# BGR to RGB, HWC to CHW, numpy to tensor
img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
img_lq = img2tensor(img_lq, color_space=color_space, float32=True)
# normalize
if self.mean is not None or self.std is not None:
normalize(img_lq, self.mean, self.std, inplace=True)
Expand Down
14 changes: 11 additions & 3 deletions basicsr/data/vimeo90k_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.utils import data as data

from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils import ColorSpace, FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


Expand Down Expand Up @@ -120,7 +120,11 @@ def __getitem__(self, index):
img_lqs.append(img_gt)
img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

img_results = img2tensor(img_results)
# color space transform
color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']

# BGR to RGB, HWC to CHW, numpy to tensor
img_results = img2tensor(img_results, color_space=color_space)
img_lqs = torch.stack(img_results[0:-1], dim=0)
img_gt = img_results[-1]

Expand Down Expand Up @@ -182,7 +186,11 @@ def __getitem__(self, index):
img_lqs.extend(img_gts)
img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])

img_results = img2tensor(img_results)
# color space transform
color_space = ColorSpace.RGB if 'color' not in self.opt else self.opt['color']

# BGR to RGB, HWC to CHW, numpy to tensor
img_results = img2tensor(img_results, color_space=color_space)
img_lqs = torch.stack(img_results[:7], dim=0)
img_gts = torch.stack(img_results[7:], dim=0)

Expand Down
6 changes: 3 additions & 3 deletions basicsr/metrics/test_metrics/test_psnr_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from basicsr.metrics import calculate_psnr, calculate_ssim
from basicsr.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt
from basicsr.utils import img2tensor
from basicsr.utils import ColorSpace, img2tensor


def test(img_path, img_path2, crop_border, test_y_channel=False):
Expand All @@ -16,8 +16,8 @@ def test(img_path, img_path2, crop_border, test_y_channel=False):
print(f'\tNumpy\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')

# --------------------- PyTorch (CPU) ---------------------
img = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
img2 = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
img = img2tensor(img / 255., color_space=ColorSpace.RGB, float32=True).unsqueeze_(0)
img2 = img2tensor(img2 / 255., color_space=ColorSpace.RGB, float32=True).unsqueeze_(0)

psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
Expand Down
5 changes: 3 additions & 2 deletions basicsr/models/sr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):

def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
dataset_name = dataloader.dataset.opt['name']
color_space = dataloader.dataset.opt['color'] if 'color' in dataloader.dataset.opt else 'rgb'
with_metrics = self.opt['val'].get('metrics') is not None
use_pbar = self.opt['val'].get('pbar', False)

Expand All @@ -205,10 +206,10 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
self.test()

visuals = self.get_current_visuals()
sr_img = tensor2img([visuals['result']])
sr_img = tensor2img([visuals['result']], color_space=color_space)
metric_data['img'] = sr_img
if 'gt' in visuals:
gt_img = tensor2img([visuals['gt']])
gt_img = tensor2img([visuals['gt']], color_space=color_space)
metric_data['img2'] = gt_img
del self.gt

Expand Down
3 changes: 2 additions & 1 deletion basicsr/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .diffjpeg import DiffJPEG
from .file_client import FileClient
from .img_process_util import USMSharp, usm_sharp
from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
from .img_util import ColorSpace, crop_border, imfrombytes, img2tensor, imwrite, tensor2img
from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
from .options import yaml_load
Expand All @@ -17,6 +17,7 @@
# file_client.py
'FileClient',
# img_util.py
'ColorSpace',
'img2tensor',
'tensor2img',
'imfrombytes',
Expand Down
4 changes: 4 additions & 0 deletions basicsr/utils/color_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def ycbcr2rgb(img):
"""
img_type = img.dtype
img = _convert_input_type_range(img) * 255
if img.shape[2] == 1: # only y channel
img = np.pad(img, ((0, 0), (0, 0), (0, 2)), 'constant')
out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
[0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126
out_img = _convert_output_type_range(out_img, img_type)
Expand Down Expand Up @@ -120,6 +122,8 @@ def ycbcr2bgr(img):
"""
img_type = img.dtype
img = _convert_input_type_range(img) * 255
if img.shape[2] == 1: # only y channel
img = np.pad(img, ((0, 0), (0, 0), (0, 2)), 'constant')
out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
[0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126
out_img = _convert_output_type_range(out_img, img_type)
Expand Down
Loading