Skip to content

Commit 2961230

Browse files
example with simple neural ode
1 parent 8277e69 commit 2961230

File tree

1 file changed

+143
-0
lines changed

1 file changed

+143
-0
lines changed

examples/toy_2d_neuralode.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
import matplotlib.pyplot as plt
2+
import os
3+
4+
import torch
5+
from torch import optim
6+
7+
8+
import flowcon.flows as flows
9+
import flowcon.distributions as distributions
10+
import flowcon.transforms as transforms
11+
import flowcon.nn as nn
12+
13+
from flowcon.nn.neural_odes import ODEnet, ODEfunc
14+
15+
from flowcon.datasets import load_plane_dataset, InfiniteLoader
16+
17+
if torch.cuda.is_available():
18+
device = "cuda"
19+
else:
20+
device = "cpu"
21+
22+
CONTINUE_TRAINING = False
23+
24+
os.makedirs("figures", exist_ok=True)
25+
os.makedirs("models", exist_ok=True)
26+
27+
base_dist = distributions.StandardNormal(shape=[2])
28+
MB_SIZE = 500
29+
selected_data = "two_spirals"
30+
num_layers = 1
31+
32+
num_iter = {"eight_gaussians": 3_000,
33+
"diamond": 50_000,
34+
"crescent": 3_000,
35+
"four_circles": 3_000,
36+
"two_circles": 3_000
37+
}.get(selected_data, 10_000)
38+
39+
40+
def main():
41+
42+
# create data
43+
train_dataset = load_plane_dataset(selected_data, int(1e7))
44+
train_loader = InfiniteLoader(
45+
dataset=train_dataset,
46+
batch_size=MB_SIZE,
47+
shuffle=True,
48+
drop_last=True,
49+
num_epochs=None
50+
)
51+
test_loader = InfiniteLoader(
52+
dataset=train_dataset,
53+
batch_size=10_000,
54+
shuffle=True,
55+
drop_last=True,
56+
num_epochs=None
57+
)
58+
59+
transform = build_transform(num_layers=num_layers)
60+
flow = flows.Flow(transform, base_dist).to(device)
61+
62+
63+
optimizer = optim.Adam(flow.parameters(), lr=1e-3, weight_decay=1e-5)
64+
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9995)
65+
try:
66+
for i in range(num_iter):
67+
x = next(train_loader).to(device)
68+
69+
optimizer.zero_grad()
70+
loss = -flow.log_prob(inputs=x).mean()
71+
if (i % 50) == 0:
72+
print(f"{i:04}: {loss=:.3f}")
73+
loss.backward()
74+
optimizer.step()
75+
scheduler.step()
76+
if (i + 1) % 250 == 0:
77+
with torch.no_grad():
78+
flow.eval()
79+
x = next(test_loader).to(device)
80+
test_loss = -flow.log_prob(inputs=x).mean()
81+
print(f"{i:04}: {test_loss=:.3f} {scheduler.get_last_lr()}")
82+
83+
plot_model(flow)
84+
flow.train()
85+
86+
except KeyboardInterrupt:
87+
pass
88+
plot_model(flow)
89+
90+
91+
def build_transform(n_features=2, num_layers=5) -> transforms.Transform:
92+
ode_net = ODEnet((32, 32, 32), input_shape=(n_features,),
93+
layer_type="hyper")
94+
transform_list = []
95+
transform_list.append(transforms.neuralode.SimpleCNF(ode_net,
96+
divergence_fn="brute_force"))
97+
transform = transforms.CompositeTransform(transform_list)
98+
return transform
99+
100+
101+
def count_parameters(model):
102+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
103+
104+
105+
def plot_model(flow):
106+
flow.eval()
107+
nsamples = 200
108+
fig, axs = plt.subplots(1, 2, figsize=(20, 10))
109+
axs = axs.flatten()
110+
x_min, x_max, y_min, y_max = dict(
111+
two_spirals=[-4, 4, -4, 4],
112+
checkerboard=[-4, 4, -4, 4]
113+
).get(selected_data, [-4, 4, -4, 4])
114+
# x_min = torch.floor(x.min(0)[0][0]) - 1e-1
115+
# y_min = torch.floor(x.min(0)[0][1]) - 1e-1
116+
# x_max = torch.ceil(x.max(0)[0][0]) + 1e-1
117+
# y_max = torch.ceil(x.max(0)[0][1]) + 1e-1
118+
xline = torch.linspace(x_min, x_max, nsamples).to(device)
119+
yline = torch.linspace(y_min, y_max, nsamples).to(device)
120+
xgrid, ygrid = torch.meshgrid(xline, yline, indexing='xy')
121+
xyinput = torch.cat([xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)], dim=1)
122+
# flow.train()
123+
# flow.eval()
124+
with torch.no_grad():
125+
zgrid = flow.log_prob(xyinput).reshape(nsamples, nsamples).exp()
126+
127+
#samples = flow.sample(num_samples=1_000)
128+
# plt.contourf(xgrid.detach().cpu().numpy(), ygrid.detach().cpu().numpy(), zgrid.detach().cpu().numpy())
129+
axs[0].imshow(zgrid.detach().cpu().numpy(), origin='lower', extent=[-4, 4, -4, 4],
130+
cmap="inferno")
131+
axs[0].axis('off')
132+
axs[1].axis('off')
133+
axs[0].set_aspect('equal', 'box')
134+
axs[1].set_aspect('equal', 'box')
135+
#axs[1].scatter(*samples.detach().cpu().numpy().T, marker='+', alpha=0.1, color="black")
136+
plt.tight_layout()
137+
# plt.title('iteration {}'.format(i + 1))
138+
plt.savefig(f"figures/{selected_data}.png")
139+
plt.close()
140+
print("Saved plot to figures/{}.png".format(selected_data))
141+
142+
if __name__ == "__main__":
143+
main()

0 commit comments

Comments
 (0)