Skip to content

Commit

Permalink
Extract iOS training library (#29)
Browse files Browse the repository at this point in the history
* mv out common library code

* README for ios package

* fix ios ci

* migrate the config as xcode wish

* rm unused async&func

* invert control of loss&accuracy calculation
  • Loading branch information
SichangHe committed Oct 9, 2023
1 parent 22534e1 commit 953c154
Show file tree
Hide file tree
Showing 43 changed files with 44 additions and 95 deletions.
7 changes: 0 additions & 7 deletions .github/workflows/flutter-client-ios.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,6 @@ jobs:
- uses: actions/checkout@v3
with:
lfs: true
- name: Download MNIST data set
run: |
cd fed_kit_client/ios/Runner/Utils/
mkdir MNIST
cd MNIST
curl -LO https://github.com/FedCampus/FedKit/files/12543877/MNIST.zip
unzip MNIST.zip
- uses: subosito/flutter-action@v2
with:
cache: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public class MLClient {
}

func fit(epochs: Int? = nil, callback: ((Float) -> Void)? = nil) async throws {
let config = try await config()
let config = try config()
if epochs != nil {
config.parameters![MLParameterKey.epochs] = epochs
}
Expand Down Expand Up @@ -86,39 +86,33 @@ public class MLClient {
try saveModel(updateContext)
}

/// Currently, calculates Mean Square Error and category correctness.
func evaluate() async throws -> (Float, Float) {
let config = try await config()
func evaluate(
_ singleLossAccuracy: (MLFeatureProvider, MLFeatureProvider) throws -> (Float, Float)
) throws -> (Float, Float) {
let config = try config()
let model = try MLModel(contentsOf: compiledModelUrl)
let batch = dataLoader.testBatchProvider
let predictions = try model.predictions(fromBatch: batch)

var totalLoss: Float = 0.0
var nCorrect = 0
var totalLoss: Float = 0
var totalAccuracy: Float = 0
for index in 0 ..< predictions.count {
guard let pred =
predictions.features(at: index).featureValue(for: modelProto.output)
else {
throw MLClientErr.FeatureNoValue
}
let actl = batch.features(at: index).featureValue(for: modelProto.target)!
let prediction = try pred.multiArrayValue!.toArray(type: Float.self)
let actual = try actl.multiArrayValue!.toArray(type: Int32.self)
totalLoss += meanSquareErrors(prediction, actual)
if actual.argmax() == prediction.argmax() {
nCorrect += 1
}
let prediction = predictions.features(at: index)
let actual = batch.features(at: index)
let (loss, accuracy) = try singleLossAccuracy(actual, prediction)
totalLoss += loss
totalAccuracy += accuracy
}

let total = Float(predictions.count)
let loss = totalLoss / total
let accuracy = Float(nCorrect) / total
let accuracy = totalAccuracy / total
return (loss, accuracy)
}

/// Guarantee that the config returned has non-nil `parameters`.
/// Update `compiledModelUrl` to match new parameters.
private func config() async throws -> MLModelConfiguration {
private func config() throws -> MLModelConfiguration {
let config = MLModelConfiguration()
if config.parameters == nil {
config.parameters = [:]
Expand Down
3 changes: 3 additions & 0 deletions fed_kit_client/ios/FedKitTrain/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# FedKit Training Swift Package

To use this package, open the example client in Xcode and copy the files.
27 changes: 13 additions & 14 deletions fed_kit_client/ios/Runner.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
3F4AFE982A90977B004B55C8 /* SVM.pb.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3F4AFE762A90977B004B55C8 /* SVM.pb.swift */; };
3F4AFE9E2A909E77004B55C8 /* SwiftProtobuf in Frameworks */ = {isa = PBXBuildFile; productRef = 3F4AFE9D2A909E77004B55C8 /* SwiftProtobuf */; };
3F8854242A90A53400B37487 /* MLClient.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3F8854232A90A53400B37487 /* MLClient.swift */; };
3FCC6BCB2AD2424100A73E1F /* README.md in Resources */ = {isa = PBXBuildFile; fileRef = 3FCC6BCA2AD2424100A73E1F /* README.md */; };
74858FAF1ED2DC5600515810 /* AppDelegate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 74858FAE1ED2DC5600515810 /* AppDelegate.swift */; };
97C146FC1CF9000F007C117D /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 97C146FA1CF9000F007C117D /* Main.storyboard */; };
97C146FE1CF9000F007C117D /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 97C146FD1CF9000F007C117D /* Assets.xcassets */; };
Expand Down Expand Up @@ -130,6 +131,7 @@
3F4AFE752A90977B004B55C8 /* LinkedModel.pb.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = LinkedModel.pb.swift; sourceTree = "<group>"; };
3F4AFE762A90977B004B55C8 /* SVM.pb.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = SVM.pb.swift; sourceTree = "<group>"; };
3F8854232A90A53400B37487 /* MLClient.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = MLClient.swift; sourceTree = "<group>"; };
3FCC6BCA2AD2424100A73E1F /* README.md */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
74858FAD1ED2DC5600515810 /* Runner-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "Runner-Bridging-Header.h"; sourceTree = "<group>"; };
74858FAE1ED2DC5600515810 /* AppDelegate.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = AppDelegate.swift; sourceTree = "<group>"; };
7AFA3C8E1D35360C0083082E /* Release.xcconfig */ = {isa = PBXFileReference; lastKnownFileType = text.xcconfig; name = Release.xcconfig; path = Flutter/Release.xcconfig; sourceTree = "<group>"; };
Expand Down Expand Up @@ -190,15 +192,16 @@
path = RunnerTests;
sourceTree = "<group>";
};
3F4AFE4C2A9096D2004B55C8 /* Utils */ = {
3F4AFE4C2A9096D2004B55C8 /* FedKitTrain */ = {
isa = PBXGroup;
children = (
3FCC6BCA2AD2424100A73E1F /* README.md */,
3F8854232A90A53400B37487 /* MLClient.swift */,
3F3379462A91D7490084B7C8 /* MLHelper.swift */,
3F3379482A91E9C80084B7C8 /* Helpers.swift */,
3F8854222A90A52700B37487 /* Flwr */,
3F4AFE542A909769004B55C8 /* CoreMLProto */,
3F3379462A91D7490084B7C8 /* MLHelper.swift */,
);
path = Utils;
path = FedKitTrain;
sourceTree = "<group>";
};
3F4AFE542A909769004B55C8 /* CoreMLProto */ = {
Expand Down Expand Up @@ -242,14 +245,6 @@
path = CoreMLProto;
sourceTree = "<group>";
};
3F8854222A90A52700B37487 /* Flwr */ = {
isa = PBXGroup;
children = (
3F8854232A90A53400B37487 /* MLClient.swift */,
);
path = Flwr;
sourceTree = "<group>";
};
5603D2B3AEE34821ABBC862F /* Frameworks */ = {
isa = PBXGroup;
children = (
Expand All @@ -273,6 +268,7 @@
97C146E51CF9000F007C117D = {
isa = PBXGroup;
children = (
3F4AFE4C2A9096D2004B55C8 /* FedKitTrain */,
9740EEB11CF90186004384FC /* Flutter */,
97C146F01CF9000F007C117D /* Runner */,
97C146EF1CF9000F007C117D /* Products */,
Expand All @@ -294,7 +290,6 @@
97C146F01CF9000F007C117D /* Runner */ = {
isa = PBXGroup;
children = (
3F4AFE4C2A9096D2004B55C8 /* Utils */,
97C146FA1CF9000F007C117D /* Main.storyboard */,
97C146FD1CF9000F007C117D /* Assets.xcassets */,
97C146FF1CF9000F007C117D /* LaunchScreen.storyboard */,
Expand Down Expand Up @@ -363,7 +358,7 @@
97C146E61CF9000F007C117D /* Project object */ = {
isa = PBXProject;
attributes = {
LastUpgradeCheck = 1300;
LastUpgradeCheck = 1400;
ORGANIZATIONNAME = "";
TargetAttributes = {
331C8080294A63A400263BE5 = {
Expand Down Expand Up @@ -414,6 +409,7 @@
97C147011CF9000F007C117D /* LaunchScreen.storyboard in Resources */,
3F4AFE912A90977B004B55C8 /* README.rst in Resources */,
3F4AFE832A90977B004B55C8 /* LICENSE.txt in Resources */,
3FCC6BCB2AD2424100A73E1F /* README.md in Resources */,
3B3967161E833CAA004F5970 /* AppFrameworkInfo.plist in Resources */,
97C146FE1CF9000F007C117D /* Assets.xcassets in Resources */,
97C146FC1CF9000F007C117D /* Main.storyboard in Resources */,
Expand Down Expand Up @@ -624,6 +620,7 @@
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
Expand Down Expand Up @@ -746,6 +743,7 @@
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
Expand Down Expand Up @@ -801,6 +799,7 @@
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<Scheme
LastUpgradeVersion = "1300"
LastUpgradeVersion = "1400"
version = "1.3">
<BuildAction
parallelizeBuildables = "YES"
Expand Down
14 changes: 13 additions & 1 deletion fed_kit_client/ios/Runner/AppDelegate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,19 @@ let log = logger(String(describing: AppDelegate.self))

func evaluate(_ result: @escaping FlutterResult) {
runAsync(result) {
let (loss, accuracy) = try await self.mlClient!.evaluate()
let (loss, accuracy) = try self.mlClient!.evaluate { a, p in
guard let pred =
p.featureValue(for: self.mlClient!.modelProto.output)
else {
throw MLClientErr.FeatureNoValue
}
let actl = a.featureValue(for: self.mlClient!.modelProto.target)!
let prediction = try pred.multiArrayValue!.toArray(type: Float.self)
let actual = try actl.multiArrayValue!.toArray(type: Int32.self)
let loss = meanSquareErrors(prediction, actual)
let accuracy: Float = actual.argmax() == prediction.argmax() ? 1 : 0
return (loss, accuracy)
}
let lossAccuracy = [loss, accuracy]
return FlutterStandardTypedData(float32: Data(fromArray: lossAccuracy))
}
Expand Down
52 changes: 0 additions & 52 deletions fed_kit_client/ios/Runner/DataLoader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,58 +53,6 @@ func testBatchProvider(
)
}

private func extract(_ set: String) throws -> URL {
let sourceURL = Bundle.main.url(forResource: dataset + set, withExtension: "csv.lzfse")!
let destinationURL = appDirectory.appendingPathComponent(dataset + set + ".csv")
return try extractFile(from: sourceURL, to: destinationURL)
}

/// Extract file
///
/// - parameter sourceURL:URL of source file
/// - parameter destinationFilename: Choosen destination filename
///
/// - returns: Temporary path of extracted file
private func extractFile(from sourceURL: URL, to destinationURL: URL) throws -> URL {
let fileManager = FileManager.default
if fileManager.fileExists(atPath: destinationURL.path) {
return destinationURL
}
var isDir: ObjCBool = true
if !fileManager.fileExists(atPath: appDirectory.path, isDirectory: &isDir) {
try fileManager.createDirectory(at: appDirectory, withIntermediateDirectories: true)
}
fileManager.createFile(atPath: destinationURL.path, contents: nil, attributes: nil)

let destinationFileHandle = try FileHandle(forWritingTo: destinationURL)
defer {
destinationFileHandle.closeFile()
}

let filter = try OutputFilter(.decompress, using: .lzfse) { data in
if let data = data {
destinationFileHandle.write(data)
}
}

let sourceFileHandle = try FileHandle(forReadingFrom: sourceURL)
defer {
sourceFileHandle.closeFile()
}
let bufferSize = 65536
while true {
let data = sourceFileHandle.readData(ofLength: bufferSize)

try filter.write(data)
if data.count < bufferSize {
break
}
}
try filter.finalize()

return destinationURL
}

private func prepareMLBatchProvider(
filePath: String,
inputName: String,
Expand Down

0 comments on commit 953c154

Please sign in to comment.