From 347550b74000dfc25c39d234f76a750fde8f56dc Mon Sep 17 00:00:00 2001 From: Angela Yueran Jia Date: Sat, 12 Oct 2024 22:14:52 -0700 Subject: [PATCH] Customize actfn in fc layers --- models.py | 2 +- train.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/models.py b/models.py index a64a044..95a18ef 100644 --- a/models.py +++ b/models.py @@ -26,7 +26,7 @@ def __init__(self, input_size, output_size, ch_in=1, activation='LeakyReLU'): ) self.fc_layers = nn.Sequential( nn.Linear(640, 16), - nn.ReLU(), + act_fn_by_name[activation], nn.Linear(16, output_size) ) diff --git a/train.py b/train.py index 79d904b..409f050 100644 --- a/train.py +++ b/train.py @@ -23,11 +23,11 @@ def __init__(self, test_mode, h5_file, step, group_size=256): self.step = step self.group_size = group_size with h5py.File(self.h5_file, 'r') as f: - num_groups = (f['Time (s)'].shape[1] - group_size) // step + 1 + num_groups = (f['PD (V)'].shape[1] - group_size) // step + 1 if test_mode: - self.length = len(f['Time (s)']) # in test_mode, length of dataset = num shots + self.length = len(f['PD (V)']) # in test_mode, length of dataset = num shots else: - self.length = len(f['Time (s)']) * num_groups + self.length = len(f['PD (V)']) * num_groups print(self.h5_file) self.opened_flag = False self.test_mode = test_mode @@ -210,7 +210,7 @@ def train_model(self, model_name, save_name=None, **kwargs): return model, result def scan_hyperparams(self): - lr_list = [1e-3, 1e-4] # [1e-3, 1e-4, 1e-5] + lr_list = [1e-3, 1e-4]# [1e-3, 1e-4, 1e-5] # [1e-3, 1e-4, 1e-5] act_list = ['LeakyReLU'] #, 'ReLU'] optim_list = ['Adam'] #, 'SGD'] for lr, activation, optim in product(lr_list, act_list, optim_list): #, 1e-2, 3e-2]: