Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PMData instead of MNIST #27

Merged
merged 4 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import org.eu.fedcampus.fed_kit_train.helpers.classifierAccuracy
import org.eu.fedcampus.fed_kit_train.helpers.maxSquaredErrorLoss
import java.io.File

fun sampleSpec() = SampleSpec<Float3DArray, FloatArray>(
fun sampleSpec() = SampleSpec<FloatArray, FloatArray>(
{ it.toTypedArray() },
{ it.toTypedArray() },
{ Array(it) { FloatArray(1) } },
Expand All @@ -36,38 +36,30 @@ private suspend fun processSet(dataSetDir: String, call: suspend (Int, String) -
}

suspend fun loadData(
dataDir: String, flowerClient: FlowerClient<Float3DArray, FloatArray>, partitionId: Int
dataDir: String, flowerClient: FlowerClient<FloatArray, FloatArray>, partitionId: Int
) {
processSet("$dataDir/MNIST_train.csv") { index, line ->
if (index / 6000 + 1 != partitionId) {
return@processSet
}
if (index % 1000 == 999) {
Log.i(TAG, "Prepared $index training data points.")
}
// process training set
Log.i(TAG, "loading pmdata")
processSet("$dataDir/p${partitionId.toString().padStart(2,'0')}_train.csv") { index, line ->
addSample(flowerClient, line, true)
}
processSet("$dataDir/MNIST_test.csv") { index, line ->
if (index % 1000 == 999) {
Log.i(TAG, "Prepared $index test data points.")
}
// process test set
processSet("$dataDir/test.csv") { index, line ->
addSample(flowerClient, line, false)
}
}

private fun addSample(
flowerClient: FlowerClient<Float3DArray, FloatArray>, line: String, isTraining: Boolean
flowerClient: FlowerClient<FloatArray, FloatArray>, line: String, isTraining: Boolean
) {

val splits = line.split(",")
val feature = Array(IMAGE_SIZE) { Array(IMAGE_SIZE) { FloatArray(1) } }
val label = FloatArray(1)
for (i in 0 until LENGTH_ENTRY) {
feature[i / IMAGE_SIZE][i % IMAGE_SIZE][0] = splits[i + 1].toFloat() / NORMALIZATION
}
if (splits.first().toInt() == 1) {
label[0] = 1f
val label = floatArrayOf(splits[splits.size-1].toFloat())
val featureArray = FloatArray(splits.size-2)
for (i in featureArray.indices){
featureArray[i] = splits[i+1].toFloat()
}
flowerClient.addSample(feature, label, isTraining)
flowerClient.addSample(featureArray, label, isTraining)
}

private const val TAG = "MNIST Data Loader"
Expand Down
93 changes: 93 additions & 0 deletions convert_model/pmdata_eg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# from .. import tf
import tensorflow as tf

k = tf.keras


def build_model(input_dim=7, units=512, num_hidden_layers=2, learning_rate=0.0013826):
    """Build and compile a small fully connected regression network.

    The network is a stack of ``num_hidden_layers`` ReLU ``Dense`` layers of
    ``units`` neurons each, followed by a single linear output neuron. It is
    compiled with Adam, mean-squared-error loss and a mean-absolute-error
    metric.

    Args:
        input_dim: Number of input features per sample. The default of 7
            matches the seven feature columns listed in ``f_columns`` in this
            file's ``__main__`` test code.
        units: Width of every hidden layer.
        num_hidden_layers: Number of hidden layers; must be at least 1.
        learning_rate: Learning rate handed to the Adam optimizer.
            NOTE(review): the default looks like the result of a
            hyper-parameter search (the original carried commented-out
            ``hp.Choice`` calls) - confirm before changing it.

    Returns:
        The compiled ``k.Sequential`` model.

    Raises:
        ValueError: If ``num_hidden_layers`` is smaller than 1.
    """
    if num_hidden_layers < 1:
        raise ValueError("num_hidden_layers must be at least 1")

    model = k.Sequential()
    for layer_index in range(num_hidden_layers):
        # Only the first layer declares the input size. Later layers take
        # their input from the previous layer's output, so passing
        # `input_dim` to them (as the original loop did) is at best ignored.
        input_kwargs = {"input_dim": input_dim} if layer_index == 0 else {}
        model.add(k.layers.Dense(units=units, activation="relu", **input_kwargs))

    # Single linear output neuron: the model predicts one scalar per sample.
    model.add(k.layers.Dense(1))

    model.compile(
        optimizer=k.optimizers.Adam(learning_rate=learning_rate),
        loss="mean_squared_error",
        metrics=["mean_absolute_error"],
    )
    return model


if __name__ == "__main__":
    # Manual test code: trains the model from `build_model()` on a local CSV
    # with early stopping on the validation loss. Requires the CSV file named
    # below to be present in the current working directory.
    import pandas as pd

    target_column = "efficiency"
    # f_columns=['age','height','steps','calories','very_active_minutes','moderately_active_minutes','lightly_active_minutes','sedentary_minutes','female','male']
    f_columns = [
        "steps",
        "distance_km",
        "calories",
        "very_active_minutes",
        "moderately_active_minutes",
        "lightly_active_minutes",
        "sedentary_minutes",
    ]

    df = pd.read_csv(r"merged_features_efficiency_FINAL2_modified_Beilong.csv")
    # Keep only the feature columns plus the regression target.
    data = df[f_columns + [target_column]]
    # one-hot encode the values in the gender column with pd.get_dummies.
    # data = pd.get_dummies(data, columns=["gender"], prefix="", prefix_sep="")
    # Shuffle the data. I noticed that shuffling the data increased the loss
    # compared with non-shuffled data.
    data = data.sample(frac=1).reset_index(drop=True)

    # 70% train / 20% validation / 10% test, split by row position.
    n = len(data)
    train_end = int(n * 0.7)
    val_end = int(n * 0.9)

    train_data = data.iloc[:train_end]
    val_data = data.iloc[train_end:val_end]
    test_data = data.iloc[val_end:]

    train_features = train_data[f_columns]
    val_features = val_data[f_columns]
    # The test split is prepared but not used by the fit call below.
    test_features = test_data[f_columns]

    train_target = train_data[target_column]
    val_target = val_data[target_column]
    test_target = test_data[target_column]

    model = build_model()
    hist = model.fit(
        train_features,
        train_target,
        epochs=300,
        # `callbacks` is documented as a list of callback instances.
        callbacks=[k.callbacks.EarlyStopping(monitor="val_loss", patience=3)],
        verbose=1,
        validation_data=(val_features, val_target),
    )
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.nio.ByteBuffer

class MainActivity : FlutterActivity() {
val scope = MainScope()
lateinit var flowerClient: FlowerClient<Float3DArray, FloatArray>
lateinit var flowerClient: FlowerClient<FloatArray, FloatArray>
var events: EventSink? = null

override fun configureFlutterEngine(flutterEngine: FlutterEngine) {
Expand Down
11 changes: 5 additions & 6 deletions fed_kit_client/lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -238,15 +238,14 @@ class _MyAppState extends State<MyApp> {
Future<int> deviceId() async => (await AppSetId().getIdentifier()).hashCode;

Future<String> downloadDataSet() async {
const url =
'https://github.com/FedCampus/FedKit/files/12707552/MNIST_data.zip';
final tempDir = '${(await getApplicationDocumentsDirectory()).path}/MNIST/';
final zipFile = File('$tempDir/MNIST_data.zip');
const url = 'https://github.com/FedCampus/FedKit/files/12744347/pmdata.zip';
final tempDir = '${(await getApplicationDocumentsDirectory()).path}/pmdata/';
final zipFile = File('$tempDir/pmdata.zip');
if (!await zipFile.exists()) {
await dio.download(url, zipFile.path);
}
final trainFile = File('$tempDir/MNIST_train.csv');
final testFile = File('$tempDir/MNIST_test.csv');
final trainFile = File('$tempDir/p01_train.csv');
final testFile = File('$tempDir/test.csv');
if (await trainFile.exists() && await testFile.exists()) {
return tempDir;
}
Expand Down