Merge pull request #346 from SMILELab-FL/feature-datavisualize-siqi
add feddata visualization code and readme tutorial
AgentDS authored Dec 9, 2023
2 parents 0ecd419 + f6960bb commit 49cf058
Showing 4 changed files with 124 additions and 0 deletions.
38 changes: 38 additions & 0 deletions README.md
@@ -276,6 +276,44 @@ Non-iid partition used in [[1]](#1). Data example for 4 clients could be shown a
</tbody>
</table>

### Partition Visualization

To visualize how a partition distributes samples across clients and classes, we provide `fedlab.utils.dataset.functional.feddata_scatterplot()`.

The following code visualizes a synthetic partition:
```python
import numpy as np
from matplotlib import pyplot as plt
from fedlab.utils.dataset.functional import feddata_scatterplot

sample_num = 15
class_num = 4
clients_num = 3
num_per_client = sample_num // clients_num
labels = np.random.randint(class_num, size=sample_num)  # generate 15 labels, each an integer from 0 to 3
rand_per = np.random.permutation(sample_num)
# partition synthetic data into 3 clients
data_indices = {0: rand_per[0:num_per_client],
                1: rand_per[num_per_client:num_per_client*2],
                2: rand_per[num_per_client*2:num_per_client*3]}
title = 'Data Distribution over Clients for Each Class'
fig = feddata_scatterplot(labels.tolist(),
                          data_indices,
                          clients_num,
                          class_num,
                          figsize=(6, 4),
                          max_size=200,
                          title=title)
fig.savefig('imgs/feddata-scatterplot-vis.png')  # save first; the imgs/ directory must already exist
plt.show()  # plt.show() displays all open figures and takes no figure argument
```
<p align="center"><img src="./tutorials/Datasets-DataPartitioner-tutorials/imgs/feddata-scatterplot-vis.png" height="300"></p>
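
In the scatterplot, each marker corresponds to one (client, class) pair and its size reflects how many samples of that class the client holds. As a rough illustration of the underlying statistic, the sketch below counts samples per client and class with NumPy alone. The helper name `class_counts_per_client` is hypothetical and not part of FedLab; the `labels` and `data_indices` layout is assumed to match the example above.

```python
import numpy as np


def class_counts_per_client(labels, data_indices, class_num):
    """Return an array of shape (num_clients, class_num) with sample counts.

    Hypothetical helper for illustration only; not part of FedLab.
    """
    labels = np.asarray(labels)
    client_ids = sorted(data_indices)
    counts = np.zeros((len(client_ids), class_num), dtype=int)
    for row, cid in enumerate(client_ids):
        # np.bincount counts how often each label occurs among this client's samples
        counts[row] = np.bincount(labels[data_indices[cid]], minlength=class_num)
    return counts


rng = np.random.default_rng(0)
labels = rng.integers(4, size=15)  # 15 labels, each an integer from 0 to 3
perm = rng.permutation(15)
data_indices = {cid: perm[cid * 5:(cid + 1) * 5] for cid in range(3)}
counts = class_counts_per_client(labels, data_indices, class_num=4)
print(counts)  # one row per client, one column per class
```

Each row sums to the number of samples held by that client, which makes the array a quick sanity check before plotting.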


Visualization result for the CIFAR-10 Dirichlet non-IID partition with $\alpha=0.6$ on 5 clients:
<p align="center"><img src="./tutorials/Datasets-DataPartitioner-tutorials/imgs/train_vis-noniid-labeldir.png" height="300"></p>
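
For readers unfamiliar with Dirichlet-based label skew, the sketch below shows one common formulation: for every class, client shares are drawn from a symmetric Dirichlet distribution with concentration $\alpha$, and the class's sample indices are split accordingly. This is an illustrative assumption, not necessarily how FedLab's own partitioner is implemented; the function name `dirichlet_label_partition` is hypothetical, and the labels are synthetic stand-ins rather than real CIFAR-10 targets.

```python
import numpy as np


def dirichlet_label_partition(labels, num_clients, alpha, seed=0):
    """Split sample indices across clients using per-class Dirichlet shares.

    Illustrative sketch only; FedLab's own partitioner may differ in details.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for k in np.unique(labels):
        # Shuffle the indices of all samples belonging to class k
        idx_k = rng.permutation(np.flatnonzero(labels == k))
        # Draw the share of class k that each client receives
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        # Turn cumulative shares into split points inside idx_k
        split_points = (np.cumsum(proportions)[:-1] * len(idx_k)).astype(int)
        for cid, part in enumerate(np.split(idx_k, split_points)):
            client_indices[cid].extend(part.tolist())
    return {cid: np.array(idx, dtype=int) for cid, idx in enumerate(client_indices)}


# Synthetic stand-in for 10-class labels (assumption: not real CIFAR-10 data)
labels = np.random.default_rng(1).integers(10, size=1000)
data_indices = dirichlet_label_partition(labels, num_clients=5, alpha=0.6)
print({cid: len(idx) for cid, idx in data_indices.items()})
```

Smaller values of $\alpha$ concentrate each class on fewer clients, while larger values approach an even split. The returned dictionary uses the same `{client_id: indices}` layout that `feddata_scatterplot` takes as its second argument.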


## Performance & Insights

We provide the performance report of several reproduced federated learning algorithms to illustrate the correctness of FedLab in simulation. Furthermore, we describe several insights FedLab could provide for federated learning research. Without loss of generality, this section's experiments are conducted on partitioned MNIST datasets. The conclusions and observations in this section should still be valid in other data sets and scenarios.
86 changes: 86 additions & 0 deletions fedlab/utils/dataset/functional.py
@@ -14,6 +14,9 @@

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
import warnings
from collections import Counter

@@ -599,3 +602,86 @@ def partition_report(targets, data_indices, class_num=None, verbose=True, file=N
print(stats_df)

return stats_df


def feddata_scatterplot(
    targets,
    client_dict,
    num_clients,
    num_classes,
    figsize=(6, 4),
    max_size=200,
    title=None,
):
    """Visualize the data distribution for each client and class in a federated setting.

    Args:
        targets (list[int]): List of labels, with each entry an integer class index.
        client_dict (dict): Dictionary containing the sample index list for each client, ``{ client_id: indices}``.
        num_clients (int): Total number of clients.
        num_classes (int): Total number of classes.
        figsize (tuple, optional): Figure size for the scatter plot. Defaults to (6, 4).
        max_size (int, optional): Maximum scatter marker size. Defaults to 200.
        title (str, optional): Title for the scatter plot. Defaults to None.

    Returns:
        matplotlib.figure.Figure: The figure containing the scatter plot.

    Examples:
        First generate a data partition:

        >>> sample_num = 15
        >>> class_num = 4
        >>> clients_num = 3
        >>> num_per_client = sample_num // clients_num
        >>> labels = np.random.randint(class_num, size=sample_num)  # generate 15 labels, each an integer from 0 to 3
        >>> rand_per = np.random.permutation(sample_num)
        >>> # partition synthetic data into 3 clients
        >>> data_indices = {0: rand_per[0:num_per_client],
        ...                 1: rand_per[num_per_client:num_per_client*2],
        ...                 2: rand_per[num_per_client*2:num_per_client*3]}

        Now generate the visualization for this data distribution:

        >>> title = 'Data Distribution over Clients for Each Class'
        >>> fig = feddata_scatterplot(labels.tolist(),
        ...                           data_indices,
        ...                           clients_num,
        ...                           class_num,
        ...                           figsize=(6, 4),
        ...                           max_size=200,
        ...                           title=title)
        >>> fig.savefig('feddata-scatterplot-vis.png')  # Save the plot
        >>> plt.show()  # Show the plot
    """
    palette = sns.color_palette("Set2", num_classes)
    report_df = partition_report(
        targets, client_dict, class_num=num_classes, verbose=True
    )
    # Per-client, per-class statistics: skip the first column of the report
    sample_stats = report_df.values[:, 1 : 1 + num_classes]
    min_max_ratio = np.min(sample_stats) / np.max(sample_stats)
    data_tuples = []
    for cid in range(num_clients):
        for k in range(num_classes):
            # Normalize each entry by the global maximum so marker sizes are comparable
            data_tuples.append((cid, k, sample_stats[cid, k] / np.max(sample_stats)))

    df = pd.DataFrame(data_tuples, columns=["Client", "Class", "Samples"])
    plt.figure(figsize=figsize)
    scatter = sns.scatterplot(
        data=df,
        x="Client",
        y="Class",
        size="Samples",
        hue="Class",
        palette=palette,
        legend=False,
        sizes=(max_size * min_max_ratio, max_size),
    )

    # Customize the axes and layout
    plt.xticks(range(num_clients), [f"Client {cid+1}" for cid in range(num_clients)])
    plt.yticks(range(num_classes), [f"Class {k+1}" for k in range(num_classes)])
    plt.xlabel("Clients")
    plt.ylabel("Classes")
    plt.title(title)
    return plt.gcf()