Implement test_mode to run model on single buffer
ange1a-j14 committed Aug 20, 2024
1 parent 2e2d8b2 commit 0f6a47f
Showing 2 changed files with 57 additions and 24 deletions.
44 changes: 34 additions & 10 deletions decoder.py
@@ -25,6 +25,8 @@ def __init__(self, model_name, model_hparams, optimizer_name,
self.model = self.create_model(model_name, model_hparams)
# Create loss module
self.loss_function = nn.MSELoss()

self.step = misc_hparams["step"]

torch.set_float32_matmul_precision('medium')

@@ -87,15 +89,37 @@ def validation_step(self, batch, batch_idx):
return loss

def test_step(self, batch, batch_idx):
x, y = batch
preds = self.model(x)
loss = self.loss_function(preds, y)
acc = (preds == y).float().mean()
self.log("test_acc", acc, on_step=False, on_epoch=True)
self.log("test_loss", loss, prog_bar=True)
x_tot, y_tot = batch # [batch_size, 1, buffer_size], [batch_size, 1, num_groups]
num_groups = y_tot.shape[2]
group_size = x_tot.shape[2] - (num_groups - 1) * self.step
avg_loss = 0
avg_acc = 0
for i in range(num_groups):
start_idx = i*self.step
x = x_tot[:, :, start_idx:start_idx+group_size]
y = y_tot[:, :, i]
preds = self.model(x)
avg_loss += self.loss_function(preds, y)
avg_acc += (preds == y).float().mean()
avg_loss /= num_groups
avg_acc /= num_groups
self.log("test_acc", avg_acc, on_step=False, on_epoch=True)
self.log("test_loss", avg_loss, prog_bar=True)
return avg_loss

def predict_step(self, batch, batch_idx, dataloader_idx=0):
x, y = batch
y_hat = self.model(x)
return y_hat, y
def predict_step(self, batch, batch_idx, test_mode=False, dataloader_idx=0):
x_tot, y_tot = batch
if test_mode:
# x_tot, y_tot: [batch_size, 1, buffer_size], [batch_size, 1, num_groups]
num_groups = y_tot.shape[2]
group_size = x_tot.shape[2] - (num_groups - 1) * self.step
y_hat = []
for i in range(num_groups):
start_idx = i*self.step
x = x_tot[:, :, start_idx:start_idx+group_size] # [batch_size, 1, group_size]
y = y_tot[:, :, i] # [batch_size, 1]
y_hat.append(self.model(x).flatten())
y_hat = torch.squeeze(torch.transpose(torch.stack(y_hat), dim0=0, dim1=1), dim=1) # [batch_size, num_groups]
else:
y_hat = self.model(x_tot)
return y_hat, y_tot
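
For orientation, here is a minimal standalone sketch (not part of the commit) of the rolling-window slicing that the new test_step/predict_step apply to a whole buffer. The sizes below (batch_size=4, buffer_size=2048, step=256, group_size=256) are illustrative assumptions:

import torch

batch_size, buffer_size = 4, 2048
step = group_size = 256
num_groups = (buffer_size - group_size) // step + 1    # 8 windows per buffer
# test_step recovers group_size from the tensor shapes via the inverse relation:
assert buffer_size - (num_groups - 1) * step == group_size

x_tot = torch.randn(batch_size, 1, buffer_size)        # one full buffer per shot
for i in range(num_groups):
    start_idx = i * step
    x = x_tot[:, :, start_idx:start_idx + group_size]  # [batch_size, 1, group_size]
    assert x.shape == (batch_size, 1, group_size)
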
37 changes: 23 additions & 14 deletions train.py
@@ -18,7 +18,7 @@

# Define a custom Dataset class
class VelocityDataset(Dataset):
def __init__(self, h5_file, step, group_size=256):
def __init__(self, test_mode, h5_file, step, group_size=256):
self.h5_file = h5_file
self.step = step
self.group_size = group_size
@@ -27,6 +27,7 @@ def __init__(self, h5_file, step, group_size=256):
self.length = len(f['Time (s)']) * num_groups # num_shots * num_groups samples
print(self.h5_file)
self.opened_flag = False
self.test_mode = test_mode

def open_hdf5(self, rolling=True, step=256, group_size=256, ch_in=1):
"""Set up inputs and targets. For each shot, buffer is split into groups of sequences.
@@ -49,22 +50,30 @@ def open_hdf5(self, rolling=True, step=256, group_size=256, ch_in=1):
self.file = h5py.File(self.h5_file, 'r')
pds = torch.Tensor(np.array(self.file['PD (V)'])) # [num_shots, buffer_size]
vels = torch.Tensor(np.array(self.file['Speaker (Microns/s)'])) # [num_shots, buffer_size]

if rolling:
# ROLLING INPUT INDICES
num_groups = (pds.shape[1] - group_size) // step + 1
start_idxs = torch.arange(num_groups) * step # starting indices for each group
idxs = torch.arange(group_size)[:, None] + start_idxs
idxs = torch.transpose(idxs, dim0=0, dim1=1) # indices in shape [num_groups, group_size]
self.inputs = pds[:, idxs].reshape(-1, group_size) # [num_shots * num_groups, group_size]
grouped_vels = vels[:, idxs].reshape(-1, group_size) # [num_shots * num_groups, group_size]
self.targets = torch.unsqueeze(torch.mean(grouped_vels, dim=1), dim=1) # [num_shots * num_groups, 1]
if self.test_mode:
self.inputs = pds # [num_shots, buffer_size]
grouped_vels = vels[:, idxs] # [num_shots, num_groups, group_size]
self.targets = torch.mean(grouped_vels, dim=2) # [num_shots, num_groups]
else:
self.inputs = pds[:, idxs].reshape(-1, group_size) # [num_shots * num_groups, group_size]
grouped_vels = vels[:, idxs].reshape(-1, group_size) # [num_shots * num_groups, group_size]
self.targets = torch.unsqueeze(torch.mean(grouped_vels, dim=1), dim=1) # [num_shots * num_groups, 1]
else:
# STEP INPUT
self.inputs = torch.cat(torch.split(pds, group_size, dim=1), dim=0) # [num_shots * num_groups, group_size]
grouped_vels = torch.cat(torch.split(vels, group_size, dim=1), dim=0)
self.targets = torch.unsqueeze(torch.mean(grouped_vels, dim=1), dim=1) # [num_shots * num_groups, 1]

if self.test_mode:
raise NotImplementedError('test_mode not implemented for step input; use rolling with step=256')
else:
self.inputs = torch.cat(torch.split(pds, group_size, dim=1), dim=0) # [num_shots * num_groups, group_size]
grouped_vels = torch.cat(torch.split(vels, group_size, dim=1), dim=0)
self.targets = torch.unsqueeze(torch.mean(grouped_vels, dim=1), dim=1) # [num_shots * num_groups, 1]

if ch_in == 1:
self.inputs = torch.unsqueeze(self.inputs, dim=1)
self.targets = torch.unsqueeze(self.targets, dim=1)
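
As a sanity check on the rolling index construction above, a toy-sized sketch (the sizes are illustrative assumptions, not from the commit):

import torch

buffer_size, step, group_size = 8, 2, 4
num_groups = (buffer_size - group_size) // step + 1    # 3
start_idxs = torch.arange(num_groups) * step           # tensor([0, 2, 4])
idxs = torch.arange(group_size)[:, None] + start_idxs  # [group_size, num_groups]
idxs = torch.transpose(idxs, dim0=0, dim1=1)           # [num_groups, group_size]
# idxs -> [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7]]: overlapping windows

pds = torch.arange(buffer_size, dtype=torch.float32).unsqueeze(0)  # [1, buffer_size]
windows = pds[:, idxs]                                 # [1, num_groups, group_size]
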
@@ -121,10 +130,10 @@ def __init__(self, training_h5, validation_h5, testing_h5, step=256,
self.checkpoint_dir = "./checkpoints"
print('TrainingRunner initialized', datetime.datetime.now())

def get_custom_dataloader(self, h5_file, batch_size=128, shuffle=True,
def get_custom_dataloader(self, test_mode, h5_file, batch_size=128, shuffle=True,
velocity_only=True):
# if velocity_only:
dataset = VelocityDataset(h5_file, self.step)
dataset = VelocityDataset(test_mode, h5_file, self.step)
print("dataset initialized")
# We can use DataLoader to get batches of data
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
@@ -135,9 +144,9 @@ def get_custom_dataloader(self, h5_file, batch_size=128, shuffle=True,

def set_dataloaders(self, batch_size=128):
self.batch_size = batch_size
self.train_loader = self.get_custom_dataloader(self.training_h5, batch_size=self.batch_size)
self.valid_loader = self.get_custom_dataloader(self.validation_h5, batch_size=self.batch_size, shuffle=False)
self.test_loader = self.get_custom_dataloader(self.testing_h5, batch_size=self.batch_size, shuffle=False)
self.train_loader = self.get_custom_dataloader(False, self.training_h5, batch_size=self.batch_size)
self.valid_loader = self.get_custom_dataloader(False, self.validation_h5, batch_size=self.batch_size, shuffle=False)
self.test_loader = self.get_custom_dataloader(True, self.testing_h5, batch_size=self.batch_size, shuffle=False)

def train_model(self, model_name, save_name=None, **kwargs):
"""Train model.
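A hypothetical usage sketch of the new flag (the HDF5 file names are placeholders, not from the commit): training and validation keep the per-group samples, while the test loader now yields whole buffers.

runner = TrainingRunner('train.h5', 'valid.h5', 'test.h5', step=256)
runner.set_dataloaders(batch_size=128)
# train/valid batches: x [batch_size, 1, group_size], y [batch_size, 1, 1]
# test batches (test_mode=True): x [batch_size, 1, buffer_size],
#                                y [batch_size, 1, num_groups]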
