# this import statement is needed if you want to use the AWS Lambda Layer called "pytorch-v1-py36"
# it unzips all of the PyTorch and dependency packages when the script is loaded to avoid the 250 MB unpacked limit in AWS Lambda
try:
import unzip_requirements
except ImportError:
pass
import os
import io
import json
import tarfile
import glob
import time
import logging
import boto3
import requests
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision import models, transforms
# load the S3 client when lambda execution context is created
s3 = boto3.client('s3')
# classes for the image classification
classes = []
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# get bucket name from ENV variable
MODEL_BUCKET = os.environ.get('MODEL_BUCKET')
logger.info(f'Model Bucket is {MODEL_BUCKET}')
# get model object key from ENV variable
MODEL_KEY = os.environ.get('MODEL_KEY')
logger.info(f'Model Key is {MODEL_KEY}')
# processing pipeline to resize, normalize and create tensor object
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
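# For reference (not executed): applying `preprocess` to a 3-channel PIL image
# yields a tensor of shape [3, 224, 224]; input_fn below adds the batch
# dimension with unsqueeze(0), giving [1, 3, 224, 224].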
def load_model():
"""Loads the PyTorch model into memory from a file on S3.
Returns
------
Vision model: Module
Returns the vision PyTorch model to use for inference.
"""
global classes
logger.info('Loading model from S3')
model_dir = '/tmp/model'
local_model=f'{model_dir}/model.tar.gz'
# download the model tar.gz file from S3 and extract
logger.info(f'Downloading model from S3 to {local_model}')
if not os.path.exists(model_dir):
os.mkdir(model_dir)
s3.download_file(
MODEL_BUCKET, MODEL_KEY, local_model)
logger.info('Extracting model tarfile')
    with tarfile.open(local_model) as tar:
        tar.extractall(model_dir)
os.remove(local_model)
logger.info('Getting classes from file')
# get the classes from saved 'classes.txt' file
with open(f'{model_dir}/classes.txt', 'r') as f:
classes = f.read().splitlines()
logger.info(f'Classes are {classes}')
model_path = glob.glob(f'{model_dir}/*_jit.pth')[0]
logger.info(f'Model path is {model_path}')
model = torch.jit.load(model_path, map_location=torch.device('cpu'))
return model.eval()
# load the model when lambda execution context is created
model = load_model()
def predict(input_object, model):
"""Predicts the class from an input image.
Parameters
----------
input_object: Tensor, required
The tensor object containing the image pixels reshaped and normalized.
Returns
------
Response object: dict
Returns the predicted class and confidence score.
"""
logger.info("Calling prediction on model")
    start_time = time.time()
    # disable gradient tracking during inference to save memory and compute
    with torch.no_grad():
        predict_values = model(input_object)
    logger.info("--- Inference time: %s seconds ---" % (time.time() - start_time))
    preds = F.softmax(predict_values, dim=1)
    conf_score, indx = torch.max(preds, dim=1)
    predict_class = classes[indx.item()]
logger.info(f'Predicted class is {predict_class}')
logger.info(f'Softmax confidence score is {conf_score.item()}')
response = {}
response['class'] = str(predict_class)
response['confidence'] = conf_score.item()
return response
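# Illustrative only: predict() returns a dict such as
# {'class': 'golden_retriever', 'confidence': 0.93}; the class string comes
# from the classes.txt shipped with the model and these values are made up.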
def input_fn(request_body):
"""Pre-processes the input data from JSON to PyTorch Tensor.
Parameters
----------
request_body: dict, required
The request body submitted by the client. Expect an entry 'url' containing a URL of an image to classify.
Returns
------
PyTorch Tensor object: Tensor
"""
logger.info("Getting input URL to a image Tensor object")
if isinstance(request_body, str):
request_body = json.loads(request_body)
logger.info(request_body)
    img_request = requests.get(request_body['url'], stream=True)
    img = Image.open(io.BytesIO(img_request.content))
img_tensor = preprocess(img)
img_tensor = img_tensor.unsqueeze(0)
return img_tensor
def lambda_handler(event, context):
"""Lambda handler function
Parameters
----------
event: dict, required
API Gateway Lambda Proxy Input Format
Event doc: https://docs.aws.amazon.com/apigateway/latest/developerguide/set-up-lambda-proxy-integrations.html#api-gateway-simple-proxy-for-lambda-input-format
context: object, required
Lambda Context runtime methods and attributes
Context doc: https://docs.aws.amazon.com/lambda/latest/dg/python-context-object.html
Returns
------
API Gateway Lambda Proxy Output Format: dict
Return doc: https://docs.aws.amazon.com/apigateway/latest/developerguide/set-up-lambda-proxy-integrations.html
"""
print("Starting event")
logger.info(event)
print("Getting input object")
input_object = input_fn(event['body'])
print("Calling prediction")
response = predict(input_object, model)
print("Returning response")
return {
"statusCode": 200,
"headers": {
'Content-Type': 'application/json',
'Access-Control-Allow-Origin': '*'
},
"body": json.dumps(response)
}
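# --- Optional local smoke test (a minimal sketch, not part of the Lambda runtime path) ---
# Assumes MODEL_BUCKET and MODEL_KEY are set in the environment and that the
# example image URL below is a placeholder to be replaced with a real one.
if __name__ == '__main__':
    sample_event = {
        # hypothetical API Gateway proxy body with a placeholder image URL
        'body': json.dumps({'url': 'https://example.com/sample.jpg'})
    }
    print(lambda_handler(sample_event, None))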