This repository has been archived by the owner on Oct 26, 2024. It is now read-only.
-
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy path pycoral_image_classification.py
81 lines (65 loc) · 2.42 KB
/
pycoral_image_classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# based on https://github.com/google-coral/pycoral/blob/master/examples/classify_image.py
from imutils.video import VideoStream, FPS
import argparse
import time
import cv2
from PIL import Image
import numpy as np
from pycoral.adapters import classify
from pycoral.adapters import common
from pycoral.utils.dataset import read_label_file
from pycoral.utils.edgetpu import make_interpreter
def draw_image(image, classes, labels):
    """Overlay the top classification on ``image`` and show it in a window.

    Args:
      image: Frame to display, either a PIL image or a numpy array, in the
        channel order ``cv2.imshow`` expects. It is copied before drawing,
        so the caller's object is left unmodified.
      classes: Classification results exposing an ``id`` attribute (as
        returned by ``classify.get_classes``). Only the first entry is
        drawn; an empty sequence shows the frame without a label.
      labels: Mapping from class id to label text. Ids missing from the
        mapping are drawn as the numeric id.
    """
    # np.asarray() on a PIL image can produce a read-only array, which
    # cv2.putText rejects as an output argument; np.array() always copies
    # into a writable buffer.
    image_np = np.array(image)
    if classes:
        top = classes[0]
        # cv2.putText requires a str; the fallback from labels.get is an int.
        text = str(labels.get(top.id, top.id))
        cv2.putText(image_np, text, (10, 35),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
    cv2.imshow('Live Inference', image_np)
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', help='File path of Tflite model.', required=True)
    parser.add_argument(
        '--labels', help='File path of label file.', required=True)
    parser.add_argument('--picamera',
                        action='store_true',
                        help="Use PiCamera for image capture",
                        default=False)
    parser.add_argument(
        '-t', '--threshold', type=float, default=0.5,
        help='Classification score threshold')
    return parser.parse_args()


def _classify_frame(interpreter, frame, size, threshold):
    """Run the model on one camera frame and return its top-1 result.

    ``frame`` is the numpy array returned by ``VideoStream.read()``. OpenCV
    capture (and, by default, imutils' PiCamera stream) delivers BGR, while
    the classifier is fed RGB as in the upstream pycoral example, so the
    channels are swapped before resizing to the model's input ``size``.
    Returns the (possibly empty) list from ``classify.get_classes``.
    """
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # LANCZOS is the filter formerly aliased as Image.ANTIALIAS, which was
    # removed in Pillow 10.
    model_input = Image.fromarray(rgb).resize(size, Image.LANCZOS)
    common.set_input(interpreter, model_input)
    interpreter.invoke()
    return classify.get_classes(interpreter, 1, threshold)


def main():
    """Classify live camera frames on the Edge TPU and display the result.

    Loads the TFLite model and label file named on the command line, reads
    frames from a ``VideoStream`` (webcam, or the Pi camera with
    ``--picamera``), classifies each frame and shows the top label in an
    OpenCV window. Press ``q`` in the window or Ctrl+C to stop; elapsed time
    and approximate FPS are printed on a normal exit.
    """
    args = _parse_args()

    print('Loading {} with {} labels.'.format(args.model, args.labels))
    interpreter = make_interpreter(args.model)
    interpreter.allocate_tensors()
    labels = read_label_file(args.labels)
    size = common.input_size(interpreter)

    # Initialize video stream and give the camera a moment to start up.
    vs = VideoStream(usePiCamera=args.picamera, resolution=(640, 480)).start()
    time.sleep(1)
    fps = FPS().start()
    try:
        while True:
            frame = vs.read()
            if frame is None:
                # No frame available: the camera is missing or disconnected.
                print('No frame received from the video stream; stopping.')
                break
            classes = _classify_frame(interpreter, frame, size, args.threshold)
            # Display the original (BGR) frame, which is what imshow expects.
            draw_image(frame, classes, labels)
            if cv2.waitKey(5) & 0xFF == ord('q'):
                break
            fps.update()
    except KeyboardInterrupt:
        pass
    finally:
        # Release the camera and window even if inference raised.
        fps.stop()
        cv2.destroyAllWindows()
        vs.stop()
    print("Elapsed time: " + str(fps.elapsed()))
    print("Approx FPS: " + str(fps.fps()))
    # Give the stream's capture thread time to shut down before exit.
    time.sleep(2)
# Script entry point: start the camera loop only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()