Skip to content

Commit

Permalink
initial commit for fedavg and fedprox
Browse files Browse the repository at this point in the history
Signed-off-by: Mansi Sharma <[email protected]>
  • Loading branch information
Mansi Sharma committed Dec 8, 2023
1 parent d4108fb commit 1da47a0
Show file tree
Hide file tree
Showing 8 changed files with 949 additions and 21 deletions.
70 changes: 49 additions & 21 deletions openfl-tutorials/experimental/Workflow_Interface_101_MNIST.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@
"metadata": {},
"outputs": [],
"source": [
"# !pip install torch\n",
"# !pip install torchvision\n",
"\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
Expand All @@ -103,14 +106,14 @@
"torch.backends.cudnn.enabled = False\n",
"torch.manual_seed(random_seed)\n",
"\n",
"mnist_train = torchvision.datasets.MNIST('files/', train=True, download=True,\n",
"mnist_train = torchvision.datasets.MNIST('files/', train=True, download=False,\n",
" transform=torchvision.transforms.Compose([\n",
" torchvision.transforms.ToTensor(),\n",
" torchvision.transforms.Normalize(\n",
" (0.1307,), (0.3081,))\n",
" ]))\n",
"\n",
"mnist_test = torchvision.datasets.MNIST('files/', train=False, download=True,\n",
"mnist_test = torchvision.datasets.MNIST('files/', train=False, download=False,\n",
" transform=torchvision.transforms.Compose([\n",
" torchvision.transforms.ToTensor(),\n",
" torchvision.transforms.Normalize(\n",
Expand Down Expand Up @@ -178,17 +181,8 @@
"from openfl.experimental.runtime import LocalRuntime\n",
"from openfl.experimental.placement import aggregator, collaborator\n",
"\n",
"\n",
"def FedAvg(models, weights=None):\n",
" new_model = models[0]\n",
" state_dicts = [model.state_dict() for model in models]\n",
" state_dict = new_model.state_dict()\n",
" for key in models[1].state_dict():\n",
" state_dict[key] = torch.from_numpy(np.average([state[key].numpy() for state in state_dicts],\n",
" axis=0, \n",
" weights=weights))\n",
" new_model.load_state_dict(state_dict)\n",
" return new_model"
"from openfl.plugins.frameworks_adapters.pytorch_adapter import FrameworkAdapterPlugin as fa\n",
"from openfl.experimental.interface.aggregation_functions.weighted_average import WeightedAverage\n"
]
},
{
Expand Down Expand Up @@ -244,7 +238,9 @@
" self.next(self.train)\n",
"\n",
" @collaborator\n",
" def train(self):\n",
" def train(self): \n",
" self.train_dataset_length = len(self.train_loader.dataset)\n",
" \n",
" self.model.train()\n",
" self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,\n",
" momentum=momentum)\n",
Expand Down Expand Up @@ -279,8 +275,26 @@
" print(f'Average aggregated model validation values = {self.aggregated_model_accuracy}')\n",
" print(f'Average training loss = {self.average_loss}')\n",
" print(f'Average local model validation values = {self.local_model_accuracy}')\n",
" self.model = FedAvg([input.model for input in inputs])\n",
" self.optimizer = [input.optimizer for input in inputs][0]\n",
"\n",
" train_tensors, train_weights=[],[]\n",
" agg_func = WeightedAverage()\n",
" for input in inputs:\n",
" train_weights.append(input.train_dataset_length/len(mnist_train.data))\n",
" keys_list, inner_tensors_list = [], []\n",
" for k,v in (fa.get_tensor_dict(input.model, input.optimizer)).items():\n",
" if k == '__opt_state_needed':\n",
" continue\n",
" else:\n",
" inner_tensors_list.append(v)\n",
" keys_list.append(k)\n",
" train_tensors.append(inner_tensors_list)\n",
" avg_tensors = agg_func.aggregate_models(train_tensors, train_weights)\n",
" \n",
" state_dict = dict(zip(keys_list, avg_tensors))\n",
" # Add back __opt_state_needed key\n",
" state_dict['__opt_state_needed'] = 'true'\n",
" fa.set_tensor_dict(self.model, state_dict, self.optimizer)\n",
" \n",
" self.current_round += 1\n",
" if self.current_round < self.rounds:\n",
" self.next(self.aggregated_model_validation, foreach='collaborators', exclude=['private'])\n",
Expand Down Expand Up @@ -314,8 +328,11 @@
"aggregator.private_attributes = {}\n",
"\n",
"# Setup collaborators with private attributes\n",
"collaborator_names = ['Portland', 'Seattle', 'Chandler','Bangalore']\n",
"collaborator_names = ['Portland', 'Seattle', 'Chandler', 'Bangalore']\n",
"collaborators = [Collaborator(name=name) for name in collaborator_names]\n",
"# Keep a list of collaborator weights. The weights are decided by the number of samples for each collaborator\n",
"collaborators_weights_dict = {}\n",
"\n",
"for idx, collaborator in enumerate(collaborators):\n",
" local_train = deepcopy(mnist_train)\n",
" local_test = deepcopy(mnist_test)\n",
Expand All @@ -327,6 +344,15 @@
" 'train_loader': torch.utils.data.DataLoader(local_train,batch_size=batch_size_train, shuffle=True),\n",
" 'test_loader': torch.utils.data.DataLoader(local_test,batch_size=batch_size_train, shuffle=True)\n",
" }\n",
" collaborators_weights_dict[collaborator] = len(local_train.data)\n",
"\n",
"for col in collaborators_weights_dict:\n",
" collaborators_weights_dict[col] /= len(mnist_train.data)\n",
"\n",
"if len(collaborators_weights_dict) != 0:\n",
" assert np.abs(1.0 - sum(collaborators_weights_dict.values())) < 0.01, (\n",
" f'Collaborator weights do not sum to 1.0: {collaborators_weights_dict}'\n",
" )\n",
"\n",
"local_runtime = LocalRuntime(aggregator=aggregator, collaborators=collaborators, backend='single_process')\n",
"print(f'Local runtime collaborators = {local_runtime.collaborators}')"
Expand All @@ -344,7 +370,9 @@
"cell_type": "code",
"execution_count": null,
"id": "16937a65",
"metadata": {},
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"model = None\n",
Expand Down Expand Up @@ -640,9 +668,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "workflow-interface-py38",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "workflow-interface-py38"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -654,7 +682,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.8.0"
}
},
"nbformat": 4,
Expand Down

Large diffs are not rendered by default.

Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""openfl.experimental.interface.aggregation_functions.fedcurv_weighted_average package."""

import numpy as np

from .interface import AggregationFunction
from openfl.interface.aggregation_functions.weighted_average import weighted_average

class FedcurvWeightedAverage(WeightedAverage):
"""Fedcurv Weighted average aggregation."""
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)

def aggregate_models(self, model_weights, collaborator_weights, state_keys):
"""Compute weighted average."""
return weighted_average(model_weights, collaborator_weights)

def __call__(self, tensors_list, weights_list, state_dict_keys_list) -> np.ndarray:
final_weights_list = []
for key,val in dict_.items():
if (key[-2:] == '_u' or key[-2:] == '_v' or key[-2:] == '_w'):
final_weights_list.append(np.sum())
continue
final_weights_list.append(np.average())
28 changes: 28 additions & 0 deletions openfl/experimental/interface/aggregation_functions/fedprox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""openfl.experimental.interface.aggregation_functions.fedprox package."""

import numpy as np

from .weighted_average import weighted_average, WeightedAverage


class FedProx(WeightedAverage):
"""Weighted average aggregation."""
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)

def __repr__(self) -> str:
"""Compute a string representation of the function."""
rep = "FedProx"
return rep

def aggregate_metrics(self, metrics, train_weights, test_weights):
"""Weighted average of loss and accuracy metrics"""
agg_model_loss_list, agg_model_accuracy_list = metrics
aggregated_model_training_loss = weighted_average(agg_model_loss_list, train_weights)
aggregated_model_test_accuracy = weighted_average(agg_model_accuracy_list, test_weights)

return (aggregated_model_training_loss, aggregated_model_test_accuracy)




19 changes: 19 additions & 0 deletions openfl/experimental/interface/aggregation_functions/interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (C) 2020-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Aggregation function interface module."""
from abc import abstractmethod

from openfl.utilities import SingletonABCMeta

class AggregationFunction(metaclass=SingletonABCMeta):
"""Abstract base class for specifying aggregation functions."""

@abstractmethod
def aggregate_models(self, **kwargs):
"""Aggregate training results using algorithms"""
raise NotImplementedError

@abstractmethod
def aggregate_metrics(self, **kwargs):
"""Aggregate evaluation metrics"""
raise NotImplementedError
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""openfl.experimental.interface.aggregation_functions.weighted_average package."""

import numpy as np

from .interface import AggregationFunction


def weighted_average(tensors, weights):
"""Compute weighted average"""
return np.average(tensors, weights=weights, axis=0)



class WeightedAverage(AggregationFunction):
"""Weighted average aggregation."""

def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)

def __repr__(self) -> str:
"""Compute a string representation of the function."""
rep = "FedAvg"
return rep


def aggregate_models(self, model_weights, collaborator_weights):
"""Compute weighted average."""
return weighted_average(model_weights, collaborator_weights)


def aggregate_metrics(self, **kwargs):
"""Aggregate loss"""
pass

1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def run(self):
'openfl.databases.utilities',
'openfl.experimental',
'openfl.experimental.interface',
'openfl.experimental.interface.aggregation_functions',
'openfl.experimental.placement',
'openfl.experimental.runtime',
'openfl.experimental.utilities',
Expand Down

0 comments on commit 1da47a0

Please sign in to comment.