Skip to content

Commit 2d28639

Browse files
committed
Simplify model and use public solver
1 parent 51a7d51 commit 2d28639

1 file changed

Lines changed: 14 additions & 17 deletions

File tree

examples/mmd_ae.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
from torchvision.utils import make_grid, save_image
1212

1313
from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine as GRBM
14-
from dwave.plugins.torch.nn import (ConvolutionNetwork, FullyConnectedNetwork, LinearBlock,
15-
MaximumMeanDiscrepancy, RadialBasis, StraightThroughTanh,
16-
rands_like, zephyr_subgraph)
14+
from dwave.plugins.torch.nn.modules import (ConvolutionNetwork, FullyConnectedNetwork,
15+
MaximumMeanDiscrepancy, RadialBasis,
16+
StraightThroughTanh, rands_like, zephyr_subgraph)
1717
from dwave.system import DWaveSampler
1818

1919

@@ -32,7 +32,6 @@ def __init__(self, shape, n_bits):
3232
nn.Flatten(),
3333
FullyConnectedNetwork(chidden*h*w, n_bits, depth_fcnn, False, dropout),
3434
)
35-
self.mixer = LinearBlock(n_bits, n_bits, False, dropout)
3635
self.binarizer = StraightThroughTanh()
3736
self.decoder = nn.Sequential(
3837
FullyConnectedNetwork(n_bits, chidden*h*w, depth_fcnn, False, dropout),
@@ -41,24 +40,21 @@ def __init__(self, shape, n_bits):
4140
)
4241

4342
def decode(self, q):
44-
z = self.mixer(q)
45-
xhat = self.decoder(z)
46-
return z, xhat
43+
xhat = self.decoder(q)
44+
return xhat
4745

4846
def forward(self, x):
4947
spins = self.binarizer(self.encoder(x))
50-
z, xhat = self.decode(spins)
51-
return spins, z, xhat
48+
xhat = self.decode(spins)
49+
return spins, xhat
5250

5351

5452
def collect_stats(model, grbm, x, q, compute_mmd):
55-
s, z, xhat = model(x)
56-
zgen, xgen = model.decode(q)
53+
s, xhat = model(x)
5754
stats = {
5855
"quasi": grbm.quasi_objective(s.detach(), q),
5956
"bce": nn.functional.binary_cross_entropy_with_logits(xhat, x),
6057
"mmd": compute_mmd(s, q),
61-
"mmd2": compute_mmd(z, zgen),
6258
}
6359
return stats
6460

@@ -79,10 +75,10 @@ def round_graph_down(graph, group_size):
7975

8076

8177
def run(*, num_steps):
82-
sampler = DWaveSampler(solver="Advantage2_system1.7")
78+
sampler = DWaveSampler(solver="Advantage2_system1.8")
8379
sample_params = dict(num_reads=500, annealing_time=0.5, answer_mode="raw", auto_scale=False)
8480
h_range, j_range = sampler.properties["h_range"], sampler.properties["j_range"]
85-
outdir = "output/mmd_ae/"
81+
outdir = "output/example_mmd_ae/"
8682
makedirs(outdir, exist_ok=True)
8783

8884
device = "cuda"
@@ -117,7 +113,7 @@ def run(*, num_steps):
117113
# Train autoencoder
118114
stats = collect_stats(model, grbm, x, q, compute_mmd)
119115
opt_ae.zero_grad()
120-
(stats["bce"] + stats["mmd"] + stats["mmd2"]).backward()
116+
(stats["bce"] + stats["mmd"]).backward()
121117
opt_ae.step()
122118

123119
# Train GRBM
@@ -131,9 +127,10 @@ def run(*, num_steps):
131127
if step % 10 == 0:
132128
with torch.no_grad():
133129
grbm.eval()
134-
xgen = model.decode(q[:100])[-1]
135-
xuni = model.decode(rands_like(q[:100]))[-1]
130+
xgen = model.decode(q[:100])
131+
xuni = model.decode(rands_like(q[:100]))
136132
xhat = model(x[:100])[-1]
133+
save_image(make_grid(x[:100], 10, pad_value=1), outdir + "x.png")
137134
save_image(make_grid(xgen.sigmoid(), 10, pad_value=1), outdir + "xgen.png")
138135
save_image(make_grid(xhat.sigmoid(), 10, pad_value=1), outdir + "xhat.png")
139136
save_image(make_grid(xuni.sigmoid(), 10, pad_value=1), outdir + "xuni.png")

0 commit comments

Comments
 (0)