From afd7f4958583b565d51ff56a1fdabe193d710131 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Steven=20H=C3=A9=20=28S=C4=ABch=C3=A0ng=29?= Date: Mon, 25 Sep 2023 22:13:44 +0800 Subject: [PATCH] sync android mnist binary loading;fix ios data type --- .../org/eu/fedcampus/fed_kit_examples/mnist/mnist.kt | 9 +++++---- backend/train/run.py | 2 +- fed_kit_client/ios/Runner/DataLoader.swift | 6 +++--- fed_kit_client/ios/Runner/Utils/Flwr/MLClient.swift | 4 ++-- fed_kit_client/ios/Runner/Utils/MLHelper.swift | 2 +- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/android/fed_kit_examples/src/main/java/org/eu/fedcampus/fed_kit_examples/mnist/mnist.kt b/android/fed_kit_examples/src/main/java/org/eu/fedcampus/fed_kit_examples/mnist/mnist.kt index 1c83a2a..dae0e5b 100644 --- a/android/fed_kit_examples/src/main/java/org/eu/fedcampus/fed_kit_examples/mnist/mnist.kt +++ b/android/fed_kit_examples/src/main/java/org/eu/fedcampus/fed_kit_examples/mnist/mnist.kt @@ -14,7 +14,7 @@ import java.io.File fun sampleSpec() = SampleSpec( { it.toTypedArray() }, { it.toTypedArray() }, - { Array(it) { FloatArray(N_CLASSES) } }, + { Array(it) { FloatArray(1) } }, ::maxSquaredErrorLoss, { samples, logits -> samples.zip(logits).forEach { (sample, logit) -> @@ -60,16 +60,17 @@ private fun addSample( ) { val splits = line.split(",") val feature = Array(IMAGE_SIZE) { Array(IMAGE_SIZE) { FloatArray(1) } } - val label = FloatArray(N_CLASSES) + val label = FloatArray(1) for (i in 0 until LENGTH_ENTRY) { feature[i / IMAGE_SIZE][i % IMAGE_SIZE][0] = splits[i + 1].toFloat() / NORMALIZATION } - label[splits.first().toInt()] = 1f + if (splits.first().toInt() == 1) { + label[0] = 1f + } flowerClient.addSample(feature, label, isTraining) } private const val TAG = "MNIST Data Loader" private const val IMAGE_SIZE = 28 -private const val N_CLASSES = 10 private const val LENGTH_ENTRY = IMAGE_SIZE * IMAGE_SIZE private const val NORMALIZATION = 255f diff --git a/backend/train/run.py b/backend/train/run.py index 08f0f00..da9aa39 100644 --- a/backend/train/run.py +++ b/backend/train/run.py @@ -69,7 +69,7 @@ def fit_config(_: int): """ config: dict[str, Scalar] = { "batch_size": 32, - "local_epochs": 2, + "local_epochs": 10, } return config diff --git a/fed_kit_client/ios/Runner/DataLoader.swift b/fed_kit_client/ios/Runner/DataLoader.swift index a063477..3d9539d 100644 --- a/fed_kit_client/ios/Runner/DataLoader.swift +++ b/fed_kit_client/ios/Runner/DataLoader.swift @@ -122,14 +122,14 @@ private func prepareMLBatchProvider( group.addTask { let splits = line.split(separator: ",") let imageMultiArr = try! MLMultiArray(shape: shapeData, dataType: .float32) - let outputMultiArr = try! MLMultiArray(shape: shapeTarget, dataType: .int32) + let outputMultiArr = try! MLMultiArray(shape: shapeTarget, dataType: .double) for i in 0 ..< lengthEntry { imageMultiArr[i] = (Float(String(splits[i + 1]))! / normalization) as NSNumber } if Int(String(splits[0]))! == 1 { - outputMultiArr[0] = 1 + outputMultiArr[0] = 1.0 } else { - outputMultiArr[0] = 0 + outputMultiArr[0] = 0.0 } let imageValue = MLFeatureValue(multiArray: imageMultiArr) let outputValue = MLFeatureValue(multiArray: outputMultiArr) diff --git a/fed_kit_client/ios/Runner/Utils/Flwr/MLClient.swift b/fed_kit_client/ios/Runner/Utils/Flwr/MLClient.swift index 3995cc7..cd302e1 100644 --- a/fed_kit_client/ios/Runner/Utils/Flwr/MLClient.swift +++ b/fed_kit_client/ios/Runner/Utils/Flwr/MLClient.swift @@ -103,9 +103,9 @@ public class MLClient { } let actl = batch.features(at: index).featureValue(for: modelProto.target)! let prediction = try pred.multiArrayValue!.toArray(type: Float.self) - let actual = try actl.multiArrayValue!.toArray(type: Int32.self) + let actual = try actl.multiArrayValue!.toArray(type: Double.self) totalLoss += meanSquareErrors(prediction, actual) - if actual.argmax() == prediction.argmax() { + if (actual[0] == 0 && prediction[0] < 0.5) || (actual[0] != 0 && prediction[0] >= 0.5) { nCorrect += 1 } } diff --git a/fed_kit_client/ios/Runner/Utils/MLHelper.swift b/fed_kit_client/ios/Runner/Utils/MLHelper.swift index 50d7f39..cd7fef1 100644 --- a/fed_kit_client/ios/Runner/Utils/MLHelper.swift +++ b/fed_kit_client/ios/Runner/Utils/MLHelper.swift @@ -70,7 +70,7 @@ struct Layer { } /// Assuming the two have the same length. -func meanSquareErrors(_ a: [Float], _ b: [Int32]) -> Float { +func meanSquareErrors(_ a: [Float], _ b: [Double]) -> Float { var sum: Float = 0.0 for (index, value) in a.enumerated() { let diff = value - Float(b[index])