@@ -364,10 +364,10 @@ def initialize_grid_from_another_model(self, model, x):
364
364
365
365
def forward (self , x , singularity_avoiding = False , y_th = 10. ):
366
366
367
- assert x .shape [1 ] == self .width_in [0 ]
368
-
369
367
x = x [:,self .input_id .long ()]
370
368
369
+ assert x .shape [1 ] == self .width_in [0 ]
370
+
371
371
# cache data
372
372
self .cache_data = x
373
373
@@ -775,9 +775,9 @@ def score2alpha(score):
775
775
n = self .width_in [0 ]
776
776
for i in range (n ):
777
777
if isinstance (in_vars [i ], sympy .Expr ):
778
- plt .gcf ().get_axes ()[0 ].text (1 / (2 * (n )) + i / (n ), - 0.1 , f'${ latex (in_vars [i ])} $' , fontsize = 40 * scale * varscale , horizontalalignment = 'center' , verticalalignment = 'center' )
778
+ plt .gcf ().get_axes ()[0 ].text (1 / (2 * (n )) + i / (n ), - 0.1 , f'${ latex (in_vars [self . input_id [ i ] ])} $' , fontsize = 40 * scale * varscale , horizontalalignment = 'center' , verticalalignment = 'center' )
779
779
else :
780
- plt .gcf ().get_axes ()[0 ].text (1 / (2 * (n )) + i / (n ), - 0.1 , in_vars [i ], fontsize = 40 * scale * varscale , horizontalalignment = 'center' , verticalalignment = 'center' )
780
+ plt .gcf ().get_axes ()[0 ].text (1 / (2 * (n )) + i / (n ), - 0.1 , in_vars [self . input_id [ i ] ], fontsize = 40 * scale * varscale , horizontalalignment = 'center' , verticalalignment = 'center' )
781
781
782
782
783
783
@@ -873,7 +873,7 @@ def get_params(self):
873
873
874
874
875
875
def fit (self , dataset , opt = "LBFGS" , steps = 100 , log = 1 , lamb = 0. , lamb_l1 = 1. , lamb_entropy = 2. , lamb_coef = 0. , lamb_coefdiff = 0. , update_grid = True , grid_update_num = 10 , loss_fn = None , lr = 1. ,start_grid_update_step = - 1 , stop_grid_update_step = 50 , batch = - 1 ,
876
- metrics = None , save_fig = False , in_vars = None , out_vars = None , beta = 3 , save_fig_freq = 1 , img_folder = './video' , singularity_avoiding = False , y_th = 1000. , reg_metric = 'edge_backward ' , display_metrics = None ):
876
+ metrics = None , save_fig = False , in_vars = None , out_vars = None , beta = 3 , save_fig_freq = 1 , img_folder = './video' , singularity_avoiding = False , y_th = 1000. , reg_metric = 'edge_forward_n ' , display_metrics = None ):
877
877
878
878
if lamb > 0. and not self .save_act :
879
879
print ('setting lamb=0. If you want to set lamb > 0, set self.save_act=True' )
@@ -937,7 +937,7 @@ def closure():
937
937
938
938
if _ == steps - 1 and old_save_act :
939
939
#self.save_act = True
940
- self .recover_save_act_in_fit ()
940
+ self .recover_save_act_in_fit (old_save_act )
941
941
942
942
train_id = np .random .choice (dataset ['train_input' ].shape [0 ], batch_size , replace = False )
943
943
test_id = np .random .choice (dataset ['test_input' ].shape [0 ], batch_size_test , replace = False )
0 commit comments