1
+ import matplotlib .pyplot as plt
2
+ import os
3
+
4
+ import torch
5
+ from torch import optim
6
+
7
+
8
+ import flowcon .flows as flows
9
+ import flowcon .distributions as distributions
10
+ import flowcon .transforms as transforms
11
+ import flowcon .nn as nn
12
+
13
+ from flowcon .nn .neural_odes import ODEnet , ODEfunc
14
+
15
+ from flowcon .datasets import load_plane_dataset , InfiniteLoader
16
+
17
+ if torch .cuda .is_available ():
18
+ device = "cuda"
19
+ else :
20
+ device = "cpu"
21
+
22
+ CONTINUE_TRAINING = False
23
+
24
+ os .makedirs ("figures" , exist_ok = True )
25
+ os .makedirs ("models" , exist_ok = True )
26
+
27
+ base_dist = distributions .StandardNormal (shape = [2 ])
28
+ MB_SIZE = 500
29
+ selected_data = "two_spirals"
30
+ num_layers = 1
31
+
32
+ num_iter = {"eight_gaussians" : 3_000 ,
33
+ "diamond" : 50_000 ,
34
+ "crescent" : 3_000 ,
35
+ "four_circles" : 3_000 ,
36
+ "two_circles" : 3_000
37
+ }.get (selected_data , 10_000 )
38
+
39
+
40
+ def main ():
41
+
42
+ # create data
43
+ train_dataset = load_plane_dataset (selected_data , int (1e7 ))
44
+ train_loader = InfiniteLoader (
45
+ dataset = train_dataset ,
46
+ batch_size = MB_SIZE ,
47
+ shuffle = True ,
48
+ drop_last = True ,
49
+ num_epochs = None
50
+ )
51
+ test_loader = InfiniteLoader (
52
+ dataset = train_dataset ,
53
+ batch_size = 10_000 ,
54
+ shuffle = True ,
55
+ drop_last = True ,
56
+ num_epochs = None
57
+ )
58
+
59
+ transform = build_transform (num_layers = num_layers )
60
+ flow = flows .Flow (transform , base_dist ).to (device )
61
+
62
+
63
+ optimizer = optim .Adam (flow .parameters (), lr = 1e-3 , weight_decay = 1e-5 )
64
+ scheduler = torch .optim .lr_scheduler .ExponentialLR (optimizer , gamma = 0.9995 )
65
+ try :
66
+ for i in range (num_iter ):
67
+ x = next (train_loader ).to (device )
68
+
69
+ optimizer .zero_grad ()
70
+ loss = - flow .log_prob (inputs = x ).mean ()
71
+ if (i % 50 ) == 0 :
72
+ print (f"{ i :04} : { loss = :.3f} " )
73
+ loss .backward ()
74
+ optimizer .step ()
75
+ scheduler .step ()
76
+ if (i + 1 ) % 250 == 0 :
77
+ with torch .no_grad ():
78
+ flow .eval ()
79
+ x = next (test_loader ).to (device )
80
+ test_loss = - flow .log_prob (inputs = x ).mean ()
81
+ print (f"{ i :04} : { test_loss = :.3f} { scheduler .get_last_lr ()} " )
82
+
83
+ plot_model (flow )
84
+ flow .train ()
85
+
86
+ except KeyboardInterrupt :
87
+ pass
88
+ plot_model (flow )
89
+
90
+
91
+ def build_transform (n_features = 2 , num_layers = 5 ) -> transforms .Transform :
92
+ ode_net = ODEnet ((32 , 32 , 32 ), input_shape = (n_features ,),
93
+ layer_type = "hyper" )
94
+ transform_list = []
95
+ transform_list .append (transforms .neuralode .SimpleCNF (ode_net ,
96
+ divergence_fn = "brute_force" ))
97
+ transform = transforms .CompositeTransform (transform_list )
98
+ return transform
99
+
100
+
101
+ def count_parameters (model ):
102
+ return sum (p .numel () for p in model .parameters () if p .requires_grad )
103
+
104
+
105
+ def plot_model (flow ):
106
+ flow .eval ()
107
+ nsamples = 200
108
+ fig , axs = plt .subplots (1 , 2 , figsize = (20 , 10 ))
109
+ axs = axs .flatten ()
110
+ x_min , x_max , y_min , y_max = dict (
111
+ two_spirals = [- 4 , 4 , - 4 , 4 ],
112
+ checkerboard = [- 4 , 4 , - 4 , 4 ]
113
+ ).get (selected_data , [- 4 , 4 , - 4 , 4 ])
114
+ # x_min = torch.floor(x.min(0)[0][0]) - 1e-1
115
+ # y_min = torch.floor(x.min(0)[0][1]) - 1e-1
116
+ # x_max = torch.ceil(x.max(0)[0][0]) + 1e-1
117
+ # y_max = torch.ceil(x.max(0)[0][1]) + 1e-1
118
+ xline = torch .linspace (x_min , x_max , nsamples ).to (device )
119
+ yline = torch .linspace (y_min , y_max , nsamples ).to (device )
120
+ xgrid , ygrid = torch .meshgrid (xline , yline , indexing = 'xy' )
121
+ xyinput = torch .cat ([xgrid .reshape (- 1 , 1 ), ygrid .reshape (- 1 , 1 )], dim = 1 )
122
+ # flow.train()
123
+ # flow.eval()
124
+ with torch .no_grad ():
125
+ zgrid = flow .log_prob (xyinput ).reshape (nsamples , nsamples ).exp ()
126
+
127
+ #samples = flow.sample(num_samples=1_000)
128
+ # plt.contourf(xgrid.detach().cpu().numpy(), ygrid.detach().cpu().numpy(), zgrid.detach().cpu().numpy())
129
+ axs [0 ].imshow (zgrid .detach ().cpu ().numpy (), origin = 'lower' , extent = [- 4 , 4 , - 4 , 4 ],
130
+ cmap = "inferno" )
131
+ axs [0 ].axis ('off' )
132
+ axs [1 ].axis ('off' )
133
+ axs [0 ].set_aspect ('equal' , 'box' )
134
+ axs [1 ].set_aspect ('equal' , 'box' )
135
+ #axs[1].scatter(*samples.detach().cpu().numpy().T, marker='+', alpha=0.1, color="black")
136
+ plt .tight_layout ()
137
+ # plt.title('iteration {}'.format(i + 1))
138
+ plt .savefig (f"figures/{ selected_data } .png" )
139
+ plt .close ()
140
+ print ("Saved plot to figures/{}.png" .format (selected_data ))
141
+
142
+ if __name__ == "__main__" :
143
+ main ()
0 commit comments