-
Notifications
You must be signed in to change notification settings - Fork 1
/
Predict_from_model.py
37 lines (29 loc) · 1.17 KB
/
Predict_from_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# NOTE(review): leftover debug output — this prints on every import of the
# module (before any imports run). Consider removing it or switching to logging.
print("so far so good")
import torch
from torchvision import transforms
from PIL import Image
import pandas as pd
import numpy as np
class GarbagePredict:
    """Classify a single image with a pre-trained, pickled PyTorch model.

    The model and a label table are loaded once at construction time;
    ``predict`` runs one image through the model and looks up the row of
    the label table that corresponds to the predicted class index.
    """

    def __init__(self, labels_file='index_to_label.csv',
                 model_file='ResNet_Transfer_model.pkl'):
        """Load the model (onto the CPU) and the label table.

        Args:
            labels_file: CSV read with ``pandas.read_csv``. Row *i* is used
                for predicted class index *i*; column 1 is returned as the
                label and columns 2 onward as the accompanying values.
            model_file: Path to a whole pickled model object (not a state
                dict). Defaults to the file name that used to be hard-coded.
        """
        import inspect

        # SECURITY: torch.load unpickles arbitrary Python objects, which can
        # execute code. Only load model files that come from a trusted source.
        load_kwargs = {'map_location': torch.device('cpu')}
        # The file holds a full model object, so weights_only must be False:
        # PyTorch >= 2.6 defaults it to True and would refuse to unpickle the
        # model class. Older releases (< 1.13) do not accept the argument.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        self.model = torch.load(model_file, **load_kwargs)
        self.model.eval()  # inference mode for dropout/batch-norm layers
        self.labels = pd.read_csv(labels_file)
        # Build the preprocessing pipeline once rather than on every predict().
        # NOTE(review): there is no transforms.Normalize step; presumably the
        # model was trained on un-normalized [0, 1] inputs — confirm against
        # the training code.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])
        print('Model Loaded')

    def predict(self, image):
        """Classify one image.

        Args:
            image: A PIL image. Images in any mode other than RGB (e.g.
                grayscale, RGBA, palette) are converted to RGB first.

        Returns:
            ``(label, suggestions)``: the value in column 1 of the label table
            for the predicted class, and that row's columns from index 2
            onward as a pandas Series.
        """
        # Match the RGB conversion done in the __main__ demo; the model
        # presumably expects 3-channel input. Only PIL-like objects have a mode.
        mode = getattr(image, 'mode', None)
        if mode is not None and mode != 'RGB':
            image = image.convert('RGB')
        batch = self.transform(image).unsqueeze(dim=0)  # add a batch dimension
        # Pure inference: don't build an autograd graph.
        with torch.no_grad():
            scores = self.model(batch)
        # Assumes scores are (1, num_classes); the argmax along dim 1 is used
        # as a positional row index into the label table.
        prediction = torch.argmax(scores, dim=1).item()
        label = self.labels.iloc[prediction, 1]
        suggestions = self.labels.iloc[prediction, 2:]
        return label, suggestions
if __name__ == '__main__':
    # Smoke test: load the model, classify one sample image and print just
    # the predicted label.
    predictor = GarbagePredict()
    # The context manager releases the underlying file once the RGB copy exists.
    with Image.open('cardboard10.jpg') as raw_image:
        sample = raw_image.convert('RGB')
    predicted_label = predictor.predict(sample)[0]
    print(predicted_label)