-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
125 lines (86 loc) · 3.29 KB
/
Copy pathinference.py
File metadata and controls
125 lines (86 loc) · 3.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#Dependencies
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from PIL import Image
from PIL import ImageFile
# Let PIL decode image files whose data stream is truncated instead of
# raising an error while loading them.
ImageFile.LOAD_TRUNCATED_IMAGES = True
import os
import logging
import sys
import time
import json
import base64
import argparse
# Module-wide compute device: the first CUDA GPU when one is visible,
# otherwise the CPU. net() moves freshly built models onto it.
device = (
    torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)
def input_fn(request_body, request_content_type):
    """Deserialize a prediction request into a normalized image tensor.

    Args:
        request_body: Raw request payload. For ``application/json`` it must be
            a JSON object whose ``"arr"`` entry is image data accepted by
            ``matplotlib.pyplot.imsave`` (NOTE(review): confirm the exact
            array layout/dtype the client sends).
        request_content_type: MIME type of ``request_body``.

    Returns:
        A (3, 224, 224) float tensor: the image converted to RGB, resized to
        224x224, scaled by ``ToTensor`` and normalized with the ImageNet
        mean/std. No batch dimension is added here.

    Raises:
        ValueError: if ``request_content_type`` is not ``application/json``
            (``json.JSONDecodeError``, a ``ValueError`` subclass, is raised
            for a malformed body).
    """
    import io  # local import keeps this block self-contained

    if request_content_type != "application/json":
        raise ValueError(f"Unsupported content type: {request_content_type!r}")

    deserialized_data = json.loads(request_body)

    # Render the array to PNG in memory instead of a shared "image.png" in the
    # working directory: concurrent requests would otherwise overwrite each
    # other's file, and nothing would be left behind on disk.
    png_buffer = io.BytesIO()
    plt.imsave(png_buffer, deserialized_data['arr'], format="png")
    png_buffer.seek(0)
    with Image.open(png_buffer) as rendered:
        image = rendered.convert('RGB')

    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    image_tensor = test_transform(image)
    return image_tensor
def predict_fn(input_data, model):
    """Run the wrapped network on a single preprocessed sample.

    Args:
        input_data: tensor for one sample without a batch axis; a leading
            batch dimension of size 1 is added before the forward pass.
        model: dict with ``"model"`` (the network) and ``"class"`` (the
            class-name -> index map), as returned by ``model_fn``.

    Returns:
        ``{"prediction": <network output for the batch>, "class": <class map>}``.
    """
    # Picked per call (rather than reusing the module-level ``device``) so the
    # behavior matches "cuda"/"cpu" selection exactly.
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network = model['model']
    network.to(target_device)
    batch = input_data.unsqueeze(0).to(target_device)
    network.eval()
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = network(batch)
    return {"prediction": logits, "class": model['class']}
def output_fn(prediction_output, response_content_type):
    """Serialize the top predictions of a single sample as JSON.

    Args:
        prediction_output: dict from ``predict_fn`` with ``"prediction"``
            (model outputs; softmax and top-k are taken along dim 1 and only
            the first row is reported) and ``"class"`` (class-name -> index).
        response_content_type: requested MIME type of the response.

    Returns:
        JSON string ``{"body": {"prob": [...], "indices": [...], "class": {...}}}``.
        ``prob`` holds up to 5 softmax probabilities in descending order;
        ``indices`` holds the matching class names (an index with no entry in
        the class map is reported as the raw integer); ``class`` echoes the
        full class map.

    Raises:
        ValueError: if ``response_content_type`` is not ``application/json``.
    """
    if response_content_type != "application/json":
        raise ValueError(f"Unsupported content type: {response_content_type!r}")

    class_to_idx = prediction_output['class']
    probabilities = nn.functional.softmax(prediction_output['prediction'], dim=1)

    # torch.topk raises if k exceeds the size of the dimension, so never ask
    # for more entries than the model has outputs.
    k = min(5, probabilities.shape[1])
    top_probs, top_idx = torch.topk(probabilities, k)

    # Invert the class map once (O(C)) instead of scanning it per index.
    # setdefault keeps the first name seen for an index, as the original
    # first-match scan did.
    idx_to_class = {}
    for name, idx in class_to_idx.items():
        idx_to_class.setdefault(idx, name)

    prob = top_probs[0].tolist()
    indices = [idx_to_class.get(i, i) for i in top_idx[0].tolist()]

    body = {"prob": prob, "indices": indices, "class": class_to_idx}
    # Plain json.dumps serialization of the response envelope.
    return json.dumps({'body': body})
def net(*, num_classes=133, pretrained=True):
    '''
    Build a ResNet-18 with a frozen backbone and a new classification head.

    Parameters:
        num_classes: number of outputs of the replacement fully connected
            layer. Defaults to 133, the value this file has always used.
        pretrained: forwarded to ``torchvision.models.resnet18``. Defaults to
            True (original behavior). Callers that immediately load a full
            state dict, as ``model_fn`` here does, can pass False to skip
            fetching ImageNet weights that would be overwritten anyway.
    Returns:
        The model moved to the module-level ``device``. All backbone
        parameters have ``requires_grad=False``; the new ``fc`` head is
        freshly initialized (untrained) and trainable.
    '''
    pretrained_model = models.resnet18(pretrained=pretrained)
    # Freeze the backbone so only the replacement head would receive gradients.
    for param in pretrained_model.parameters():
        param.requires_grad = False
    # Swap in a new fully connected layer sized for this task; it is created
    # after the freeze loop, so its parameters remain trainable.
    in_features = pretrained_model.fc.in_features
    pretrained_model.fc = nn.Linear(in_features, num_classes)
    return pretrained_model.to(device)
def model_fn(model_dir):
    """Load the trained classifier and its label map from ``model_dir``.

    Args:
        model_dir: directory containing ``model.pth``, a checkpoint dict with
            ``"model_state_dict"`` and ``"class_to_idx"`` entries.

    Returns:
        ``{"model": <network with the checkpoint weights>,
        "class": <class_to_idx mapping>}``, the structure consumed by
        ``predict_fn`` and ``output_fn``.
    """
    model = net()
    checkpoint_path = os.path.join(model_dir, 'model.pth')
    # map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
    # host; load_state_dict then copies the tensors onto whatever device the
    # model's parameters already live on.
    # NOTE(security): torch.load unpickles the file, so model_dir must be a
    # trusted model artifact.
    with open(checkpoint_path, 'rb') as f:
        checkpoint = torch.load(f, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    return {"model": model, "class": checkpoint['class_to_idx']}