-
Notifications
You must be signed in to change notification settings - Fork 65
/
Copy pathdataloader_test.py
31 lines (26 loc) · 1.17 KB
/
dataloader_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import glob
import os
import time
import cv2
from torch.utils.data import DataLoader
from data_loader import create_training_datasets
import numpy as np
def make_preview(batch, index=0):
    """Build one side-by-side preview (image | label | trimap) from a batch.

    Args:
        batch: mapping with 'image', 'label' and 'trimap' tensors, as yielded
            by the DataLoader. Each entry is indexed by ``index`` and then
            permuted with ``permute(1, 2, 0)``, so the tensors are assumed
            to be CPU tensors laid out (N, C, H, W). Label and trimap are
            assumed single-channel. TODO confirm against data_loader.
        index: which sample of the batch to render.

    Returns:
        A numpy array of shape (H, 3 * W, 3): the image, label and trimap
        panels concatenated along the width axis.
    """
    def to_hwc(key):
        # CHW tensor -> HWC numpy array, the layout cv2 / numpy expect.
        return batch[key][index].permute(1, 2, 0).numpy()

    # Reverse the channel axis because cv2.imshow interprets channels as BGR.
    # Presumably the dataset yields RGB; confirm.
    image = to_hwc('image')[:, :, ::-1]
    # Replicate the single mask channel into 3 channels so all panels can be
    # concatenated. This matches cv2.COLOR_GRAY2RGB but works for any dtype.
    label = np.repeat(to_hwc('label'), 3, axis=2)
    trimap = np.repeat(to_hwc('trimap'), 3, axis=2)
    return np.concatenate([image, label, trimap], axis=1)


if __name__ == '__main__':
    # Visual sanity check of the training data pipeline: shows each training
    # sample (image | label | trimap) in a window, one after another.
    data_dir = '../../dataset/anime-seg/'
    tra_fg_dir = 'fg/'
    tra_bg_dir = 'bg/'
    tra_img_dir = 'imgs/'
    tra_mask_dir = 'masks/'
    fg_ext = '.png'
    bg_ext = '.*'
    img_ext = '.jpg'
    mask_ext = '.jpg'
    # NOTE(review): the meaning of the positional 0.95, 640, True is defined
    # by data_loader.create_training_datasets; confirm there before changing.
    train_dataset, val_dataset = create_training_datasets(data_dir, tra_fg_dir, tra_bg_dir, tra_img_dir, tra_mask_dir,
                                                          fg_ext, bg_ext, img_ext, mask_ext, 0.95, 640, True)
    salobj_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=2, persistent_workers=True)
    delay_ms = 1000  # how long each sample stays on screen
    try:
        for data in salobj_dataloader:
            cv2.imshow("a", make_preview(data))
            # waitKey returns -1 when no key is pressed; mask to the low byte
            # before comparing. Esc (27) or 'q' stops the viewer early.
            key = cv2.waitKey(delay_ms) & 0xFF
            if key in (27, ord('q')):
                break
    finally:
        # Always close the HighGUI window, even on early exit or error.
        cv2.destroyAllWindows()