Commit

…to ml
ange1a-j14 committed Dec 11, 2024 · 2 parents c4f9781 + a07c5da · commit ad0782a
Showing 4 changed files with 95 additions and 84 deletions.
4 changes: 2 additions & 2 deletions decoder.py
@@ -24,7 +24,7 @@ def __init__(self, model_name, model_hparams, optimizer_name,
self.save_hyperparameters()
# Create model
self.model = self.create_model(model_name, model_hparams)
- print(summary(self.model, input_size=(misc_hparams['batch_size'], 1, 256)))
+ # print(summary(self.model, input_size=(misc_hparams['batch_size'], 1, 256)))
# Create loss module
self.loss_function = nn.MSELoss()

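For reference, the silenced printout can still be reproduced outside the LightningModule. A minimal sketch, assuming `summary` here is torchinfo.summary, a batch size of 128 (as in model_analysis.py), and a placeholder network standing in for the real model built by create_model:

    import torch.nn as nn
    from torchinfo import summary

    # Placeholder with the same (batch, 1, 256) input shape; the real model
    # comes from create_model(model_name, model_hparams) in decoder.py.
    model = nn.Sequential(
        nn.Conv1d(1, 16, kernel_size=7),   # (128, 1, 256) -> (128, 16, 250)
        nn.Flatten(),                      # -> (128, 4000)
        nn.Linear(4000, 1),                # -> (128, 1)
    )
    print(summary(model, input_size=(128, 1, 256)))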
@@ -56,7 +56,7 @@ def configure_optimizers(self):

# We will reduce the learning rate by 0.1 after 100 and 150 epochs
scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
- milestones=[100, 150, 200],
+ milestones=[50, 100, 150], #changed from [100, 150, 200]
gamma=0.1)
return [optimizer], [scheduler]

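MultiStepLR multiplies the learning rate by gamma each time a milestone epoch is reached, so moving the milestones from [100, 150, 200] to [50, 100, 150] simply starts the decay earlier. A minimal sketch of the resulting schedule, assuming a base learning rate of 1e-3 (one of the values swept in train.py):

    import torch
    from torch import optim

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = optim.Adam(params, lr=1e-3)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 100, 150], gamma=0.1)

    for epoch in range(200):
        optimizer.step()
        scheduler.step()
    # lr: 1e-3 for epochs 0-49, 1e-4 for 50-99, 1e-5 for 100-149, 1e-6 from epoch 150 on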
164 changes: 87 additions & 77 deletions model_analysis.py
@@ -13,87 +13,97 @@
N = 16384
SMPL_RATE_DEC1 = 125e6
decimation = 256
- smpl_rate = SMPL_RATE_DEC1/decimation
- time_data = np.linspace(0, N-1, num=N) / smpl_rate
+ smpl_rate = SMPL_RATE_DEC1 / decimation
+ time_data = np.linspace(0, N - 1, num=N) / smpl_rate

- valid_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\valid_30to1kHz_2kshots_dec=256_randampl.h5py'
- train_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\training_30to1kHz_10kshots_dec=256_randampl.h5py'
- test_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\test_30to1kHz_2kshots_dec=256_randampl.h5py'
+ valid_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\val_1to1kHz_invertspectra_trigdelay8192_sleep100ms_2kx1shots_dec=256_8192_randampl.h5py'
+ test_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\test_1to1kHz_invertspectra_trigdelay8192_sleep100ms_2kx1shots_dec=256_8192_randampl.h5py'
+ train_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\train_1to1kHz_invertspectra_trigdelay8192_sleep100ms_10kx1shots_dec=256_8192_randampl.h5py'

- model_tag = "fu7qnzar"
- step = 128
+ model_tag = "z4dscrhq"
+ step = 64
batch_size = 128

runner = train.TrainingRunner(train_file, valid_file, test_file, step=step)
model, result = runner.load_model(model_tag)

- train_mode = True

- if train_mode:
- iter_loader = iter(runner.train_loader)
- else:
- iter_loader = iter(runner.test_loader)
- # for i in range(9):
- # next(iter_loader)
- batch_val = next(iter_loader)
- batches = [batch_val]

- for batch in batches:
- ## Plot results
- inputs = batch[0]
- targets = batch[1]

- print(inputs.shape)
- print(targets.shape)
- outputs_val, _ = model.predict_step(batch, 1, test_mode=True)

- outputs_val = torch.squeeze(outputs_val).cpu().detach().numpy()
- inputs_val = torch.squeeze(inputs).cpu().detach().numpy()
- targets_val = torch.squeeze(targets).cpu().detach().numpy()

- targets_squeezed_val = np.squeeze(targets_val)
- print("targets shape", targets_squeezed_val.shape)
- outputs_squeezed_val = np.squeeze(outputs_val)
- print("preds shape", outputs_squeezed_val.shape)

- fig, ax = plt.subplots(2)
- fig.set_size_inches(8, 6)

- if train_mode:
- inputs_flattened = inputs_val.flatten()
- ax[0].plot(inputs_flattened)
- ax[0].set_title('Training PD Trace (scrambled buffer)', fontsize=15)
- ax[0].set_ylabel('V', fontsize=15)
- ax[0].set_xticks([])
- ax[0].tick_params(axis='y', which='major', labelsize=13)
- # ax[0].set_xlabel('Time (s)')
- ax[1].plot(targets_squeezed_val, marker='.', label='Target')
- ax[1].set_title('Velocity', fontsize=15)
- ax[1].set_ylabel('um/s', fontsize=15)
- ax[1].set_xlabel('Batch Idx', fontsize=15)

- ax[1].plot(outputs_squeezed_val, marker='.', label='Pred')
- ax[1].legend(prop={'size': 12})
- ax[1].tick_params(axis='both', which='major', labelsize=13)
- else:
- idx = 0
- ax[0].plot(time_data, inputs_val[idx])
- ax[0].set_title('Test PD Trace (contiguous buffer)', fontsize=15)
- ax[0].set_ylabel('V', fontsize=15)
- ax[0].set_xticks([])
- ax[0].tick_params(axis='y', which='major', labelsize=13)
- # ax[0].set_xlabel('Time (s)')
- num_groups = outputs_squeezed_val.shape[1] # 127
- start_idxs = torch.arange(num_groups) * step

- ax[1].plot(time_data[start_idxs], targets_squeezed_val[idx], marker='.', label='Target')
- ax[1].set_title('Velocity', fontsize=15)
- ax[1].set_ylabel('um/s', fontsize=15)
- ax[1].set_xlabel('Time (s)', fontsize=15)

- ax[1].plot(time_data[start_idxs], outputs_squeezed_val[idx], marker='.', label='Pred')
- ax[1].legend(prop={'size': 12})
- ax[1].tick_params(axis='both', which='major', labelsize=13)

- fig.tight_layout()
- plt.show()
+ modes = ['train', 'valid', 'test']

+ for mode in modes:
+ if mode == 'train':
+ iter_loader = iter(runner.train_loader)
+ elif mode == 'valid':
+ iter_loader = iter(runner.valid_loader_testmode)
+ else: # test mode
+ iter_loader = iter(runner.test_loader)
+ for i in range(2):
+ next(iter_loader)
+ batch_val = next(iter_loader) # batch 3
+ batches = [batch_val]
+ next(iter_loader)
+ batches.append(next(iter_loader)) # batch 5
+ next(iter_loader)
+ batches.append(next(iter_loader)) # batch 7

+ for batch in batches:
+ ## Plot results
+ inputs = batch[0]
+ targets = batch[1]

+ print(inputs.shape)
+ print(targets.shape)
+ outputs_val, _ = model.predict_step(batch, 1, test_mode=True)

+ outputs_val = torch.squeeze(outputs_val).cpu().detach().numpy()
+ inputs_val = torch.squeeze(inputs).cpu().detach().numpy()
+ targets_val = torch.squeeze(targets).cpu().detach().numpy()

+ targets_squeezed_val = np.squeeze(targets_val)
+ print("targets shape", targets_squeezed_val.shape)
+ outputs_squeezed_val = np.squeeze(outputs_val)
+ print("preds shape", outputs_squeezed_val.shape)

+ fig, ax = plt.subplots(2)
+ fig.set_size_inches(8, 6)

+ if mode == 'train':
+ inputs_flattened = inputs_val.flatten()
+ ax[0].plot(inputs_flattened)
+ ax[0].set_title('Training PD Trace (scrambled buffer)', fontsize=15)
+ ax[0].set_ylabel('V', fontsize=15)
+ ax[0].set_xticks([])
+ ax[0].tick_params(axis='y', which='major', labelsize=13)
+ # ax[0].set_xlabel('Time (s)')
+ ax[1].plot(targets_squeezed_val, marker='.', label='Target')
+ ax[1].set_title('Velocity', fontsize=15)
+ ax[1].set_ylabel('um/s', fontsize=15)
+ ax[1].set_xlabel('Batch Idx', fontsize=15)

+ ax[1].plot(outputs_squeezed_val, marker='.', label='Pred')
+ ax[1].legend(prop={'size': 12})
+ ax[1].tick_params(axis='both', which='major', labelsize=13)
+ else:
+ idx = 0 # use first in batch of 128 shots
+ ax[0].plot(time_data, inputs_val[idx])
+ if mode == 'valid':
+ ax[0].set_title('Validation PD Trace (contiguous buffer)', fontsize=15)
+ elif mode == 'test':
+ ax[0].set_title('Test PD Trace (contiguous buffer)', fontsize=15)
+ ax[0].set_ylabel('V', fontsize=15)
+ ax[0].set_xticks([])
+ ax[0].tick_params(axis='y', which='major', labelsize=13)
+ # ax[0].set_xlabel('Time (s)')
+ num_groups = outputs_squeezed_val.shape[1] # 127
+ start_idxs = torch.arange(num_groups) * step

+ ax[1].plot(time_data[start_idxs], targets_squeezed_val[idx], marker='.', label='Target')
+ ax[1].set_title('Velocity', fontsize=15)
+ ax[1].set_ylabel('um/s', fontsize=15)
+ ax[1].set_xlabel('Time (s)', fontsize=15)

+ ax[1].plot(time_data[start_idxs], outputs_squeezed_val[idx], marker='.', label='Pred')
+ ax[1].legend(prop={'size': 12})
+ ax[1].tick_params(axis='both', which='major', labelsize=13)

+ fig.tight_layout()
+ plt.show()
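In the contiguous-buffer branches, each prediction corresponds to one sliding window of the photodiode trace, and start_idxs maps those windows back onto the time axis. With a window length of 256 samples (the group_size default in train.py, assumed here) and the new step = 64, a 16384-sample shot gives (16384 - 256) // 64 + 1 = 253 windows; the inline "# 127" comment dates from the earlier step = 128. A standalone sketch of that bookkeeping:

    import numpy as np

    N = 16384                        # samples per shot
    smpl_rate = 125e6 / 256          # decimated sample rate, as above
    group_size = 256                 # assumed window length (train.py default)
    step = 64

    time_data = np.linspace(0, N - 1, num=N) / smpl_rate
    num_groups = (N - group_size) // step + 1     # 253 (127 when step = 128)
    start_idxs = np.arange(num_groups) * step
    group_times = time_data[start_idxs]           # one timestamp per predicted velocity sample
    print(num_groups, group_times[:3])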
2 changes: 1 addition & 1 deletion models.py
@@ -26,7 +26,7 @@ def __init__(self, input_size, output_size, ch_in=1, activation='LeakyReLU'):
)
self.fc_layers = nn.Sequential(
nn.Linear(640, 16),
- nn.ReLU(),
+ act_fn_by_name[activation],
nn.Linear(16, output_size)
)

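The fully connected head now uses the same configurable activation as the rest of the network instead of a hard-coded nn.ReLU(). act_fn_by_name is defined elsewhere in models.py; the sketch below assumes it maps names to module instances, consistent with how it is indexed here:

    import torch.nn as nn

    # Assumed contents; the real dict in models.py may include more entries.
    act_fn_by_name = {
        'ReLU': nn.ReLU(),
        'LeakyReLU': nn.LeakyReLU(),
    }

    activation = 'LeakyReLU'   # hyperparameter passed into the model
    output_size = 1            # illustrative value
    fc_layers = nn.Sequential(
        nn.Linear(640, 16),
        act_fn_by_name[activation],
        nn.Linear(16, output_size),
    )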
9 changes: 5 additions & 4 deletions train.py
@@ -23,11 +23,11 @@ def __init__(self, test_mode, h5_file, step, group_size=256):
self.step = step
self.group_size = group_size
with h5py.File(self.h5_file, 'r') as f:
- num_groups = (f['Time (s)'].shape[1] - group_size) // step + 1
+ num_groups = (f['PD (V)'].shape[1] - group_size) // step + 1
if test_mode:
- self.length = len(f['Time (s)']) # in test_mode, length of dataset = num shots
+ self.length = len(f['PD (V)']) # in test_mode, length of dataset = num shots
else:
- self.length = len(f['Time (s)']) * num_groups
+ self.length = len(f['PD (V)']) * num_groups
print(self.h5_file)
self.opened_flag = False
self.test_mode = test_mode
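Switching the lookup key from 'Time (s)' to 'PD (V)' derives the dataset length from the photodiode traces themselves. A short worked example of the resulting sizes, assuming shots of 16384 samples with group_size = 256 and step = 64 as elsewhere in this commit:

    num_shots = 10_000                                          # e.g. the 10k-shot training file
    samples_per_shot = 16384
    group_size = 256
    step = 64

    num_groups = (samples_per_shot - group_size) // step + 1    # 253 windows per shot
    length_windowed = num_shots * num_groups                    # 2,530,000 items when test_mode is False
    length_testmode = num_shots                                 # 10,000 items (one per shot) when test_mode is True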
@@ -149,6 +149,7 @@ def set_dataloaders(self, batch_size=128):
self.batch_size = batch_size
self.train_loader = self.get_custom_dataloader(False, self.training_h5, batch_size=self.batch_size)
self.valid_loader = self.get_custom_dataloader(False, self.validation_h5, batch_size=self.batch_size, shuffle=False)
+ self.valid_loader_testmode = self.get_custom_dataloader(True, self.validation_h5, batch_size=self.batch_size, shuffle=False)
self.test_loader = self.get_custom_dataloader(True, self.testing_h5, batch_size=self.batch_size, shuffle=False)

def train_model(self, model_name, save_name=None, **kwargs):
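The new valid_loader_testmode gives a test-style view of the validation set, which is what the 'valid' branch in model_analysis.py iterates over. A usage sketch, reusing the runner already constructed in model_analysis.py and assuming the default group_size of 256:

    # Continuing from model_analysis.py, where `runner` is a train.TrainingRunner:
    windowed_batch = next(iter(runner.valid_loader))             # one 256-sample window per item
    contiguous_batch = next(iter(runner.valid_loader_testmode))  # one full 16384-sample shot per item
    print(windowed_batch[0].shape, contiguous_batch[0].shape)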
@@ -210,7 +210,7 @@ def scan_hyperparams(self):
return model, result

def scan_hyperparams(self):
- lr_list = [1e-3, 1e-4] # [1e-3, 1e-4, 1e-5]
+ lr_list = [1e-3, 1e-4]# [1e-3, 1e-4, 1e-5] # [1e-3, 1e-4, 1e-5]
act_list = ['LeakyReLU'] #, 'ReLU']
optim_list = ['Adam'] #, 'SGD']
for lr, activation, optim in product(lr_list, act_list, optim_list): #, 1e-2, 3e-2]:
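scan_hyperparams sweeps the Cartesian product of these lists, so trimming lr_list directly shrinks the sweep. A minimal sketch of the iteration pattern (the call into train_model is only indicated in a comment, since its exact keyword arguments are not shown here):

    from itertools import product

    lr_list = [1e-3, 1e-4]
    act_list = ['LeakyReLU']
    optim_list = ['Adam']

    for lr, activation, optim_name in product(lr_list, act_list, optim_list):
        # 2 x 1 x 1 = 2 configurations with the lists above;
        # each would be handed to runner.train_model(...) with these settings.
        print(f"lr={lr}, activation={activation}, optimizer={optim_name}")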
