@@ -26,15 +26,21 @@ def __init__(self, test_mode, h5_file, step, group_size=256):
        self.h5_file = h5_file
        self.step = step
        self.group_size = group_size
-        with h5py.File(self.h5_file, 'r') as f:
-            num_groups = (f['Time (s)'].shape[1] - group_size) // step + 1
-            if test_mode:
-                self.length = len(f['Time (s)'])  # in test_mode, length of dataset = num shots
-            else:
-                self.length = len(f['Time (s)']) * num_groups
+        self.length = self.get_length(h5_file, step, group_size, test_mode)
        print(self.h5_file)
        self.opened_flag = False
        self.test_mode = test_mode
+
+    def get_length(self, h5_file, step, group_size, test_mode):
+        with h5py.File(self.h5_file, 'r') as f:
+            num_groups = (f['signal'].shape[1] - group_size) // step + 1
+            if test_mode:
+                length = len(f['signal'])
+                # in test_mode, length of dataset = num shots
+            else:
+                length = len(f['signal']) * num_groups
+
+        return length

    def open_hdf5(self, rolling=True, step=256, group_size=256):
        """Set up inputs and targets. For each shot, buffer is split into groups of sequences.
@@ -58,44 +64,60 @@ def open_hdf5(self, rolling=True, step=256, group_size=256):
        signal = torch.Tensor(np.array(self.file['signal']))
        # [num_shots, buffer_size, num_channels]
        velocity = torch.Tensor(np.array(self.file['velocity']))
-        # [num_shots, buffer_size]
+        # [num_shots, buffer_size, 1]

        num_channels = signal.shape[-1]
-        if num_channels == 1:
-            signal = torch.squeeze(signal, dim=-1)
-        else:
-            raise ValueError('num_channels must be 1')
-            pass
+        velocity = velocity.squeeze(dim=-1)

        if rolling:
            # ROLLING INPUT INDICES
            num_groups = (signal.shape[1] - group_size) // step + 1
-            start_idxs = torch.arange(num_groups) * step  # starting indices for each group
+            start_idxs = torch.arange(num_groups) * step
+            # starting indices for each group
            idxs = torch.arange(group_size)[:, None] + start_idxs
-            idxs = torch.transpose(idxs, dim0=0, dim1=1)  # indices in shape [num_groups, group_size]
+            idxs = torch.transpose(idxs, dim0=0, dim1=1)
+            # indices in shape [num_groups, group_size]
            if self.test_mode:
-                self.inputs = signal  # [num_shots, buffer_size]
-                grouped_vels = velocity[:, idxs]  # [num_shots, num_groups, group_size]
-                self.targets = torch.mean(grouped_vels, dim=2)  # [num_shots, num_groups]
+                self.inputs = signal  # [num_shots, buffer_size, num_channels]
+                grouped_vels = velocity[:, idxs]
+                # [num_shots, num_groups, group_size]
+                self.targets = torch.mean(grouped_vels, dim=2)
+                # [num_shots, num_groups]
            else:
-                self.inputs = signal[:, idxs].reshape(-1, group_size)  # [num_shots * num_groups, group_size]
-                grouped_vels = velocity[:, idxs].reshape(-1, group_size)  # [num_shots * num_groups, group_size]
-                self.targets = torch.unsqueeze(torch.mean(grouped_vels, dim=1), dim=1)  # [num_shots * num_groups, 1]
+                self.inputs = signal[:, idxs, :].reshape(-1, group_size,
+                                                         num_channels)
+                # [num_shots * num_groups, group_size, num_channels]
+                grouped_vels = velocity[:, idxs].reshape(-1, group_size)
+                # [num_shots * num_groups, group_size]
+                self.targets = torch.unsqueeze(torch.mean(grouped_vels, dim=1),
+                                               dim=1)
+                # [num_shots * num_groups, 1]
        else:
            # STEP INPUT
            if self.test_mode:
-                assert False, 'test_mode not implemented for step input. use rolling step=256'
+                raise NotImplementedError("test_mode not implemented for step "
+                                          "input. use rolling step=256")
            else:
+                self.inputs = torch.cat(torch.split(signal, group_size,
+                                                    dim=1), dim=0)
+                # [num_shots * num_groups, group_size, num_channels]
+                grouped_vels = torch.cat(torch.split(velocity, group_size,
+                                                     dim=1), dim=0)
                # [num_shots * num_groups, group_size]
-                self.inputs = torch.cat(torch.split(signal, group_size, dim=1), dim=0)
-                grouped_vels = torch.cat(torch.split(velocity, group_size, dim=1), dim=0)
-                self.targets = torch.unsqueeze(torch.mean(grouped_vels, dim=1), dim=1)  # [num_shots * num_groups, 1]
+                self.targets = torch.unsqueeze(torch.mean(grouped_vels,
+                                                          dim=1), dim=1)
+                # [num_shots * num_groups, 1]

        if num_channels == 1:
-            self.inputs = torch.unsqueeze(self.inputs, dim=1)
-            self.targets = torch.unsqueeze(self.targets, dim=1)
+            # self.inputs = torch.unsqueeze(self.inputs, dim=1)
+            # self.targets = torch.unsqueeze(self.targets, dim=1)
+            self.inputs = torch.reshape(self.inputs, (-1, 1, group_size))
+            self.targets = torch.reshape(self.targets, (-1, 1, 1))
        else:
-            assert False, 'ch > 1 not implemented'
+            self.inputs = torch.reshape(self.inputs, (-1, num_channels, group_size))
+            self.targets = torch.reshape(self.targets, (-1, 1, 1))
+            print(self.inputs.shape)
+            print(self.targets.shape)

        # total number of group_size length sequences = num_shots * num_groups
        # print("open_hdf5 input size", self.inputs.size())  # [self.length, 256]
@@ -148,14 +170,13 @@ def __init__(self, training_h5, validation_h5, testing_h5, step=256,
        self.checkpoint_dir = "./checkpoints"
        print('TrainingRunner initialized', datetime.datetime.now())

-    def get_custom_dataloader(self, test_mode, h5_file, batch_size=128, shuffle=True,
-                              velocity_only=True):
-        # if velocity_only:
+    def get_custom_dataloader(self, test_mode, h5_file, batch_size=128, shuffle=True):
+
        dataset = VelocityDataset(test_mode, h5_file, self.step)
        print("dataset initialized")
        # We can use DataLoader to get batches of data
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
-                                num_workers=16, persistent_workers=True,
+                                num_workers=1, persistent_workers=True,
                                pin_memory=True)
        print("dataloader initialized")
        return dataloader
@@ -197,8 +218,8 @@ def train_model(self, model_name, save_name=None, **kwargs):
        # Create a PyTorch Lightning trainer with the generation callback
        trainer = L.Trainer(
            default_root_dir=os.path.join(self.checkpoint_dir, save_name),
-            accelerator="gpu",
-            devices=[0],
+            accelerator="cpu",
+            # devices=[0],
            max_epochs=800,
            callbacks=[early_stop_callback, checkpoint_callback],
            check_val_every_n_epoch=5,
@@ -229,10 +250,11 @@ def scan_hyperparams(self):
        lr_list = [1e-3, 1e-4]  # [1e-3, 1e-4, 1e-5]
        act_list = ['LeakyReLU']  # , 'ReLU']
        optim_list = ['Adam']  # , 'SGD']
-        for lr, activation, optim in product(lr_list, act_list, optim_list):  # , 1e-2, 3e-2]:
+        for lr, activation, optim in product(lr_list, act_list, optim_list):
            model_config = {"input_size": self.input_size,
                            "output_size": self.output_size,
-                            "activation": activation}
+                            "activation": activation,
+                            "in_channels": 2}
            optimizer_config = {"lr": lr}
                                # "momentum": 0.9,}
            misc_config = {"batch_size": self.batch_size, "step": self.step}