-
Notifications
You must be signed in to change notification settings - Fork 0
/
fedbase.py
125 lines (94 loc) · 4.58 KB
/
fedbase.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import numpy as np
# from src.dataset import FMNISTDataset
import json
# Region identifiers used as client keys when indexing the event data.
# Every entry carries the same 'US-' prefix, so only the suffixes are spelled
# out here. The order is kept exactly as before because callers iterate this
# list positionally.
# NOTE(review): 'US-X' is not a regular two-letter code; its meaning is not
# evident from this file -- confirm against the data source.
client_list = ['US-' + region_code for region_code in (
    'AK', 'AL', 'AR', 'AZ', 'CA', 'CO', 'CT',
    'DC', 'DE', 'FL', 'GA', 'HI', 'IA', 'ID',
    'IL', 'IN', 'KS', 'KY', 'LA', 'MA', 'MD',
    'ME', 'MI', 'MN', 'MO', 'MS', 'MT', 'NC',
    'ND', 'NE', 'NH', 'NJ', 'NM', 'NV', 'NY',
    'OH', 'OK', 'OR', 'PA', 'RI', 'SC', 'SD',
    'TN', 'TX', 'UT', 'VA', 'VT', 'WA', 'WI',
    'WV', 'WY', 'X',
)]
# Symptom keys used to index the event data. The position of a key in this
# list is what FedBase.sampling stores as the symptom code, so the order must
# not change.
# 'symptom:Anosmia' is deliberately not part of the list (it was disabled in
# the original definition, between Asthma and Alcoholism).
symptom_list = ['symptom:' + symptom_name for symptom_name in (
    'Anxiety',
    'Asthma',
    'Alcoholism',
    'Common cold',
    'Cough',
    'Depression',
    'Fatigue',
    'Fever',
    'Headache',
    'Nausea',
    'Shortness of breath',
)]
class FedBase(object):
    """Base class that prepares per-client event sequences.

    ``sampling`` builds the train/test data; ``train``, ``val`` and ``epoch``
    are hooks that raise ``NotImplementedError`` here and are meant to be
    provided by subclasses.
    """

    def __init__(self, dataset=None):
        """Remember the dataset name that ``sampling`` dispatches on."""
        self.dataset = dataset

    @staticmethod
    def _year_sequence(client_data, year, symptoms):
        """Merge one client's per-symptom events of a year into an (n, 2) array.

        Column 0 holds the event value mapped through ``value % 50 + 1``;
        column 1 holds the position of the symptom in ``symptoms``. Rows are
        ordered by column 0.
        """
        blocks = []
        for code, symptom in enumerate(symptoms):
            events = np.array(client_data[symptom][year]).reshape(-1, 1)
            # NOTE(review): for non-negative inputs this maps into 1..50,
            # which presumably keeps 0 free for the padding rows -- confirm.
            events = events % 50 + 1
            codes = np.ones_like(events) * code
            blocks.append(np.concatenate((events, codes), axis=-1))
        if not blocks:
            # No symptoms requested: nothing to merge, return an empty sequence.
            return np.zeros((0, 2))
        # Concatenate once instead of growing the array inside a bare
        # try/except, which would hide real errors and silently drop rows.
        merged = np.concatenate(blocks, axis=0)
        return merged[np.argsort(merged[:, 0])]

    @staticmethod
    def _pad_sequences(per_client):
        """Zero-pad, in place, every sequence to the longest length present."""
        max_len = max(
            (len(seq) for sequences in per_client for seq in sequences),
            default=0,
        )
        for sequences in per_client:
            for idx, seq in enumerate(sequences):
                padding = np.zeros([max_len - len(seq), 2])
                sequences[idx] = np.concatenate((seq, padding), axis=0)

    def sampling(self, data_path='./data/event.json', clients=None, symptoms=None):
        """Load the event file and build padded train/test sequences.

        input format = data[region][symptom][year]
        output format = [year, max_seq_len, [event, symptom]] for each client

        Args:
            data_path: JSON file to read; defaults to the original fixed path.
            clients: region keys to process; defaults to the module-level
                ``client_list``.
            symptoms: symptom keys to process, their position becomes the
                symptom code; defaults to the module-level ``symptom_list``.

        Returns:
            ``(client_train_data, client_test_data)``. Each is a list with one
            entry per client; an entry is a list of arrays, one per year
            ('2017'..'2020' for training, '2021' for testing). Train and test
            sequences are padded independently with zero rows, each set to its
            own longest sequence. The padding comes from ``np.zeros``, so
            integer inputs come back as floating-point arrays.

        Raises:
            Exception: if no dataset was specified.
            ValueError: if the dataset name is not 'Outbreak'.
        """
        if self.dataset is None:
            raise Exception('Please Specify the Dataset')
        if self.dataset != 'Outbreak':
            # Fail loudly instead of falling through without a usable result.
            raise ValueError('Unsupported dataset: {!r}'.format(self.dataset))

        # Resolve the defaults lazily so explicit arguments never touch the
        # module-level lists.
        clients = client_list if clients is None else clients
        symptoms = symptom_list if symptoms is None else symptoms

        with open(data_path, encoding='utf-8') as f:
            all_data = json.load(f)

        train_years = [str(year) for year in range(2017, 2021)]
        test_year = '2021'

        client_train_data = [
            [self._year_sequence(all_data[client], year, symptoms)
             for year in train_years]
            for client in clients
        ]
        client_test_data = [
            [self._year_sequence(all_data[client], test_year, symptoms)]
            for client in clients
        ]

        # Train and test sets are padded separately, each to its own maximum.
        self._pad_sequences(client_train_data)
        self._pad_sequences(client_test_data)
        return client_train_data, client_test_data

    def train(self):
        """Training hook; must be implemented by subclasses."""
        raise NotImplementedError

    def val(self):
        """Validation hook; must be implemented by subclasses."""
        raise NotImplementedError

    def epoch(self):
        """Per-epoch hook; must be implemented by subclasses."""
        raise NotImplementedError