diff --git a/decoder.py b/decoder.py index 7dfac85..91fd559 100644 --- a/decoder.py +++ b/decoder.py @@ -70,7 +70,7 @@ def training_step(self, batch, batch_idx): loss = self.loss_function(preds, y) acc = (preds == y).float().mean() self.log("train_acc", acc, on_step=False, on_epoch=True) - self.log("train_loss", loss, on_step=True, prog_bar=True) + self.log("train_loss", loss, on_epoch=True, prog_bar=True) return loss def validation_step(self, batch, batch_idx): diff --git a/train.py b/train.py index a6dd57c..1403b98 100644 --- a/train.py +++ b/train.py @@ -18,24 +18,31 @@ # Define a custom Dataset class class VelocityDataset(Dataset): - def __init__(self, h5_file, step): + def __init__(self, h5_file, step, group_size=256): self.h5_file = h5_file self.step = step + self.group_size = group_size + with h5py.File(self.h5_file, 'r') as f: + num_groups = (f['Time (s)'].shape[1] - group_size) // step + 1 + self.length = len(f['Time (s)']) * num_groups # num shots print(self.h5_file) self.opened_flag = False - def open_hdf5(self, rolling=True, step, group_size=256): + def open_hdf5(self, rolling=True, step=256, group_size=256, ch_in = 1): """Set up inputs and targets. For each shot, buffer is split into groups of sequences. Inputs include grouped photodiode trace of 'group_size', spaced interval 'step' apart for each buffer. Targets include average velocity of each group. - Input shape is [num_shots * num_groups, group_size] and target shape is [num_shots * num_groups, 1], - where num_groups = (buffer_len - group_size)/step + 1, given that buffer_len - group_size is a multiple of step. + Input shape is [num_shots * num_groups, ch_in, group_size]. Target shape is [num_shots * num_groups, ch_in, 1], + where num_groups = (buffer_len - group_size)/step + 1, given that buffer_len - group_size is a multiple of step. + Input and target shape made to fit (N, C_in, L_in) in PyTorch Conv1d doc. 
If the given 'group_size' and 'step' do not satisfy the above requirement, the data will not be cleanly grouped. Args: - step (int): Size of step between group starts. buffer_len - grou_size = 0 (mod step). + rolling (bool, optional): Whether to use rolling input. Defaults to True. + step (int, optional): Size of step between group starts. buffer_len - group_size = 0 (mod step). Defaults to 256. group_size (int, optional): Size of each group. buffer_len - group_size = 0 (mod step). Defaults to 256. + ch_in (int, optional): Number of input channels for model. Defaults to 1. """ # solves issue where hdf5 file opened in __init__ prevents multiple # workers: https://github.com/pytorch/pytorch/issues/11929 @@ -58,11 +65,18 @@ def open_hdf5(self, rolling=True, step, group_size=256): grouped_vels = torch.cat(torch.split(vels, group_size, dim=1), dim=0) self.targets = torch.unsqueeze(torch.mean(grouped_vels, dim=1), dim=1) # [num_shots * num_groups, 1] - self.length = self.inputs.shape[0] # total number of group_size length sequences = num_shots * num_groups + if ch_in == 1: + self.inputs = torch.unsqueeze(self.inputs, dim=1) + self.targets = torch.unsqueeze(self.targets, dim=1) + else: + assert False, 'ch > 1 not implemented' + + # total number of group_size length sequences = num_shots * num_groups # print(self.inputs.size()) # [10k*64, 256] # print(self.targets.size()) # [10k*64, 1] def __len__(self): + # print("__len__:", self.length) return self.length def __getitem__(self, idx): @@ -94,8 +108,8 @@ def __init__(self, training_h5, validation_h5, testing_h5, step=256, print("dataloaders set:", datetime.datetime.now()) input_ref = next(iter(self.train_loader)) # print("loaded next(iter", datetime.datetime.now()) - self.input_size = input_ref[0].shape[1] # group_size - self.output_size = input_ref[1].shape[1] # 1 + self.input_size = input_ref[0].shape[2] # group_size + self.output_size = input_ref[1].shape[2] # 1 print(f"input ref {len(input_ref)} , {input_ref[0].size()}") 
print(f"output ref {len(input_ref)} , {input_ref[1].size()}") print(f"train.py input_size {self.input_size}") @@ -155,9 +169,9 @@ def train_model(self, model_name, save_name=None, **kwargs): default_root_dir=os.path.join(self.checkpoint_dir, save_name), accelerator="gpu", devices=[0], - max_epochs=2000, + max_epochs=800, callbacks=[early_stop_callback, checkpoint_callback], - check_val_every_n_epoch=20, + check_val_every_n_epoch=5, logger=logger )