-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* download MNIST data set instead of bundling in iOS * fix backend not picking up coreml param * android load MNIST;bug: TFLite shape mismatch * fix android mnist training by changing Y_SHAPE * get Float from loss not Double * fix iOS data loading;iOS MSE&classification accuracy;bug: predict 0s * fix android mnist data loading * no debug print when loading data * drop last relu in model;ios client directly mutate model on disk;clean up logging * fix try_make_layers_updatable marking updatable layers * correct mnist core ml layers * improve printout * change CIFAR10 names to MNIST for flutter example
- Loading branch information
Showing
19 changed files
with
252 additions
and
79 deletions.
There are no files selected for viewing
75 changes: 75 additions & 0 deletions
75
android/fed_kit_examples/src/main/java/org/eu/fedcampus/fed_kit_examples/mnist/mnist.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
package org.eu.fedcampus.fed_kit_examples.mnist | ||
|
||
import android.util.Log | ||
import kotlinx.coroutines.Dispatchers | ||
import kotlinx.coroutines.launch | ||
import kotlinx.coroutines.withContext | ||
import org.eu.fedcampus.fed_kit_examples.cifar10.Float3DArray | ||
import org.eu.fedcampus.fed_kit_train.FlowerClient | ||
import org.eu.fedcampus.fed_kit_train.SampleSpec | ||
import org.eu.fedcampus.fed_kit_train.helpers.classifierAccuracy | ||
import org.eu.fedcampus.fed_kit_train.helpers.maxSquaredErrorLoss | ||
import java.io.File | ||
|
||
fun sampleSpec() = SampleSpec<Float3DArray, FloatArray>( | ||
{ it.toTypedArray() }, | ||
{ it.toTypedArray() }, | ||
{ Array(it) { FloatArray(N_CLASSES) } }, | ||
::maxSquaredErrorLoss, | ||
{ samples, logits -> | ||
samples.zip(logits).forEach { (sample, logit) -> | ||
Log.d( | ||
TAG, | ||
"actual: ${sample.label.contentToString()}, predicted: ${logit.contentToString()}" | ||
) | ||
} | ||
classifierAccuracy(samples, logits) | ||
}, | ||
) | ||
|
||
private suspend fun processSet(dataSetDir: String, call: suspend (Int, String) -> Unit) { | ||
withContext(Dispatchers.IO) { | ||
File(dataSetDir).useLines { | ||
it.forEachIndexed { i, l -> launch { call(i, l) } } | ||
} | ||
} | ||
} | ||
|
||
suspend fun loadData( | ||
dataDir: String, flowerClient: FlowerClient<Float3DArray, FloatArray>, partitionId: Int | ||
) { | ||
processSet("$dataDir/MNIST_train.csv") { index, line -> | ||
if (index / 6000 + 1 != partitionId) { | ||
return@processSet | ||
} | ||
if (index % 1000 == 999) { | ||
Log.i(TAG, "Prepared $index training data points.") | ||
} | ||
addSample(flowerClient, line, true) | ||
} | ||
processSet("$dataDir/MNIST_test.csv") { index, line -> | ||
if (index % 1000 == 999) { | ||
Log.i(TAG, "Prepared $index test data points.") | ||
} | ||
addSample(flowerClient, line, false) | ||
} | ||
} | ||
|
||
private fun addSample( | ||
flowerClient: FlowerClient<Float3DArray, FloatArray>, line: String, isTraining: Boolean | ||
) { | ||
val splits = line.split(",") | ||
val feature = Array(IMAGE_SIZE) { Array(IMAGE_SIZE) { FloatArray(1) } } | ||
val label = FloatArray(N_CLASSES) | ||
for (i in 0 until LENGTH_ENTRY) { | ||
feature[i / IMAGE_SIZE][i % IMAGE_SIZE][0] = splits[i + 1].toFloat() / NORMALIZATION | ||
} | ||
label[splits.first().toInt()] = 1f | ||
flowerClient.addSample(feature, label, isTraining) | ||
} | ||
|
||
private const val TAG = "MNIST Data Loader" | ||
private const val IMAGE_SIZE = 28 | ||
private const val N_CLASSES = 10 | ||
private const val LENGTH_ENTRY = IMAGE_SIZE * IMAGE_SIZE | ||
private const val NORMALIZATION = 255f |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.