Skip to content

Commit

Permalink
Android MNIST example (#28)
Browse files Browse the repository at this point in the history
* download MNIST data set instead of bundling in iOS

* fix backend not picking up coreml param

* android load MNIST;bug: TFLite shape mismatch

* fix android mnist training by changing Y_SHAPE

* get Float from loss not Double

* fix iOS data loading;iOS MSE&classification accuracy;bug: predict 0s

* fix android mnist data loading

* no debug print when loading data

* drop last relu in model;ios client directly mutate model on disk;clean up logging

* fix try_make_layers_updatable marking updatable layers

* correct mnist core ml layers

* improve printout

* change CIFAR10 names to MNIST for flutter example
  • Loading branch information
SichangHe committed Oct 1, 2023
1 parent f5ae4c4 commit eb60d46
Show file tree
Hide file tree
Showing 19 changed files with 252 additions and 79 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package org.eu.fedcampus.fed_kit_examples.mnist

import android.util.Log
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import org.eu.fedcampus.fed_kit_examples.cifar10.Float3DArray
import org.eu.fedcampus.fed_kit_train.FlowerClient
import org.eu.fedcampus.fed_kit_train.SampleSpec
import org.eu.fedcampus.fed_kit_train.helpers.classifierAccuracy
import org.eu.fedcampus.fed_kit_train.helpers.maxSquaredErrorLoss
import java.io.File

/**
 * Builds the [SampleSpec] used by the MNIST example, with [Float3DArray] features and
 * [FloatArray] labels.
 *
 * The arguments are, in order: two collection-to-array converters, a factory producing an
 * array of [N_CLASSES]-long float arrays, the loss function [maxSquaredErrorLoss], and an
 * accuracy callback that logs every actual/predicted pair at debug level before delegating
 * to [classifierAccuracy].
 */
fun sampleSpec() = SampleSpec<Float3DArray, FloatArray>(
    { xs -> xs.toTypedArray() },
    { ys -> ys.toTypedArray() },
    // One zero-filled output row per requested element.
    { size -> Array(size) { FloatArray(N_CLASSES) } },
    ::maxSquaredErrorLoss,
    { samples, logits ->
        for ((sample, logit) in samples.zip(logits)) {
            Log.d(
                TAG,
                "actual: ${sample.label.contentToString()}, predicted: ${logit.contentToString()}"
            )
        }
        classifierAccuracy(samples, logits)
    },
)

/**
 * Reads the file at [dataSetDir] line by line on [Dispatchers.IO] and launches one child
 * coroutine per line that invokes [call] with the zero-based line index and the line text.
 *
 * Despite its name, [dataSetDir] is passed straight to [File] and read as a file.
 * Because every line is handled in its own coroutine, invocations of [call] may run
 * concurrently and in any order. This function returns once all of them have completed.
 */
private suspend fun processSet(dataSetDir: String, call: suspend (Int, String) -> Unit) {
    withContext(Dispatchers.IO) {
        File(dataSetDir).useLines { lines ->
            for ((lineIndex, lineText) in lines.withIndex()) {
                launch { call(lineIndex, lineText) }
            }
        }
    }
}

/**
 * Loads the MNIST CSV files found under [dataDir] into [flowerClient].
 *
 * Training samples are read from `MNIST_train.csv`; only the lines belonging to
 * [partitionId] are added. Partitions are 1-based and contiguous: lines 0..5999 form
 * partition 1, lines 6000..11999 form partition 2, and so on. A [partitionId] that matches
 * no line results in no training data being added.
 *
 * Test samples are read from `MNIST_test.csv` and are all added regardless of partition.
 *
 * Progress is logged every 1000 samples. Lines are processed concurrently, so the logged
 * figure is the position of the sample within its set, not a strict completion count.
 */
suspend fun loadData(
    dataDir: String, flowerClient: FlowerClient<Float3DArray, FloatArray>, partitionId: Int
) {
    val partitionSize = 6000
    val logInterval = 1000
    processSet("$dataDir/MNIST_train.csv") { index, line ->
        if (index / partitionSize + 1 != partitionId) {
            return@processSet
        }
        addSample(flowerClient, line, true)
        // Report the count within this partition rather than the raw line index, which is
        // off by one and includes every line skipped for earlier partitions.
        val prepared = index % partitionSize + 1
        if (prepared % logInterval == 0) {
            Log.i(TAG, "Prepared $prepared training data points.")
        }
    }
    processSet("$dataDir/MNIST_test.csv") { index, line ->
        addSample(flowerClient, line, false)
        // `index` is zero-based, so the number of samples seen so far is `index + 1`.
        val prepared = index + 1
        if (prepared % logInterval == 0) {
            Log.i(TAG, "Prepared $prepared test data points.")
        }
    }
}

/**
 * Parses one CSV [line] and adds it to [flowerClient] as a training or test sample.
 *
 * The first field of the line is the integer class label; the following [LENGTH_ENTRY]
 * fields are the pixel values in row-major order. Each pixel is divided by [NORMALIZATION]
 * and stored in an [IMAGE_SIZE] x [IMAGE_SIZE] x 1 array; the label is one-hot encoded into
 * an array of length [N_CLASSES].
 *
 * Blank lines are ignored so that stray empty lines do not abort the whole load.
 *
 * @throws IllegalArgumentException if the line has too few fields or the label is outside
 * `0 until N_CLASSES`.
 * @throws NumberFormatException if a field is not a valid number.
 */
private fun addSample(
    flowerClient: FlowerClient<Float3DArray, FloatArray>, line: String, isTraining: Boolean
) {
    if (line.isBlank()) {
        return
    }
    val splits = line.split(",")
    require(splits.size > LENGTH_ENTRY) {
        "Expected at least ${LENGTH_ENTRY + 1} comma-separated fields but got ${splits.size}."
    }
    val classIndex = splits.first().toInt()
    require(classIndex in 0 until N_CLASSES) {
        "Label $classIndex is outside the valid range 0 until $N_CLASSES."
    }
    // Field 0 is the label, so pixel (row, col) lives at offset row * IMAGE_SIZE + col + 1.
    val feature = Array(IMAGE_SIZE) { row ->
        Array(IMAGE_SIZE) { col ->
            floatArrayOf(splits[row * IMAGE_SIZE + col + 1].toFloat() / NORMALIZATION)
        }
    }
    val label = FloatArray(N_CLASSES)
    label[classIndex] = 1f
    flowerClient.addSample(feature, label, isTraining)
}

// Tag attached to every Logcat message emitted by this file.
private const val TAG = "MNIST Data Loader"
// Side length used for both dimensions of the feature array built per sample.
private const val IMAGE_SIZE = 28
// Length of the one-hot label array and of each output row.
private const val N_CLASSES = 10
// Number of pixel fields read from each CSV line (after the leading label field).
private const val LENGTH_ENTRY = IMAGE_SIZE * IMAGE_SIZE
// Divisor applied to every pixel value; presumably maps 0..255 intensities into 0..1 — confirm.
private const val NORMALIZATION = 255f
2 changes: 1 addition & 1 deletion backend/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ cat seed.py | python3 manage.py shell

The above script adds the CIFAR10 training data type and model to database.

## Prepare MNIST model for iOS example
## Prepare MNIST model for cross platform Flutter example

Run at `../`:

Expand Down
5 changes: 3 additions & 2 deletions convert_model/coreml.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ def try_make_layers_updatable(builder: NeuralNetworkBuilder, limit_last: int = 0
made_updatable = updatable_layer_names[-limit_last:]
builder.make_updatable(made_updatable)
print(f"Made {made_updatable} updatable.")
for updatable_layer in updatable_layers[-limit_last:]:
updatable_layer["updatable"] = True
for updatable_layer in updatable_layers:
if updatable_layer["name"] in made_updatable:
updatable_layer["updatable"] = True
print(f"All updatable layers:\n\t{red(updatable_layers)}")
return updatable_layers

Expand Down
2 changes: 1 addition & 1 deletion convert_model/mnist_eg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def mnist_model():
pool_layer(),
k.layers.Flatten(),
k.layers.Dense(500, activation="relu", kernel_initializer="he_uniform"),
k.layers.Dense(n_classes, activation="relu"),
k.layers.Dense(n_classes),
]
)
model.compile(optimizer="adam", loss="mse", metrics=["accuracy"])
Expand Down
1 change: 1 addition & 0 deletions convert_model/mnist_eg/cm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def main():
try_make_layers_updatable(builder, 2)
builder.inspect_layers()
save_builder(builder, COREML_FILE)
print(f"Successfully converted to Core ML model at {COREML_FILE}.")


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion convert_model/mnist_eg/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
@tflite_model_class
class MnistTFModel(BaseTFLiteModel):
X_SHAPE = list(in_shape)
Y_SHAPE = [1]
Y_SHAPE = [10]

