Skip to content

Commit

Permalink
compared double and single channel models. further work on prediction…
Browse files Browse the repository at this point in the history
… plotting.
  • Loading branch information
npeard committed Sep 25, 2024
1 parent df0843a commit 932b727
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 19 deletions.
8 changes: 5 additions & 3 deletions decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,14 @@ def predict_step(self, batch, batch_idx, test_mode=True, dataloader_idx=0):
if test_mode:
# x_tot, y_tot: [batch_size, 1, buffer_size], [batch_size, 1, num_groups]
print(x_tot.shape, y_tot.shape)
num_groups = y_tot.shape[2]
group_size = x_tot.shape[2] - (num_groups - 1) * self.step
num_groups = y_tot.shape[1]
group_size = x_tot.shape[1] - (num_groups - 1) * self.step
y_hat = []
for i in range(num_groups):
start_idx = i * self.step
x = x_tot[:, :, start_idx:start_idx + group_size] # [batch_size, 1, group_size]
x = x_tot[:, start_idx:start_idx + group_size, :]
x = x.reshape(x.shape[0], x.shape[2], x.shape[1])
# [batch_size, group_size, num_channels]
y_hat.append(self.model(x).flatten())
y_hat = torch.unsqueeze(torch.transpose(torch.stack(y_hat), dim0=0, dim1=1),
dim=1) # [batch_size, 1, num_groups]
Expand Down
9 changes: 5 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
# train_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\training_max10kHz_30to1kHz_10kshots_dec=256_randampl.h5py'
# test_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\test_max10kHz_30to1kHz_2kshots_dec=256_randampl.h5py'

train_file = "/Users/nolanpeard/Desktop/SMI_sim/train_single.h5"
valid_file = "/Users/nolanpeard/Desktop/SMI_sim/valid_single.h5"
test_file = "/Users/nolanpeard/Desktop/SMI_sim/test_single.h5"
train_file = "/Users/nolanpeard/Desktop/SMI_sim/train_double.h5"
valid_file = "/Users/nolanpeard/Desktop/SMI_sim/valid_double.h5"
test_file = "/Users/nolanpeard/Desktop/SMI_sim/test_double.h5"

print('begin main', datetime.datetime.now())
# step_list = [256]#, 128, 64, 32] # step sizes for rolling input
Expand All @@ -30,7 +30,8 @@

runner = train.TrainingRunner(train_file, valid_file, test_file,
step=256)
runner.plot_predictions(model_name="CNN", model_id="tdwhpu2l")
#runner.plot_predictions(model_name="CNN", model_id="tdwhpu2l")
runner.plot_predictions(model_name="CNN", model_id="e8vpuie1")

else:
print("Error: Unsupported number of command-line arguments")
27 changes: 15 additions & 12 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,16 @@ def open_hdf5(self, rolling=True, step=256, group_size=256):
dim=1), dim=1)
# [num_shots * num_groups, 1]

if num_channels == 1:
# self.inputs = torch.unsqueeze(self.inputs, dim=1)
# self.targets = torch.unsqueeze(self.targets, dim=1)
print(self.inputs.shape, self.targets.shape)
self.inputs = torch.reshape(self.inputs, (-1, 1, group_size))
self.targets = torch.reshape(self.targets, (-1, 1, 1))
else:
self.inputs = torch.reshape(self.inputs, (-1, num_channels, group_size))
self.targets = torch.reshape(self.targets, (-1, 1, 1))
if not self.test_mode:
if num_channels == 1:
# self.inputs = torch.unsqueeze(self.inputs, dim=1)
# self.targets = torch.unsqueeze(self.targets, dim=1)
print(self.inputs.shape, self.targets.shape)
self.inputs = torch.reshape(self.inputs, (-1, 1, group_size))
self.targets = torch.reshape(self.targets, (-1, 1, 1))
else:
self.inputs = torch.reshape(self.inputs, (-1, num_channels, group_size))
self.targets = torch.reshape(self.targets, (-1, 1, 1))

# total number of group_size length sequences = num_shots * num_groups
# print("open_hdf5 input size", self.inputs.size()) # [self.length, 256]
Expand Down Expand Up @@ -296,6 +297,7 @@ def plot_predictions(self, model_name="CNN", model_id="i52c3rlz"):
y = trainer.predict(model, dataloaders=self.test_loader)

print(y[0][0].numpy().shape)
print(y[0][1].numpy().shape)
print(y[0][2].numpy().shape)
# y[batch_idx][return_idx], return_idx 0...3: 0: Predictions, 1:
# Targets, 2: inputs, 3: encoded
Expand All @@ -304,12 +306,13 @@ def plot_predictions(self, model_name="CNN", model_id="i52c3rlz"):
for i in range(len(y[0][0].numpy()[:, 0])):
fig = plt.figure(figsize=(5, 10))
ax1, ax2 = fig.subplots(2, 1)
for channel in range(y[0][2].numpy().shape[1]):
ax1.plot(y[0][2].numpy()[i, channel, :], label="Input"+str(channel))
for channel in range(y[0][2].numpy().shape[2]):
ax1.plot(y[0][2].numpy()[i, :, channel], label="Input"+str(
channel))
ax1.set_title("Inputs")
ax2.legend()

ax2.plot(y[0][1].numpy()[i, 0, :], label="Targets")
ax2.plot(y[0][1].numpy()[i, :], label="Targets")
ax2.plot(y[0][0].numpy()[i, 0, :], label="Predictions")
ax2.set_title("MSE Loss: " + str(nn.MSELoss(reduction='sum')
(y[0][0][i, :], y[0][1][i, :]).item()))
Expand Down

0 comments on commit 932b727

Please sign in to comment.