@@ -19,16 +19,18 @@ class VelocityDataset(Dataset):
    def __init__(self, h5_file):
        self.h5_file = h5_file
        with h5py.File(self.h5_file, 'r') as f:
-           self.length = len(f['time_data'])  # num shots
+           self.length = len(f['Time (s)'])  # num shots

-   def open_hdf5(self, group_size=64, num_groups=256):
+   def open_hdf5(self, num_groups=64, group_size=256):
        # solves issue where hdf5 file opened in __init__ prevents multiple
        # workers: https://github.com/pytorch/pytorch/issues/11929
        self.file = h5py.File(self.h5_file, 'r')
-       self.inputs = self.file['PD (V)'][:, ::group_size]  # take num_groups evenly spaced points, [num_shots, num_groups]
+       grouped_pd = np.array(np.hsplit(self.file['PD (V)'], num_groups))  # [num_groups, num_shots, group_size]
+       self.inputs = np.transpose(grouped_pd, [1, 0, 2])  # [num_shots, num_groups, group_size]
        grouped_velocities = np.array(np.hsplit(self.file['Speaker (Microns/s)'], num_groups))  # [num_groups, num_shots, group_size]
        grouped_velocities = np.transpose(grouped_velocities, [1, 0, 2])  # [num_shots, num_groups, group_size]
-       self.targets = np.average(grouped_velocities, axis=3)  # store average velocity per group per shot: [num_shots, num_groups]
+       grouped_velocities = np.average(grouped_velocities, axis=2)  # store average velocity per group per shot: [num_shots, num_groups]
+       self.targets = np.expand_dims(grouped_velocities, axis=2)  # [num_shots, num_groups, 1]

    def __len__(self):
        return self.length
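(Aside, not part of the commit: a quick toy check of the new grouping logic. np.hsplit requires the per-shot sample count to be divisible by num_groups; the shapes below are made up for illustration.)

import numpy as np

pd = np.arange(16).reshape(2, 8)                       # [num_shots=2, num_samples=8]
grouped = np.array(np.hsplit(pd, 4))                   # [num_groups=4, num_shots, group_size=2]
inputs = np.transpose(grouped, [1, 0, 2])              # [num_shots, num_groups, group_size]
targets = np.expand_dims(inputs.mean(axis=2), axis=2)  # [num_shots, num_groups, 1]
print(inputs.shape, targets.shape)                     # (2, 4, 2) (2, 4, 1)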
@@ -40,7 +42,7 @@ def __getitem__(self, idx):
class TrainingRunner:
    def __init__(self, training_h5, validation_h5, testing_h5,
-                velocity_only=False):
+                velocity_only=False, num_groups=64):
        self.training_h5 = training_h5
        self.validation_h5 = validation_h5
        self.testing_h5 = testing_h5
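(Context, an assumption since the commit doesn't show it: the usual companion to the open_hdf5 workaround flagged in the first hunk is a lazy open inside the dataset's __getitem__, so each DataLoader worker gets its own h5py handle.)

import torch

def __getitem__(self, idx):
    # open the file lazily in the worker process; opening it in __init__
    # breaks num_workers > 0 (pytorch/pytorch#11929)
    if not hasattr(self, 'file'):
        self.open_hdf5()
    return (torch.from_numpy(self.inputs[idx]).float(),
            torch.from_numpy(self.targets[idx]).float())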
@@ -50,16 +52,22 @@ def __init__(self, training_h5, validation_h5, testing_h5,
        self.set_dataloaders()

        # dimensions
-       self.input_size = next(iter(self.train_loader))[0].size(-1) ** 2
-       self.output_size = next(iter(self.train_loader))[1].size(-1)
+       input_ref = next(iter(self.train_loader))
+       output_ref = next(iter(self.train_loader))
+       self.input_size = num_groups  # input_ref[0].size(-1) ** 2
+       self.output_size = num_groups  # output_ref[1].size(-1)
+       print(f"input ref {len(input_ref)}, {input_ref[0].size()}")
+       print(f"output ref {len(output_ref)}, {output_ref[1].size()}")
+       print(f"train.py input_size {self.input_size}")
+       print(f"train.py output_size {self.output_size}")

        # directories
        self.checkpoint_dir = "./checkpoints"

    def get_custom_dataloader(self, h5_file, batch_size=128, shuffle=True,
                              velocity_only=True):
-       if velocity_only:
-           dataset = VelocityDataset(h5_file)
+       # if velocity_only:
+       dataset = VelocityDataset(h5_file)

        # We can use DataLoader to get batches of data
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
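(Side note: input_ref and output_ref above each build a fresh iterator over train_loader, so two first batches are fetched just for shape inspection. A sketch that pulls one batch once, assuming the loader yields (inputs, targets) pairs:)

first_batch = next(iter(self.train_loader))
inputs, targets = first_batch[0], first_batch[1]
print(inputs.size(), targets.size())  # e.g. [batch, num_groups, group_size] and [batch, num_groups, 1]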
@@ -109,7 +117,7 @@ def train_model(self, model_name, save_name=None, **kwargs):
            devices=[0],
            max_epochs=180,
            callbacks=[early_stop_callback, checkpoint_callback],
-           check_val_every_n_epoch=10,
+           check_val_every_n_epoch=1,  # 10,
            logger=logger
        )
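(Why this change matters: EarlyStopping only sees its monitored metric when validation runs, so check_val_every_n_epoch=1 lets it count patience per epoch instead of per 10 epochs. A minimal sketch; the monitored metric name "val_loss" is an assumption:)

from lightning.pytorch.callbacks import EarlyStopping

early_stop_callback = EarlyStopping(monitor="val_loss", patience=10, mode="min")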
@@ -133,45 +141,45 @@ def train_model(self, model_name, save_name=None, **kwargs):

        return model, result

-   def scan_hyperparams(self):
-       for lr in [1e-3, 1e-2, 3e-2]:
-
-           model_config = {"input_size": self.input_size,
-                           "output_size": self.output_size}
-           optimizer_config = {"lr": lr}
-                               # "momentum": 0.9,}
-           misc_config = {"batch_size": self.batch_size}
-
-           self.train_model(model_name="CNN",
-                            model_hparams=model_config,
-                            optimizer_name="Adam",
-                            optimizer_hparams=optimizer_config,
-                            misc_hparams=misc_config)
-
-   def load_model(self):
-       Check whether pretrained model exists. If yes, load it and skip training
-       pretrained_filename = os.path.join(self.checkpoint_dir, "SMI", "f63rieqp",
-                                          "checkpoints", "*" + ".ckpt")
-       print(pretrained_filename)
-       if os.path.isfile(glob.glob(pretrained_filename)[0]):
-           pretrained_filename = glob.glob(pretrained_filename)[0]
-           print(
-               f"Found pretrained model at {pretrained_filename}, loading...")
-           # Automatically loads the model with the saved hyperparameters
-           model = VelocityDecoder.load_from_checkpoint(pretrained_filename)
-
-           # Create a PyTorch Lightning trainer with the generation callback
-           trainer = L.Trainer(
-               accelerator="gpu",
-               devices=[0]
-           )
-
-           # Test best model on validation and test set
-           val_result = trainer.test(model, dataloaders=self.valid_loader,
-                                     verbose=False)
-           test_result = trainer.test(model, dataloaders=self.test_loader,
-                                      verbose=False)
-           result = {"test": test_result[0]["test_acc"],
-                     "val": val_result[0]["test_acc"]}
-
-           return model, result
+   def scan_hyperparams(self):
+       for lr in [1e-3]:  # , 1e-2, 3e-2]:
+
+           model_config = {"input_size": self.input_size,
+                           "output_size": self.output_size}
+           optimizer_config = {"lr": lr}
+                               # "momentum": 0.9,}
+           misc_config = {"batch_size": self.batch_size}
+
+           self.train_model(model_name="CNN",
+                            model_hparams=model_config,
+                            optimizer_name="Adam",
+                            optimizer_hparams=optimizer_config,
+                            misc_hparams=misc_config)
+
+   def load_model(self):
+       # Check whether pretrained model exists. If yes, load it and skip training
+       pretrained_filename = os.path.join(self.checkpoint_dir, "SMI", "f63rieqp",
+                                          "checkpoints", "*" + ".ckpt")
+       print(pretrained_filename)
+       if os.path.isfile(glob.glob(pretrained_filename)[0]):
+           pretrained_filename = glob.glob(pretrained_filename)[0]
+           print(
+               f"Found pretrained model at {pretrained_filename}, loading...")
+           # Automatically loads the model with the saved hyperparameters
+           model = VelocityDecoder.load_from_checkpoint(pretrained_filename)
+
+           # Create a PyTorch Lightning trainer with the generation callback
+           trainer = L.Trainer(
+               accelerator="gpu",
+               devices=[0]
+           )
+
+           # Test best model on validation and test set
+           val_result = trainer.test(model, dataloaders=self.valid_loader,
+                                     verbose=False)
+           test_result = trainer.test(model, dataloaders=self.test_loader,
+                                      verbose=False)
+           result = {"test": test_result[0]["test_acc"],
+                     "val": val_result[0]["test_acc"]}
+
+           return model, result
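(Hedged sketch, not in the commit: glob.glob() returns an empty list when no checkpoint matches the "*.ckpt" pattern, so indexing [0] in load_model raises IndexError. A guard avoids that:)

import glob
import os

matches = glob.glob(pretrained_filename)
if matches and os.path.isfile(matches[0]):
    pretrained_filename = matches[0]
    print(f"Found pretrained model at {pretrained_filename}, loading...")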