-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
53 lines (33 loc) · 1.25 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import numpy as np
from fedzoo.fedavg import FedAvg
from fedzoo.fedprox import FedProx
from model.markedpp import MultivariateExponentialHawkes
import torch.optim as optim
# Select the torch device used by main().
# The original script reported the CUDA device when one was available, but then
# unconditionally replaced it with torch.device('cpu'), so the printed message did
# not match the device actually used. FORCE_CPU keeps that CPU-only behaviour
# explicit, and the message is printed only after the final choice is made.
# NOTE(review): set FORCE_CPU = False to use CUDA when available; whether the
# model / federated code supports CUDA is not verified here.
FORCE_CPU = True

_cuda_available = torch.cuda.is_available()
if _cuda_available and not FORCE_CPU:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Report GPU availability together with the device that will really be used.
if _cuda_available:
    print("GPU Available, Using Device:", device)
else:
    print("GPU Unavailable, Using Device:", device)
def main(algorithm='fedavg', seed=None):
    """Run one federated training epoch of a multivariate exponential Hawkes model.

    Builds a ``MultivariateExponentialHawkes`` model with randomly drawn
    ``alphas`` (n_class x n_class) and ``beta`` (n_class,) from U[0, 1), and
    ``mu`` fixed at 0.01. Then calls ``epoch()`` once on the selected federated
    algorithm, using the 'Outbreak' dataset.

    Args:
        algorithm: Federated algorithm to run, case-insensitive:
            - 'fedavg' (default; the original behaviour)
            - 'fedprox' (previously available only as commented-out code)
        seed: Optional seed for NumPy's global RNG, which makes the random
            initial ``alphas``/``beta`` reproducible. ``None`` (default) leaves
            the RNG untouched, as before.

    Raises:
        ValueError: If ``algorithm`` is neither 'fedavg' nor 'fedprox'.
    """
    # Validate first so a typo fails fast, before any model is built.
    algo = algorithm.lower()
    if algo not in ('fedavg', 'fedprox'):
        raise ValueError(
            f"unknown algorithm {algorithm!r}; expected 'fedavg' or 'fedprox'"
        )
    if seed is not None:
        np.random.seed(seed)

    # NOTE(review): T presumably is the observation window [start, end];
    # confirm against model.markedpp.
    T = np.array([0., 50.])
    data_dim = 2
    n_class = 11
    alphas = np.random.uniform(low=0.0, high=1.0, size=(n_class, n_class))
    beta = np.random.uniform(low=0.0, high=1.0, size=n_class)
    model = MultivariateExponentialHawkes(
        T=T, mu=0.01 * np.ones(n_class), alphas=alphas, beta=beta,
        data_dim=data_dim, device=device,
    )

    # The optimizer *class* (not an instance) is passed along with its kwargs;
    # presumably the federated wrapper instantiates it itself. TODO confirm.
    optimizer = optim.Adadelta
    optimizer_args = {'lr': 1e0}

    fed_cls = FedAvg if algo == 'fedavg' else FedProx
    fed = fed_cls(
        dataset='Outbreak', model=model, optimizer=optimizer,
        optimizer_args=optimizer_args,
    )
    fed.epoch()
# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()