-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathdataset.py
249 lines (189 loc) · 7.51 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
import os
import json
from pathlib import Path
import random
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_tensor, adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation, resize
from torchvision.transforms import RandomResizedCrop, Compose, RandomHorizontalFlip, RandomVerticalFlip, RandomRotation, CenterCrop, Resize, ColorJitter
from einops import rearrange
class RepeatDataset(Dataset):
    """Present ``dataset`` as if it were concatenated with itself ``factor`` times.

    Index ``i`` maps to ``dataset[i % len(dataset)]``. The wrapped dataset's
    length is re-read on every call, so it is never cached.

    Args:
        dataset: any object supporting ``len()`` and integer indexing.
        factor: how many times to repeat the underlying dataset (default 10).
    """

    def __init__(self, dataset, factor=10):
        self.dataset = dataset
        self.factor = factor

    def __len__(self):
        base_len = len(self.dataset)
        return base_len * self.factor

    def __getitem__(self, idx):
        # Wrap the virtual index back into the range of the real dataset.
        wrapped = idx % len(self.dataset)
        return self.dataset[wrapped]
class TrainingFramesDataset(Dataset):
    """Frames read one at a time from a directory, cropped for RAFT.

    Every entry in ``root`` (sorted by filename) is treated as one frame.
    ``__getitem__`` returns the image converted by ``to_tensor`` and then
    cropped by ``transform_frame`` so height and width are multiples of 8.

    Args:
        root: path to a directory containing only frame image files.
    """

    def __init__(self, root):
        self.root = Path(root)
        # Sorted so that dataset index order follows filename order.
        self.frame_names = sorted(os.listdir(self.root))

    def transform_frame(self, frame):
        """Crop ``frame`` (indexed as ``[c, h, w]``) to H and W multiples of 8.

        RAFT needs spatial sizes divisible by 8. The crop drops trailing rows
        and columns. n.b. technically this is only necessary for training.
        """
        _, h, w = frame.shape
        return frame[:, :(h // 8) * 8, :(w // 8) * 8]

    def __getitem__(self, idx):
        fname = self.frame_names[idx]
        # Image.open keeps the file open lazily. The context manager closes it
        # once to_tensor has copied the pixel data into a tensor.
        with Image.open(self.root / fname) as img:
            frame = to_tensor(img)
        return self.transform_frame(frame)

    def __len__(self):
        return len(self.frame_names)
class FramesDataset(Dataset):
    """Frames read one at a time from a directory, at native resolution.

    Every entry in ``root`` (sorted by filename) is treated as one frame.
    ``__getitem__`` returns the image exactly as converted by ``to_tensor``,
    with no crop or other transform applied.

    Args:
        root: path to a directory containing only frame image files.
    """

    def __init__(self, root):
        self.root = Path(root)
        # Sorted so that dataset index order follows filename order.
        self.frame_names = sorted(os.listdir(self.root))

    def transform_frame(self, frame):
        """Crop ``frame`` (indexed as ``[c, h, w]``) to H and W multiples of 8.

        RAFT needs spatial sizes divisible by 8. n.b. technically this is only
        necessary for training. It is NOT called by ``__getitem__``; callers
        that need RAFT-compatible sizes must apply it themselves.
        """
        _, h, w = frame.shape
        return frame[:, :(h // 8) * 8, :(w // 8) * 8]

    def __getitem__(self, idx):
        fname = self.frame_names[idx]
        # Close the lazily opened file once the pixels are copied into a tensor.
        with Image.open(self.root / fname) as img:
            frame = to_tensor(img)
        # NOTE(review): unlike TrainingFramesDataset, no crop is applied here.
        # Presumably intentional so outputs keep the input resolution — confirm.
        return frame

    def __len__(self):
        return len(self.frame_names)
class TestTimeAdaptDataset(Dataset):
    """Augmented pairs of frames from a single directory.

    Each item is a tensor indexed ``[c, t, h, w]`` with ``t == 2``. The two
    frames are stacked along dim 1 and share the same geometric and color
    augmentation. The output is cropped to ``im_size x im_size``.
    """

    def __init__(self, root, mode='first', length=None):
        '''
        args:
            root: (string) path to directory of frames
            mode: ['first', 'random'] how to sample frames
                first: always samples first frame + idx^th frame
                    (idx taken modulo the number of frames)
                random: randomly samples two frames (independently, so they
                    may coincide)
            length: value reported by __len__; defaults to the number of
                frames in root. It does not restrict which frames are sampled.
        '''
        self.root = Path(root)
        self.frame_names = sorted(os.listdir(self.root))
        self.mode = mode
        self.im_size = 512  # side length of the square output crop
        self.scale = 1.1  # stretching scale: per-axis resize factor in [1/1.1, 1.1]
        self.geom_transform = Compose([
            RandomRotation(5),
        ])
        self.color_transform = ColorJitter(brightness=.5, contrast=.5, saturation=.5, hue=.3)
        if length is not None:
            self.length = length
        else:
            self.length = len(self.frame_names)

    def transform_frames(self, frames):
        """Jointly augment a ``[c, t, h, w]`` frame stack.

        Applies, in order:
        - a random rotation;
        - an independent random stretch of H and W;
        - zero-padding up to ``im_size``;
        - a random ``im_size`` square crop;
        - color jitter.

        Each random parameter is drawn once and shared by all ``t`` frames.
        """
        c, t, h, w = frames.shape
        # Merge channel and time so one sampled rotation angle hits every frame.
        frames = rearrange(frames, 'c t h w -> (c t) h w')
        frames = self.geom_transform(frames)
        frames = rearrange(frames, '(c t) h w -> c t h w', c=c, t=t)
        # Stretch: scale H and W independently by a factor in [1/scale, scale].
        new_h = int(h * self.scale**np.random.uniform(-1, 1))
        new_w = int(w * self.scale**np.random.uniform(-1, 1))
        frames = resize(frames, (new_h, new_w))
        # Zero-pad (centered) so both edges are at least im_size.
        to_pad_h = max(self.im_size - new_h, 0)
        to_pad_w = max(self.im_size - new_w, 0)
        pad_l_h = to_pad_h // 2
        pad_r_h = to_pad_h - pad_l_h
        pad_l_w = to_pad_w // 2
        pad_r_w = to_pad_w - pad_l_w
        frames = F.pad(frames, (pad_l_w, pad_r_w, pad_l_h, pad_r_h))
        _, _, padded_h, padded_w = frames.shape
        # Random crop to (im_size, im_size).
        ch = self.im_size
        cw = self.im_size
        ct = np.random.randint(padded_h - ch + 1)
        cl = np.random.randint(padded_w - cw + 1)
        frames = frames[:, :, ct:ct + ch, cl:cl + cw]
        # Apply color transforms (must have c=3, treating t as batch so both
        # frames get the same jitter).
        frames = rearrange(frames, 'c t h w -> t c h w')
        frames = self.color_transform(frames)
        frames = rearrange(frames, 't c h w -> c t h w')
        return frames

    def _sample_indices(self, idx):
        """Return the ``(idx0, idx1)`` frame indices to load for item ``idx``."""
        n_frames = len(self.frame_names)
        idx = idx % n_frames
        if self.mode == 'first':
            return 0, idx
        if self.mode == 'random':
            # Sample from the actual frame list, not self.length. A
            # user-supplied length larger than the number of frames would
            # otherwise raise IndexError, and a smaller one would silently
            # never sample the later frames.
            return np.random.randint(n_frames), np.random.randint(n_frames)
        raise NotImplementedError

    def _load_frame(self, frame_idx):
        """Load frame ``frame_idx`` as a tensor, closing the image file afterwards."""
        with Image.open(self.root / self.frame_names[frame_idx]) as img:
            return to_tensor(img)

    def __getitem__(self, idx):
        idx0, idx1 = self._sample_indices(idx)
        frames = torch.stack([self._load_frame(idx0), self._load_frame(idx1)], dim=1)
        # Transform tensors
        return self.transform_frames(frames)

    def __len__(self):
        return self.length
class FlowMagDataset(Dataset):
    """Pairs of frames (frameA, frameB) listed in a JSON filename index.

    Expected layout under ``data_root``:
        {split}_fn.json         - JSON filename index (assumed to be a list,
                                  since it is indexed by integer position)
        {split}/frameA/<fname>  - first frame of each pair
        {split}/frameB/<fname>  - second frame of each pair

    Each item is ``(frames, info)``:
    - ``frames`` is indexed ``[c, t, h, w]`` with ``t == 2`` (A then B).
    - ``info`` is ``{'fname': <filename>}``.

    Args:
        data_root: dataset root directory.
        split: split name; 'valid' is mapped to 'test'.
        aug: if truthy, apply random geometric + color augmentation;
            otherwise resize and center-crop only.
        img_size: output side length passed to the crop/resize transforms.
    """

    def __init__(self, data_root, split, aug=False, img_size=256):
        self.data_root = Path(data_root)
        self.split = split
        self.img_size = img_size
        # There is no separate validation set, only a test set.
        if self.split == 'valid':
            self.split = 'test'
        # Get frame metadata
        self.frameA_dir = self.data_root / self.split / 'frameA'
        self.frameB_dir = self.data_root / self.split / 'frameB'
        with open(self.data_root / f'{self.split}_fn.json', 'r', encoding='utf-8') as f:
            self.fnames = json.load(f)
        # Make augmentations
        if aug:
            self.transform = Compose([
                RandomResizedCrop(img_size, scale=(0.7, 1.0)),
                RandomHorizontalFlip(.5),
                RandomVerticalFlip(.5),
                RandomRotation(15),
            ])
            self.color_transform = ColorJitter(brightness=.5, contrast=.5, saturation=.5, hue=.3)
        else:
            self.transform = Compose([
                Resize(img_size),
                CenterCrop(img_size),
            ])
            self.color_transform = nn.Identity()

    def transform_frames(self, frames):
        """Apply the same geometric and color transform to every frame in ``[c, t, h, w]``."""
        c, t, h, w = frames.shape
        # Merge channel and time so each random geometric parameter is drawn
        # once and applied identically to both frames.
        frames = rearrange(frames, 'c t h w -> (c t) h w')
        frames = self.transform(frames)
        frames = rearrange(frames, '(c t) h w -> c t h w', c=c, t=t)
        # Apply color transforms (must have c=3, treating t as batch)
        frames = rearrange(frames, 'c t h w -> t c h w')
        frames = self.color_transform(frames)
        frames = rearrange(frames, 't c h w -> c t h w')
        return frames

    def __getitem__(self, idx):
        fname = self.fnames[idx]
        # Load frameA then frameB, closing each file once its pixels are copied.
        images = []
        for frame_dir in (self.frameA_dir, self.frameB_dir):
            with Image.open(frame_dir / fname) as img:
                images.append(to_tensor(img))
        frames = torch.stack(images, dim=1)
        frames = self.transform_frames(frames)
        info = {'fname': fname}
        return frames, info

    def __len__(self):
        return len(self.fnames)
def get_dataloader(config, split):
    """Construct the ``FlowMagDataset`` for ``split`` from ``config.data``.

    NOTE(review): despite the name, this returns the Dataset itself, not a
    ``DataLoader``. Batching/loading is presumably done by the caller.

    ``config.data.aug`` is forwarded as the augmentation flag only for the
    'train' split; every other split is built with ``aug=False``. Also reads
    ``config.data.dataroot`` and ``config.data.im_size``.
    """
    use_aug = config.data.aug if split == 'train' else False
    return FlowMagDataset(
        config.data.dataroot,
        split=split,
        aug=use_aug,
        img_size=config.data.im_size,
    )