-
Notifications
You must be signed in to change notification settings - Fork 1
/
dataset.py
73 lines (59 loc) · 2.66 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import cv2
import os
import numpy as np
from options import opt
import torchvision
import torchvision.transforms.functional as F
import numbers
import random
from PIL import Image
import glob
class ToTensor(object):
    """Turn both images of a ``{'hazy': ..., 'clean': ...}`` sample into tensors.

    Each image is copied into a float32 NumPy array and wrapped as a torch
    tensor whose axes are reordered from (0, 1, 2) to (2, 0, 1). For a
    channels-last (H, W, C) image this yields channels-first (C, H, W).
    Values are not rescaled. Any keys other than ``'hazy'`` and ``'clean'``
    are dropped from the returned dict.
    """

    @staticmethod
    def _hwc_to_chw(image):
        """Return *image* as a float32 tensor with its last axis moved to the front."""
        # np.array() copies, so the resulting tensor never aliases the caller's data.
        as_float = np.array(image).astype(np.float32)
        return torch.from_numpy(as_float).permute(2, 0, 1)

    def __call__(self, sample):
        # Look up both entries before converting anything, as the original did.
        hazy, clean = sample['hazy'], sample['clean']
        converted = {}
        converted['hazy'] = self._hwc_to_chw(hazy)
        converted['clean'] = self._hwc_to_chw(clean)
        return converted
class Dataset_Load(Dataset):
    """Paired image dataset yielding ``{'hazy': ..., 'clean': ...}`` samples.

    Images are read from ``filesA`` (``'hazy'``) and ``filesB`` (``'clean'``),
    resized, converted from OpenCV's BGR order to RGB and scaled to float32
    values in [0, 1]. ``transform``, if given, is applied to the sample dict.

    Args:
        data_path: Root directory; its expected layout depends on
            ``dataset_name`` (see ``get_file_paths``).
        transform: Optional callable applied to each sample dict,
            e.g. ``ToTensor()``.
        dataset_name: Directory layout: ``'UIEB'`` (default, the layout
            previously hard-coded), ``'SUIM'`` or ``'EUVP'``.
        image_size: Target size passed to ``cv2.resize`` as ``dsize``, i.e.
            ``(width, height)``; a single int means a square image.
            Defaults to ``(256, 256)`` as before.
    """

    def __init__(self, data_path, transform=None, dataset_name='UIEB',
                 image_size=(256, 256)):
        self.data_path = data_path
        self.dataset_name = dataset_name
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = tuple(image_size)
        self.filesA, self.filesB = self.get_file_paths(self.data_path, dataset_name)
        # NOTE(review): images are paired purely by sorted position. Surplus
        # files in the longer list are ignored, and a file missing from one
        # side shifts every later pair -- confirm the directories match.
        self.len = min(len(self.filesA), len(self.filesB))
        self.transform = transform

    def __len__(self):
        """Return the number of (hazy, clean) pairs."""
        return self.len

    def __getitem__(self, index):
        """Load pair ``index`` (wrapped modulo ``len``) and apply ``transform``."""
        hazy_im = self._load_image(self.filesA[index % self.len])
        clean_im = self._load_image(self.filesB[index % self.len])
        sample = {'hazy': hazy_im,
                  'clean': clean_im}
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def _load_image(self, path):
        """Read *path*, resize to ``image_size``, convert BGR->RGB, scale to [0, 1] float32.

        Raises:
            IOError: if OpenCV cannot read the file.
        """
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising;
            # without this check cv2.resize fails with an opaque assertion.
            raise IOError('Could not read image: {}'.format(path))
        image = cv2.resize(image, self.image_size, interpolation=cv2.INTER_AREA)
        image = image[:, :, ::-1]  # BGR to RGB
        return np.float32(image) / 255.0

    def get_file_paths(self, root, dataset_name):
        """Return ``(filesA, filesB)`` path lists for *dataset_name* under *root*.

        Layouts:
            EUVP: ``root/<sub>/trainA`` and ``root/<sub>/trainB`` for each of
                ``underwater_imagenet``, ``underwater_dark`` and
                ``underwater_scenes``. Each sub-directory is sorted on its
                own, then the lists are concatenated in that order.
            SUIM, UIEB: ``root/inp`` and ``root/gt``.

        Only entries matching ``*.*`` are collected.

        Raises:
            ValueError: if *dataset_name* is not one of the layouts above.
        """
        if dataset_name == 'EUVP':
            filesA, filesB = [], []
            sub_dirs = ['underwater_imagenet', 'underwater_dark', 'underwater_scenes']
            for sd in sub_dirs:
                filesA += self._sorted_files(os.path.join(root, sd, 'trainA'))
                filesB += self._sorted_files(os.path.join(root, sd, 'trainB'))
        elif dataset_name in ('SUIM', 'UIEB'):
            filesA = self._sorted_files(os.path.join(root, 'inp'))
            filesB = self._sorted_files(os.path.join(root, 'gt'))
        else:
            # Previously this fell through to an UnboundLocalError on return.
            raise ValueError(
                "Unknown dataset_name {!r}; expected 'EUVP', 'SUIM' or 'UIEB'".format(
                    dataset_name))
        return filesA, filesB

    @staticmethod
    def _sorted_files(directory):
        """Return the sorted paths in *directory* matching ``*.*``."""
        return sorted(glob.glob(os.path.join(directory, '*.*')))