-
Notifications
You must be signed in to change notification settings - Fork 0
/
set.py
30 lines (24 loc) · 858 Bytes
/
set.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from torchvision.datasets import ImageFolder
import torch.multiprocessing
# Process-global side effect executed at import time: switch torch's tensor
# sharing strategy between processes (e.g. DataLoader workers) to
# "file_system" instead of the default file-descriptor based one.
# NOTE(review): presumably done to avoid "too many open files" errors when
# many workers share tensors — confirm this is still needed.
torch.multiprocessing.set_sharing_strategy("file_system")
class HistoDataset(ImageFolder):
    """ImageFolder that can return several independently transformed views.

    If ``transform_list`` is given, each transform in it is applied to its own
    copy of the loaded image and the results are combined with
    ``torch.stack`` into a single tensor with a new leading "view" dimension.
    In that case ``transform`` is ignored. Otherwise this behaves like a
    plain ``ImageFolder`` using the single ``transform`` (if any).

    Args:
        root: Root directory, laid out as expected by ``ImageFolder``.
        transform: Single transform, used only when ``transform_list`` is None.
        transform_list: Iterable of transforms, one per view. Each must return
            a tensor, and all outputs must have the same shape so they can be
            stacked. It is materialized to a list, so generators and other
            one-shot iterables are safe to pass.
        target_transform: Optional transform applied to the class index.

    Raises:
        ValueError: If ``transform_list`` is given but contains no transforms.
    """

    def __init__(
        self,
        root,
        transform=None,
        transform_list=None,
        target_transform=None,
    ):
        if transform_list is not None:
            # Materialize once: a one-shot iterable (e.g. a generator) would
            # otherwise be exhausted after the first __getitem__ call.
            transform_list = list(transform_list)
            if not transform_list:
                # torch.stack cannot stack zero tensors; fail here with a
                # clear message instead of on every item access.
                raise ValueError(
                    "transform_list must contain at least one transform"
                )
        super().__init__(
            root=root, transform=transform, target_transform=target_transform
        )
        self.transform_list = transform_list

    def __getitem__(self, index):
        """Return ``(img, target)`` for the sample at ``index``.

        ``img`` is a stacked tensor of views when ``transform_list`` is set,
        otherwise the loaded image after ``transform`` (if one was given).
        ``target`` is the class index, passed through ``target_transform``
        when one was given.
        """
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform_list is not None:
            # Each transform gets its own copy of the image, so a transform
            # that mutates its input cannot affect the other views.
            img = torch.stack(
                [view_transform(img.copy()) for view_transform in self.transform_list]
            )
        elif self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target