Skip to content

Commit

Permalink
sync android mnist binary loading;fix ios data type
Browse files Browse the repository at this point in the history
  • Loading branch information
SichangHe committed Sep 25, 2023
1 parent cf28ea2 commit afd7f49
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import java.io.File
fun sampleSpec() = SampleSpec<Float3DArray, FloatArray>(
{ it.toTypedArray() },
{ it.toTypedArray() },
{ Array(it) { FloatArray(N_CLASSES) } },
{ Array(it) { FloatArray(1) } },
::maxSquaredErrorLoss,
{ samples, logits ->
samples.zip(logits).forEach { (sample, logit) ->
Expand Down Expand Up @@ -60,16 +60,17 @@ private fun addSample(
) {
val splits = line.split(",")
val feature = Array(IMAGE_SIZE) { Array(IMAGE_SIZE) { FloatArray(1) } }
val label = FloatArray(N_CLASSES)
val label = FloatArray(1)
for (i in 0 until LENGTH_ENTRY) {
feature[i / IMAGE_SIZE][i % IMAGE_SIZE][0] = splits[i + 1].toFloat() / NORMALIZATION
}
label[splits.first().toInt()] = 1f
if (splits.first().toInt() == 1) {
label[0] = 1f
}
flowerClient.addSample(feature, label, isTraining)
}

private const val TAG = "MNIST Data Loader"
private const val IMAGE_SIZE = 28
private const val N_CLASSES = 10
private const val LENGTH_ENTRY = IMAGE_SIZE * IMAGE_SIZE
private const val NORMALIZATION = 255f
2 changes: 1 addition & 1 deletion backend/train/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def fit_config(_: int):
"""
config: dict[str, Scalar] = {
"batch_size": 32,
"local_epochs": 2,
"local_epochs": 10,
}
return config

Expand Down
6 changes: 3 additions & 3 deletions fed_kit_client/ios/Runner/DataLoader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,14 @@ private func prepareMLBatchProvider(
group.addTask {
let splits = line.split(separator: ",")
let imageMultiArr = try! MLMultiArray(shape: shapeData, dataType: .float32)
let outputMultiArr = try! MLMultiArray(shape: shapeTarget, dataType: .int32)
let outputMultiArr = try! MLMultiArray(shape: shapeTarget, dataType: .double)
for i in 0 ..< lengthEntry {
imageMultiArr[i] = (Float(String(splits[i + 1]))! / normalization) as NSNumber
}
if Int(String(splits[0]))! == 1 {
outputMultiArr[0] = 1
outputMultiArr[0] = 1.0
} else {
outputMultiArr[0] = 0
outputMultiArr[0] = 0.0
}
let imageValue = MLFeatureValue(multiArray: imageMultiArr)
let outputValue = MLFeatureValue(multiArray: outputMultiArr)
Expand Down
4 changes: 2 additions & 2 deletions fed_kit_client/ios/Runner/Utils/Flwr/MLClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ public class MLClient {
}
let actl = batch.features(at: index).featureValue(for: modelProto.target)!
let prediction = try pred.multiArrayValue!.toArray(type: Float.self)
let actual = try actl.multiArrayValue!.toArray(type: Int32.self)
let actual = try actl.multiArrayValue!.toArray(type: Double.self)
totalLoss += meanSquareErrors(prediction, actual)
if actual.argmax() == prediction.argmax() {
if (actual[0] == 0 && prediction[0] < 0.5) || (actual[0] != 0 && prediction[0] >= 0.5) {
nCorrect += 1
}
}
Expand Down
2 changes: 1 addition & 1 deletion fed_kit_client/ios/Runner/Utils/MLHelper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ struct Layer {
}

/// Assuming the two have the same length.
func meanSquareErrors(_ a: [Float], _ b: [Int32]) -> Float {
func meanSquareErrors(_ a: [Float], _ b: [Double]) -> Float {
var sum: Float = 0.0
for (index, value) in a.enumerated() {
let diff = value - Float(b[index])
Expand Down

0 comments on commit afd7f49

Please sign in to comment.