-
Notifications
You must be signed in to change notification settings - Fork 215
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
initial commit for fedavg and fedprox
Signed-off-by: Mansi Sharma <[email protected]>
- Loading branch information
Mansi Sharma
committed
Dec 8, 2023
1 parent
d4108fb
commit 1da47a0
Showing
8 changed files
with
949 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
795 changes: 795 additions & 0 deletions
795
openfl-tutorials/experimental/Workflow_Interface_104_MNIST_with_fedprox.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
Empty file.
23 changes: 23 additions & 0 deletions
23
openfl/experimental/interface/aggregation_functions/fedcurv_weighted_average.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
"""openfl.experimental.interface.aggregation_functions.fedcurv_weighted_average package.""" | ||
|
||
import numpy as np | ||
|
||
from .interface import AggregationFunction | ||
from openfl.interface.aggregation_functions.weighted_average import weighted_average | ||
|
||
class FedcurvWeightedAverage(WeightedAverage): | ||
"""Fedcurv Weighted average aggregation.""" | ||
def __init__(self, **kwargs) -> None: | ||
super().__init__(**kwargs) | ||
|
||
def aggregate_models(self, model_weights, collaborator_weights, state_keys): | ||
"""Compute weighted average.""" | ||
return weighted_average(model_weights, collaborator_weights) | ||
|
||
def __call__(self, tensors_list, weights_list, state_dict_keys_list) -> np.ndarray: | ||
final_weights_list = [] | ||
for key,val in dict_.items(): | ||
if (key[-2:] == '_u' or key[-2:] == '_v' or key[-2:] == '_w'): | ||
final_weights_list.append(np.sum()) | ||
continue | ||
final_weights_list.append(np.average()) |
28 changes: 28 additions & 0 deletions
28
openfl/experimental/interface/aggregation_functions/fedprox.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
"""openfl.experimental.interface.aggregation_functions.fedprox package.""" | ||
|
||
import numpy as np | ||
|
||
from .weighted_average import weighted_average, WeightedAverage | ||
|
||
|
||
class FedProx(WeightedAverage): | ||
"""Weighted average aggregation.""" | ||
def __init__(self, **kwargs) -> None: | ||
super().__init__(**kwargs) | ||
|
||
def __repr__(self) -> str: | ||
"""Compute a string representation of the function.""" | ||
rep = "FedProx" | ||
return rep | ||
|
||
def aggregate_metrics(self, metrics, train_weights, test_weights): | ||
"""Weighted average of loss and accuracy metrics""" | ||
agg_model_loss_list, agg_model_accuracy_list = metrics | ||
aggregated_model_training_loss = weighted_average(agg_model_loss_list, train_weights) | ||
aggregated_model_test_accuracy = weighted_average(agg_model_accuracy_list, test_weights) | ||
|
||
return (aggregated_model_training_loss, aggregated_model_test_accuracy) | ||
|
||
|
||
|
||
|
19 changes: 19 additions & 0 deletions
19
openfl/experimental/interface/aggregation_functions/interface.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright (C) 2020-2023 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
"""Aggregation function interface module.""" | ||
from abc import abstractmethod | ||
|
||
from openfl.utilities import SingletonABCMeta | ||
|
||
class AggregationFunction(metaclass=SingletonABCMeta): | ||
"""Abstract base class for specifying aggregation functions.""" | ||
|
||
@abstractmethod | ||
def aggregate_models(self, **kwargs): | ||
"""Aggregate training results using algorithms""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def aggregate_metrics(self, **kwargs): | ||
"""Aggregate evaluation metrics""" | ||
raise NotImplementedError |
34 changes: 34 additions & 0 deletions
34
openfl/experimental/interface/aggregation_functions/weighted_average.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
"""openfl.experimental.interface.aggregation_functions.weighted_average package.""" | ||
|
||
import numpy as np | ||
|
||
from .interface import AggregationFunction | ||
|
||
|
||
def weighted_average(tensors, weights): | ||
"""Compute weighted average""" | ||
return np.average(tensors, weights=weights, axis=0) | ||
|
||
|
||
|
||
class WeightedAverage(AggregationFunction): | ||
"""Weighted average aggregation.""" | ||
|
||
def __init__(self, **kwargs) -> None: | ||
super().__init__(**kwargs) | ||
|
||
def __repr__(self) -> str: | ||
"""Compute a string representation of the function.""" | ||
rep = "FedAvg" | ||
return rep | ||
|
||
|
||
def aggregate_models(self, model_weights, collaborator_weights): | ||
"""Compute weighted average.""" | ||
return weighted_average(model_weights, collaborator_weights) | ||
|
||
|
||
def aggregate_metrics(self, **kwargs): | ||
"""Aggregate loss""" | ||
pass | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters