Skip to content

Commit

Permalink
Add scannable step size
Browse files Browse the repository at this point in the history
  • Loading branch information
ange1a-j14 committed Aug 13, 2024
1 parent fe34171 commit 5705a1e
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 22 deletions.
14 changes: 7 additions & 7 deletions decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def configure_optimizers(self):

# We will reduce the learning rate by 0.1 after 100 and 150 epochs
scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
milestones=[100, 150],
milestones=[1000, 1500, 1800],
gamma=0.1)
return [optimizer], [scheduler]

Expand All @@ -68,8 +68,8 @@ def training_step(self, batch, batch_idx):
x, y = batch
preds = self.model(x)
loss = self.loss_function(preds, y)
# acc = (preds == y).float().mean()
# self.log("train_acc", acc, on_step=False, on_epoch=True)
acc = (preds == y).float().mean()
self.log("train_acc", acc, on_step=False, on_epoch=True)
self.log("train_loss", loss, prog_bar=True)
return loss

Expand All @@ -81,17 +81,17 @@ def validation_step(self, batch, batch_idx):
# print(f"y size {y.size()}")
preds = self.model(x)
loss = self.loss_function(preds, y)
# acc = (preds == y).float().mean()
# self.log("val_acc", acc, on_step=False, on_epoch=True)
acc = (preds == y).float().mean()
self.log("val_acc", acc, on_step=False, on_epoch=True)
self.log("val_loss", loss, prog_bar=True)
return loss

def test_step(self, batch, batch_idx):
x, y = batch
preds = self.model(x)
loss = self.loss_function(preds, y)
# acc = (preds == y).float().mean()
# self.log("test_acc", acc, on_step=False, on_epoch=True)
acc = (preds == y).float().mean()
self.log("test_acc", acc, on_step=False, on_epoch=True)
self.log("test_loss", loss, prog_bar=True)
return loss

Expand Down
6 changes: 4 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
test_file = 'C:\\Users\\aj14\\Desktop\\SMI\\data\\test_30to1kHz_2kshots_dec=256_randampl.h5py'

print('begin main', datetime.datetime.now())
runner = train.TrainingRunner(train_file, valid_file, test_file)
runner.scan_hyperparams()
step_list = [256, 128, 64]
for step in step_list:
runner = train.TrainingRunner(train_file, valid_file, test_file, step)
runner.scan_hyperparams()

else:
print("Error: Unsupported number of command-line arguments")
2 changes: 1 addition & 1 deletion models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch import nn
import torch

act_fn_by_name = {'Tanh': nn.Tanh(), 'LeakyReLU': nn.LeakyReLU(), 'ReLU': nn.ReLU()}
act_fn_by_name = {'LeakyReLU': nn.LeakyReLU(), 'ReLU': nn.ReLU()}

class CNN(nn.Module):
def __init__(self, input_size, output_size, activation='LeakyReLU'):
Expand Down
28 changes: 16 additions & 12 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@

# Define a custom Dataset class
class VelocityDataset(Dataset):
def __init__(self, h5_file):
def __init__(self, h5_file, step):
self.h5_file = h5_file
self.step = step
with h5py.File(self.h5_file, 'r') as f:
self.length = len(f['Time (s)']) # num shots
print(self.h5_file)
Expand Down Expand Up @@ -70,7 +71,7 @@ def __len__(self):
def __getitem__(self, idx):
# print("getitem entered")
if not self.opened_flag: #not hasattr(self, 'h5_file'):
self.open_hdf5()
self.open_hdf5(step=self.step)
self.opened_flag = True
# print("open_hdf5 finished")
return FloatTensor(self.inputs[idx]), FloatTensor(self.targets[idx])
Expand All @@ -83,12 +84,13 @@ def __getitem__(self, idx):
# return FloatTensor(self.inputs[indices]), FloatTensor(self.targets[indices])

class TrainingRunner:
def __init__(self, training_h5, validation_h5, testing_h5,
def __init__(self, training_h5, validation_h5, testing_h5, step=256,
velocity_only=True):
self.training_h5 = training_h5
self.validation_h5 = validation_h5
self.testing_h5 = testing_h5
self.velocity_only = velocity_only
self.step = step

# get dataloaders
self.set_dataloaders()
Expand All @@ -110,21 +112,20 @@ def __init__(self, training_h5, validation_h5, testing_h5,
def get_custom_dataloader(self, h5_file, batch_size=128, shuffle=True,
velocity_only=True):
# if velocity_only:
dataset = VelocityDataset(h5_file)
dataset = VelocityDataset(h5_file, self.step)
print("dataset initialized")
# We can use DataLoader to get batches of data
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
num_workers=16, persistent_workers=True,
num_workers=8, persistent_workers=True,
pin_memory=True)
print("dataloader initialized")
return dataloader

def set_dataloaders(self, batch_size=128):
self.batch_size = batch_size
self.train_loader = self.get_custom_dataloader(self.training_h5, velocity_only=self.velocity_only, batch_size=self.batch_size)
self.valid_loader = self.get_custom_dataloader(self.validation_h5, velocity_only=self.velocity_only, batch_size=self.batch_size, shuffle=False)
self.test_loader = self.get_custom_dataloader(self.testing_h5, velocity_only=self.velocity_only, batch_size=self.batch_size, shuffle=False)

self.train_loader = self.get_custom_dataloader(self.training_h5, batch_size=self.batch_size)
self.valid_loader = self.get_custom_dataloader(self.validation_h5, batch_size=self.batch_size, shuffle=False)
self.test_loader = self.get_custom_dataloader(self.testing_h5, batch_size=self.batch_size, shuffle=False)

def train_model(self, model_name, save_name=None, **kwargs):
"""Train model.
Expand Down Expand Up @@ -160,7 +161,7 @@ def train_model(self, model_name, save_name=None, **kwargs):
devices=[0],
max_epochs=2000,
callbacks=[early_stop_callback, checkpoint_callback],
check_val_every_n_epoch=10,
check_val_every_n_epoch=20,
logger=logger
)

Expand All @@ -185,13 +186,16 @@ def train_model(self, model_name, save_name=None, **kwargs):
return model, result

def scan_hyperparams(self):
for lr in [1e-4]:#[1e-3]: #, 1e-2, 3e-2]:
lr_list = [1e-3, 1e-4]
act_list = ['LeakyReLU', 'ReLU']
optim_list = ['Adam', 'SGD']
for lr, activation, step in product(lr_list, act_list, step_list): #, 1e-2, 3e-2]:

model_config = {"input_size": self.input_size,
"output_size": self.output_size}
optimizer_config = {"lr": lr}
#"momentum": 0.9,}
misc_config = {"batch_size": self.batch_size}
misc_config = {"batch_size": self.batch_size, "step": self.step}

self.train_model(model_name="CNN",
model_hparams=model_config,
Expand Down

0 comments on commit 5705a1e

Please sign in to comment.