-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
347 additions
and
245 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
package ch.bildspur.vision; | ||
|
||
import ch.bildspur.vision.util.CvProcessingUtils; | ||
import org.bytedeco.opencv.opencv_core.Mat; | ||
import processing.core.PImage; | ||
|
||
import static org.bytedeco.opencv.global.opencv_core.CV_8UC4; | ||
import static org.bytedeco.opencv.global.opencv_imgproc.COLOR_RGBA2RGB; | ||
import static org.bytedeco.opencv.global.opencv_imgproc.cvtColor; | ||
|
||
public abstract class DeepNeuralNetwork<R> { | ||
|
||
abstract boolean setup(); | ||
|
||
public R run(PImage image) { | ||
// prepare frame | ||
Mat frame = new Mat(image.height, image.width, CV_8UC4); | ||
CvProcessingUtils.toCv(image, frame); | ||
cvtColor(frame, frame, COLOR_RGBA2RGB); | ||
|
||
// inference | ||
return run(frame); | ||
} | ||
|
||
abstract R run(Mat frame); | ||
} |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
34 changes: 34 additions & 0 deletions
34
src/main/java/ch/bildspur/vision/ObjectDetectionNetwork.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
package ch.bildspur.vision; | ||
|
||
import ch.bildspur.vision.result.ObjectDetectionResult; | ||
|
||
import java.io.IOException; | ||
import java.nio.file.Files; | ||
import java.nio.file.Paths; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
public abstract class ObjectDetectionNetwork extends DeepNeuralNetwork<List<ObjectDetectionResult>> { | ||
public static final String COCONamesFile = "data/darknet/coco.names"; | ||
|
||
private List<String> names = new ArrayList<>(); | ||
|
||
@Override | ||
boolean setup() { | ||
|
||
return false; | ||
} | ||
|
||
public void loadNames(String namesFile) { | ||
try { | ||
names.clear(); | ||
names.addAll(Files.readAllLines(Paths.get(namesFile))); | ||
} catch (IOException e) { | ||
e.printStackTrace(); | ||
} | ||
} | ||
|
||
public List<String> getNames() { | ||
return names; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
package ch.bildspur.vision; | ||
|
||
import ch.bildspur.vision.config.YoloConfig; | ||
import ch.bildspur.vision.result.ObjectDetectionResult; | ||
import org.bytedeco.javacpp.DoublePointer; | ||
import org.bytedeco.javacpp.FloatPointer; | ||
import org.bytedeco.javacpp.IntPointer; | ||
import org.bytedeco.opencv.opencv_core.*; | ||
import org.bytedeco.opencv.opencv_dnn.Net; | ||
import org.bytedeco.opencv.opencv_text.FloatVector; | ||
import org.bytedeco.opencv.opencv_text.IntVector; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
import static org.bytedeco.opencv.global.opencv_core.CV_32F; | ||
import static org.bytedeco.opencv.global.opencv_core.minMaxLoc; | ||
import static org.bytedeco.opencv.global.opencv_dnn.*; | ||
|
||
public class YoloNetwork extends ObjectDetectionNetwork { | ||
|
||
private String configPath; | ||
private String weightsPath; | ||
private String namesPath; | ||
private int width; | ||
private int height; | ||
|
||
private float confidenceThreshold = 0.5f; | ||
private float nonMaximumSuppressionThreshold = 0.4f; | ||
private boolean skipNonMaximumSuppression = false; | ||
|
||
private Net net; | ||
|
||
public YoloNetwork(YoloConfig config) { | ||
this.configPath = config.getConfigPath(); | ||
this.weightsPath = config.getWeightsPath(); | ||
this.namesPath = config.getNamesPath(); | ||
this.width = config.getWidth(); | ||
this.height = config.getHeight(); | ||
} | ||
|
||
public YoloNetwork(String configPath, String weightsPath, String namesPath, int width, int height) { | ||
this.configPath = configPath; | ||
this.weightsPath = weightsPath; | ||
this.namesPath = namesPath; | ||
this.width = width; | ||
this.height = height; | ||
} | ||
|
||
public boolean setup() { | ||
net = readNetFromDarknet(configPath, weightsPath); | ||
|
||
// load names | ||
// todo: do this in object detection or factory | ||
loadNames(namesPath); | ||
|
||
if (net.empty()) { | ||
System.out.println("Can't load network!"); | ||
return false; | ||
} | ||
|
||
return true; | ||
} | ||
|
||
public List<ObjectDetectionResult> run(Mat frame) { | ||
// convert image into batch of images | ||
Mat inputBlob = blobFromImage(frame, | ||
1 / 255.0, | ||
new Size(width, height), | ||
new Scalar(0.0), | ||
true, false, CV_32F); | ||
|
||
// set input | ||
net.setInput(inputBlob); | ||
|
||
// create output layers | ||
StringVector outNames = net.getUnconnectedOutLayersNames(); | ||
MatVector outs = new MatVector(outNames.size()); | ||
|
||
// run detection | ||
net.forward(outs, outNames); | ||
|
||
// evaluate result | ||
return postprocess(frame, outs); | ||
} | ||
|
||
/** | ||
* Remove the bounding boxes with low confidence using non-maxima suppression | ||
* | ||
* @param frame Input frame | ||
* @param outs Network outputs | ||
* @return List of objects | ||
*/ | ||
private List<ObjectDetectionResult> postprocess(Mat frame, MatVector outs) { | ||
IntVector classIds = new IntVector(); | ||
FloatVector confidences = new FloatVector(); | ||
RectVector boxes = new RectVector(); | ||
|
||
for (int i = 0; i < outs.size(); ++i) { | ||
// Scan through all the bounding boxes output from the network and keep only the | ||
// ones with high confidence scores. Assign the box's class label as the class | ||
// with the highest score for the box. | ||
Mat result = outs.get(i); | ||
|
||
for (int j = 0; j < result.rows(); j++) { | ||
FloatPointer data = new FloatPointer(result.row(j).data()); | ||
Mat scores = result.row(j).colRange(5, result.cols()); | ||
|
||
Point classIdPoint = new Point(1); | ||
DoublePointer confidence = new DoublePointer(1); | ||
|
||
// Get the value and location of the maximum score | ||
minMaxLoc(scores, null, confidence, null, classIdPoint, null); | ||
if (confidence.get() > confidenceThreshold) { | ||
// todo: maybe round instead of floor | ||
int centerX = (int) (data.get(0) * frame.cols()); | ||
int centerY = (int) (data.get(1) * frame.rows()); | ||
int width = (int) (data.get(2) * frame.cols()); | ||
int height = (int) (data.get(3) * frame.rows()); | ||
int left = centerX - width / 2; | ||
int top = centerY - height / 2; | ||
|
||
classIds.push_back(classIdPoint.x()); | ||
confidences.push_back((float) confidence.get()); | ||
boxes.push_back(new Rect(left, top, width, height)); | ||
} | ||
} | ||
} | ||
|
||
// skip nms | ||
if (skipNonMaximumSuppression) { | ||
List<ObjectDetectionResult> detections = new ArrayList<>(); | ||
for (int i = 0; i < confidences.size(); ++i) { | ||
Rect box = boxes.get(i); | ||
|
||
int classId = classIds.get(i); | ||
detections.add(new ObjectDetectionResult(classId, getNames().get(classId), confidences.get(i), | ||
box.x(), box.y(), box.width(), box.height())); | ||
} | ||
return detections; | ||
} | ||
|
||
// Perform non maximum suppression to eliminate redundant overlapping boxes with | ||
// lower confidences | ||
IntPointer indices = new IntPointer(confidences.size()); | ||
FloatPointer confidencesPointer = new FloatPointer(confidences.size()); | ||
confidencesPointer.put(confidences.get()); | ||
|
||
NMSBoxes(boxes, confidencesPointer, confidenceThreshold, nonMaximumSuppressionThreshold, indices, 1.f, 0); | ||
|
||
List<ObjectDetectionResult> detections = new ArrayList<>(); | ||
for (int i = 0; i < indices.limit(); ++i) { | ||
int idx = indices.get(i); | ||
Rect box = boxes.get(idx); | ||
|
||
int classId = classIds.get(idx); | ||
detections.add(new ObjectDetectionResult(classId, getNames().get(classId), confidences.get(idx), | ||
box.x(), box.y(), box.width(), box.height())); | ||
} | ||
|
||
return detections; | ||
} | ||
|
||
public float getConfidenceThreshold() { | ||
return confidenceThreshold; | ||
} | ||
|
||
public void setConfidenceThreshold(float confidenceThreshold) { | ||
this.confidenceThreshold = confidenceThreshold; | ||
} | ||
|
||
public float getNonMaximumSuppressionThreshold() { | ||
return nonMaximumSuppressionThreshold; | ||
} | ||
|
||
public void setNonMaximumSuppressionThreshold(float nonMaximumSuppressionThreshold) { | ||
this.nonMaximumSuppressionThreshold = nonMaximumSuppressionThreshold; | ||
} | ||
|
||
public boolean isSkipNonMaximumSuppression() { | ||
return skipNonMaximumSuppression; | ||
} | ||
|
||
public void setSkipNonMaximumSuppression(boolean skipNonMaximumSuppression) { | ||
this.skipNonMaximumSuppression = skipNonMaximumSuppression; | ||
} | ||
|
||
public String getConfigPath() { | ||
return configPath; | ||
} | ||
|
||
public String getWeightsPath() { | ||
return weightsPath; | ||
} | ||
|
||
public String getNamesPath() { | ||
return namesPath; | ||
} | ||
|
||
public int getWidth() { | ||
return width; | ||
} | ||
|
||
public int getHeight() { | ||
return height; | ||
} | ||
|
||
public Net getNet() { | ||
return net; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
package ch.bildspur.vision.config; | ||
|
||
import ch.bildspur.vision.ObjectDetectionNetwork; | ||
|
||
// todo: maybe replace through factory | ||
public enum YoloConfig { | ||
YOLOv3_Tiny("data/darknet/yolov3-tiny.cfg", | ||
"data/darknet/yolov3-tiny.weights", | ||
ObjectDetectionNetwork.COCONamesFile, | ||
416, 416), | ||
|
||
YOLOv3_608("data/darknet/yolov3-608.cfg", | ||
"data/darknet/yolov3.weights", | ||
ObjectDetectionNetwork.COCONamesFile, | ||
608, 608); | ||
|
||
private String configPath; | ||
private String weightsPath; | ||
private String namesPath; | ||
private int width; | ||
private int height; | ||
|
||
YoloConfig(String configPath, String weightsPath, String namesPath, int width, int height) { | ||
this.configPath = configPath; | ||
this.weightsPath = weightsPath; | ||
this.namesPath = namesPath; | ||
this.width = width; | ||
this.height = height; | ||
} | ||
|
||
public String getConfigPath() { | ||
return configPath; | ||
} | ||
|
||
public String getWeightsPath() { | ||
return weightsPath; | ||
} | ||
|
||
public String getNamesPath() { | ||
return namesPath; | ||
} | ||
|
||
public int getWidth() { | ||
return width; | ||
} | ||
|
||
public int getHeight() { | ||
return height; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
package ch.bildspur.vision.deps; | ||
|
||
/**
 * Placeholder for a downloadable network dependency, identified by a name
 * and a URL. Neither field is read or written yet.
 */
public class NetworkDependency {
    /** Name of the dependency. */
    private String name;

    /** URL of the dependency. */
    private String url;

    // todo: implement dependency management system
}
5 changes: 0 additions & 5 deletions
5
src/main/java/ch/bildspur/vision/network/DeepNeuralNetwork.java
This file was deleted.
Oops, something went wrong.
5 changes: 0 additions & 5 deletions
5
src/main/java/ch/bildspur/vision/network/DeepNeuralNetworkFactory.java
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.