Skip to content

Commit 544a9b6

Browse files
committed
fix recover_save_act
1 parent f320f72 commit 544a9b6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+5263
-200
lines changed

kan/MultKAN.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -364,10 +364,10 @@ def initialize_grid_from_another_model(self, model, x):
364364

365365
def forward(self, x, singularity_avoiding=False, y_th=10.):
366366

367-
assert x.shape[1] == self.width_in[0]
368-
369367
x = x[:,self.input_id.long()]
370368

369+
assert x.shape[1] == self.width_in[0]
370+
371371
# cache data
372372
self.cache_data = x
373373

@@ -775,9 +775,9 @@ def score2alpha(score):
775775
n = self.width_in[0]
776776
for i in range(n):
777777
if isinstance(in_vars[i], sympy.Expr):
778-
plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), -0.1, f'${latex(in_vars[i])}$', fontsize=40 * scale * varscale, horizontalalignment='center', verticalalignment='center')
778+
plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), -0.1, f'${latex(in_vars[self.input_id[i]])}$', fontsize=40 * scale * varscale, horizontalalignment='center', verticalalignment='center')
779779
else:
780-
plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), -0.1, in_vars[i], fontsize=40 * scale * varscale, horizontalalignment='center', verticalalignment='center')
780+
plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), -0.1, in_vars[self.input_id[i]], fontsize=40 * scale * varscale, horizontalalignment='center', verticalalignment='center')
781781

782782

783783

@@ -873,7 +873,7 @@ def get_params(self):
873873

874874

875875
def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., lamb_coef=0., lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1.,start_grid_update_step=-1, stop_grid_update_step=50, batch=-1,
876-
metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', singularity_avoiding=False, y_th=1000., reg_metric='edge_backward', display_metrics=None):
876+
metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_n', display_metrics=None):
877877

878878
if lamb > 0. and not self.save_act:
879879
print('setting lamb=0. If you want to set lamb > 0, set self.save_act=True')
@@ -937,7 +937,7 @@ def closure():
937937

938938
if _ == steps-1 and old_save_act:
939939
#self.save_act = True
940-
self.recover_save_act_in_fit()
940+
self.recover_save_act_in_fit(old_save_act)
941941

942942
train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
943943
test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)

tutorials/.ipynb_checkpoints/Interp_3A_KAN_Compiler_PDE-checkpoint.ipynb

Lines changed: 55 additions & 31 deletions
Large diffs are not rendered by default.

tutorials/.ipynb_checkpoints/Interp_3_KAN_Compiler-checkpoint.ipynb

Lines changed: 17 additions & 13 deletions
Large diffs are not rendered by default.

tutorials/.ipynb_checkpoints/Interp_4_feature_attribution-checkpoint.ipynb

Lines changed: 300 additions & 17 deletions
Large diffs are not rendered by default.

tutorials/.ipynb_checkpoints/Interp_8_adding_auxillary_variables-checkpoint.ipynb

Lines changed: 34 additions & 26 deletions
Large diffs are not rendered by default.

tutorials/.ipynb_checkpoints/Interp_9_different_plotting_metrics-checkpoint.ipynb

Lines changed: 299 additions & 0 deletions
Large diffs are not rendered by default.

tutorials/Interp_3A_KAN_Compiler_PDE.ipynb

Lines changed: 55 additions & 31 deletions
Large diffs are not rendered by default.

tutorials/Interp_3_KAN_Compiler.ipynb

Lines changed: 75 additions & 25 deletions
Large diffs are not rendered by default.

tutorials/Interp_4_feature_attribution.ipynb

Lines changed: 351 additions & 9 deletions
Large diffs are not rendered by default.

tutorials/Interp_7_Building_in_structural_biases.ipynb

Lines changed: 15 additions & 15 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)