Skip to content

Commit

Permalink
Merge pull request #334 from SMILELab-FL/fix-evaluate_demo-siqi
Browse files Browse the repository at this point in the history
Fix evaluate demo siqi
  • Loading branch information
AgentDS authored Sep 19, 2023
2 parents 17b6668 + 28bab82 commit a62845e
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 47 deletions.
17 changes: 16 additions & 1 deletion fedlab/contrib/algorithm/basic_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.

import torch
from torch import nn
import random
from copy import deepcopy

from typing import List
from ...utils import Logger, Aggregators, SerializationTool
from ...utils.functional import evaluate
from ...core.server.handler import ServerHandler
from ..client_sampler.base_sampler import FedSampler
from ..client_sampler.uniform_sampler import RandomSampler
Expand All @@ -36,7 +38,7 @@ class SyncServerHandler(ServerHandler):
Args:
model (torch.nn.Module): model trained by federated learning.
global_round (int): stop condition. Shut down FL system when global round is reached.
num_clients (int): number of clients in FL. Default: 0 (initialized external).
num_clients (int): number of clients in FL. Default: 0 (initialized external).
sample_ratio (float): the result of ``sample_ratio * num_clients`` is the number of clients for every FL round.
cuda (bool): use GPUs or not. Default: ``False``.
device (str, optional): assign model/data to the given GPUs. E.g., 'device:0' or 'device:0,1'. Defaults to None. If device is None and cuda is True, FedLab will set the gpu with the largest memory as default.
Expand Down Expand Up @@ -149,6 +151,19 @@ def load(self, payload: List[torch.Tensor]) -> bool:
else:
return False

def setup_dataset(self, dataset) -> None:
self.dataset = dataset

def evaluate(self):
self._model.eval()
test_loader = self.dataset.get_dataloader(type="test", batch_size=128)
loss_, acc_ = evaluate(self._model, nn.CrossEntropyLoss(), test_loader)
self._LOGGER.info(
f"Round [{self.round - 1}/{self.global_round}] test performance on server: \t Loss: {loss_:.5f} \t Acc: {100*acc_:.3f}%"
)

return loss_, acc_


class AsyncServerHandler(ServerHandler):
"""Asynchronous Parameter Server Handler
Expand Down
60 changes: 43 additions & 17 deletions fedlab/contrib/dataset/pathological_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,17 @@
class PathologicalMNIST(FedDataset):
"""The partition stratigy in FedAvg. See http://proceedings.mlr.press/v54/mcmahan17a?ref=https://githubhelp.com
Args:
root (str): Path to download raw dataset.
path (str): Path to save partitioned subdataset.
num_clients (int): Number of clients.
shards (int, optional): Sort the dataset by the label, and uniformly partition them into shards. Then
download (bool, optional): Download. Defaults to True.
"""
def __init__(self, root, path, num_clients=100, shards=200, download=True, preprocess=False) -> None:
Args:
root (str): Path to download raw dataset.
path (str): Path to save partitioned subdataset.
num_clients (int): Number of clients.
shards (int, optional): Sort the dataset by the label, and uniformly partition them into shards. Then
download (bool, optional): Download. Defaults to True.
"""

def __init__(
self, root, path, num_clients=100, shards=200, download=True, preprocess=False
) -> None:
self.root = os.path.expanduser(root)
self.path = path
self.num_clients = num_clients
Expand All @@ -48,15 +51,19 @@ def preprocess(self, download=True):

if os.path.exists(self.path) is not True:
os.mkdir(self.path)

if os.path.exists(os.path.join(self.path, "train")) is not True:
os.mkdir(os.path.join(self.path, "train"))
os.mkdir(os.path.join(self.path, "var"))
os.mkdir(os.path.join(self.path, "test"))

# train
mnist = torchvision.datasets.MNIST(self.root, train=True, download=self.download,
transform=transforms.ToTensor())
mnist = torchvision.datasets.MNIST(
self.root,
train=True,
download=self.download,
transform=transforms.ToTensor(),
)
data_indices = noniid_slicing(mnist, self.num_clients, self.shards)

samples, labels = [], []
Expand All @@ -70,9 +77,25 @@ def preprocess(self, download=True):
data.append(x)
label.append(y)
dataset = BaseDataset(data, label)
torch.save(dataset, os.path.join(self.path, "train", "data{}.pkl".format(id)))

def get_dataset(self, id, type="train"):
torch.save(
dataset, os.path.join(self.path, "train", "data{}.pkl".format(id))
)

# test
mnist_test = torchvision.datasets.MNIST(
self.root,
train=False,
download=self.download,
transform=transforms.ToTensor(),
)
test_samples, test_labels = [], []
for x, y in mnist_test:
test_samples.append(x)
test_labels.append(y)
test_dataset = BaseDataset(test_samples, test_labels)
torch.save(test_dataset, os.path.join(self.path, "test", "test.pkl"))

