Commit 29e7a5c

Run model with no error
1 parent 43e5a20 commit 29e7a5c

4 files changed (+92, −59 lines)

decoder.py

Lines changed: 5 additions & 4 deletions
@@ -5,6 +5,8 @@
 import lightning as L
 from models import CNN
 
+model_dict = {"CNN": CNN}
+
 class VelocityDecoder(L.LightningModule):
     def __init__(self, model_name, model_hparams, optimizer_name,
                  optimizer_hparams, misc_hparams):
@@ -64,7 +66,6 @@ def training_step(self, batch, batch_idx):
         # training_step defines the train loop.
         # it is independent of forward
         x, y = batch
-        x = x.view(-1, x.size(1)**2)
         preds = self.model(x)
         loss = self.loss_function(preds, y)
         acc = (preds == y).float().mean()
@@ -75,7 +76,9 @@ def training_step(self, batch, batch_idx):
     def validation_step(self, batch, batch_idx):
         # validation_step defines the validation loop.
         x, y = batch
-        x = x.view(-1, x.size(1)**2)
+        print(type(x))
+        print(f"x size {x.size()}")
+        print(f"y size {y.size()}")
         preds = self.model(x)
         loss = self.loss_function(preds, y)
         acc = (preds == y).float().mean()
@@ -85,7 +88,6 @@ def validation_step(self, batch, batch_idx):
 
     def test_step(self, batch, batch_idx):
         x, y = batch
-        x = x.view(-1, x.size(1)**2)
         preds = self.model(x)
         loss = self.loss_function(preds, y)
         acc = (preds == y).float().mean()
@@ -95,6 +97,5 @@ def test_step(self, batch, batch_idx):
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         x, y = batch
-        x = x.view(-1, x.size(1)**2)
         y_hat = self.model(x)
         return y_hat, y
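
The common thread in these decoder.py hunks is dropping x = x.view(-1, x.size(1)**2) from every step: Conv1d consumes [batch, channels, length] tensors, and the old flatten would have destroyed that layout (it was also only shape-consistent when the two trailing sizes matched). A minimal shape sketch, assuming the sizes used elsewhere in this commit (batch 128, 64 groups of 256 samples); the first-layer channel counts are an assumption, since that layer sits above the models.py hunk:

import torch
import torch.nn as nn

x = torch.randn(128, 64, 256)            # [batch, num_groups, group_size] as built in train.py
conv = nn.Conv1d(64, 16, kernel_size=7)  # assumed first layer of the CNN
print(conv(x).shape)                     # torch.Size([128, 16, 250]) -- the "L = 250" in models.py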

main.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+import numpy as np
+import sys
+import torch
+
+import train
+import decoder
+import models
+
+if __name__ == '__main__':
+    np.random.seed(0x5EED+3)
+    if len(sys.argv) == 1:
+        """Run functions in this scratch area.
+        """
+        valid_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\valid_data.h5py'
+        train_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\training_data.h5py'
+        test_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\test_data.h5py'
+
+        runner = train.TrainingRunner(train_file, valid_file, test_file)
+        runner.scan_hyperparams()
+
+    else:
+        print("Error: Unsupported number of command-line arguments")

models.py

Lines changed: 5 additions & 3 deletions
@@ -11,20 +11,22 @@ def __init__(self, input_size, output_size):
             nn.MaxPool1d(2), # Lout = 125, given L = 250
             nn.Conv1d(16, 32, kernel_size=7), # Lout = 119, given L = 125
             nn.MaxPool1d(2), # Lout = 59, given L = 119
-            nn.Conv1d(32, 64, kernel_size=7) # Lout = 53, given L = 59
+            nn.Conv1d(32, 64, kernel_size=7), # Lout = 53, given L = 59
             nn.MaxPool1d(2), # Lout = 26, given L = 53
             nn.Dropout(0.1),
             nn.Conv1d(64, 64, kernel_size=7), # Lout = 20, given L = 26
             nn.MaxPool1d(2) # Lout = 10, given L = 20
         )
         self.fc_layers = nn.Sequential(
-            nn.Linear(640, 16),
+            nn.Linear(10, 16),
+            nn.ReLU(),
             nn.Linear(16, 1)
         )
 
     def forward(self, x):
         out = self.conv_layers(x)
-        out = self.view(640)
+        print(f"out size after conv: {out.size()}") # expect [128, 64, 10]
         out = self.fc_layers(out)
+        print(f"out size after fc: {out.size()}") # expect [128, 64, 1]
         return out
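
The Lout comments follow the standard 1-D shape rules: a stride-1, unpadded Conv1d gives Lout = L - kernel_size + 1, and MaxPool1d(2) gives Lout = L // 2. Replacing nn.Linear(640, 16) with nn.Linear(10, 16) works because the conv output is never flattened: nn.Linear acts on the last axis, so [128, 64, 10] becomes [128, 64, 16] and finally the expected [128, 64, 1]. A quick arithmetic check of the comment chain (plain Python; the first Conv1d, above this hunk, is assumed to take L = 256 to 250):

# Conv1d (stride 1, no padding): Lout = L - kernel_size + 1
# MaxPool1d(2):                  Lout = L // 2
L = 256
for op, k in [("conv", 7), ("pool", 2)] * 4:
    L = L - k + 1 if op == "conv" else L // k
    print(op, L)  # 250, 125, 119, 59, 53, 26, 20, 10 -- matching the Lout comments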

train.py

Lines changed: 60 additions & 52 deletions
@@ -19,16 +19,18 @@ class VelocityDataset(Dataset):
     def __init__(self, h5_file):
         self.h5_file = h5_file
         with h5py.File(self.h5_file, 'r') as f:
-            self.length = len(f['time_data']) # num shots
+            self.length = len(f['Time (s)']) # num shots
 
-    def open_hdf5(self, group_size=64, num_groups=256):
+    def open_hdf5(self, num_groups=64, group_size=256):
         # solves issue where hdf5 file opened in __init__ prevents multiple
         # workers: https://github.com/pytorch/pytorch/issues/11929
         self.file = h5py.File(self.h5_file, 'r')
-        self.inputs = self.file['PD (V)'][:, ::group_size] # take num_groups evenly spaced points, [num_shots, num_groups]
+        grouped_pd = np.array(np.hsplit(self.file['PD (V)'], num_groups)) # [num_groups, num_shots, group_size]
+        self.inputs = np.transpose(grouped_pd, [1, 0, 2]) # [num_shots, num_groups, group_size]
         grouped_velocities = np.array(np.hsplit(self.file['Speaker (Microns/s)'], num_groups)) # [num_groups, num_shots, group_size]
         grouped_velocities = np.transpose(grouped_velocities, [1, 0, 2]) # [num_shots, num_groups, group_size]
-        self.targets = np.average(grouped_velocities, axis=3) # store average velocity per group per shot: [num_shots, num_groups]
+        grouped_velocities = np.average(grouped_velocities, axis=2) # store average velocity per group per shot: [num_shots, num_groups]
+        self.targets = np.expand_dims(grouped_velocities, axis=2) # [num_shots, num_groups, 1]
 
     def __len__(self):
         return self.length
@@ -40,7 +42,7 @@ def __getitem__(self, idx):
 
 class TrainingRunner:
     def __init__(self, training_h5, validation_h5, testing_h5,
-                 velocity_only=False):
+                 velocity_only=False, num_groups=64):
         self.training_h5 = training_h5
         self.validation_h5 = validation_h5
         self.testing_h5 = testing_h5
@@ -50,16 +52,22 @@ def __init__(self, training_h5, validation_h5, testing_h5,
         self.set_dataloaders()
 
         # dimensions
-        self.input_size = next(iter(self.train_loader))[0].size(-1) ** 2
-        self.output_size = next(iter(self.train_loader))[1].size(-1)
+        input_ref = next(iter(self.train_loader))
+        output_ref = next(iter(self.train_loader))
+        self.input_size = num_groups #input_ref[0].size(-1) #** 2
+        self.output_size = num_groups # output_ref[1].size(-1)
+        print(f"input ref {len(input_ref)} , {input_ref[0].size()}")
+        print(f"output ref {len(output_ref)} , {output_ref[1].size()}")
+        print(f"train.py input_size {self.input_size}")
+        print(f"train.py output_size {self.output_size}")
 
         # directories
         self.checkpoint_dir = "./checkpoints"
 
     def get_custom_dataloader(self, h5_file, batch_size=128, shuffle=True,
                               velocity_only=True):
-        if velocity_only:
-            dataset = VelocityDataset(h5_file)
+        # if velocity_only:
+        dataset = VelocityDataset(h5_file)
 
         # We can use DataLoader to get batches of data
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
@@ -109,7 +117,7 @@ def train_model(self, model_name, save_name=None, **kwargs):
             devices=[0],
             max_epochs=180,
             callbacks=[early_stop_callback, checkpoint_callback],
-            check_val_every_n_epoch=10,
+            check_val_every_n_epoch=1, #10,
             logger=logger
         )
 
@@ -133,45 +141,45 @@ def train_model(self, model_name, save_name=None, **kwargs):
 
         return model, result
 
-    def scan_hyperparams(self):
-        for lr in [1e-3, 1e-2, 3e-2]:
-
-            model_config = {"input_size": self.input_size,
-                            "output_size": self.output_size}
-            optimizer_config = {"lr": lr}
-                                #"momentum": 0.9,}
-            misc_config = {"batch_size": self.batch_size}
-
-            self.train_model(model_name="CNN",
-                             model_hparams=model_config,
-                             optimizer_name="Adam",
-                             optimizer_hparams=optimizer_config,
-                             misc_hparams=misc_config)
-
-    def load_model(self):
-        Check whether pretrained model exists. If yes, load it and skip training
-        pretrained_filename = os.path.join(self.checkpoint_dir, "SMI", "f63rieqp",
-                                           "checkpoints", "*" + ".ckpt")
-        print(pretrained_filename)
-        if os.path.isfile(glob.glob(pretrained_filename)[0]):
-            pretrained_filename = glob.glob(pretrained_filename)[0]
-            print(
-                f"Found pretrained model at {pretrained_filename}, loading...")
-            # Automatically loads the model with the saved hyperparameters
-            model = VelocityDecoder.load_from_checkpoint(pretrained_filename)
-
-        # Create a PyTorch Lightning trainer with the generation callback
-        trainer = L.Trainer(
-            accelerator="gpu",
-            devices=[0]
-        )
-
-        # Test best model on validation and test set
-        val_result = trainer.test(model, dataloaders=self.valid_loader,
-                                  verbose=False)
-        test_result = trainer.test(model, dataloaders=self.test_loader,
-                                   verbose=False)
-        result = {"test": test_result[0]["test_acc"],
-                  "val": val_result[0]["test_acc"]}
-
-        return model, result
+    def scan_hyperparams(self):
+        for lr in [1e-3]:#, 1e-2, 3e-2]:
+
+            model_config = {"input_size": self.input_size,
+                            "output_size": self.output_size}
+            optimizer_config = {"lr": lr}
+                                #"momentum": 0.9,}
+            misc_config = {"batch_size": self.batch_size}
+
+            self.train_model(model_name="CNN",
+                             model_hparams=model_config,
+                             optimizer_name="Adam",
+                             optimizer_hparams=optimizer_config,
+                             misc_hparams=misc_config)
+
+    def load_model(self):
+        # Check whether pretrained model exists. If yes, load it and skip training
+        pretrained_filename = os.path.join(self.checkpoint_dir, "SMI", "f63rieqp",
+                                           "checkpoints", "*" + ".ckpt")
+        print(pretrained_filename)
+        if os.path.isfile(glob.glob(pretrained_filename)[0]):
+            pretrained_filename = glob.glob(pretrained_filename)[0]
+            print(
+                f"Found pretrained model at {pretrained_filename}, loading...")
+            # Automatically loads the model with the saved hyperparameters
+            model = VelocityDecoder.load_from_checkpoint(pretrained_filename)
+
+        # Create a PyTorch Lightning trainer with the generation callback
+        trainer = L.Trainer(
+            accelerator="gpu",
+            devices=[0]
+        )
+
+        # Test best model on validation and test set
+        val_result = trainer.test(model, dataloaders=self.valid_loader,
+                                  verbose=False)
+        test_result = trainer.test(model, dataloaders=self.test_loader,
+                                   verbose=False)
+        result = {"test": test_result[0]["test_acc"],
+                  "val": val_result[0]["test_acc"]}
+
+        return model, result
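
The reworked open_hdf5 reshapes each flat [num_shots, num_groups * group_size] dataset into per-group windows, and the axis fix is the substantive change: the old np.average(grouped_velocities, axis=3) addressed a fourth axis that the 3-D array does not have, which is plausibly the error this commit clears. A tiny NumPy sketch of the shape pipeline, with sizes shrunk for readability and made-up values:

import numpy as np

num_shots, num_groups, group_size = 2, 4, 3
pd = np.arange(num_shots * num_groups * group_size, dtype=float).reshape(num_shots, -1)  # (2, 12)

grouped = np.array(np.hsplit(pd, num_groups))  # [num_groups, num_shots, group_size] = (4, 2, 3)
inputs = np.transpose(grouped, [1, 0, 2])      # [num_shots, num_groups, group_size] = (2, 4, 3)
means = np.average(inputs, axis=2)             # per-group average:                    (2, 4)
targets = np.expand_dims(means, axis=2)        # add trailing feature axis:            (2, 4, 1)

print(inputs.shape, targets.shape)             # (2, 4, 3) (2, 4, 1)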