def __init__(self):
self.model = mnist_model()
Expand All @@ -25,6 +25,7 @@ def tflite():
save_model(model, SAVED_MODEL_DIR)
tflite_model = convert_saved_model(SAVED_MODEL_DIR)
save_tflite_model(tflite_model, TFLITE_FILE)
print(f"Successfully converted to TFLite model at {TFLITE_FILE}.")


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion convert_model/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def save_model(model, saved_model_dir):
init_params = parameters()
restore = model.restore.get_concrete_function(**init_params)
restore_test = restore(**init_params)
print(f"Restore test result: {restore_test}.")
print(f"{len(restore_test)} restore test results.")
tf.saved_model.save(
model,
saved_model_dir,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.MainScope
import kotlinx.coroutines.launch
import org.eu.fedcampus.fed_kit_examples.cifar10.Float3DArray
import org.eu.fedcampus.fed_kit_examples.cifar10.loadData
import org.eu.fedcampus.fed_kit_examples.cifar10.sampleSpec
import org.eu.fedcampus.fed_kit_examples.mnist.loadData
import org.eu.fedcampus.fed_kit_examples.mnist.sampleSpec
import org.eu.fedcampus.fed_kit_train.FlowerClient
import org.eu.fedcampus.fed_kit_train.helpers.loadMappedFile
import java.io.File
Expand All @@ -28,8 +28,8 @@ class MainActivity : FlutterActivity() {
override fun configureFlutterEngine(flutterEngine: FlutterEngine) {
super.configureFlutterEngine(flutterEngine)
val messenger = flutterEngine.dartExecutor.binaryMessenger
MethodChannel(messenger, "fed_kit_client_cifar10_ml_client").setMethodCallHandler(::handle)
EventChannel(messenger, "fed_kit_client_cifar10_ml_client_log").setStreamHandler(object :
MethodChannel(messenger, "fed_kit_client_mnist_ml_client").setMethodCallHandler(::handle)
EventChannel(messenger, "fed_kit_client_mnist_ml_client_log").setStreamHandler(object :
EventChannel.StreamHandler {
override fun onListen(arguments: Any?, eventSink: EventSink?) {
if (eventSink === null) {
Expand Down Expand Up @@ -91,12 +91,13 @@ class MainActivity : FlutterActivity() {
}

fun initML(call: MethodCall, result: Result) = scope.launch(Dispatchers.Default) {
val dataDir = call.argument<String>("dataDir")!!
val modelDir = call.argument<String>("modelDir")!!
val layersSizes = call.argument<List<Int>>("layersSizes")!!.toIntArray()
val partitionId = call.argument<Int>("partitionId")!!
val buffer = loadMappedFile(File(modelDir))
flowerClient = FlowerClient(buffer, layersSizes, sampleSpec())
loadData(this@MainActivity, flowerClient, partitionId)
loadData(dataDir, flowerClient, partitionId)
runOnUiThread { result.success(null) }
}

Expand Down
16 changes: 0 additions & 16 deletions fed_kit_client/ios/Runner.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
3B3967161E833CAA004F5970 /* AppFrameworkInfo.plist in Resources */ = {isa = PBXBuildFile; fileRef = 3B3967151E833CAA004F5970 /* AppFrameworkInfo.plist */; };
3F3379472A91D7490084B7C8 /* MLHelper.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3F3379462A91D7490084B7C8 /* MLHelper.swift */; };
3F3379492A91E9C80084B7C8 /* Helpers.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3F3379482A91E9C80084B7C8 /* Helpers.swift */; };
3F33794D2A91F6EF0084B7C8 /* MNIST_test.csv.lzfse in Resources */ = {isa = PBXBuildFile; fileRef = 3F33794A2A91F6EF0084B7C8 /* MNIST_test.csv.lzfse */; };
3F33794E2A91F6EF0084B7C8 /* MNIST_train.csv.lzfse in Resources */ = {isa = PBXBuildFile; fileRef = 3F33794B2A91F6EF0084B7C8 /* MNIST_train.csv.lzfse */; };
3F4AFE3E2A90656F004B55C8 /* DataLoader.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3F4AFE3D2A90656F004B55C8 /* DataLoader.swift */; };
3F4AFE432A906BAF004B55C8 /* NIOCore in Frameworks */ = {isa = PBXBuildFile; productRef = 3F4AFE422A906BAF004B55C8 /* NIOCore */; };
3F4AFE452A906BAF004B55C8 /* NIOPosix in Frameworks */ = {isa = PBXBuildFile; productRef = 3F4AFE442A906BAF004B55C8 /* NIOPosix */; };
Expand Down Expand Up @@ -96,8 +94,6 @@
3B3967151E833CAA004F5970 /* AppFrameworkInfo.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; name = AppFrameworkInfo.plist; path = Flutter/AppFrameworkInfo.plist; sourceTree = "<group>"; };
3F3379462A91D7490084B7C8 /* MLHelper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLHelper.swift; sourceTree = "<group>"; };
3F3379482A91E9C80084B7C8 /* Helpers.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Helpers.swift; sourceTree = "<group>"; };
3F33794A2A91F6EF0084B7C8 /* MNIST_test.csv.lzfse */ = {isa = PBXFileReference; lastKnownFileType = file; path = MNIST_test.csv.lzfse; sourceTree = "<group>"; };
3F33794B2A91F6EF0084B7C8 /* MNIST_train.csv.lzfse */ = {isa = PBXFileReference; lastKnownFileType = file; path = MNIST_train.csv.lzfse; sourceTree = "<group>"; };
3F4AFE3D2A90656F004B55C8 /* DataLoader.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DataLoader.swift; sourceTree = "<group>"; };
3F4AFE552A909778004B55C8 /* Scaler.pb.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Scaler.pb.swift; sourceTree = "<group>"; };
3F4AFE562A909778004B55C8 /* AudioFeaturePrint.pb.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = AudioFeaturePrint.pb.swift; sourceTree = "<group>"; };
Expand Down Expand Up @@ -200,21 +196,11 @@
3F3379482A91E9C80084B7C8 /* Helpers.swift */,
3F8854222A90A52700B37487 /* Flwr */,
3F4AFE542A909769004B55C8 /* CoreMLProto */,
3F4AFE4D2A90974E004B55C8 /* MNIST */,
3F3379462A91D7490084B7C8 /* MLHelper.swift */,
);
path = Utils;
sourceTree = "<group>";
};
3F4AFE4D2A90974E004B55C8 /* MNIST */ = {
isa = PBXGroup;
children = (
3F33794A2A91F6EF0084B7C8 /* MNIST_test.csv.lzfse */,
3F33794B2A91F6EF0084B7C8 /* MNIST_train.csv.lzfse */,
);
path = MNIST;
sourceTree = "<group>";
};
3F4AFE542A909769004B55C8 /* CoreMLProto */ = {
isa = PBXGroup;
children = (
Expand Down Expand Up @@ -426,13 +412,11 @@
buildActionMask = 2147483647;
files = (
97C147011CF9000F007C117D /* LaunchScreen.storyboard in Resources */,
3F33794E2A91F6EF0084B7C8 /* MNIST_train.csv.lzfse in Resources */,
3F4AFE912A90977B004B55C8 /* README.rst in Resources */,
3F4AFE832A90977B004B55C8 /* LICENSE.txt in Resources */,
3B3967161E833CAA004F5970 /* AppFrameworkInfo.plist in Resources */,
97C146FE1CF9000F007C117D /* Assets.xcassets in Resources */,
97C146FC1CF9000F007C117D /* Main.storyboard in Resources */,
3F33794D2A91F6EF0084B7C8 /* MNIST_test.csv.lzfse in Resources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
Expand Down
18 changes: 9 additions & 9 deletions fed_kit_client/ios/Runner/AppDelegate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ let log = logger(String(describing: AppDelegate.self))
let controller = window?.rootViewController as! FlutterViewController
let messenger = controller.binaryMessenger
FlutterMethodChannel(
name: "fed_kit_client_cifar10_ml_client", binaryMessenger: messenger
name: "fed_kit_client_mnist_ml_client", binaryMessenger: messenger
).setMethodCallHandler(handle)
FlutterEventChannel(
name: "fed_kit_client_cifar10_ml_client_log", binaryMessenger: messenger
name: "fed_kit_client_mnist_ml_client_log", binaryMessenger: messenger
).setStreamHandler(self)
}

Expand All @@ -63,7 +63,7 @@ let log = logger(String(describing: AppDelegate.self))
func evaluate(_ result: @escaping FlutterResult) {
runAsync(result) {
let (loss, accuracy) = try await self.mlClient!.evaluate()
let lossAccuracy = [Float(loss), Float(accuracy)]
let lossAccuracy = [loss, accuracy]
return FlutterStandardTypedData(float32: Data(fromArray: lossAccuracy))
}
}
Expand Down Expand Up @@ -106,17 +106,17 @@ let log = logger(String(describing: AppDelegate.self))
func initML(_ call: FlutterMethodCall, _ result: @escaping FlutterResult) {
runAsync(result) {
let args = call.arguments as! [String: Any]
let dataDir = args["dataDir"] as! String
let modelDir = args["modelDir"] as! String
let layers = try (args["layersSizes"] as! [[String: Any]]).map(Layer.init)
log.error("Model layers: \(layers)")
let partitionId = (args["partitionId"] as! NSNumber).intValue
let url = URL(fileURLWithPath: modelDir)
log.error("Accessing: \(url.startAccessingSecurityScopedResource())")
log.error("Model URL: \(url).")
let content = try Data(contentsOf: url)
let modelProto = try ModelProto(data: content)
let loader = try await self.dataLoader(
partitionId, modelProto.input, modelProto.target
dataDir, partitionId, modelProto.input, modelProto.target
)
self.mlClient = try MLClient(layers, loader, url, modelProto)
self.ready = true
Expand All @@ -139,15 +139,15 @@ let log = logger(String(describing: AppDelegate.self))
}

private func dataLoader(
_ partitionId: Int, _ inputName: String, _ outputName: String
_ dataDir: String, _ partitionId: Int, _ inputName: String, _ outputName: String
) async throws -> MLDataLoader {
if dataLoader != nil && self.partitionId == partitionId &&
if dataLoader != nil && partitionId == partitionId &&
self.inputName == inputName && self.outputName == outputName
{
return dataLoader!
}
let trainBatchProvider = try await trainBatchProvider(
partitionId, inputName: inputName, outputName: outputName
dataDir, partitionId, inputName: inputName, outputName: outputName
) { count in
if count % 1000 == 999 {
log.error("Prepared \(count) training data points.")
Expand All @@ -156,7 +156,7 @@ let log = logger(String(describing: AppDelegate.self))
log.error("trainBatchProvider: \(trainBatchProvider.count)")

let testBatchProvider = try await testBatchProvider(
inputName: inputName, outputName: outputName
dataDir, inputName: inputName, outputName: outputName
) { count in
if count % 1000 == 999 {
log.error("Prepared \(count) test data points.")
Expand Down
18 changes: 12 additions & 6 deletions fed_kit_client/ios/Runner/DataLoader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@ private let nClasses = 10
private let shapeTarget: [NSNumber] = [nClasses as NSNumber]

func trainBatchProvider(
_ dataDir: String,
_ partitionId: Int,
inputName: String,
outputName: String,
progressHandler: @escaping (Int) -> Void
) async throws -> MLBatchProvider {
return try await prepareMLBatchProvider(
filePath: extract("_train"),
filePath: "\(dataDir)/MNIST_train.csv",
inputName: inputName,
outputName: outputName,
progressHandler: progressHandler
Expand All @@ -39,12 +40,13 @@ func trainBatchProvider(
}

func testBatchProvider(
_ dataDir: String,
inputName: String,
outputName: String,
progressHandler: @escaping (Int) -> Void
) async throws -> MLBatchProvider {
return try await prepareMLBatchProvider(
filePath: extract("_test"),
filePath: "\(dataDir)/MNIST_test.csv",
inputName: inputName,
outputName: outputName,
progressHandler: progressHandler
Expand Down Expand Up @@ -104,15 +106,16 @@ private func extractFile(from sourceURL: URL, to destinationURL: URL) throws ->
}

private func prepareMLBatchProvider(
filePath: URL,
filePath: String,
inputName: String,
outputName: String,
progressHandler: @escaping (Int) -> Void,
indexFilter: ((Int) -> Bool)? = nil
) async throws -> MLBatchProvider {
var count = 0
let featureProviders = try await withThrowingTaskGroup(of: MLDictionaryFeatureProvider.self) { group in
let (bytes, _) = try await URLSession.shared.bytes(from: filePath)
let (bytes, _) =
try await URLSession.shared.bytes(from: URL(fileURLWithPath: filePath))
for try await line in bytes.lines {
count += 1
if indexFilter.map({ !$0(count) }) ?? false { continue }
Expand All @@ -122,9 +125,12 @@ private func prepareMLBatchProvider(
let imageMultiArr = try! MLMultiArray(shape: shapeData, dataType: .float32)
let outputMultiArr = try! MLMultiArray(shape: shapeTarget, dataType: .int32)
for i in 0 ..< lengthEntry {
imageMultiArr[i] = (Float(String(splits[i]))! / normalization) as NSNumber
imageMultiArr[i] = (Float(String(splits[i + 1]))! / normalization) as NSNumber
}
outputMultiArr[Int(String(splits.last!))!] = 1
for i in 0 ..< outputMultiArr.count {
outputMultiArr[i] = 0
}
outputMultiArr[Int(String(splits[0]))!] = 1
let imageValue = MLFeatureValue(multiArray: imageMultiArr)
let outputValue = MLFeatureValue(multiArray: outputMultiArr)
let dataPointFeatures: [String: MLFeatureValue] =
Expand Down
Loading

0 comments on commit eb60d46

Please sign in to comment.