diff --git a/decoder.py b/decoder.py index 2e00297..14254ba 100644 --- a/decoder.py +++ b/decoder.py @@ -116,12 +116,14 @@ def predict_step(self, batch, batch_idx, test_mode=True, dataloader_idx=0): if test_mode: # x_tot, y_tot: [batch_size, 1, buffer_size], [batch_size, 1, num_groups] print(x_tot.shape, y_tot.shape) - num_groups = y_tot.shape[2] - group_size = x_tot.shape[2] - (num_groups - 1) * self.step + num_groups = y_tot.shape[1] + group_size = x_tot.shape[1] - (num_groups - 1) * self.step y_hat = [] for i in range(num_groups): start_idx = i * self.step - x = x_tot[:, :, start_idx:start_idx + group_size] # [batch_size, 1, group_size] + x = x_tot[:, start_idx:start_idx + group_size, :] + x = x.reshape(x.shape[0], x.shape[2], x.shape[1]) + # [batch_size, group_size, num_channels] y_hat.append(self.model(x).flatten()) y_hat = torch.unsqueeze(torch.transpose(torch.stack(y_hat), dim0=0, dim1=1), dim=1) # [batch_size, 1, num_groups] diff --git a/main.py b/main.py index 8532152..f6a91c9 100644 --- a/main.py +++ b/main.py @@ -17,9 +17,9 @@ # train_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\training_max10kHz_30to1kHz_10kshots_dec=256_randampl.h5py' # test_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\test_max10kHz_30to1kHz_2kshots_dec=256_randampl.h5py' - train_file = "/Users/nolanpeard/Desktop/SMI_sim/train_single.h5" - valid_file = "/Users/nolanpeard/Desktop/SMI_sim/valid_single.h5" - test_file = "/Users/nolanpeard/Desktop/SMI_sim/test_single.h5" + train_file = "/Users/nolanpeard/Desktop/SMI_sim/train_double.h5" + valid_file = "/Users/nolanpeard/Desktop/SMI_sim/valid_double.h5" + test_file = "/Users/nolanpeard/Desktop/SMI_sim/test_double.h5" print('begin main', datetime.datetime.now()) # step_list = [256]#, 128, 64, 32] # step sizes for rolling input @@ -30,7 +30,8 @@ runner = train.TrainingRunner(train_file, valid_file, test_file, step=256) - runner.plot_predictions(model_name="CNN", model_id="tdwhpu2l") + #runner.plot_predictions(model_name="CNN", model_id="tdwhpu2l") + runner.plot_predictions(model_name="CNN", model_id="e8vpuie1") else: print("Error: Unsupported number of command-line arguments") \ No newline at end of file diff --git a/train.py b/train.py index 5c3cedb..80e03c8 100644 --- a/train.py +++ b/train.py @@ -109,15 +109,16 @@ def open_hdf5(self, rolling=True, step=256, group_size=256): dim=1), dim=1) # [num_shots * num_groups, 1] - if num_channels == 1: - # self.inputs = torch.unsqueeze(self.inputs, dim=1) - # self.targets = torch.unsqueeze(self.targets, dim=1) - print(self.inputs.shape, self.targets.shape) - self.inputs = torch.reshape(self.inputs, (-1, 1, group_size)) - self.targets = torch.reshape(self.targets, (-1, 1, 1)) - else: - self.inputs = torch.reshape(self.inputs, (-1, num_channels, group_size)) - self.targets = torch.reshape(self.targets, (-1, 1, 1)) + if not self.test_mode: + if num_channels == 1: + # self.inputs = torch.unsqueeze(self.inputs, dim=1) + # self.targets = torch.unsqueeze(self.targets, dim=1) + print(self.inputs.shape, self.targets.shape) + self.inputs = torch.reshape(self.inputs, (-1, 1, group_size)) + self.targets = torch.reshape(self.targets, (-1, 1, 1)) + else: + self.inputs = torch.reshape(self.inputs, (-1, num_channels, group_size)) + self.targets = torch.reshape(self.targets, (-1, 1, 1)) # total number of group_size length sequences = num_shots * num_groups # print("open_hdf5 input size", self.inputs.size()) # [self.length, 256] @@ -296,6 +297,7 @@ def plot_predictions(self, model_name="CNN", model_id="i52c3rlz"): y = trainer.predict(model, dataloaders=self.test_loader) print(y[0][0].numpy().shape) + print(y[0][1].numpy().shape) print(y[0][2].numpy().shape) # y[batch_idx][return_idx], return_idx 0...3: 0: Predictions, 1: # Targets, 2: inputs, 3: encoded @@ -304,12 +306,13 @@ def plot_predictions(self, model_name="CNN", model_id="i52c3rlz"): for i in range(len(y[0][0].numpy()[:, 0])): fig = plt.figure(figsize=(5, 10)) ax1, ax2 = fig.subplots(2, 1) - for channel in range(y[0][2].numpy().shape[1]): - ax1.plot(y[0][2].numpy()[i, channel, :], label="Input"+str(channel)) + for channel in range(y[0][2].numpy().shape[2]): + ax1.plot(y[0][2].numpy()[i, :, channel], label="Input"+str( + channel)) ax1.set_title("Inputs") ax2.legend() - ax2.plot(y[0][1].numpy()[i, 0, :], label="Targets") + ax2.plot(y[0][1].numpy()[i, :], label="Targets") ax2.plot(y[0][0].numpy()[i, 0, :], label="Predictions") ax2.set_title("MSE Loss: " + str(nn.MSELoss(reduction='sum') (y[0][0][i, :], y[0][1][i, :]).item()))