def get_dataset(self, id=None, type="train"):
"""Load subdataset for client with client ID ``cid`` from local file.
Args:
Expand All @@ -82,10 +105,13 @@ def get_dataset(self, id, type="train"):
Returns:
Dataset
"""
dataset = torch.load(os.path.join(self.path, type, "data{}.pkl".format(id)))
if type == "train":
dataset = torch.load(os.path.join(self.path, type, "data{}.pkl".format(id)))
else:
dataset = torch.load(os.path.join(self.path, "test", "test.pkl"))
return dataset

def get_dataloader(self, id, batch_size=None, type="train"):
def get_dataloader(self, id=None, batch_size=None, type="train"):
"""Return dataload for client with client ID ``cid``.
Args:
Expand Down
3 changes: 2 additions & 1 deletion fedlab/core/standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .client.trainer import SerialClientTrainer
from .server.handler import ServerHandler


class StandalonePipeline(object):
def __init__(self, handler: ServerHandler, trainer: SerialClientTrainer):
"""Perform standalone simulation process.
Expand Down Expand Up @@ -48,4 +49,4 @@ def main(self):
# self.handler.evaluate()

def evaluate(self):
print("This is a example implementation. Please read the source code at fedlab.core.standalone.")
loss_, acc_ = self.handler.evaluate()
59 changes: 31 additions & 28 deletions fedlab/utils/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@ def reset(self):

def update(self, val, n=1):
self.val = val
self.sum += val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count


def evaluate(model, criterion, test_loader):
"""Evaluate classify task model accuracy.
Returns:
(loss.sum, acc.avg)
"""
Expand All @@ -62,17 +62,18 @@ def evaluate(model, criterion, test_loader):
acc_ = AverageMeter()
with torch.no_grad():
for inputs, labels in test_loader:
batch_size = len(labels)
inputs = inputs.to(gpu)
labels = labels.to(gpu)

outputs = model(inputs)
loss = criterion(outputs, labels)

_, predicted = torch.max(outputs, 1)
loss_.update(loss.item())
acc_.update(torch.sum(predicted.eq(labels)).item(), len(labels))
loss_.update(loss.item(), batch_size)
acc_.update(torch.sum(predicted.eq(labels)).item() / batch_size, batch_size)

return loss_.sum, acc_.avg
return loss_.avg, acc_.avg


def read_config_from_json(json_file: str, user_name: str):
Expand Down Expand Up @@ -114,8 +115,12 @@ def read_config_from_json(json_file: str, user_name: str):
with open(json_file) as f:
config = json.load(f)
config_info = config[user_name]
return (config_info["ip"], config_info["port"], config_info["world_size"],
config_info["rank"])
return (
config_info["ip"],
config_info["port"],
config_info["world_size"],
config_info["rank"],
)


def get_best_gpu():
Expand All @@ -126,7 +131,7 @@ def get_best_gpu():

if "CUDA_VISIBLE_DEVICES" in os.environ.keys() is not None:
cuda_devices = [
int(device) for device in os.environ["CUDA_VISIBLE_DEVICES"].split(',')
int(device) for device in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
]
else:
cuda_devices = range(deviceCount)
Expand All @@ -142,11 +147,7 @@ def get_best_gpu():
return torch.device("cuda:%d" % (best_device_index))


def partition_report(targets,
data_indices,
class_num=None,
verbose=True,
file=None):
def partition_report(targets, data_indices, class_num=None, verbose=True, file=None):
"""Generate data partition report for clients in ``data_indices``.
Generate data partition report for each client according to ``data_indices``, including
Expand Down Expand Up @@ -203,12 +204,10 @@ def partition_report(targets,
if not class_num:
class_num = max(targets) + 1

sorted_cid = sorted(
data_indices.keys()) # sort client id in ascending order
sorted_cid = sorted(data_indices.keys()) # sort client id in ascending order

header_line = "Class frequencies:"
col_name = "client," + ','.join([f"class{i}"
for i in range(class_num)]) + ",Amount"
col_name = "client," + ",".join([f"class{i}" for i in range(class_num)]) + ",Amount"

if verbose:
print(header_line)
Expand All @@ -221,16 +220,21 @@ def partition_report(targets,
for client_id in sorted_cid:
indices = data_indices[client_id]
client_targets = targets[indices]
client_sample_num = len(
indices) # total number of samples of current client
client_target_cnt = Counter(
client_targets) # { cls1: num1, cls2: num2, ... }

report_line = f"Client {client_id:3d}," + \
','.join([
f"{client_target_cnt[cls] / client_sample_num:.3f}" if cls in client_target_cnt else "0.00"
for cls in range(class_num)]) + \
f",{client_sample_num}"
client_sample_num = len(indices) # total number of samples of current client
client_target_cnt = Counter(client_targets) # { cls1: num1, cls2: num2, ... }

report_line = (
f"Client {client_id:3d},"
+ ",".join(
[
f"{client_target_cnt[cls] / client_sample_num:.3f}"
if cls in client_target_cnt
else "0.00"
for cls in range(class_num)
]
)
+ f",{client_sample_num}"
)
if verbose:
print(report_line)
if file is not None:
Expand All @@ -240,4 +244,3 @@ def partition_report(targets,
fh = open(file, "w")
fh.write("\n".join(reports))
fh.close()

0 comments on commit a62845e

Please sign in to comment.