dataset.py
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from PIL import Image
import scipy.io
import numpy as np
import glob


class NYUDDataset(Dataset):
    """
    The dataset is downloaded from http://dl.caffe.berkeleyvision.org/nyud.tar.gz
    """

    def __init__(self, img_paths, seg_paths, depth_paths, transform=None):
        super().__init__()
        self.img_paths = img_paths
        self.seg_paths = seg_paths
        self.depth_paths = depth_paths
        self.transform = transform
        # Keys of the target masks returned in each sample.
        self.mask_names = ("depth", "segm")

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        # Segmentation labels are stored in .mat files under the "segmentation" key;
        # images and depth maps are plain image files.
        sample = {"image": np.array(Image.open(self.img_paths[idx])),
                  "segm": np.array(scipy.io.loadmat(self.seg_paths[idx])["segmentation"]),
                  "depth": np.array(Image.open(self.depth_paths[idx])),
                  "names": self.mask_names}
        if self.transform:
            sample = self.transform(sample)
        # if "names" in sample:
        #     del sample["names"]
        return sample
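

# Illustrative only (not part of the original file): a minimal sketch of what a
# `transform` callable for these datasets could look like. Transforms receive the
# whole sample dict and must return it; the name `to_tensor_sample` and the choice
# of converting everything to float tensors are assumptions, not something the
# original code prescribes.
def to_tensor_sample(sample):
    import torch  # local import; the rest of this file only needs torch.utils.data

    out = dict(sample)
    # HWC uint8 image -> CHW float tensor scaled to [0, 1].
    out["image"] = torch.from_numpy(out["image"]).permute(2, 0, 1).float() / 255.
    # Cast each target mask to a float tensor without changing its values.
    for name in sample["names"]:
        out[name] = torch.from_numpy(np.ascontiguousarray(sample[name]).astype(np.float32))
    return out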


class CityscapesDataset(Dataset):
    """
    The dataset should be downloaded from
    1. https://www.cityscapes-dataset.com/file-handling/?packageID=1
    2. https://www.cityscapes-dataset.com/file-handling/?packageID=3
    3. https://www.cityscapes-dataset.com/file-handling/?packageID=7
    and placed in a directory named cityscapes
    """

    def __init__(self, img_paths, seg_paths, ins_paths, depth_paths, transform=None):
        super().__init__()
        self.img_paths = img_paths
        self.seg_paths = seg_paths
        self.ins_paths = ins_paths
        self.depth_paths = depth_paths
        self.transform = transform
        # Keys of the target masks returned in each sample.
        self.mask_names = ("depth", "segm", "ins")

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        # Cityscapes stores disparity as uint16 PNGs: 0 marks pixels without a
        # disparity estimate, valid pixels encode disparity as (p - 1) / 256.
        disparity = np.array(Image.open(self.depth_paths[idx])).astype(np.float32)
        disparity[disparity > 0] = (disparity[disparity > 0] - 1) / 256.
        # Convert disparity (px) to depth (m): depth = baseline * focal_length / disparity,
        # using the Cityscapes stereo rig constants (baseline 0.209313 m, focal length 2262.52 px).
        disparity[disparity > 0] = (0.209313 * 2262.52) / disparity[disparity > 0]
        sample = {"image": np.array(Image.open(self.img_paths[idx])),
                  "segm": np.array(Image.open(self.seg_paths[idx])),
                  "ins": np.array(Image.open(self.ins_paths[idx])),
                  "depth": disparity,
                  "names": self.mask_names}
        if self.transform:
            sample = self.transform(sample)
        # if "names" in sample:
        #     del sample["names"]
        return sample
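

# Illustrative only (not part of the original file): the same disparity-to-depth
# conversion as in CityscapesDataset.__getitem__, factored into a standalone helper
# so the arithmetic is easy to check. For a raw PNG value of 25857 the encoded
# disparity is (25857 - 1) / 256 = 101.0 px, giving a depth of roughly
# 0.209313 * 2262.52 / 101.0 ≈ 4.69 m. The function name and keyword defaults are
# assumptions taken from the constants used above.
def disparity_png_to_depth(raw, baseline=0.209313, focal=2262.52):
    depth = raw.astype(np.float32)
    # Decode the uint16 PNG encoding on valid pixels only (0 means "no estimate").
    depth[depth > 0] = (depth[depth > 0] - 1) / 256.
    # Disparity (px) -> depth (m); pixels that decoded to 0 stay 0.
    depth[depth > 0] = (baseline * focal) / depth[depth > 0]
    return depth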


if __name__ == "__main__":
    # For NYUD dataset
    # img_paths = sorted(glob.glob("./nyud/data/images/*"))
    # seg_paths = sorted(glob.glob("./nyud/segmentation/*"))
    # depth_paths = sorted(glob.glob("./nyud/data/depth/*"))
    # dataset = NYUDDataset(img_paths, seg_paths, depth_paths)
    # sample = dataset[5]
    # f, ax = plt.subplots(1, 3)
    # ax[0].imshow(sample["image"])
    # ax[1].imshow(sample["segm"])
    # ax[2].imshow(sample["depth"])
    # plt.show()

    # For Cityscapes dataset
    img_paths = sorted(glob.glob("./cityscapes/leftImg8bit_trainvaltest/leftImg8bit/train/*/*"))
    seg_paths = sorted(glob.glob("./cityscapes/gtFine_trainvaltest/gtFine/train/*/*labelIds.png"))
    ins_paths = sorted(glob.glob("./cityscapes/gtFine_trainvaltest/gtFine/train/*/*instanceIds.png"))
    depth_paths = sorted(glob.glob("./cityscapes/disparity_trainvaltest/disparity/train/*/*"))
    dataset = CityscapesDataset(img_paths, seg_paths, ins_paths, depth_paths)

    # Visual sanity check on a single sample: image, semantic mask, instance mask, depth.
    sample = dataset[0]
    f, ax = plt.subplots(1, 4)
    ax[0].imshow(sample["image"])
    ax[1].imshow(sample["segm"])
    ax[2].imshow(sample["ins"])
    ax[3].imshow(sample["depth"])
    plt.show()
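
    # Illustrative only (not part of the original file), kept commented out like the
    # NYUD block above: a rough sketch of how one of these datasets might be consumed
    # with a DataLoader. Batching assumes the samples have been converted to equally
    # sized tensors (e.g. with a transform such as the to_tensor_sample sketch above)
    # and that the string-valued "names" entry has been dropped; otherwise a custom
    # collate_fn would be needed.
    # from torch.utils.data import DataLoader
    # loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
    # for batch in loader:
    #     images, depths = batch["image"], batch["depth"]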