Skip to content

Commit 86640fc

Browse files
committed
minimal working example of 2-channel interferometer training - additional rewriting of the VelocityDataset
1 parent 4fa561f commit 86640fc

File tree

5 files changed

+87
-48
lines changed

5 files changed

+87
-48
lines changed

decoder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ def __init__(self, model_name, model_hparams, optimizer_name,
2525
self.save_hyperparameters()
2626
# Create model
2727
self.model = self.create_model(model_name, model_hparams)
28-
print(summary(self.model, input_size=(misc_hparams['batch_size'], 1, 256)))
28+
print(summary(self.model, input_size=(misc_hparams['batch_size'],
29+
model_hparams['in_channels'],
30+
256)))
2931
# Create loss module
3032
self.loss_function = nn.MSELoss()
3133

interferometers.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ def write_pretraining_data(num_shots, num_channels, file_path):
9696
_, signal, _, velocity = interferometer.get_buffer()
9797
signal = np.expand_dims(signal, axis=-1)
9898
velocity = np.expand_dims(velocity, axis=-1)
99+
# Want to end up with these shapes in h5 file:
100+
# signal: (num_shots, buffer_size, 1)
101+
# velocity: (num_shots, buffer_size, 1)
99102
entries = {"signal": signal, "velocity": velocity}
100103
util.write_data(file_path, entries)
101104
elif num_channels == 2:
@@ -107,8 +110,12 @@ def write_pretraining_data(num_shots, num_channels, file_path):
107110
_, signal2, _, _ = interferometer2.get_buffer()
108111
signal = np.stack((signal1, signal2), axis=-1)
109112
velocity = np.expand_dims(velocity, axis=-1)
110-
print(signal.shape)
113+
# Want to end up with these shapes in h5 file:
114+
# signal: (num_shots, buffer_size, num_channels)
115+
# velocity: (num_shots, buffer_size, 1)
111116
entries = {"signal": signal, "velocity": velocity}
117+
print(signal.shape)
118+
print(velocity.shape)
112119
util.write_data(file_path, entries)
113120

114121

main.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,19 @@
1313
if len(sys.argv) == 1:
1414
"""Run functions in this scratch area.
1515
"""
16-
valid_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\valid_max10kHz_30to1kHz_2kshots_dec=256_randampl.h5py'
17-
train_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\training_max10kHz_30to1kHz_10kshots_dec=256_randampl.h5py'
18-
test_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\test_max10kHz_30to1kHz_2kshots_dec=256_randampl.h5py'
16+
# valid_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\valid_max10kHz_30to1kHz_2kshots_dec=256_randampl.h5py'
17+
# train_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\training_max10kHz_30to1kHz_10kshots_dec=256_randampl.h5py'
18+
# test_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\test_max10kHz_30to1kHz_2kshots_dec=256_randampl.h5py'
19+
20+
train_file = "/Users/nolanpeard/Desktop/test.h5"
21+
valid_file = "/Users/nolanpeard/Desktop/test.h5"
22+
test_file = "/Users/nolanpeard/Desktop/test.h5"
1923

2024
print('begin main', datetime.datetime.now())
21-
step_list = [256, 128, 64, 32] # step sizes for rolling input
25+
step_list = [256]#, 128, 64, 32] # step sizes for rolling input
2226
for step in step_list:
23-
runner = train.TrainingRunner(train_file, valid_file, test_file, step)
27+
runner = train.TrainingRunner(train_file, valid_file, test_file,
28+
step)
2429
runner.scan_hyperparams()
2530

2631
else:

models.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77

88

99
class CNN(nn.Module):
10-
def __init__(self, input_size, output_size, ch_in=1, activation='LeakyReLU'):
10+
def __init__(self, input_size, output_size, in_channels=1, activation='LeakyReLU'):
1111
super(CNN, self).__init__()
12-
self.ch_in = ch_in
12+
self.in_channels = in_channels
13+
print("self.in_channels = ", self.in_channels)
1314
self.conv_layers = nn.Sequential(
14-
nn.Conv1d(ch_in, 16, kernel_size=7), # Lout = 250, given L = 256
15+
nn.Conv1d(in_channels, 16, kernel_size=7), # Lout = 250, given L = 256
1516
act_fn_by_name[activation],
1617
nn.MaxPool1d(2), # Lout = 125, given L = 250
1718
nn.Conv1d(16, 32, kernel_size=7), # Lout = 119, given L = 125
@@ -26,16 +27,18 @@ def __init__(self, input_size, output_size, ch_in=1, activation='LeakyReLU'):
2627
nn.MaxPool1d(2) # Lout = 10, given L = 20
2728
)
2829
self.fc_layers = nn.Sequential(
29-
nn.Linear(640, 16),
30+
#nn.Linear(640, 16),
31+
nn.Linear(320, 16),
3032
nn.ReLU(),
3133
nn.Linear(16, output_size)
3234
)
3335

3436
def forward(self, x):
37+
print("x.shape", x.shape)
3538
out = self.conv_layers(x)
3639
# print(f"post conv out size: {out.size()}") # [128, 64, 10]
37-
out = out.view(out.size(0), self.ch_in, -1)
38-
# print(f"post conv out reshaped size: {out.size()}") # confirmed [128, 1, 640]
40+
out = out.view(out.size(0), self.in_channels, -1)
41+
print(f"post conv out reshaped size: {out.size()}") # confirmed [128, 1, 640]
3942
out = self.fc_layers(out) # expect out [128, 1, 1]
4043
# print(f"post fc out size: {out.size()}") # confirmed: [128, 1, 1]
4144
return out

train.py

Lines changed: 57 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,21 @@ def __init__(self, test_mode, h5_file, step, group_size=256):
2626
self.h5_file = h5_file
2727
self.step = step
2828
self.group_size = group_size
29-
with h5py.File(self.h5_file, 'r') as f:
30-
num_groups = (f['Time (s)'].shape[1] - group_size) // step + 1
31-
if test_mode:
32-
self.length = len(f['Time (s)']) # in test_mode, length of dataset = num shots
33-
else:
34-
self.length = len(f['Time (s)']) * num_groups
29+
self.length = self.get_length(h5_file, step, group_size, test_mode)
3530
print(self.h5_file)
3631
self.opened_flag = False
3732
self.test_mode = test_mode
33+
34+
def get_length(self, h5_file, step, group_size, test_mode):
35+
with h5py.File(self.h5_file, 'r') as f:
36+
num_groups = (f['signal'].shape[1] - group_size) // step + 1
37+
if test_mode:
38+
length = len(f['signal'])
39+
# in test_mode, length of dataset = num shots
40+
else:
41+
length = len(f['signal']) * num_groups
42+
43+
return length
3844

3945
def open_hdf5(self, rolling=True, step=256, group_size=256):
4046
"""Set up inputs and targets. For each shot, buffer is split into groups of sequences.
@@ -58,44 +64,60 @@ def open_hdf5(self, rolling=True, step=256, group_size=256):
5864
signal = torch.Tensor(np.array(self.file['signal']))
5965
# [num_shots, buffer_size, num_channels]
6066
velocity = torch.Tensor(np.array(self.file['velocity']))
61-
# [num_shots, buffer_size]
67+
# [num_shots, buffer_size, 1]
6268

6369
num_channels = signal.shape[-1]
64-
if num_channels == 1:
65-
signal = torch.squeeze(signal, dim=-1)
66-
else:
67-
raise ValueError('num_channels must be 1')
68-
pass
70+
velocity = velocity.squeeze(dim=-1)
6971

7072
if rolling:
7173
# ROLLING INPUT INDICES
7274
num_groups = (signal.shape[1] - group_size) // step + 1
73-
start_idxs = torch.arange(num_groups) * step # starting indices for each group
75+
start_idxs = torch.arange(num_groups) * step
76+
# starting indices for each group
7477
idxs = torch.arange(group_size)[:, None] + start_idxs
75-
idxs = torch.transpose(idxs, dim0=0, dim1=1) # indices in shape [num_groups, group_size]
78+
idxs = torch.transpose(idxs, dim0=0, dim1=1)
79+
# indices in shape [num_groups, group_size]
7680
if self.test_mode:
77-
self.inputs = signal # [num_shots, buffer_size]
78-
grouped_vels = velocity[:, idxs] # [num_shots, num_groups, group_size]
79-
self.targets = torch.mean(grouped_vels, dim=2) # [num_shots, num_groups]
81+
self.inputs = signal # [num_shots, buffer_size, num_channels]
82+
grouped_vels = velocity[:, idxs]
83+
# [num_shots, num_groups, group_size]
84+
self.targets = torch.mean(grouped_vels, dim=2)
85+
# [num_shots, num_groups]
8086
else:
81-
self.inputs = signal[:, idxs].reshape(-1, group_size) # [num_shots * num_groups, group_size]
82-
grouped_vels = velocity[:, idxs].reshape(-1, group_size) # [num_shots * num_groups, group_size]
83-
self.targets = torch.unsqueeze(torch.mean(grouped_vels, dim=1), dim=1) # [num_shots * num_groups, 1]
87+
self.inputs = signal[:, idxs, :].reshape(-1, group_size,
88+
num_channels)
89+
# [num_shots * num_groups, group_size, num_channels]
90+
grouped_vels = velocity[:, idxs].reshape(-1, group_size)
91+
# [num_shots * num_groups, group_size]
92+
self.targets = torch.unsqueeze(torch.mean(grouped_vels, dim=1),
93+
dim=1)
94+
# [num_shots * num_groups, 1]
8495
else:
8596
# STEP INPUT
8697
if self.test_mode:
87-
assert False, 'test_mode not implemented for step input. use rolling step=256'
98+
raise NotImplementedError("test_mode not implemented for step "
99+
"input. use rolling step=256")
88100
else:
101+
self.inputs = torch.cat(torch.split(signal, group_size,
102+
dim=1), dim=0)
103+
# [num_shots * num_groups, group_size, num_channels]
104+
grouped_vels = torch.cat(torch.split(velocity, group_size,
105+
dim=1), dim=0)
89106
# [num_shots * num_groups, group_size]
90-
self.inputs = torch.cat(torch.split(signal, group_size, dim=1), dim=0)
91-
grouped_vels = torch.cat(torch.split(velocity, group_size, dim=1), dim=0)
92-
self.targets = torch.unsqueeze(torch.mean(grouped_vels, dim=1), dim=1) # [num_shots * num_groups, 1]
107+
self.targets = torch.unsqueeze(torch.mean(grouped_vels,
108+
dim=1), dim=1)
109+
# [num_shots * num_groups, 1]
93110

94111
if num_channels == 1:
95-
self.inputs = torch.unsqueeze(self.inputs, dim=1)
96-
self.targets = torch.unsqueeze(self.targets, dim=1)
112+
# self.inputs = torch.unsqueeze(self.inputs, dim=1)
113+
# self.targets = torch.unsqueeze(self.targets, dim=1)
114+
self.inputs = torch.reshape(self.inputs, (-1, 1, group_size))
115+
self.targets = torch.reshape(self.targets, (-1, 1, 1))
97116
else:
98-
assert False, 'ch > 1 not implemented'
117+
self.inputs = torch.reshape(self.inputs, (-1, num_channels, group_size))
118+
self.targets = torch.reshape(self.targets, (-1, 1, 1))
119+
print(self.inputs.shape)
120+
print(self.targets.shape)
99121

100122
# total number of group_size length sequences = num_shots * num_groups
101123
# print("open_hdf5 input size", self.inputs.size()) # [self.length, 256]
@@ -148,14 +170,13 @@ def __init__(self, training_h5, validation_h5, testing_h5, step=256,
148170
self.checkpoint_dir = "./checkpoints"
149171
print('TrainingRunner initialized', datetime.datetime.now())
150172

151-
def get_custom_dataloader(self, test_mode, h5_file, batch_size=128, shuffle=True,
152-
velocity_only=True):
153-
# if velocity_only:
173+
def get_custom_dataloader(self, test_mode, h5_file, batch_size=128, shuffle=True):
174+
154175
dataset = VelocityDataset(test_mode, h5_file, self.step)
155176
print("dataset initialized")
156177
# We can use DataLoader to get batches of data
157178
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
158-
num_workers=16, persistent_workers=True,
179+
num_workers=1, persistent_workers=True,
159180
pin_memory=True)
160181
print("dataloader initialized")
161182
return dataloader
@@ -197,8 +218,8 @@ def train_model(self, model_name, save_name=None, **kwargs):
197218
# Create a PyTorch Lightning trainer with the generation callback
198219
trainer = L.Trainer(
199220
default_root_dir=os.path.join(self.checkpoint_dir, save_name),
200-
accelerator="gpu",
201-
devices=[0],
221+
accelerator="cpu",
222+
#devices=[0],
202223
max_epochs=800,
203224
callbacks=[early_stop_callback, checkpoint_callback],
204225
check_val_every_n_epoch=5,
@@ -229,10 +250,11 @@ def scan_hyperparams(self):
229250
lr_list = [1e-3, 1e-4] # [1e-3, 1e-4, 1e-5]
230251
act_list = ['LeakyReLU'] # , 'ReLU']
231252
optim_list = ['Adam'] # , 'SGD']
232-
for lr, activation, optim in product(lr_list, act_list, optim_list): # , 1e-2, 3e-2]:
253+
for lr, activation, optim in product(lr_list, act_list, optim_list):
233254
model_config = {"input_size": self.input_size,
234255
"output_size": self.output_size,
235-
"activation": activation}
256+
"activation": activation,
257+
"in_channels": 2}
236258
optimizer_config = {"lr": lr}
237259
# "momentum": 0.9,}
238260
misc_config = {"batch_size": self.batch_size, "step": self.step}

0 commit comments

Comments
 (0)