Skip to content

Commit

Permalink
Fix data shape to [num_shots*num_groups, in_ch, group_size]
Browse files Browse the repository at this point in the history
  • Loading branch information
ange1a-j14 committed Aug 14, 2024
1 parent ab9467f commit ad6c6f1
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
2 changes: 1 addition & 1 deletion decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def training_step(self, batch, batch_idx):
loss = self.loss_function(preds, y)
acc = (preds == y).float().mean()
self.log("train_acc", acc, on_step=False, on_epoch=True)
self.log("train_loss", loss, on_step=True, prog_bar=True)
self.log("train_loss", loss, on_epoch=True, prog_bar=True)
return loss

def validation_step(self, batch, batch_idx):
Expand Down
34 changes: 24 additions & 10 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,31 @@

# Define a custom Dataset class
class VelocityDataset(Dataset):
def __init__(self, h5_file, step):
def __init__(self, h5_file, step, group_size=256):
self.h5_file = h5_file
self.step = step
self.group_size = group_size
with h5py.File(self.h5_file, 'r') as f:
num_groups = (f['Time (s)'].shape[1] - group_size) // step + 1
self.length = len(f['Time (s)']) * num_groups # num shots
print(self.h5_file)
self.opened_flag = False

def open_hdf5(self, rolling=True, step, group_size=256):
def open_hdf5(self, rolling=True, step=256, group_size=256, ch_in = 1):
"""Set up inputs and targets. For each shot, buffer is split into groups of sequences.
Inputs include grouped photodiode trace of 'group_size', spaced interval 'step' apart for each buffer.
Targets include average velocity of each group.
Input shape is [num_shots * num_groups, group_size] and target shape is [num_shots * num_groups, 1],
where num_groups = (buffer_len - group_size)/step + 1, given that buffer_len - group_size is a multiple of step.
Input shape is [num_shots * num_groups, ch_in, group_size]. Target shape is [num_shots * num_groups, ch_in, 1],
where num_groups = (buffer_len - group_size)/step + 1, given that buffer_len - group_size is a multiple of step.
Input and target shape made to fit (N, C_in, L_in) in PyTorch Conv1d doc.
If the given 'group_size' and 'step' do not satisfy the above requirement,
the data will not be cleanly grouped.
Args:
step (int): Size of step between group starts. buffer_len - group_size = 0 (mod step).
rolling (bool, optional): Whether to use rolling input. Defaults to True.
step (int, optional): Size of step between group starts. buffer_len - group_size = 0 (mod step). Defaults to 256.
group_size (int, optional): Size of each group. buffer_len - group_size = 0 (mod step). Defaults to 256.
ch_in (int, optional): Number of input channels for model. Defaults to 1.
"""
# solves issue where hdf5 file opened in __init__ prevents multiple
# workers: https://github.com/pytorch/pytorch/issues/11929
Expand All @@ -58,11 +65,18 @@ def open_hdf5(self, rolling=True, step, group_size=256):
grouped_vels = torch.cat(torch.split(vels, group_size, dim=1), dim=0)
self.targets = torch.unsqueeze(torch.mean(grouped_vels, dim=1), dim=1) # [num_shots * num_groups, 1]

self.length = self.inputs.shape[0] # total number of group_size length sequences = num_shots * num_groups
if ch_in == 1:
self.inputs = torch.unsqueeze(self.inputs, dim=1)
self.targets = torch.unsqueeze(self.targets, dim=1)
else:
assert False, 'ch > 1 not implemented'

# total number of group_size length sequences = num_shots * num_groups
# print(self.inputs.size()) # [10k*64, 1, 256]
# print(self.targets.size()) # [10k*64, 1, 1]

def __len__(self):
# print("__len__:", self.length)
return self.length

def __getitem__(self, idx):
Expand Down Expand Up @@ -94,8 +108,8 @@ def __init__(self, training_h5, validation_h5, testing_h5, step=256,
print("dataloaders set:", datetime.datetime.now())
input_ref = next(iter(self.train_loader))
# print("loaded next(iter", datetime.datetime.now())
self.input_size = input_ref[0].shape[1] # group_size
self.output_size = input_ref[1].shape[1] # 1
self.input_size = input_ref[0].shape[2] # group_size
self.output_size = input_ref[1].shape[2] # 1
print(f"input ref {len(input_ref)} , {input_ref[0].size()}")
print(f"output ref {len(input_ref)} , {input_ref[1].size()}")
print(f"train.py input_size {self.input_size}")
Expand Down Expand Up @@ -155,9 +169,9 @@ def train_model(self, model_name, save_name=None, **kwargs):
default_root_dir=os.path.join(self.checkpoint_dir, save_name),
accelerator="gpu",
devices=[0],
max_epochs=2000,
max_epochs=800,
callbacks=[early_stop_callback, checkpoint_callback],
check_val_every_n_epoch=20,
check_val_every_n_epoch=5,
logger=logger
)

Expand Down

0 comments on commit ad6c6f1

Please sign in to comment.