-
Notifications
You must be signed in to change notification settings - Fork 216
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Mansi Sharma <[email protected]>
- Loading branch information
Mansi Sharma
committed
Dec 18, 2023
1 parent
1da47a0
commit 1739720
Showing
8 changed files
with
535 additions
and
400 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
327 changes: 15 additions & 312 deletions
327
openfl-tutorials/experimental/Workflow_Interface_104_MNIST_with_fedprox.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
431 changes: 431 additions & 0 deletions
431
openfl-tutorials/experimental/Workflow_Interface_105_MNIST_with_fedcurv.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
29 changes: 29 additions & 0 deletions
29
openfl/experimental/interface/aggregation_functions/fedcurv.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
"""openfl.experimental.interface.aggregation_functions.fedcurv_weighted_average package.""" | ||
|
||
import numpy as np | ||
from typing import Tuple | ||
|
||
from .fedavg import FedAvg, weighted_average | ||
|
||
class FedCurvAgg(FedAvg): | ||
"""Fedcurv Weighted average aggregation. | ||
Applies weighted average aggregation to all tensors | ||
except Fisher matrices variables (u_t, v_t). | ||
These variables are summed without weights. | ||
FedCurv paper: https://arxiv.org/pdf/1910.07796.pdf | ||
""" | ||
|
||
def __init__(self, **kwargs) -> None: | ||
super().__init__(**kwargs) | ||
|
||
def __repr__(self) -> str: | ||
"""Compute a string representation of the function.""" | ||
rep = "FedCurv Aggregation Function" | ||
return rep | ||
|
||
def aggregate_models(self, model_weights, fisher_matrix_model_weights, collaborator_weights) -> Tuple[np.ndarray, np.ndarray]: | ||
"""Compute Fedcurv weighted average.""" | ||
# For Fisher variables, compute sum and for non Fisher elements, compute average | ||
return (np.sum(fisher_matrix_model_weights, axis=0), weighted_average(model_weights, collaborator_weights)) |
23 changes: 0 additions & 23 deletions
23
openfl/experimental/interface/aggregation_functions/fedcurv_weighted_average.py
This file was deleted.
Oops, something went wrong.
32 changes: 17 additions & 15 deletions
32
openfl/experimental/interface/aggregation_functions/fedprox.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,30 @@ | ||
"""openfl.experimental.interface.aggregation_functions.fedprox package.""" | ||
|
||
import numpy as np | ||
from typing import List | ||
|
||
from .fedavg import FedAvg, weighted_average | ||
|
||
from .weighted_average import weighted_average, WeightedAverage | ||
|
||
class FedProxAgg(FedAvg): | ||
"""FedProx aggregation. | ||
A representation of FedAvg with inclusion of the | ||
proximal term for heterogeneous data distributions | ||
FedProx paper: https://arxiv.org/pdf/1812.06127.pdf | ||
""" | ||
|
||
class FedProx(WeightedAverage): | ||
"""Weighted average aggregation.""" | ||
def __init__(self, **kwargs) -> None: | ||
super().__init__(**kwargs) | ||
|
||
def __repr__(self) -> str: | ||
"""Compute a string representation of the function.""" | ||
rep = "FedProx" | ||
rep = "FedProx Aggregation Function" | ||
return rep | ||
|
||
def aggregate_metrics(self, metrics, train_weights, test_weights): | ||
def aggregate_metrics(self, metrics, weights) -> List[np.ndarray]: | ||
"""Weighted average of loss and accuracy metrics""" | ||
agg_model_loss_list, agg_model_accuracy_list = metrics | ||
aggregated_model_training_loss = weighted_average(agg_model_loss_list, train_weights) | ||
aggregated_model_test_accuracy = weighted_average(agg_model_accuracy_list, test_weights) | ||
|
||
return (aggregated_model_training_loss, aggregated_model_test_accuracy) | ||
|
||
|
||
|
||
|
||
agg_metrics = [] | ||
for idx,metric in enumerate(metrics): | ||
agg_metrics.append(weighted_average(metric, weights=weights[idx])) | ||
return agg_metrics |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters