Adding fedcurv with example
Signed-off-by: Mansi Sharma <[email protected]>
Mansi Sharma committed Dec 18, 2023
1 parent 1da47a0 commit 1739720
Showing 8 changed files with 535 additions and 400 deletions.
28 changes: 17 additions & 11 deletions openfl-tutorials/experimental/Workflow_Interface_101_MNIST.ipynb
@@ -5,7 +5,7 @@
"id": "14821d97",
"metadata": {},
"source": [
"# Workflow Interface 101: Quickstart\n",
"# Workflow Interface 101: MNIST with FedAvg aggregation algorithm\n",
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intel/openfl/blob/develop/openfl-tutorials/experimental/Workflow_Interface_101_MNIST.ipynb)"
]
},
@@ -181,8 +181,9 @@
"from openfl.experimental.runtime import LocalRuntime\n",
"from openfl.experimental.placement import aggregator, collaborator\n",
"\n",
"#Import plugin adapter and FedAvg aggregation algorithm\n",
"from openfl.plugins.frameworks_adapters.pytorch_adapter import FrameworkAdapterPlugin as fa\n",
"from openfl.experimental.interface.aggregation_functions.weighted_average import WeightedAverage\n"
"from openfl.experimental.interface.aggregation_functions.fedavg import FedAvg\n"
]
},
{
@@ -221,6 +222,7 @@
" self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,\n",
" momentum=momentum)\n",
" self.rounds = rounds\n",
" self.agg_func = FedAvg()\n",
"\n",
" @aggregator\n",
" def start(self):\n",
@@ -269,26 +271,30 @@
"\n",
" @aggregator\n",
" def join(self,inputs):\n",
" self.average_loss = sum(input.loss for input in inputs)/len(inputs)\n",
" self.aggregated_model_accuracy = sum(input.agg_validation_score for input in inputs)/len(inputs)\n",
" self.local_model_accuracy = sum(input.local_validation_score for input in inputs)/len(inputs)\n",
" loss_list, aggregated_model_accuracy_list, local_model_accuracy_list = [], [], []\n",
" for input in inputs:\n",
" loss_list.append(input.loss)\n",
" aggregated_model_accuracy_list.append(input.agg_validation_score)\n",
" local_model_accuracy_list.append(input.local_validation_score)\n",
" self.average_loss, self.aggregated_model_accuracy, self.local_model_accuracy= self.agg_func.aggregate_metrics(\n",
" [loss_list, aggregated_model_accuracy_list, local_model_accuracy_list])\n",
" \n",
" print(f'Average aggregated model validation values = {self.aggregated_model_accuracy}')\n",
" print(f'Average training loss = {self.average_loss}')\n",
" print(f'Average local model validation values = {self.local_model_accuracy}')\n",
"\n",
" train_tensors, train_weights=[],[]\n",
" agg_func = WeightedAverage()\n",
" model_weights, collaborator_weights=[],[]\n",
" for input in inputs:\n",
" train_weights.append(input.train_dataset_length/len(mnist_train.data))\n",
" collaborator_weights.append(input.train_dataset_length/len(mnist_train.data))\n",
" keys_list, inner_tensors_list = [], []\n",
" for k,v in (fa.get_tensor_dict(input.model, input.optimizer)).items():\n",
" if k == '__opt_state_needed':\n",
" continue\n",
" else:\n",
" inner_tensors_list.append(v)\n",
" keys_list.append(k)\n",
" train_tensors.append(inner_tensors_list)\n",
" avg_tensors = agg_func.aggregate_models(train_tensors, train_weights)\n",
" model_weights.append(inner_tensors_list)\n",
" avg_tensors = self.agg_func.aggregate_models(model_weights, collaborator_weights)\n",
" \n",
" state_dict = dict(zip(keys_list, avg_tensors))\n",
" # Add back __opt_state_needed key\n",
@@ -313,7 +319,7 @@
"source": [
"You'll notice in the `FederatedFlow` definition above that there were certain attributes that the flow was not initialized with, namely the `train_loader` and `test_loader` for each of the collaborators. These are **private_attributes** that are exposed only throught he runtime. Each participant has it's own set of private attributes: a dictionary where the key is the attribute name, and the value is the object that will be made accessible through that participant's task. \n",
"\n",
"Below, we segment shards of the MNIST dataset for **four collaborators**: Portland, Seattle, Chandler, and Portland. Each has their own slice of the dataset that's accessible via the `train_loader` or `test_loader` attribute. Note that the private attributes are flexible, and you can choose to pass in a completely different type of object to any of the collaborators or aggregator (with an arbitrary name). These private attributes will always be filtered out of the current state when transfering from collaborator to aggregator, or vice versa. "
"Below, we segment shards of the MNIST dataset for **four collaborators**: Portland, Seattle, Chandler, and Bangalore. Each has their own slice of the dataset that's accessible via the `train_loader` or `test_loader` attribute. Note that the private attributes are flexible, and you can choose to pass in a completely different type of object to any of the collaborators or aggregator (with an arbitrary name). These private attributes will always be filtered out of the current state when transfering from collaborator to aggregator, or vice versa. "
]
},
{
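A minimal sketch of the private-attribute setup described in the cell above. The shard logic, the `mnist_train`/`mnist_test` datasets, and `batch_size` are assumed from the surrounding notebook, and the `Collaborator` construction follows the pattern used elsewhere in these tutorials:

    import torch
    from copy import deepcopy
    from openfl.experimental.interface import Collaborator

    collaborator_names = ['Portland', 'Seattle', 'Chandler', 'Bangalore']
    collaborators = [Collaborator(name=name) for name in collaborator_names]
    for idx, collaborator in enumerate(collaborators):
        # Give each collaborator every n-th sample of MNIST (hypothetical split).
        local_train, local_test = deepcopy(mnist_train), deepcopy(mnist_test)
        local_train.data = mnist_train.data[idx::len(collaborators)]
        local_train.targets = mnist_train.targets[idx::len(collaborators)]
        local_test.data = mnist_test.data[idx::len(collaborators)]
        local_test.targets = mnist_test.targets[idx::len(collaborators)]
        # Private attributes are visible only inside this collaborator's tasks
        # and are filtered out of the flow state on every transition.
        collaborator.private_attributes = {
            'train_loader': torch.utils.data.DataLoader(
                local_train, batch_size=batch_size, shuffle=True),
            'test_loader': torch.utils.data.DataLoader(
                local_test, batch_size=batch_size, shuffle=False),
        }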

Large diffs are not rendered by default.

Large diffs are not rendered by default.

@@ -1,6 +1,7 @@
"""openfl.experimental.interface.aggregation_functions.weighted_average package."""

import numpy as np
from typing import List

from .interface import AggregationFunction

@@ -10,25 +11,30 @@ def weighted_average(tensors, weights):
return np.average(tensors, weights=weights, axis=0)



class WeightedAverage(AggregationFunction):
"""Weighted average aggregation."""
class FedAvg(AggregationFunction):
"""Federated average aggregation.
FedAvg paper: https://arxiv.org/pdf/1602.05629.pdf
"""

def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)

def __repr__(self) -> str:
"""Compute a string representation of the function."""
rep = "FedAvg"
rep = "FedAvg Aggregation Function"
return rep


def aggregate_models(self, model_weights, collaborator_weights):
"""Compute weighted average."""
def aggregate_models(self, model_weights, collaborator_weights) -> np.ndarray:
"""Compute fed avg across models."""
return weighted_average(model_weights, collaborator_weights)


def aggregate_metrics(self, **kwargs):
"""Aggregate loss"""
pass
def aggregate_metrics(self, metrics) -> List[np.ndarray]:
"""Aggregate metrics like loss and accuracy"""
agg_metrics = []
for metric in metrics:
agg_metrics.append(weighted_average(metric, weights=None))
return agg_metrics
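A standalone usage sketch of the new FedAvg class with hypothetical values: `aggregate_models` delegates to `weighted_average`, while `aggregate_metrics` averages each metric list uniformly (`weights=None`):

    import numpy as np
    from openfl.experimental.interface.aggregation_functions.fedavg import FedAvg

    agg = FedAvg()

    # Per-collaborator model tensors, weighted by dataset proportion.
    model_weights = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
    collaborator_weights = [0.25, 0.75]
    avg_model = agg.aggregate_models(model_weights, collaborator_weights)
    # avg_model -> array([2.5, 3.5])

    # One list per metric, e.g. losses and aggregated-model accuracies.
    avg_loss, avg_acc = agg.aggregate_metrics([[0.9, 1.1], [0.6, 0.8]])
    # avg_loss -> 1.0, avg_acc -> 0.7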

29 changes: 29 additions & 0 deletions openfl/experimental/interface/aggregation_functions/fedcurv.py
@@ -0,0 +1,29 @@
"""openfl.experimental.interface.aggregation_functions.fedcurv_weighted_average package."""

import numpy as np
from typing import Tuple

from .fedavg import FedAvg, weighted_average

class FedCurvAgg(FedAvg):
"""Fedcurv Weighted average aggregation.
Applies weighted average aggregation to all tensors
except Fisher matrices variables (u_t, v_t).
These variables are summed without weights.
FedCurv paper: https://arxiv.org/pdf/1910.07796.pdf
"""

def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)

def __repr__(self) -> str:
"""Compute a string representation of the function."""
rep = "FedCurv Aggregation Function"
return rep

def aggregate_models(self, model_weights, fisher_matrix_model_weights, collaborator_weights) -> Tuple[np.ndarray, np.ndarray]:
"""Compute Fedcurv weighted average."""
# For Fisher variables, compute sum and for non Fisher elements, compute average
return (np.sum(fisher_matrix_model_weights, axis=0), weighted_average(model_weights, collaborator_weights))
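A usage sketch with hypothetical tensors, showing the FedCurv aggregation rule: the Fisher-matrix tensors (u_t, v_t) are summed without weights, while the remaining model tensors are weighted-averaged:

    import numpy as np
    from openfl.experimental.interface.aggregation_functions.fedcurv import FedCurvAgg

    agg = FedCurvAgg()
    model_weights = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]   # per collaborator
    fisher_weights = [np.array([0.1, 0.2]), np.array([0.3, 0.4])]  # Fisher tensors
    collaborator_weights = [0.5, 0.5]

    fisher_sum, avg_model = agg.aggregate_models(
        model_weights, fisher_weights, collaborator_weights)
    # fisher_sum -> array([0.4, 0.6])  (unweighted sum)
    # avg_model  -> array([2.0, 3.0])  (weighted average)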

This file was deleted.

32 changes: 17 additions & 15 deletions openfl/experimental/interface/aggregation_functions/fedprox.py
@@ -1,28 +1,30 @@
"""openfl.experimental.interface.aggregation_functions.fedprox package."""

import numpy as np
from typing import List

from .fedavg import FedAvg, weighted_average

from .weighted_average import weighted_average, WeightedAverage

class FedProxAgg(FedAvg):
"""FedProx aggregation.
A representation of FedAvg with inclusion of the
proximal term for heterogeneous data distributions
FedProx paper: https://arxiv.org/pdf/1812.06127.pdf
"""

class FedProx(WeightedAverage):
"""Weighted average aggregation."""
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)

def __repr__(self) -> str:
"""Compute a string representation of the function."""
rep = "FedProx"
rep = "FedProx Aggregation Function"
return rep

def aggregate_metrics(self, metrics, train_weights, test_weights):
def aggregate_metrics(self, metrics, weights) -> List[np.ndarray]:
"""Weighted average of loss and accuracy metrics"""
agg_model_loss_list, agg_model_accuracy_list = metrics
aggregated_model_training_loss = weighted_average(agg_model_loss_list, train_weights)
aggregated_model_test_accuracy = weighted_average(agg_model_accuracy_list, test_weights)

return (aggregated_model_training_loss, aggregated_model_test_accuracy)




agg_metrics = []
for idx,metric in enumerate(metrics):
agg_metrics.append(weighted_average(metric, weights=weights[idx]))
return agg_metrics
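A sketch of the reworked `aggregate_metrics` with hypothetical values; each metric now carries its own weight vector (e.g. train-set proportions for loss, test-set proportions for accuracy):

    import numpy as np
    from openfl.experimental.interface.aggregation_functions.fedprox import FedProxAgg

    agg = FedProxAgg()
    losses = [0.9, 0.5]            # per-collaborator training loss
    accuracies = [0.8, 0.6]        # per-collaborator test accuracy
    train_weights = [0.25, 0.75]
    test_weights = [0.4, 0.6]

    avg_loss, avg_acc = agg.aggregate_metrics(
        [losses, accuracies], weights=[train_weights, test_weights])
    # avg_loss -> 0.6, avg_acc -> 0.68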
41 changes: 11 additions & 30 deletions openfl/utilities/fedcurv/torch/fedcurv.py
@@ -21,28 +21,6 @@ def register_buffer(module: torch.nn.Module, name: str, value: torch.Tensor):
mod.register_buffer(name, value)


def get_buffer(module, target):
"""Get module buffer.
Remove after pinning to a version
where https://github.com/pytorch/pytorch/pull/61429 is included.
Use module.get_buffer() instead.
"""
module_path, _, buffer_name = target.rpartition('.')

mod: torch.nn.Module = module.get_submodule(module_path)

if not hasattr(mod, buffer_name):
raise AttributeError(f'{mod._get_name()} has no attribute `{buffer_name}`')

buffer: torch.Tensor = getattr(mod, buffer_name)

if buffer_name not in mod._buffers:
raise AttributeError('`' + buffer_name + '` is not a buffer')

return buffer


class FedCurv:
"""Federated Curvature class.
@@ -80,7 +58,7 @@ def _register_fisher_parameters(self, model):
def _update_params(self, model):
self._params = deepcopy({n: p for n, p in model.named_parameters() if p.requires_grad})

def _diag_fisher(self, model, data_loader, device):
def _diag_fisher(self, model, data_loader, device='cpu', loss_fn='nll'):
precision_matrices = {}
for n, p in self._params.items():
p.data.zero_()
@@ -93,7 +71,10 @@ def _diag_fisher(self, model, data_loader, device):
sample = sample.to(device)
target = target.to(device)
output = model(sample)
loss = F.nll_loss(F.log_softmax(output, dim=1), target)
if loss_fn == 'cross_entropy':
loss = F.cross_entropy(output, target)
else:
loss = F.nll_loss(F.log_softmax(output, dim=1), target)
loss.backward()

for n, p in model.named_parameters():
@@ -102,7 +83,7 @@

return precision_matrices

def get_penalty(self, model):
def get_penalty(self, model, device='cpu'):
"""Calculate the penalty term for the loss function.
Args:
@@ -117,11 +98,11 @@ def get_penalty(self, model):
for name, param in model.named_parameters():
if param.requires_grad:
u_global, v_global, w_global = (
get_buffer(model, target).detach()
model.get_buffer(target).detach().to(device)
for target in (f'{name}_u', f'{name}_v', f'{name}_w')
)
u_local, v_local, w_local = (
getattr(self, name).detach()
getattr(self, name).detach().to(device)
for name in (f'{name}_u', f'{name}_v', f'{name}_w')
)
u = u_global - u_local
@@ -140,7 +121,7 @@ def on_train_begin(self, model):
"""
self._update_params(model)

def on_train_end(self, model: torch.nn.Module, data_loader, device):
def on_train_end(self, model: torch.nn.Module, data_loader, device='cpu', loss_fn='nll'):
"""Post-train steps.
Args:
@@ -149,7 +130,7 @@ def on_train_end(self, model: torch.nn.Module, data_loader, device):
device(str): Model device.
loss_fn(Callable): Train loss function.
"""
precision_matrices = self._diag_fisher(model, data_loader, device)
precision_matrices = self._diag_fisher(model, data_loader, device, loss_fn)
for n, m in precision_matrices.items():
u = m.data.to(device)
v = m.data * model.get_parameter(n)
@@ -161,4 +142,4 @@ def on_train_end(self, model: torch.nn.Module, data_loader, device):
register_buffer(model, f'{n}_w', w.clone().detach())
setattr(self, f'{n}_u', u.clone().detach())
setattr(self, f'{n}_v', v.clone().detach())
setattr(self, f'{n}_w', w.clone().detach())
setattr(self, f'{n}_w', w.clone().detach())
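A local-training sketch showing where the updated FedCurv hooks fit. `model`, `optimizer`, and `train_loader` are assumed to exist, and the `FedCurv(model, importance)` constructor signature is assumed from the existing utility:

    import torch.nn.functional as F
    from openfl.utilities.fedcurv.torch.fedcurv import FedCurv

    fedcurv = FedCurv(model, importance=0.01)  # assumed constructor signature

    def train_one_round(model, optimizer, train_loader, device='cpu'):
        fedcurv.on_train_begin(model)  # snapshot the current parameters
        model.train()
        for sample, target in train_loader:
            sample, target = sample.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(sample)
            # Task loss plus the (now device-aware) FedCurv penalty term
            loss = F.cross_entropy(output, target) + fedcurv.get_penalty(model, device)
            loss.backward()
            optimizer.step()
        # Recompute the diagonal Fisher matrix using the new loss_fn option
        fedcurv.on_train_end(model, train_loader, device=device, loss_fn='cross_entropy')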
