From a07c5dab9849e83026fa5fbc2110daec0a862950 Mon Sep 17 00:00:00 2001
From: Angela Yueran Jia
Date: Sat, 19 Oct 2024 15:37:27 -0700
Subject: [PATCH] Improve readability of prediction plotting code

---
 decoder.py        |   4 +-
 model_analysis.py | 164 ++++++++++++++++++++++++----------------
 train.py          |   1 +
 3 files changed, 90 insertions(+), 79 deletions(-)

diff --git a/decoder.py b/decoder.py
index 91d39da..af43f9e 100644
--- a/decoder.py
+++ b/decoder.py
@@ -24,7 +24,7 @@ def __init__(self, model_name, model_hparams, optimizer_name,
         self.save_hyperparameters()
         # Create model
         self.model = self.create_model(model_name, model_hparams)
-        print(summary(self.model, input_size=(misc_hparams['batch_size'], 1, 256)))
+        # print(summary(self.model, input_size=(misc_hparams['batch_size'], 1, 256)))
 
         # Create loss module
         self.loss_function = nn.MSELoss()
@@ -56,7 +56,7 @@ def configure_optimizers(self):
 
         # We will reduce the learning rate by 0.1 after 100 and 150 epochs
         scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
-                                                   milestones=[100, 150, 200],
+                                                   milestones=[50, 100, 150], #changed from [100, 150, 200]
                                                    gamma=0.1)
 
         return [optimizer], [scheduler]
diff --git a/model_analysis.py b/model_analysis.py
index 2d2a29b..be1bf08 100644
--- a/model_analysis.py
+++ b/model_analysis.py
@@ -13,87 +13,97 @@
 N = 16384
 SMPL_RATE_DEC1 = 125e6
 decimation = 256
-smpl_rate = SMPL_RATE_DEC1/decimation
-time_data = np.linspace(0, N-1, num=N) / smpl_rate
+smpl_rate = SMPL_RATE_DEC1 / decimation
+time_data = np.linspace(0, N - 1, num=N) / smpl_rate
 
-valid_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\valid_30to1kHz_2kshots_dec=256_randampl.h5py'
-train_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\training_30to1kHz_10kshots_dec=256_randampl.h5py'
-test_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\test_30to1kHz_2kshots_dec=256_randampl.h5py'
+valid_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\val_1to1kHz_invertspectra_trigdelay8192_sleep100ms_2kx1shots_dec=256_8192_randampl.h5py'
+test_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\test_1to1kHz_invertspectra_trigdelay8192_sleep100ms_2kx1shots_dec=256_8192_randampl.h5py'
+train_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\train_1to1kHz_invertspectra_trigdelay8192_sleep100ms_10kx1shots_dec=256_8192_randampl.h5py'
 
-model_tag = "fu7qnzar"
-step = 128
+model_tag = "z4dscrhq"
+step = 64
 batch_size = 128
 
 runner = train.TrainingRunner(train_file, valid_file, test_file, step=step)
 model, result = runner.load_model(model_tag)
 
-train_mode = True
-
-if train_mode:
-    iter_loader = iter(runner.train_loader)
-else:
-    iter_loader = iter(runner.test_loader)
-# for i in range(9):
-#     next(iter_loader)
-batch_val = next(iter_loader)
-batches = [batch_val]
-
-for batch in batches:
-    ## Plot results
-    inputs = batch[0]
-    targets = batch[1]
-
-    print(inputs.shape)
-    print(targets.shape)
-    outputs_val, _ = model.predict_step(batch, 1, test_mode=True)
-
-    outputs_val = torch.squeeze(outputs_val).cpu().detach().numpy()
-    inputs_val = torch.squeeze(inputs).cpu().detach().numpy()
-    targets_val = torch.squeeze(targets).cpu().detach().numpy()
-
-    targets_squeezed_val = np.squeeze(targets_val)
-    print("targets shape", targets_squeezed_val.shape)
-    outputs_squeezed_val = np.squeeze(outputs_val)
-    print("preds shape", outputs_squeezed_val.shape)
-
-    fig, ax = plt.subplots(2)
-    fig.set_size_inches(8, 6)
-
-    if train_mode:
-        inputs_flattened = inputs_val.flatten()
-        ax[0].plot(inputs_flattened)
-        ax[0].set_title('Training PD Trace (scrambled buffer)', fontsize=15)
-        ax[0].set_ylabel('V', fontsize=15)
-        ax[0].set_xticks([])
-        ax[0].tick_params(axis='y', which='major', labelsize=13)
-        # ax[0].set_xlabel('Time (s)')
-        ax[1].plot(targets_squeezed_val, marker='.', label='Target')
-        ax[1].set_title('Velocity', fontsize=15)
-        ax[1].set_ylabel('um/s', fontsize=15)
-        ax[1].set_xlabel('Batch Idx', fontsize=15)
-
-        ax[1].plot(outputs_squeezed_val, marker='.', label='Pred')
-        ax[1].legend(prop={'size': 12})
-        ax[1].tick_params(axis='both', which='major', labelsize=13)
-    else:
-        idx = 0
-        ax[0].plot(time_data, inputs_val[idx])
-        ax[0].set_title('Test PD Trace (contiguous buffer)', fontsize=15)
-        ax[0].set_ylabel('V', fontsize=15)
-        ax[0].set_xticks([])
-        ax[0].tick_params(axis='y', which='major', labelsize=13)
-        # ax[0].set_xlabel('Time (s)')
-        num_groups = outputs_squeezed_val.shape[1] # 127
-        start_idxs = torch.arange(num_groups) * step
-
-        ax[1].plot(time_data[start_idxs], targets_squeezed_val[idx], marker='.', label='Target')
-        ax[1].set_title('Velocity', fontsize=15)
-        ax[1].set_ylabel('um/s', fontsize=15)
-        ax[1].set_xlabel('Time (s)', fontsize=15)
-
-        ax[1].plot(time_data[start_idxs], outputs_squeezed_val[idx], marker='.', label='Pred')
-        ax[1].legend(prop={'size': 12})
-        ax[1].tick_params(axis='both', which='major', labelsize=13)
-
-    fig.tight_layout()
-    plt.show()
\ No newline at end of file
+modes = ['train', 'valid', 'test']
+
+for mode in modes:
+    if mode == 'train':
+        iter_loader = iter(runner.train_loader)
+    elif mode == 'valid':
+        iter_loader = iter(runner.valid_loader_testmode)
+    else: # test mode
+        iter_loader = iter(runner.test_loader)
+    for i in range(2):
+        next(iter_loader)
+    batch_val = next(iter_loader) # batch 3
+    batches = [batch_val]
+    next(iter_loader)
+    batches.append(next(iter_loader)) # batch 5
+    next(iter_loader)
+    batches.append(next(iter_loader)) # batch 7
+
+    for batch in batches:
+        ## Plot results
+        inputs = batch[0]
+        targets = batch[1]
+
+        print(inputs.shape)
+        print(targets.shape)
+        outputs_val, _ = model.predict_step(batch, 1, test_mode=True)
+
+        outputs_val = torch.squeeze(outputs_val).cpu().detach().numpy()
+        inputs_val = torch.squeeze(inputs).cpu().detach().numpy()
+        targets_val = torch.squeeze(targets).cpu().detach().numpy()
+
+        targets_squeezed_val = np.squeeze(targets_val)
+        print("targets shape", targets_squeezed_val.shape)
+        outputs_squeezed_val = np.squeeze(outputs_val)
+        print("preds shape", outputs_squeezed_val.shape)
+
+        fig, ax = plt.subplots(2)
+        fig.set_size_inches(8, 6)
+
+        if mode == 'train':
+            inputs_flattened = inputs_val.flatten()
+            ax[0].plot(inputs_flattened)
+            ax[0].set_title('Training PD Trace (scrambled buffer)', fontsize=15)
+            ax[0].set_ylabel('V', fontsize=15)
+            ax[0].set_xticks([])
+            ax[0].tick_params(axis='y', which='major', labelsize=13)
+            # ax[0].set_xlabel('Time (s)')
+            ax[1].plot(targets_squeezed_val, marker='.', label='Target')
+            ax[1].set_title('Velocity', fontsize=15)
+            ax[1].set_ylabel('um/s', fontsize=15)
+            ax[1].set_xlabel('Batch Idx', fontsize=15)
+
+            ax[1].plot(outputs_squeezed_val, marker='.', label='Pred')
+            ax[1].legend(prop={'size': 12})
+            ax[1].tick_params(axis='both', which='major', labelsize=13)
+        else:
+            idx = 0 # use first in batch of 128 shots
+            ax[0].plot(time_data, inputs_val[idx])
+            if mode == 'valid':
+                ax[0].set_title('Validation PD Trace (contiguous buffer)', fontsize=15)
+            elif mode == 'test':
+                ax[0].set_title('Test PD Trace (contiguous buffer)', fontsize=15)
+            ax[0].set_ylabel('V', fontsize=15)
+            ax[0].set_xticks([])
+            ax[0].tick_params(axis='y', which='major', labelsize=13)
+            # ax[0].set_xlabel('Time (s)')
+            num_groups = outputs_squeezed_val.shape[1] # 127
+            start_idxs = torch.arange(num_groups) * step
+
+            ax[1].plot(time_data[start_idxs], targets_squeezed_val[idx], marker='.', label='Target')
+            ax[1].set_title('Velocity', fontsize=15)
+            ax[1].set_ylabel('um/s', fontsize=15)
+            ax[1].set_xlabel('Time (s)', fontsize=15)
+
+            ax[1].plot(time_data[start_idxs], outputs_squeezed_val[idx], marker='.', label='Pred')
+            ax[1].legend(prop={'size': 12})
+            ax[1].tick_params(axis='both', which='major', labelsize=13)
+
+        fig.tight_layout()
+        plt.show()
\ No newline at end of file
diff --git a/train.py b/train.py
index 409f050..3bd3805 100644
--- a/train.py
+++ b/train.py
@@ -149,6 +149,7 @@ def set_dataloaders(self, batch_size=128):
         self.batch_size = batch_size
         self.train_loader = self.get_custom_dataloader(False, self.training_h5, batch_size=self.batch_size)
         self.valid_loader = self.get_custom_dataloader(False, self.validation_h5, batch_size=self.batch_size, shuffle=False)
+        self.valid_loader_testmode = self.get_custom_dataloader(True, self.validation_h5, batch_size=self.batch_size, shuffle=False)
         self.test_loader = self.get_custom_dataloader(True, self.testing_h5, batch_size=self.batch_size, shuffle=False)
 
     def train_model(self, model_name, save_name=None, **kwargs):
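
Not part of the patch: a minimal NumPy-only sketch of how the plotting code in model_analysis.py maps each prediction window back onto the time axis. The constants mirror those at the top of the script; num_groups = 127 is only an assumption carried over from the inline comment and may differ for step = 64, and np.arange stands in for the torch.arange used in the script.

    import numpy as np

    # Same constants as at the top of model_analysis.py
    N = 16384
    SMPL_RATE_DEC1 = 125e6
    decimation = 256
    step = 64

    smpl_rate = SMPL_RATE_DEC1 / decimation            # 488281.25 Hz after decimation
    time_data = np.linspace(0, N - 1, num=N) / smpl_rate

    num_groups = 127                                    # assumed, per the "# 127" comment
    start_idxs = np.arange(num_groups) * step           # sample index where each window starts
    window_times = time_data[start_idxs]                # x values used for the Target/Pred plots

    print(window_times[0], window_times[1] - window_times[0])
    # -> 0.0 and ~1.31e-4 s: consecutive predictions are spaced step / smpl_rate apart in time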