Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pytorch dev #1396

Open
wants to merge 4 commits into
base: dev
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
torch dev for ring_cnn + 2p spatial
mannypaeza authored and manuelpaeza committed Jan 7, 2025
commit 0120624f46fad50f9038ff8d609593ad99ec79aa
26 changes: 13 additions & 13 deletions caiman/source_extraction/cnmf/online_cnmf.py
Original file line number Diff line number Diff line change
@@ -323,9 +323,9 @@ def _prepare_object(self, Yr, T, new_dims=None, idx_components=None):
if self.params.get('online', 'path_to_model') is None or self.params.get('online', 'sniper_mode') is False:
loaded_model = None
self.params.set('online', {'sniper_mode': False})
# self.tf_in = None
# self.tf_out = None
self.use_torch = None #fix
self.tf_in = None
self.tf_out = None
# self.use_torch = None
else:
try:
from keras.models import load_model
@@ -340,12 +340,12 @@ def _prepare_object(self, Yr, T, new_dims=None, idx_components=None):
# uses online model -> be careful
model_path = ".".join(path + ["keras"])
loaded_model = model_load(model_path)
self.use_torch = False
# self.use_torch = False
else:
model_path = '.'.join(path + ['pt'])
loaded_model = load_graph(model_path)
loaded_model = torch.load(model_file)
self.use_torch = True
# self.use_torch = True

self.loaded_model = loaded_model

@@ -547,8 +547,8 @@ def fit_next(self, t, frame_in, num_iters_hals=3):
sniper_mode=self.params.get('online', 'sniper_mode'),
use_peak_max=self.params.get('online', 'use_peak_max'),
mean_buff=self.estimates.mean_buff,
# tf_in=self.tf_in, tf_out=self.tf_out,
use_torch=self.use_torch,
tf_in=self.tf_in, tf_out=self.tf_out,
# use_torch=self.use_torch,
ssub_B=ssub_B, W=self.estimates.W if self.is1p else None,
b0=self.estimates.b0 if self.is1p else None,
corr_img=self.estimates.corr_img if use_corr else None,
@@ -2003,8 +2003,8 @@ def get_candidate_components(sv, dims, Yres_buf, min_num_trial=3, gSig=(5, 5),
patch_size=50, loaded_model=None, test_both=False,
thresh_CNN_noisy=0.5, use_peak_max=False,
thresh_std_peak_resid = 1, mean_buff=None,
# tf_in=None, tf_out=None):
use_torch=None):
tf_in=None, tf_out=None):
# use_torch=None):
"""
Extract new candidate components from the residual buffer and test them
using space correlation or the CNN classifier. The function runs the CNN
@@ -2146,8 +2146,8 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
corr_img=None, first_moment=None, second_moment=None,
crosscorr=None, col_ind=None, row_ind=None, corr_img_mode=None,
max_img=None, downscale_matrix=None, upscale_matrix=None,
# tf_in=None, tf_out=None):
torch_in=None, torch_out=None):
tf_in=None, tf_out=None):
# torch_in=None, torch_out=None):
"""
Checks for new components in the residual buffer and incorporates them if they pass the acceptance tests
"""
@@ -2177,8 +2177,8 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
sniper_mode=sniper_mode, rval_thr=rval_thr, patch_size=50,
loaded_model=loaded_model, thresh_CNN_noisy=thresh_CNN_noisy,
use_peak_max=use_peak_max, test_both=test_both, mean_buff=mean_buff,
# tf_in=tf_in, tf_out=tf_out)
torch_in=torch_in, torch_out=torch_out)
tf_in=tf_in, tf_out=tf_out)
#torch_in=torch_in, torch_out=torch_out)

ind_new_all = ijsig_all

4 changes: 2 additions & 2 deletions caiman/utils/nn_models.py
Original file line number Diff line number Diff line change
@@ -555,7 +555,7 @@ def fit_NL_model(model_NL, Y, patience=5, val_split=0.2, batch_size=32,
Y = np.expand_dims(Y, axis=-1)
run_logdir = get_run_logdir()
os.mkdir(run_logdir)
path_to_model = os.path.join(run_logdir, 'model.h5')
path_to_model = os.path.join(run_logdir, 'model.weights.h5')
chk = ModelCheckpoint(filepath=path_to_model,
verbose=0, save_best_only=True, save_weights_only=True)
es = EarlyStopping(monitor='val_loss', patience=patience,
@@ -566,7 +566,7 @@ def fit_NL_model(model_NL, Y, patience=5, val_split=0.2, batch_size=32,
history_NL = model_NL.fit(Y, Y, epochs=epochs, batch_size=batch_size,
shuffle=True, validation_split=val_split,
callbacks=callbacks)
model_NL.load_weights(os.path.join(run_logdir, 'model.h5'))
model_NL.load_weights(os.path.join(run_logdir, 'model.weights.h5'))
return model_NL, history_NL, path_to_model

def get_MCNN_model(Y, gSig=5, n_channels=8, lr=1e-4, pct=10, r_factor=1.5,