Commit f9cb6dc

Bug fixes and moved epochs and patience parameters to top
1 parent 1294994 commit f9cb6dc

1 file changed: train_decoders.py (41 additions, 32 deletions)
@@ -8,36 +8,42 @@
 os.environ["CUDA_VISIBLE_DEVICES"]="0" #specify GPU to use
 from run_nn_models import run_nn_models
 from transfer_learn_nn import transfer_learn_nn
-from model_utils import unseen_modality_test, diff_specs
+from model_utils import unseen_modality_test, diff_specs, ntrain_combine_df, frac_combine_df
 from transfer_learn_nn_eeg import transfer_learn_nn_eeg
 
 t_start = time.time()
 ##################USER-DEFINED PARAMETERS##################
-data_lp = '.../' # data load path
-
 # Where data will be saved: rootpath + dataset + '/'
 rootpath = '.../'
 dataset = 'move_rest_ecog'
 
+# Data load paths
+ecog_lp = rootpath + 'ecog_dataset/' # data load path
+ecog_roi_proj_lp = ecog_lp+'proj_mat/' # path to ECoG projection matrix
+
 ### Tailored decoder params (within participant) ###
 n_folds_tail = 3 # number of folds (per participant)
-spec_meas_tail = ['power'] # 'power', 'power_log', 'relative_power', 'phase', 'freqslide'
+spec_meas_tail = ['power', 'power_log', 'relative_power', 'phase', 'freqslide']
 hyps_tail = {'F1' : 20, 'dropoutRate' : 0.693, 'kernLength' : 64,
              'kernLength_sep' : 56, 'dropoutType' : 'SpatialDropout2D',
              'D' : 2, 'n_estimators' : 240, 'max_depth' : 9}
 hyps_tail['F2'] = hyps_tail['F1'] * hyps_tail['D'] # F2 = F1 * D
+epochs_tail = 300
+patience_tail = 30
 
 ### Same modality decoder params (across participants) ###
 n_folds_same = 36 # number of total folds
-spec_meas_same = ['power'] # 'power', 'power_log', 'relative_power', 'phase', 'freqslide'
+spec_meas_same = ['power', 'power_log', 'relative_power', 'phase', 'freqslide']
 hyps_same = {'F1' : 19, 'dropoutRate' : 0.342, 'kernLength' : 24,
              'kernLength_sep' : 88, 'dropoutType' : 'Dropout',
              'D' : 2, 'n_estimators' : 240, 'max_depth' : 6}
 hyps_same['F2'] = hyps_same['F1'] * hyps_same['D'] # F2 = F1 * D
+epochs_same = 300
+patience_same = 20
 
 ### Unseen modality testing params (across participants) ###
-eeg_lp = '.../' # path to EEG xarray data
-eeg_roi_proj_lp = '.../' # path to EEG projection matrix
+eeg_lp = rootpath + 'eeg_dataset/' # path to EEG xarray data
+eeg_roi_proj_lp = eeg_lp+'proj_mat/' # path to EEG projection matrix
 
 ### Fine-tune same modality decoders ###
 model_type_finetune = 'eegnet_hilb' # NN model type to fine-tune (must be either 'eegnet_hilb' or 'eegnet')
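Note on the hoisted epochs_*/patience_* parameters: these are a standard maximum-epochs / early-stopping pair. run_nn_models' internals are not shown in this diff, but a minimal runnable sketch of how such values are typically wired into a Keras fit call (the toy model and data here are illustrative stand-ins, not the HTNet architecture):

import numpy as np
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping

epochs_same, patience_same = 300, 20  # values hoisted to the top of the script

# Toy stand-in model and data, purely for illustration
model = models.Sequential([layers.Input(shape=(4,)),
                           layers.Dense(2, activation='softmax')])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
X, y = np.random.rand(64, 4), np.random.randint(0, 2, 64)

# `epochs` caps training length; `patience` stops once val_loss has not
# improved for that many consecutive epochs
early_stop = EarlyStopping(monitor='val_loss', patience=patience_same,
                           restore_best_weights=True)
model.fit(X, y, validation_split=0.25, epochs=epochs_same,
          callbacks=[early_stop], verbose=0)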
@@ -52,8 +58,7 @@
 sp_finetune = [rootpath + dataset + '/tf_all_per/',
                rootpath + dataset + '/tf_per_1dconv/',
                rootpath + dataset + '/tf_depth_per/',
-               rootpath + dataset + '/tf_sep_per/',
-               rootpath + dataset + '/tf_single_sub/'] # where to save output (should match layers_to_finetune)
+               rootpath + dataset + '/tf_sep_per/'] # where to save output (should match layers_to_finetune)
 
 # How much train/val data to use, either by number of trials or percentage of available data
 use_per_vals = True #if True, use percentage values (otherwise, use number of trials)
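Note on the shortened sp_finetune list: dropping '/tf_single_sub/' keeps the save paths aligned with layers_to_finetune, as the in-line comment requires. A hypothetical guard (not in the script) that would catch a mismatch early:

# Hypothetical sanity check, assuming one save path per fine-tuning configuration
assert len(sp_finetune) == len(layers_to_finetune), \
    'sp_finetune must provide one save directory per entry in layers_to_finetune'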
@@ -67,21 +72,20 @@
 n_val_parts = 1 # number of validation participants to use
 ##################USER-DEFINED PARAMETERS##################
 
-
 #### Tailored decoder training ####
 for s,val in enumerate(spec_meas_tail):
     do_log = True if val == 'power_log' else False
     compute_val = 'power' if val == 'power_log' else val
     single_sp = rootpath + dataset + '/single_sbjs_' + val + '/'
     combined_sbjs = False
     if not os.path.exists(single_sp):
-        os.mkdirs(single_sp)
+        os.makedirs(single_sp)
     if s==0:
         models = ['eegnet_hilb','eegnet','rf','riemann'] # fit all decoder types
     else:
         models = ['eegnet_hilb'] # avoid fitting non-HTNet models again
-    run_nn_models(single_sp, n_folds_tail, combined_sbjs, lp=data_lp, test_day = 'last', do_log=do_log,
-                  epochs=300, patience=30, models=models, compute_val=compute_val,
+    run_nn_models(single_sp, n_folds_tail, combined_sbjs, ecog_lp, ecog_roi_proj_lp, test_day = 'last', do_log=do_log,
+                  epochs=epochs_tail, patience=patience_tail, models=models, compute_val=compute_val,
                   F1 = hyps_tail['F1'], dropoutRate = hyps_tail['dropoutRate'], kernLength = hyps_tail['kernLength'],
                   kernLength_sep = hyps_tail['kernLength_sep'], dropoutType = hyps_tail['dropoutType'],
                   D = hyps_tail['D'], F2 = hyps_tail['F2'], n_estimators = hyps_tail['n_estimators'], max_depth = hyps_tail['max_depth'])
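Note on the os.makedirs fixes (here and in two later hunks): os.mkdirs does not exist in the standard library, so the old code would raise AttributeError the first time a save directory was missing. A minimal sketch with an illustrative path:

import os

save_dir = 'move_rest_ecog/single_sbjs_power/'  # illustrative path
# os.mkdirs(save_dir)   # AttributeError: module 'os' has no attribute 'mkdirs'
# os.mkdir(save_dir)    # creates only the final level; fails if parents are missing
if not os.path.exists(save_dir):
    os.makedirs(save_dir)  # creates intermediate directories as needed
# Equivalent without the explicit check (Python >= 3.2):
os.makedirs(save_dir, exist_ok=True)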
@@ -93,14 +97,14 @@
     compute_val = 'power' if val == 'power_log' else val
     multi_sp = rootpath + dataset + '/combined_sbjs_' + val + '/'
     if not os.path.exists(multi_sp):
-        os.mkdirs(multi_sp)
+        os.makedirs(multi_sp)
     combined_sbjs = True
     if s==0:
         models = ['eegnet_hilb','eegnet','rf','riemann'] # fit all decoder types
     else:
         models = ['eegnet_hilb'] # avoid fitting non-HTNet models again
-    run_nn_models(multi_sp, n_folds_same, combined_sbjs, lp=data_lp, test_day = 'last', do_log=do_log,
-                  epochs=300, patience=20, models=models, compute_val=compute_val,
+    run_nn_models(multi_sp, n_folds_same, combined_sbjs, ecog_lp, ecog_roi_proj_lp, test_day = 'last', do_log=do_log,
+                  epochs=epochs_same, patience=patience_same, models=models, compute_val=compute_val,
                   F1 = hyps_same['F1'], dropoutRate = hyps_same['dropoutRate'], kernLength = hyps_same['kernLength'],
                   kernLength_sep = hyps_same['kernLength_sep'], dropoutType = hyps_same['dropoutType'],
                   D = hyps_same['D'], F2 = hyps_same['F2'], n_estimators = hyps_same['n_estimators'], max_depth = hyps_same['max_depth'])
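Note on the changed call sites: the keyword argument lp=data_lp is replaced by two positional load paths, so run_nn_models' signature presumably changed in the same commit. A sketch of the assumed updated signature; only the arguments actually passed in this script are confirmed by the diff, the rest are placeholders:

def run_nn_models(sp, n_folds, combined_sbjs, lp, roi_proj_loadpath,
                  test_day=None, do_log=False, epochs=300, patience=20,
                  models=('eegnet_hilb',), compute_val='power', **model_hyps):
    # Assumed signature sketch only; internals live in run_nn_models.py
    ...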
@@ -132,23 +136,24 @@
     lp_finetune = rootpath + dataset + '/combined_sbjs_'+spec_meas+'/'
     if use_per_vals:
         for i in range(len(per_train_trials)):
-            transfer_learn_nn(lp_finetune, sp_finetune[j], eeg_lp,
+            transfer_learn_nn(lp_finetune, sp_finetune[j],
                               model_type = model_type_finetune, layers_to_finetune = curr_layer,
                               use_per_vals = use_per_vals, per_train_trials = per_train_trials[i],
-                              per_val_trials = per_val_trials[i],single_sub = single_sub, epochs=300, patience=20)
+                              per_val_trials = per_val_trials[i],single_sub = single_sub, epochs=epochs_same, patience=patience_same)
     else:
         for i in range(len(n_train_trials)):
-            transfer_learn_nn(lp_finetune, sp_finetune[j], eeg_lp,
+            transfer_learn_nn(lp_finetune, sp_finetune[j],
                               model_type = model_type_finetune, layers_to_finetune = curr_layer,
                               use_per_vals = use_per_vals, n_train_trials = n_train_trials[i],
-                              n_val_trials = n_val_trials[i], single_sub = single_sub, epochs=300, patience=20)
+                              n_val_trials = n_val_trials[i], single_sub = single_sub, epochs=epochs_same, patience=patience_same)
 
 #### Unseen modality fine-tuning ####
 spec_meas = 'relative_power'
 for j,curr_layer in enumerate(layers_to_finetune):
+    sp_finetune_eeg = sp_finetune[j][:-1]+'_eeg/'
     # Create save directory if does not exist already
-    if not os.path.exists(sp_finetune[j]):
-        os.makedirs(sp_finetune[j])
+    if not os.path.exists(sp_finetune_eeg):
+        os.makedirs(sp_finetune_eeg)
 
     # Fine-tune with each amount of train/val data
     if curr_layer==layers_to_finetune[-1]:
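Note on sp_finetune_eeg: hoisting sp_finetune[j][:-1]+'_eeg/' into a variable also fixes a path bug, since the old code checked and created the ECoG save directory while the EEG calls below write to the '_eeg/' variant. The slice simply swaps the trailing '/' for the suffix:

sp = 'move_rest_ecog/tf_all_per/'  # illustrative entry from sp_finetune
sp_eeg = sp[:-1] + '_eeg/'
assert sp_eeg == 'move_rest_ecog/tf_all_per_eeg/'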
@@ -159,32 +164,36 @@
     lp_finetune = rootpath + dataset + '/combined_sbjs_'+spec_meas+'/'
     if use_per_vals:
         for i in range(len(per_train_trials)):
-            transfer_learn_nn_eeg(lp_finetune, sp_finetune[j][:-1]+'_eeg/',
+            transfer_learn_nn_eeg(lp_finetune, sp_finetune_eeg, eeg_lp,
                                   model_type = model_type_finetune, layers_to_finetune = curr_layer,
                                   use_per_vals = use_per_vals, per_train_trials = per_train_trials[i],
-                                  per_val_trials = per_val_trials[i],single_sub = single_sub, epochs=300, patience=20)
+                                  per_val_trials = per_val_trials[i],single_sub = single_sub, epochs=epochs_same, patience=patience_same)
     else:
         for i in range(len(n_train_trials)):
-            transfer_learn_nn_eeg(lp_finetune, sp_finetune[j][:-1]+'_eeg/',
+            transfer_learn_nn_eeg(lp_finetune, sp_finetune_eeg, eeg_lp,
                                   model_type = model_type_finetune, layers_to_finetune = curr_layer,
                                   use_per_vals = use_per_vals, n_train_trials = n_train_trials[i],
-                                  n_val_trials = n_val_trials[i], single_sub = single_sub, epochs=300, patience=20)
+                                  n_val_trials = n_val_trials[i], single_sub = single_sub, epochs=epochs_same, patience=patience_same)
 
 
 #### Training same modality decoders with different numbers of training participants ####
 for i in range(max_train_parts):
     sp_curr = rootpath + dataset + '/combined_sbjs_ntra'+str(i+1)+'/'
     combined_sbjs = True
     if not os.path.exists(sp_curr):
-        os.mkdirs(sp_curr)
-    run_nn_models(sp_curr,n_folds_same,combined_sbjs,test_day = 'last', do_log=False,
-                  epochs=300, patience=20, models=['eegnet_hilb','eegnet','rf','riemann'], compute_val='power',
+        os.makedirs(sp_curr)
+    run_nn_models(sp_curr,n_folds_same,combined_sbjs,ecog_lp,ecog_roi_proj_lp,test_day = 'last', do_log=False,
+                  epochs=epochs_same, patience=patience_same, models=['eegnet_hilb','eegnet','rf','riemann'], compute_val='power',
                   n_val = n_val_parts, n_train = i + 1, F1 = hyps_same['F1'], dropoutRate = hyps_same['dropoutRate'],
                   kernLength = hyps_same['kernLength'], kernLength_sep = hyps_same['kernLength_sep'], dropoutType = hyps_same['dropoutType'],
                   D = hyps_same['D'], F2 = hyps_same['F2'], n_estimators = hyps_same['n_estimators'], max_depth = hyps_same['max_depth'])
-
+# Combine results into dataframes
+ntrain_combine_df(rootpath + dataset)
+frac_combine_df(rootpath + dataset, ecog_roi_proj_lp)
+
+
 #### Pre-compute difference spectrograms for ECoG and EEG datasets ####
-diff_specs(rootpath + dataset + '/combined_sbjs/', data_lp, ecog = True)
-diff_specs(rootpath + dataset + '/combined_sbjs/', eeg_lp, ecog = False)
+diff_specs(rootpath + dataset + '/combined_sbjs_power/', ecog_lp, ecog = True)
+diff_specs(rootpath + dataset + '/combined_sbjs_power/', eeg_lp, ecog = False)
 
 print('Elapsed time: '+str(time.time() - t_start))
