forked from hhj1897/face_detection
-
Notifications
You must be signed in to change notification settings - Fork 3
/
face_detection_test.py
170 lines (155 loc) · 8.24 KB
/
face_detection_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import os
import cv2
import time
import torch
from argparse import ArgumentParser
from ibug.face_detection import RetinaFacePredictor, S3FDPredictor
from ibug.face_detection.utils import SimpleFaceTracker, HeadPoseEstimator
def main() -> None:
# Parse command-line arguments
parser = ArgumentParser()
parser.add_argument('--input', '-i', help='Input video path or webcam index (default=0)', default=0)
parser.add_argument('--output', '-o', help='Output file path', default=None)
parser.add_argument('--fourcc', '-f', help='FourCC of the output video (default=mp4v)',
type=str, default='mp4v')
parser.add_argument('--benchmark', '-b', help='Enable benchmark mode for CUDNN',
action='store_true', default=False)
parser.add_argument('--no-display', '-n', help='No display if processing a video file',
action='store_true', default=False)
parser.add_argument('--threshold', '-t', help='Confidence threshold (default=0.8)',
type=float, default=0.8)
parser.add_argument('--method', '-m', help='Method to use, can be either RatinaFace or S3FD (default=RatinaFace)',
default='retinaface')
parser.add_argument('--weights', '-w',
help='Weights to load, can be either resnet50 or mobilenet0.25 when using RetinaFace',
default=None)
parser.add_argument('--alternative-pth', '-p', help='Alternative pth file to load', default=None)
parser.add_argument('--device', '-d', help='Device to be used by the model (default=cuda:0)',
default='cuda:0')
parser.add_argument('--iou-threshold', '-iou',
help='IOU threshold used by the simple face tracker (default=0.4)',
type=float, default=0.4)
parser.add_argument('--minimum-face-size', '-min',
help='Minimum face size used by the simple face tracker (default=0.0)',
type=float, default=0.0)
parser.add_argument('--head-pose-preference', '-hp',
help='Head pose output preference (default=0)',
type=int, default=0)
args = parser.parse_args()
# Set benchmark mode flag for CUDNN
torch.backends.cudnn.benchmark = args.benchmark
vid = None
out_vid = None
has_window = False
try:
# Create the face detector
args.method = args.method.lower().strip()
if args.method == 'retinaface':
face_detector_class = (RetinaFacePredictor, 'RetinaFace')
elif args.method == 's3fd':
face_detector_class = (S3FDPredictor, 'S3FD')
else:
raise ValueError('method must be set to either RetinaFace or S3FD')
if args.weights is None:
fd_model = face_detector_class[0].get_model()
else:
fd_model = face_detector_class[0].get_model(args.weights)
if args.alternative_pth is not None:
fd_model.weights = args.alternative_pth
face_detector = face_detector_class[0](threshold=args.threshold, device=args.device, model=fd_model)
print(f"Face detector created using {face_detector_class[1]} ({fd_model.weights}).")
# Create the simple face tracker
face_tracker = SimpleFaceTracker(iou_threshold=args.iou_threshold,
minimum_face_size=args.minimum_face_size)
print('Simple face tracker created.')
# Create the head pose estimator
head_pose_estimator = HeadPoseEstimator()
print('Head pose estimator created.')
# Open the input video
using_webcam = not os.path.exists(args.input)
vid = cv2.VideoCapture(int(args.input) if using_webcam else args.input)
assert vid.isOpened()
if using_webcam:
print(f'Webcam #{int(args.input)} opened.')
else:
print(f'Input video "{args.input}" opened.')
# Open the output video (if a path is given)
if args.output is not None:
out_vid = cv2.VideoWriter(args.output, fps=vid.get(cv2.CAP_PROP_FPS),
frameSize=(int(vid.get(cv2.CAP_PROP_FRAME_WIDTH)),
int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))),
fourcc=cv2.VideoWriter_fourcc(*args.fourcc))
assert out_vid.isOpened()
# Process the frames
frame_number = 0
window_title = os.path.splitext(os.path.basename(__file__))[0]
colours = [(0, 0, 255), (0, 255, 0), (255, 0, 0), (0, 255, 255), (255, 0, 255), (255, 255, 0),
(0, 128, 255), (128, 255, 0), (255, 0, 128), (128, 0, 255), (0, 255, 128), (255, 128, 0)]
print('Processing started, press \'Q\' to quit or \'R\' to reset the tracker.')
while True:
# Get a new frame
_, frame = vid.read()
if frame is None:
break
else:
# Detect and track faces, also estimate head pose if landmarks are available
start_time = time.time()
faces = face_detector(frame, rgb=False)
tids = face_tracker(faces)
if faces.shape[1] >= 15:
head_poses = [head_pose_estimator(face[5:15].reshape((-1, 2)), *frame.shape[1::-1],
output_preference=args.head_pose_preference)
for face in faces]
else:
head_poses = [None] * faces.shape[0]
elapsed_time = time.time() - start_time
# Textural output
print(f'Frame #{frame_number} processed in {elapsed_time * 1000.0:.04f} ms: ' +
f'{len(faces)} faces detected.')
# Rendering
for face, tid, head_pose in zip(faces, tids, head_poses):
bbox = face[:4].astype(int)
if tid is None:
colour = (128, 128, 128)
else:
colour = colours[(tid - 1) % len(colours)]
cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color=colour, thickness=2)
if len(face) > 5:
for pts in face[5:].reshape((-1, 2)):
cv2.circle(frame, tuple(pts.astype(int).tolist()), 3, colour, -1)
if tid is not None:
cv2.putText(frame, f'Face {tid}', (bbox[0], bbox[1] - 10),
cv2.FONT_HERSHEY_DUPLEX, 0.6, colour, lineType=cv2.LINE_AA)
if head_pose is not None:
pitch, yaw, roll = head_pose
cv2.putText(frame, f'Pitch: {pitch:.1f}', (bbox[2] + 5, bbox[1] + 10),
cv2.FONT_HERSHEY_DUPLEX, 0.5, colour, lineType=cv2.LINE_AA)
cv2.putText(frame, f'Yaw: {yaw:.1f}', (bbox[2] + 5, bbox[1] + 30),
cv2.FONT_HERSHEY_DUPLEX, 0.5, colour, lineType=cv2.LINE_AA)
cv2.putText(frame, f'Roll: {roll:.1f}', (bbox[2] + 5, bbox[1] + 50),
cv2.FONT_HERSHEY_DUPLEX, 0.5, colour, lineType=cv2.LINE_AA)
# Write the frame to output video (if recording)
if out_vid is not None:
out_vid.write(frame)
# Display the frame
if using_webcam or not args.no_display:
has_window = True
cv2.imshow(window_title, frame)
key = cv2.waitKey(1) % 2 ** 16
if key == ord('q') or key == ord('Q'):
print('\'Q\' pressed, we are done here.')
break
elif key == ord('r') or key == ord('R'):
print('\'R\' pressed, reset the tracker.')
face_tracker.reset()
frame_number += 1
finally:
if has_window:
cv2.destroyAllWindows()
if out_vid is not None:
out_vid.release()
if vid is not None:
vid.release()
print('All done.')
if __name__ == '__main__':
main()