# File: data_loaders.py (62 lines, 56 loc, 2.2 KB)
# NOTE: The upstream repository was archived by its owner on Jan 6, 2025 and is now read-only.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
class Plain_Dataset(Dataset):
    '''
    Map-style Pytorch Dataset pairing image files on disk with csv labels.

    Sample ``idx`` is the image ``<img_dir>/<datatype><idx>.jpg`` together with
    the entry stored under ``idx`` in the csv file's 'emotion' column.
    '''

    def __init__(self, csv_file, img_dir, datatype, transform):
        '''
        Read the label table and remember where the images live.
        params:-
            csv_file : the path of the csv file (train, validation, test);
                       it must contain an 'emotion' column
            img_dir  : the directory of the images (train, validation, test),
                       with or without a trailing path separator
            datatype : file-name prefix of the images inside img_dir
                       (train, val, test)
            transform: callable applied to the loaded PIL image; a falsy
                       value (e.g. None) returns the PIL image unchanged
        '''
        self.csv_file = pd.read_csv(csv_file)
        # The attribute keeps its historical spelling ('lables') because
        # outside code may already read it under that name.
        self.lables = self.csv_file['emotion']
        self.img_dir = img_dir
        self.transform = transform
        self.datatype = datatype

    def __len__(self):
        '''Return the number of rows in the csv file.'''
        return len(self.csv_file)

    def _image_path(self, idx):
        '''Return the path of the image file that belongs to sample idx.'''
        # os.path.join inserts a separator only when img_dir lacks one, so
        # both 'data/train/' and 'data/train' resolve to the same file.
        return os.path.join(self.img_dir, self.datatype + str(idx) + '.jpg')

    def __getitem__(self, idx):
        '''
        Load one sample.
        params:-
            idx : integer position of the sample (a tensor index is accepted)
        return :-
            image, labels -- the (optionally transformed) image and its label
            as a torch long tensor
        '''
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Image.open is lazy and would keep the file handle open until the
        # pixels are first accessed; decode eagerly so the handle is released
        # here even when no transform touches the image.
        with open(self._image_path(idx), 'rb') as img_file:
            img = Image.open(img_file)
            img.load()
        label = torch.from_numpy(np.array(self.lables[idx])).long()
        if self.transform:
            img = self.transform(img)
        return img, label
#Helper function
def eval_data_dataloader(csv_file, img_dir, datatype, sample_number, transform=None):
    '''
    Helper function used to evaluate the Dataset class: prints the label of
    one sample and displays its image.
    params:-
        csv_file : the path of the csv file (train, validation, test)
        img_dir : the directory of the images (train, validation, test)
        datatype : string for searching along the image_dir (train, val, test)
        sample_number : any number from the data to be shown
        transform : transformation applied to the image; when None, ToTensor
                    followed by Normalize((0.5,), (0.5,)) is used. The result
                    must offer .numpy(), i.e. be a tensor.
    '''
    # Imported here because the module never imports matplotlib at file level
    # (the original code referenced an undefined 'plt'); a local import also
    # keeps plotting an optional dependency of the data-loading code.
    import matplotlib.pyplot as plt

    if transform is None:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
    dataset = Plain_Dataset(csv_file=csv_file, img_dir=img_dir,
                            datatype=datatype, transform=transform)
    # Fetch the sample once instead of loading and transforming it twice.
    image, label = dataset[sample_number]
    print(label)
    # squeeze() drops size-1 axes (e.g. a single channel axis) before plotting.
    plt.imshow(image.numpy().squeeze())
    plt.show()