1111from torchvision .utils import make_grid , save_image
1212
1313from dwave .plugins .torch .models .boltzmann_machine import GraphRestrictedBoltzmannMachine as GRBM
14- from dwave .plugins .torch .nn import (ConvolutionNetwork , FullyConnectedNetwork , LinearBlock ,
15- MaximumMeanDiscrepancy , RadialBasis , StraightThroughTanh ,
16- rands_like , zephyr_subgraph )
14+ from dwave .plugins .torch .nn . modules import (ConvolutionNetwork , FullyConnectedNetwork ,
15+ MaximumMeanDiscrepancy , RadialBasis ,
16+ StraightThroughTanh , rands_like , zephyr_subgraph )
1717from dwave .system import DWaveSampler
1818
1919
@@ -32,7 +32,6 @@ def __init__(self, shape, n_bits):
3232 nn .Flatten (),
3333 FullyConnectedNetwork (chidden * h * w , n_bits , depth_fcnn , False , dropout ),
3434 )
35- self .mixer = LinearBlock (n_bits , n_bits , False , dropout )
3635 self .binarizer = StraightThroughTanh ()
3736 self .decoder = nn .Sequential (
3837 FullyConnectedNetwork (n_bits , chidden * h * w , depth_fcnn , False , dropout ),
@@ -41,24 +40,21 @@ def __init__(self, shape, n_bits):
4140 )
4241
4342 def decode (self , q ):
44- z = self .mixer (q )
45- xhat = self .decoder (z )
46- return z , xhat
43+ xhat = self .decoder (q )
44+ return xhat
4745
4846 def forward (self , x ):
4947 spins = self .binarizer (self .encoder (x ))
50- z , xhat = self .decode (spins )
51- return spins , z , xhat
48+ xhat = self .decode (spins )
49+ return spins , xhat
5250
5351
5452def collect_stats (model , grbm , x , q , compute_mmd ):
55- s , z , xhat = model (x )
56- zgen , xgen = model .decode (q )
53+ s , xhat = model (x )
5754 stats = {
5855 "quasi" : grbm .quasi_objective (s .detach (), q ),
5956 "bce" : nn .functional .binary_cross_entropy_with_logits (xhat , x ),
6057 "mmd" : compute_mmd (s , q ),
61- "mmd2" : compute_mmd (z , zgen ),
6258 }
6359 return stats
6460
@@ -79,10 +75,10 @@ def round_graph_down(graph, group_size):
7975
8076
8177def run (* , num_steps ):
82- sampler = DWaveSampler (solver = "Advantage2_system1.7 " )
78+ sampler = DWaveSampler (solver = "Advantage2_system1.8 " )
8379 sample_params = dict (num_reads = 500 , annealing_time = 0.5 , answer_mode = "raw" , auto_scale = False )
8480 h_range , j_range = sampler .properties ["h_range" ], sampler .properties ["j_range" ]
85- outdir = "output/mmd_ae /"
81+ outdir = "output/example_mmd_ae /"
8682 makedirs (outdir , exist_ok = True )
8783
8884 device = "cuda"
@@ -117,7 +113,7 @@ def run(*, num_steps):
117113 # Train autoencoder
118114 stats = collect_stats (model , grbm , x , q , compute_mmd )
119115 opt_ae .zero_grad ()
120- (stats ["bce" ] + stats ["mmd" ] + stats [ "mmd2" ] ).backward ()
116+ (stats ["bce" ] + stats ["mmd" ]).backward ()
121117 opt_ae .step ()
122118
123119 # Train GRBM
@@ -131,9 +127,10 @@ def run(*, num_steps):
131127 if step % 10 == 0 :
132128 with torch .no_grad ():
133129 grbm .eval ()
134- xgen = model .decode (q [:100 ])[ - 1 ]
135- xuni = model .decode (rands_like (q [:100 ]))[ - 1 ]
130+ xgen = model .decode (q [:100 ])
131+ xuni = model .decode (rands_like (q [:100 ]))
136132 xhat = model (x [:100 ])[- 1 ]
133+ save_image (make_grid (x [:100 ], 10 , pad_value = 1 ), outdir + "x.png" )
137134 save_image (make_grid (xgen .sigmoid (), 10 , pad_value = 1 ), outdir + "xgen.png" )
138135 save_image (make_grid (xhat .sigmoid (), 10 , pad_value = 1 ), outdir + "xhat.png" )
139136 save_image (make_grid (xuni .sigmoid (), 10 , pad_value = 1 ), outdir + "xuni.png" )
0 commit comments