Tensorflow Lite object detection implementation!
niccellular committed Jul 31, 2021
1 parent 52685b1 commit a80a402
Showing 22 changed files with 1,480 additions and 167 deletions.
24 changes: 16 additions & 8 deletions app/build.gradle
@@ -20,20 +20,25 @@ def ftsAPIKey = properties.getProperty('FTSAPIKEY')
def rtmpIP = properties.getProperty('RTMPIP')
def droneName = properties.getProperty('DRONENAME')

def versionMajor = 1
def versionMinor = 4
def versionPatch = 6

android {
compileSdkVersion 30
buildToolsVersion '30.0.2'
defaultConfig {
applicationId "org.FreeTak.FreeTAKUAS"
versionCode 1
versionName "1.4.6"
versionCode versionMajor * 10000
+ versionMinor * 100
+ versionPatch
versionName "${versionMajor}.${versionMinor}.${versionPatch}"
manifestPlaceholders = [djiKey: djiKey, hereKey: hereKey, hereToken: hereToken, hereLic: hereLic]
buildConfigField "String", "FTSIP", ftsIP
buildConfigField "String", "FTSAPIKEY", ftsAPIKey
buildConfigField "String", "RTMPIP", rtmpIP
buildConfigField "String", "DRONENAME", droneName
buildConfigField "boolean", "RELEASE", "false"
buildConfigField "boolean", "RELEASE", "false" // true when making Play Store APK
minSdkVersion 23
targetSdkVersion 30
multiDexEnabled true
@@ -54,6 +59,7 @@ android {
}
buildTypes {
release {
resValue "string", "app_version", "${defaultConfig.versionName}"
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
signingConfig signingConfigs.release
@@ -106,8 +112,14 @@ android {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}

aaptOptions {
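// Standard TFLite packaging guidance: keep .tflite assets uncompressed so they can be opened/memory-mapped directly at runtime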
noCompress "tflite"
}
}

project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets'
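// src/main/assets is where the bundled .tflite model files opened at runtime are kept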

dependencies {
implementation ('com.dji:dji-uxsdk:4.14', {
exclude module: 'dji-sdk'
@@ -142,11 +154,7 @@ dependencies {

//HERE maps
implementation files('libs/HERE-sdk-3.15.0.aar')
/*

// TensorFlow Lite
implementation 'org.tensorflow:tensorflow-lite-task-vision:0.2.0'
implementation 'org.tensorflow:tensorflow-lite-support:0.1.0-rc1'
implementation 'org.tensorflow:tensorflow-lite-metadata:0.1.0-rc1'
implementation 'org.tensorflow:tensorflow-lite-gpu:2.3.0'
*/
}
Binary file not shown.
Binary file not shown.
205 changes: 127 additions & 78 deletions app/src/main/java/org/FreeTak/FreeTAKUAS/CompleteWidgetActivity.java
@@ -2,10 +2,11 @@

import android.app.Activity;
import android.content.Context;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.Canvas;
import android.graphics.Point;
import android.graphics.Rect;
import android.graphics.RectF;
import android.os.Bundle;
import android.os.Handler;
import android.preference.PreferenceManager;
@@ -20,45 +21,46 @@
import android.view.animation.Animation;
import android.view.animation.Transformation;
import android.widget.Button;
import android.widget.CompoundButton;
import android.widget.FrameLayout;
import android.widget.ImageView;
import android.widget.LinearLayout;
import android.widget.NumberPicker;
import android.widget.PopupWindow;
import android.widget.RelativeLayout;
import android.widget.Toast;
import android.widget.ToggleButton;

import androidx.annotation.NonNull;
import androidx.annotation.Nullable;

import org.FreeTak.FreeTAKUAS.customview.OverlayView;
import org.FreeTak.FreeTAKUAS.detector.Detector;
import org.FreeTak.FreeTAKUAS.tracking.MultiBoxTracker;
import org.jetbrains.annotations.NotNull;

//import com.amap.api.maps.model.LatLng;
import com.dji.mapkit.core.maps.DJIMap;
import com.dji.mapkit.core.models.DJILatLng;
/*
import org.FreeTAKTeam.FreeTAKUAS.ml.ObjectDetectionMobileObjectLocalizerV11Metadata1;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
import org.tensorflow.lite.support.label.Category;
import org.tensorflow.lite.task.vision.core.BaseVisionTaskApi;
import org.tensorflow.lite.task.vision.detector.Detection;
import org.tensorflow.lite.task.vision.detector.ObjectDetector;
*/
import java.io.IOException;
import java.net.URL;

import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Dictionary;
import java.util.Hashtable;
import java.util.Random;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.concurrent.atomic.AtomicBoolean;

import dji.common.airlink.PhysicalSource;
import dji.common.flightcontroller.CompassState;
import dji.common.flightcontroller.FlightControllerState;
import dji.common.flightcontroller.LocationCoordinate3D;
import dji.common.gimbal.Attitude;
import dji.common.gimbal.GimbalState;
import dji.common.model.LocationCoordinate2D;
import dji.keysdk.CameraKey;
import dji.keysdk.KeyManager;
import dji.keysdk.ProductKey;
import dji.sdk.camera.VideoFeeder;
import dji.sdk.flightcontroller.FlightController;
import dji.sdk.gimbal.Gimbal;
@@ -71,7 +73,6 @@
import dji.ux.widget.FPVWidget;
import dji.ux.widget.MapWidget;
import dji.ux.widget.controls.CameraControlsWidget;
import dji.ux.widget.dashboard.CompassWidget;

/**
* Activity that shows all the UI elements together
@@ -86,7 +87,6 @@ public class CompleteWidgetActivity extends Activity {
private FPVOverlayWidget fpvOverlayWidget;
private RelativeLayout primaryVideoView;
private FrameLayout secondaryVideoView;
private CompassWidget compassWidget;
private boolean isMapMini = true;

private Button closePopupBtn, sendPopupBtn;
@@ -116,7 +116,7 @@ public class CompleteWidgetActivity extends Activity {
int stream_delay = 3000;

public String RTMP_URL = "";
public boolean rtmp_hd;
public boolean rtmp_hd, object_detect;
public String FTS_IP, FTS_APIKEY, drone_name, rtmp_ip;
public double droneLocationLat, droneLocationLng, droneDistance, droneHeading;
public double homeLocationLat, homeLocationLng;
@@ -126,49 +126,55 @@
public boolean stream_enabled = false;
// this holds the geoobj names count
private Dictionary names = new Hashtable();
//private ObjectDetectionMobileObjectLocalizerV11Metadata1 model = null;

// ML stuff
private ObjectDetector objectDetector = null;
private OverlayView trackingOverlay = null;
private MultiBoxTracker tracker;
private long timestamp = 0;

@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_default_widgets);

names.put("Alpha",0);
names.put("Bravo",0);
names.put("Charlie",0);
names.put("Delta",0);
names.put("Echo",0);
names.put("Foxtrot",0);
names.put("Golf",0);
names.put("Hotel",0);
names.put("India",0);
names.put("Juliett",0);
names.put("Kilo",0);
names.put("Lima",0);
names.put("Mike",0);
names.put("November",0);
names.put("Oscar",0);
names.put("Papa",0);
names.put("Quebec",0);
names.put("Romeo",0);
names.put("Sierra",0);
names.put("Tango",0);
names.put("Uniform",0);
names.put("Victor",0);
names.put("Whiskey",0);
names.put("X-ray",0);
names.put("Yankee",0);
names.put("Zulu",0);
names.put("Alpha", 0);
names.put("Bravo", 0);
names.put("Charlie", 0);
names.put("Delta", 0);
names.put("Echo", 0);
names.put("Foxtrot", 0);
names.put("Golf", 0);
names.put("Hotel", 0);
names.put("India", 0);
names.put("Juliett", 0);
names.put("Kilo", 0);
names.put("Lima", 0);
names.put("Mike", 0);
names.put("November", 0);
names.put("Oscar", 0);
names.put("Papa", 0);
names.put("Quebec", 0);
names.put("Romeo", 0);
names.put("Sierra", 0);
names.put("Tango", 0);
names.put("Uniform", 0);
names.put("Victor", 0);
names.put("Whiskey", 0);
names.put("X-ray", 0);
names.put("Yankee", 0);
names.put("Zulu", 0);

height = DensityUtil.dip2px(this, 100);
width = DensityUtil.dip2px(this, 150);
margin = DensityUtil.dip2px(this, 12);

FTS_IP = PreferenceManager.getDefaultSharedPreferences(this).getString("ftsip","");
FTS_APIKEY = PreferenceManager.getDefaultSharedPreferences(this).getString("ftsapikey","");
drone_name = PreferenceManager.getDefaultSharedPreferences(this).getString("drone_name","");
rtmp_ip = PreferenceManager.getDefaultSharedPreferences(this).getString("rtmp_ip","");
rtmp_hd = PreferenceManager.getDefaultSharedPreferences(this).getBoolean("rtmp_hd",false);
FTS_IP = PreferenceManager.getDefaultSharedPreferences(this).getString("ftsip", "");
FTS_APIKEY = PreferenceManager.getDefaultSharedPreferences(this).getString("ftsapikey", "");
drone_name = PreferenceManager.getDefaultSharedPreferences(this).getString("drone_name", "");
rtmp_ip = PreferenceManager.getDefaultSharedPreferences(this).getString("rtmp_ip", "");
rtmp_hd = PreferenceManager.getDefaultSharedPreferences(this).getBoolean("rtmp_hd", false);
object_detect = PreferenceManager.getDefaultSharedPreferences(this).getBoolean("object_detect", false);

String rtmp_path = "/live/UAS-" + drone_name;
RTMP_URL = "rtmp://" + rtmp_ip + rtmp_path;
@@ -208,44 +214,46 @@ public void onClick(View view) {
secondaryFPVWidget = findViewById(R.id.secondary_fpv_widget);
secondaryFPVWidget.setOnClickListener(view -> swapVideoSource());

tracker = new MultiBoxTracker(this);
trackingOverlay = (OverlayView) findViewById(R.id.tracking_overlay);
trackingOverlay.addCallback(
canvas -> tracker.draw(canvas));

if (VideoFeeder.getInstance() != null) {
//If secondary video feed is already initialized, get video source
updateSecondaryVideoVisibility(VideoFeeder.getInstance().getSecondaryVideoFeed().getVideoSource() != PhysicalSource.UNKNOWN);
//If secondary video feed is not yet initialized, wait for active status
VideoFeeder.getInstance().getSecondaryVideoFeed()
.addVideoActiveStatusListener(isActive ->
runOnUiThread(() -> updateSecondaryVideoVisibility(isActive)));
}
/*
try {
model = ObjectDetectionMobileObjectLocalizerV11Metadata1.newInstance(this.getApplicationContext());
} catch (IOException e) {
Log.i(TAG, String.format("Failed to create TFLite model: %s",e));
}

VideoFeeder.getInstance().getSecondaryVideoFeed().addVideoDataListener((videoBuffer, size) -> {
try {
Log.i(TAG, String.format("In addVideoDataListener: size=%d",size));
if (object_detect) {
try {
AssetManager am = this.getApplicationContext().getAssets();
//InputStream model = am.open("lite-model_object_detection_mobile_object_localizer_v1_1_metadata_2.tflite");
InputStream model = am.open("lite-model_ssd_mobilenet_v1_1_metadata_2.tflite");
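// Copy the model into a direct ByteBuffer; the Task Library's createFromBufferAndOptions expects a direct or memory-mapped buffer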
ByteBuffer modelBytes = ByteBuffer.allocateDirect(model.available());
while (model.available() > 0) {
modelBytes.put((byte) model.read());
}

//Bitmap bitmap = fpvWidget.getBitmap();
Bitmap bitmap = BitmapFactory.decodeByteArray(videoBuffer,0, size);
// Creates inputs for reference.
TensorImage image = TensorImage.fromBitmap(bitmap);
List<String> allowedLabels = new ArrayList<>();
allowedLabels.add("car");
allowedLabels.add("truck");
allowedLabels.add("airplane");
allowedLabels.add("boat");
allowedLabels.add("person");

// Runs model inference and gets result.
ObjectDetectionMobileObjectLocalizerV11Metadata1.Outputs outputs = model.process(image);
TensorBuffer locations = outputs.getLocationsAsTensorBuffer();
TensorBuffer classes = outputs.getClassesAsTensorBuffer();
TensorBuffer scores = outputs.getScoresAsTensorBuffer();
TensorBuffer numberOfDetections = outputs.getNumberOfDetectionsAsTensorBuffer();
ObjectDetector.ObjectDetectorOptions options = ObjectDetector.ObjectDetectorOptions.builder().setNumThreads(4).setLabelAllowList(allowedLabels).setScoreThreshold(0.5f).build();

Log.i(TAG, String.format("TensorFlow data:\n\tLocations: %s\n\tClasses:%s\n\tScores:%s\n\tnumberOfDetections:%s", locations.toString(),classes.toString(),scores.toString(),numberOfDetections.toString()));
objectDetector = ObjectDetector.createFromBufferAndOptions(modelBytes, options);
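// The detector is built once here with the label allow-list, a 0.5 score threshold and 4 threads; each video frame then only pays for detect()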

} catch (Exception e) {
Log.i(TAG, String.format("Something bad happened doing TFLite: %s", e));
} catch (Exception e) {
Log.i(TAG, String.format("Failed to load TFLite model: %s", e));
}
}
});
*/

}
}

private void onViewClick(View view) {
@@ -360,6 +368,7 @@ protected void onResume() {
} else {
Log.i(TAG, "Streaming in 480p");
l.setLiveVideoResolution(LiveVideoResolution.VIDEO_RESOLUTION_480_360 );
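// Also reduce the video feed's transcoding data rate to suit the lower streaming resolution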
VideoFeeder.getInstance().setTranscodingDataRate(0.3f);
}

l.setLiveVideoBitRate(LiveVideoBitRateMode.AUTO.getValue());
@@ -478,6 +487,46 @@ protected void onResume() {
return false;
});

if (object_detect) {
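// Guard so frames that arrive while a previous frame is still being processed are dropped instead of queuing up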
AtomicBoolean processingFrame = new AtomicBoolean(false);
VideoFeeder.getInstance().getSecondaryVideoFeed().addVideoDataListener((videoBuffer, size) -> {
try {
if (processingFrame.get())
return;
processingFrame.set(true);

// The listener's raw buffer is encoded video data; the decoded frame is taken from the FPV widget's bitmap instead
Bitmap bitmap = fpvWidget.getBitmap();
if (bitmap == null) {
// No rendered frame yet; clear the guard so later frames are still processed
processingFrame.set(false);
return;
}

// Creates inputs for reference.
TensorImage image = TensorImage.fromBitmap(bitmap);
// Run inference
List<Detection> results = objectDetector.detect(image);

final List<Detector.Recognition> mappedRecognitions = new ArrayList<>();
for (Detection result : results) {
final RectF location = result.getBoundingBox();
List<Category> labels = result.getCategories();
for (Category label : labels) {
final float score = label.getScore();
final int id = label.getIndex();
final String title = label.getLabel();
mappedRecognitions.add(new Detector.Recognition(String.valueOf(id), title, score, location));
}
}
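// Hand the boxed results to the tracker and schedule an overlay redraw; postInvalidate is safe to call from this non-UI callback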

tracker.trackResults(mappedRecognitions, ++timestamp);
trackingOverlay.postInvalidate();

processingFrame.set(false);
} catch (Exception e) {
Log.i(TAG, String.format("Something bad happened doing TFLite: %s", e));
processingFrame.set(false);
}
});
}

boolean controller_status = initFlightController();
boolean gimbal_status = initGimbal();
if (controller_status) {