This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

[On hold] Add digit recognition example #64

Draft · wants to merge 19 commits into master
Binary file not shown.
Binary file not shown.
82 changes: 82 additions & 0 deletions scripts/run_digit_recognition.py
@@ -0,0 +1,82 @@
#!/usr/bin/env python
"""
Real-time recognition of digits drawn in the air by hand.

Usage:
  run_digit_recognition.py [--camera_id=CAMERA_ID]
                           [--path_in=FILENAME]
                           [--path_out=FILENAME]
                           [--title=TITLE]
                           [--use_gpu]
  run_digit_recognition.py (-h | --help)

Options:
  --camera_id=CAMERA_ID   Index of the camera to be used as input. Defaults to 0.
  --path_in=FILENAME      Video file to stream from.
  --path_out=FILENAME     Video file to stream to.
  --title=TITLE           Title to add to the window display.
  --use_gpu               Use GPU for inference.
"""
from docopt import docopt

import sense.display
from sense import engine
from sense import feature_extractors
from sense.controller import Controller
from sense.downstream_tasks.digit_recognition import INT2LAB, LAB2THRESHOLD
from sense.downstream_tasks.nn_utils import Pipe, LogisticRegression
from sense.downstream_tasks.postprocess import PostprocessClassificationOutput


if __name__ == "__main__":
    # Parse arguments
    args = docopt(__doc__)
    camera_id = int(args['--camera_id'] or 0)
    path_in = args['--path_in'] or None
    path_out = args['--path_out'] or None
    title = args['--title'] or None
    use_gpu = args['--use_gpu']

    # Load feature extractor
    feature_extractor = feature_extractors.StridedInflatedEfficientNet()
    checkpoint = engine.load_weights('resources/backbone/strided_inflated_efficientnet.ckpt')
    feature_extractor.load_state_dict(checkpoint)
    feature_extractor.eval()

    # Load a logistic regression classifier
    digit_classifier = LogisticRegression(num_in=feature_extractor.feature_dim,
                                          num_out=len(INT2LAB))
    checkpoint = engine.load_weights('resources/digit_recognition/efficientnet_logistic_regression.ckpt')
    digit_classifier.load_state_dict(checkpoint)
    digit_classifier.eval()

    # Concatenate feature extractor and digit classifier
    net = Pipe(feature_extractor, digit_classifier)

    postprocessor = [
        PostprocessClassificationOutput(INT2LAB, smoothing=4)
    ]

    border_size = 50  # Increase border size for showing the top-2 predictions

    display_ops = [
        sense.display.DisplayTopKClassificationOutputs(top_k=2, threshold=0),
        sense.display.DisplayClassnameOverlay(thresholds=LAB2THRESHOLD,
                                              duration=2,
                                              font_scale=20,
                                              thickness=10,
                                              border_size=border_size if not title else border_size + 50),
    ]
    display_results = sense.display.DisplayResults(title=title, display_ops=display_ops,
                                                   border_size=border_size)

    # Run live inference
    controller = Controller(
        neural_network=net,
        post_processors=postprocessor,
        results_display=display_results,
        callbacks=[],
        camera_id=camera_id,
        path_in=path_in,
        path_out=path_out,
        use_gpu=use_gpu
    )
    controller.run_inference()
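
For reference, a typical invocation of the new script (run from the repo root, assuming the two checkpoints above are in place under resources/; the file names in the second example are hypothetical) might look like:

    python scripts/run_digit_recognition.py --camera_id=0 --title="Digit demo" --use_gpu

or, to process a prerecorded clip instead of a webcam feed:

    python scripts/run_digit_recognition.py --path_in=digits.mp4 --path_out=digits_annotated.mp4
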
3 changes: 2 additions & 1 deletion sense/controller.py
@@ -109,7 +109,8 @@ def run_inference(self):
                 if not all(callback(prediction_postprocessed) for callback in self.callbacks):
                     break
 
-            except Exception as runtime_error:
+            except Exception as e:
+                runtime_error = e
                 break
 
             # Press escape to exit
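
The point of this two-line dance: Python 3 implicitly unbinds the except target name when the handler exits, so a name bound via `except Exception as runtime_error` would not survive past the handler. Rebinding the exception to a second variable keeps it available after the loop (presumably for the error handling further down in run_inference). A minimal illustration, independent of this codebase:

    try:
        raise ValueError("boom")
    except ValueError as e:
        err = e
    # 'e' is unbound here by the interpreter; 'err' still refers to the exception
    print(err)
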
89 changes: 83 additions & 6 deletions sense/display.py
@@ -10,8 +10,14 @@
 FONT = cv2.FONT_HERSHEY_PLAIN
 
 
-def put_text(img: np.ndarray, text: str, position: Tuple[int, int],
-             color: Tuple[int, int, int] = (255, 255, 255)) -> np.ndarray:
+def put_text(
+        img: np.ndarray,
+        text: str,
+        position: Tuple[int, int],
+        font_scale: float = 1.,
+        color: Tuple[int, int, int] = (255, 255, 255),
+        thickness: int = 1
+) -> np.ndarray:
     """
     Draw a white text string on an image at a specified position and return the image.
 
@@ -21,13 +27,17 @@ def put_text(img: np.ndarray, text: str, position: Tuple[int, int],
     :param text:
         The text to be written.
     :param position:
         A tuple of x and y coordinates of the bottom-left corner of the text in the image.
+    :param font_scale:
+        Font scale factor for modifying the font size.
     :param color:
         A tuple for font color. For BGR, eg: (0, 255, 0) for green color.
+    :param thickness:
+        Thickness of the lines used to draw the text.
 
     :return:
         The image with the text string drawn.
     """
-    cv2.putText(img, text, position, FONT, 1, color, 1, cv2.LINE_AA)
+    cv2.putText(img, text, position, FONT, font_scale, color, thickness, cv2.LINE_AA)
     return img
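
The new parameters are backward compatible, since the defaults reproduce the old behavior (scale 1, thickness 1, white). A quick sketch of how the extended signature is exercised, on a hypothetical blank frame:

    import numpy as np
    from sense.display import put_text

    frame = np.zeros((480, 640, 3), dtype=np.uint8)
    put_text(frame, "7", (200, 300), font_scale=20, thickness=10)  # big overlay-style digit
    put_text(frame, "Camera FPS: 30.0", (5, 460))                  # defaults: small white text
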


@@ -99,7 +109,7 @@ def __init__(self, top_k=1, threshold=0.2, **kwargs):
         :param top_k:
             Number of the top classification labels to be displayed.
         :param threshold:
-            Threshhold for the output to be displayed.
+            Threshold for the output to be displayed.
         """
         super().__init__(**kwargs)
         self.top_k = top_k
@@ -181,8 +191,75 @@ def display(self, img: np.ndarray, display_data: dict) -> np.ndarray:
             text_color = self.default_text_color
 
         # Show FPS on the video screen
-        put_text(img, "Camera FPS: {:.1f}".format(camera_fps), (5, img.shape[0] - self.y_offset - 20), text_color)
-        put_text(img, "Model FPS: {:.1f}".format(inference_engine_fps), (5, img.shape[0] - self.y_offset), text_color)
+        put_text(img, "Camera FPS: {:.1f}".format(camera_fps), (5, img.shape[0] - self.y_offset - 20),
+                 color=text_color)
+        put_text(img, "Model FPS: {:.1f}".format(inference_engine_fps), (5, img.shape[0] - self.y_offset),
+                 color=text_color)
 
         return img
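
Note that `color=` must now be passed by keyword here: with the extended signature, a bare fourth positional argument would bind to `font_scale` instead of `color`.
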


class DisplayClassnameOverlay(BaseDisplay):
    """
    Display the recognized class name as a large video overlay. Once the probability for a class
    passes its threshold, the name is shown and stays visible for a certain duration.
    """

    def __init__(self, thresholds, duration=2, font_scale=3, thickness=2, border_size=50, **kwargs):
        """
        :param thresholds:
            Dictionary of thresholds for all classes.
        :param duration:
            Duration in seconds for which the class name should be displayed after it has
            been recognized.
        :param font_scale:
            Font scale factor for modifying the font size.
        :param thickness:
            Thickness of the lines used to draw the text.
        :param border_size:
            Height of the border on top of the video display. Used for correctly centering
            the displayed class name on the video.
        """
        super().__init__(**kwargs)
        self.thresholds = thresholds
        self.duration = duration
        self.font_scale = font_scale
        self.thickness = thickness
        self.border_size = border_size

        self._current_class_name = None
        self._start_time = None

    def _get_center_coordinates(self, img, text):
        textsize = cv2.getTextSize(text, FONT, self.font_scale, self.thickness)[0]

        height, width, _ = img.shape
        height -= self.border_size

        x = int((width - textsize[0]) / 2)
        y = int((height + textsize[1]) / 2) + self.border_size

        return x, y

    def _display_class_name(self, img, class_name):
        pos = self._get_center_coordinates(img, class_name)
        put_text(img, class_name, position=pos, font_scale=self.font_scale, thickness=self.thickness)

    def display(self, img, display_data):
        now = time.perf_counter()

        if self._current_class_name and now - self._start_time < self.duration:
            # Keep displaying the same class name
            self._display_class_name(img, self._current_class_name)
        else:
            self._current_class_name = None
            for class_name, proba in display_data['sorted_predictions']:
                if class_name in self.thresholds and proba > self.thresholds[class_name]:
                    # Display the new class name and latch it for `duration` seconds
                    self._display_class_name(img, class_name)
                    self._current_class_name = class_name
                    self._start_time = now
                    break

        return img
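
To sanity-check the overlay logic in isolation, a minimal sketch along these lines should work (the dummy frame and prediction data are hypothetical; it assumes only that sense.display is importable):

    import numpy as np
    from sense.display import DisplayClassnameOverlay

    overlay = DisplayClassnameOverlay(thresholds={'3': 0.4}, duration=2)
    frame = np.zeros((480, 640, 3), dtype=np.uint8)

    # First call: "3" crosses its threshold, so it is drawn and latched
    overlay.display(frame, {'sorted_predictions': [('3', 0.9), ('Drawing', 0.05)]})

    # Calls within the 2-second window keep drawing "3",
    # even if the predictions have already moved on
    overlay.display(frame, {'sorted_predictions': [('Doing nothing', 0.8)]})
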

30 changes: 30 additions & 0 deletions sense/downstream_tasks/digit_recognition/__init__.py
@@ -0,0 +1,30 @@
LAB2INT = {
    "Doing nothing": 0,
    "Doing other things": 1,
    "0": 2,
    "1": 3,
    "2": 4,
    "3": 5,
    "4": 6,
    "5": 7,
    "6": 8,
    "7": 9,
    "8": 10,
    "9": 11,
    "Drawing": 12,
}

INT2LAB = {value: key for key, value in LAB2INT.items()}

LAB2THRESHOLD = {
    "0": 0.7,
    "1": 0.7,
    "2": 0.7,
    "3": 0.4,
    "4": 0.3,
    "5": 0.4,
    "6": 0.4,
    "7": 0.5,
    "8": 0.4,
    "9": 0.3,
}
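
Worth noting: only the ten digit classes carry thresholds. The helper classes ("Doing nothing", "Doing other things", "Drawing") are absent from LAB2THRESHOLD, so DisplayClassnameOverlay never overlays them; its display method only fires for class names present in the thresholds dict. The per-digit values presumably reflect how reliably each digit is recognized, e.g. "4" and "9" trigger at a lower confidence (0.3) than "0" and "1" (0.7).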