diff --git a/src/main/java/ch/bildspur/vision/MediaPipeBlazeFaceNetwork.java b/src/main/java/ch/bildspur/vision/MediaPipeBlazeFaceNetwork.java
new file mode 100644
index 0000000..4be28fc
--- /dev/null
+++ b/src/main/java/ch/bildspur/vision/MediaPipeBlazeFaceNetwork.java
@@ -0,0 +1,118 @@
+package ch.bildspur.vision;
+
+import ch.bildspur.vision.network.ObjectDetectionNetwork;
+import ch.bildspur.vision.result.ObjectDetectionResult;
+import ch.bildspur.vision.result.ResultList;
+import ch.bildspur.vision.util.MathUtils;
+import org.bytedeco.javacpp.FloatPointer;
+import org.bytedeco.javacpp.IntPointer;
+import org.bytedeco.javacpp.indexer.FloatIndexer;
+import org.bytedeco.opencv.global.opencv_dnn;
+import org.bytedeco.opencv.opencv_core.*;
+import org.bytedeco.opencv.opencv_dnn.Net;
+import org.bytedeco.opencv.opencv_text.FloatVector;
+
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.bytedeco.opencv.global.opencv_core.CV_32F;
+import static org.bytedeco.opencv.global.opencv_dnn.*;
+
+/**
+ * Face detection network wrapper for the MediaPipe BlazeFace model, loaded from an ONNX file.
+ * <p>
+ * Based on https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/blob/master/caffe/ultra_face_opencvdnn_inference.py
+ * Adapted and improved a lot.
+ * <p>
+ * NOTE: work in progress. {@link #run(Mat)} executes the forward pass but does not decode
+ * the network outputs yet, so it always returns an empty result list.
+ */
+public class MediaPipeBlazeFaceNetwork extends ObjectDetectionNetwork {
+    private final Path modelPath;
+    protected Net net;
+
+    // input resolution the frame is resized to before inference
+    private final int width;
+    private final int height;
+
+    // input normalisation: blob = (pixel - imageMean) / imageStd
+    private final Scalar imageMean = Scalar.all(127);
+    private final float imageStd = 128.0f;
+
+    /**
+     * Creates the network description; the model itself is loaded in {@link #setup()}.
+     *
+     * @param modelPath path to the ONNX model file
+     * @param width     width of the network input blob
+     * @param height    height of the network input blob
+     */
+    public MediaPipeBlazeFaceNetwork(Path modelPath, int width, int height) {
+        this.modelPath = modelPath;
+        this.width = width;
+        this.height = height;
+    }
+
+    /**
+     * Loads the ONNX model and, if enabled, selects the CUDA backend.
+     *
+     * @return {@code true} if the network was loaded, {@code false} if it is empty
+     */
+    @Override
+    public boolean setup() {
+        net = readNetFromONNX(modelPath.toAbsolutePath().toString());
+
+        // validate the network before configuring it
+        if (net.empty()) {
+            System.out.println("Can't load network!");
+            return false;
+        }
+
+        if (DeepVision.ENABLE_CUDA_BACKEND) {
+            net.setPreferableBackend(opencv_dnn.DNN_BACKEND_CUDA);
+            net.setPreferableTarget(opencv_dnn.DNN_TARGET_CUDA);
+        }
+
+        return true;
+    }
+
+    /**
+     * Runs a forward pass of the network on the given frame.
+     *
+     * @param frame input image
+     * @return detected faces; currently always empty because output decoding is not implemented
+     */
+    @Override
+    public ResultList<ObjectDetectionResult> run(Mat frame) {
+        // convert image into batch of images
+        Mat inputBlob = blobFromImage(frame,
+                1 / imageStd,
+                new Size(width, height),
+                imageMean,
+                false, false, CV_32F);
+
+        // set input
+        net.setInput(inputBlob);
+
+        // create output layers
+        StringVector outNames = net.getUnconnectedOutLayersNames();
+        MatVector outs = new MatVector(outNames.size());
+
+        // run detection
+        net.forward(outs, outNames);
+
+        // extract boxes and scores
+        // NOTE(review): assumes output 0 holds the boxes and output 1 the scores - confirm for this model
+        Mat boxesOut = outs.get(0);
+        Mat confidencesOut = outs.get(1);
+
+        // boxes
+        Mat boxes = boxesOut.reshape(0, boxesOut.size(1));
+
+        // class confidences (BACKGROUND, face) - NOTE(review): wording carried over from the ULFG detector, confirm layout
+        Mat confidences = confidencesOut.reshape(0, confidencesOut.size(1));
+
+        // TODO: decode boxes / confidences into ObjectDetectionResult entries
+        return new ResultList<>();
+    }
+}
diff --git a/src/test/java/ch/bildspur/vision/test/BlazeFaceTest.java b/src/test/java/ch/bildspur/vision/test/BlazeFaceTest.java
new file mode 100644
index 0000000..6832650
--- /dev/null
+++ b/src/test/java/ch/bildspur/vision/test/BlazeFaceTest.java
@@ -0,0 +1,77 @@
+package ch.bildspur.vision.test;
+
+
+import ch.bildspur.vision.DeepVisionPreview;
+import ch.bildspur.vision.MediaPipeBlazeFaceNetwork;
+import ch.bildspur.vision.TextBoxesNetwork;
+import ch.bildspur.vision.result.ObjectDetectionResult;
+import processing.core.PApplet;
+import processing.core.PImage;
+
+import java.nio.file.Paths;
+import java.util.List;
+
+/**
+ * Manual test sketch: loads the BlazeFace model, runs it on a test image
+ * and draws a rectangle for every returned detection.
+ */
+public class BlazeFaceTest extends PApplet {
+
+    public static void main(String... args) {
+        BlazeFaceTest sketch = new BlazeFaceTest();
+        sketch.runSketch();
+    }
+
+    public void settings() {
+        size(640, 480, FX2D);
+    }
+
+    PImage testImage;
+
+    DeepVisionPreview vision = new DeepVisionPreview(this);
+    MediaPipeBlazeFaceNetwork network;
+    List<ObjectDetectionResult> detections;
+
+    public void setup() {
+        colorMode(HSB, 360, 100, 100);
+
+        testImage = loadImage(sketchPath("data/faces.png"));
+
+        println("creating network...");
+        network = new MediaPipeBlazeFaceNetwork(Paths.get("networks/face_detection_back_256x256_barracuda.onnx"), 256, 256);
+
+        println("loading model...");
+        // fail fast: running inference on a network that did not load gives no useful error
+        if (!network.setup()) {
+            throw new IllegalStateException("could not load BlazeFace model");
+        }
+
+        //network.setConfidenceThreshold(0.2f);
+
+        println("inferencing...");
+        detections = network.run(testImage);
+        println("done!");
+
+        for (ObjectDetectionResult detection : detections) {
+            System.out.println(detection.getClassName() + "\t[" + detection.getConfidence() + "]");
+        }
+
+        println("found " + detections.size() + " faces!");
+    }
+
+    public void draw() {
+        background(55);
+
+        image(testImage, 0, 0);
+
+        noFill();
+        strokeWeight(2f);
+
+        stroke(200, 80, 100);
+        for (ObjectDetectionResult detection : detections) {
+            rect(detection.getX(), detection.getY(), detection.getWidth(), detection.getHeight());
+        }
+
+        surface.setTitle("BlazeFace Test - FPS: " + Math.round(frameRate));
+    }
+}