minibatch DeepGP training
karibbov committed Sep 13, 2023
1 parent 167f58e commit 11df8cb
Showing 1 changed file with 65 additions and 31 deletions.
src/neps/optimizers/bayesian_optimization/models/deepGP.py (96 changes: 65 additions & 31 deletions)
@@ -150,7 +150,7 @@ def __init__(
         self.device = (
             torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         )
-        # self.device = torch.device("cpu")
+        self.device = torch.device("cpu")

         # Save the NN args, necessary for preprocessing
         self.cnn_kernel_size = neural_network_args.get("cnn_kernel_size", 3)
@@ -335,6 +335,7 @@ def _fit(
         normalize_y: bool = False,
         normalize_budget: bool = True,
         n_epochs: int = 1000,
+        batch_size: int = 64,
         optimizer_args: dict | None = None,
         early_stopping: bool = True,
         patience: int = 10,
@@ -357,6 +358,7 @@
             self.learning_curves,
             self.y_train,
             n_epochs=n_epochs,
+            batch_size=batch_size,
             optimizer_args=optimizer_args,
             early_stopping=early_stopping,
             patience=patience,
@@ -369,6 +371,7 @@ def __train_model(
         learning_curves: torch.Tensor,
         y_train: torch.Tensor,
         n_epochs: int = 1000,
+        batch_size: int = 64,
         optimizer_args: dict | None = None,
         early_stopping: bool = True,
         patience: int = 10,
@@ -389,56 +392,87 @@ def __train_model(
         )

         count_down = patience
-        min_loss_val = np.inf
+        min_avg_loss_val = np.inf
+        average_loss: float = 0.0

         for epoch_nr in range(0, n_epochs):
             if early_stopping and count_down == 0:
                 self.logger.info(
                     f"Epoch: {epoch_nr - 1} surrogate training stops due to early "
                     f"stopping with the patience: {patience} and "
-                    f"the minimum loss of {min_loss_val} and "
-                    f"the final loss of {loss_value}"
+                    f"the minimum average loss of {min_avg_loss_val} and "
+                    f"the final average loss of {average_loss}"
                 )
                 break

-            nr_examples_batch = x_train.size(dim=0)
-            # if only one example in the batch, skip the batch.
-            # Otherwise, the code will fail because of batchnorm
-            if nr_examples_batch == 1:
-                continue
+            n_examples_batch = x_train.size(dim=0)
+
+            # get a random permutation for mini-batches
+            permutation = torch.randperm(n_examples_batch)
+
+            # optimize over mini-batches
+            total_scaled_loss = 0.0
+            for batch_idx, start_index in enumerate(
+                range(0, n_examples_batch, batch_size)
+            ):
+                end_index = start_index + batch_size
+                if end_index > n_examples_batch:
+                    end_index = n_examples_batch
+                indices = permutation[start_index:end_index]
+                batch_x, batch_budget, batch_lc, batch_y = (
+                    x_train[indices],
+                    train_budgets[indices],
+                    learning_curves[indices],
+                    y_train[indices],
+                )

-            # Zero backprop gradients
-            self.optimizer.zero_grad()
+                minibatch_size = end_index - start_index
+                # if only one example in the batch, skip the batch.
+                # Otherwise, the code will fail because of batchnorm
+                if minibatch_size <= 1:
+                    continue
+
+                # Zero backprop gradients
+                self.optimizer.zero_grad()
+
+                projected_x = self.nn(batch_x, batch_budget, batch_lc)
+                self.model.set_train_data(projected_x, batch_y, strict=False)
+                output = self.model(projected_x)
+
+                # try:
+                # Calc loss and backprop derivatives
+                loss = -self.mll(output, self.model.train_targets)
+                episodic_loss_value: float = loss.detach().to("cpu").item()
+                # weighted sum over losses in the batch
+                total_scaled_loss = (
+                    total_scaled_loss + episodic_loss_value * minibatch_size
+                )

-            projected_x = self.nn(x_train, train_budgets, learning_curves)
-            self.model.set_train_data(projected_x, y_train, strict=False)
-            output = self.model(projected_x)
+                mse = gpytorch.metrics.mean_squared_error(
+                    output, self.model.train_targets
+                )
+                self.logger.debug(
+                    f"Epoch {epoch_nr} Batch {batch_idx} - MSE {mse:.5f}, "
+                    f"Loss: {episodic_loss_value:.3f}, "
+                    f"lengthscale: {self.model.covar_module.base_kernel.lengthscale.item():.3f}, "
+                    f"noise: {self.model.likelihood.noise.item():.3f}, "
+                )

-            # try:
-            # Calc loss and backprop derivatives
-            loss = -self.mll(output, self.model.train_targets)
-            loss_value: float = loss.detach().to("cpu").item()
+                loss.backward()
+                self.optimizer.step()

-            if loss_value < min_loss_val:
-                min_loss_val = loss_value
+            # Get average weighted loss over every batch
+            average_loss = total_scaled_loss / n_examples_batch
+            if average_loss < min_avg_loss_val:
+                min_avg_loss_val = average_loss
                 count_down = patience
             elif early_stopping:
                 self.logger.debug(
-                    f"No improvement over the minimum loss value of {min_loss_val} "
+                    f"No improvement over the minimum loss value of {min_avg_loss_val} "
                     f"for the past {patience - count_down} epochs "
                     f"the training will stop in {count_down} epochs"
                 )
                 count_down -= 1
-
-            mse = gpytorch.metrics.mean_squared_error(output, self.model.train_targets)
-            self.logger.debug(
-                f"Epoch {epoch_nr} - MSE {mse:.5f}, "
-                f"Loss: {loss_value:.3f}, "
-                f"lengthscale: {self.model.covar_module.base_kernel.lengthscale.item():.3f}, "
-                f"noise: {self.model.likelihood.noise.item():.3f}, "
-            )
-            loss.backward()
-            self.optimizer.step()
             # except Exception as training_error:
             #     self.logger.error(
             #         f'The following error happened while training: {training_error}')

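For reference, the training pattern the new loop follows can be sketched in isolation. The snippet below is a minimal, hypothetical illustration and not part of the repository: train_epoch, fit_with_early_stopping, and the train_step callable are made-up names standing in for the neural-network projection, GP update, marginal log-likelihood loss, and optimizer step that __train_model performs per mini-batch; only the permutation and slicing, the batch-norm-motivated skip of size-1 batches, and the size-weighted average loss used for early stopping are taken from the diff above.

import numpy as np
import torch


def train_epoch(x, y, batch_size, train_step):
    """One epoch of mini-batch training; returns the size-weighted average loss."""
    n_examples = x.size(dim=0)
    permutation = torch.randperm(n_examples)  # reshuffle the data every epoch

    total_scaled_loss = 0.0
    for start_index in range(0, n_examples, batch_size):
        end_index = min(start_index + batch_size, n_examples)
        minibatch_size = end_index - start_index
        # a batch of size 1 is skipped because batch norm needs at least 2 examples
        if minibatch_size <= 1:
            continue
        indices = permutation[start_index:end_index]
        loss_value = train_step(x[indices], y[indices])  # forward, backward, optimizer step
        # weight each batch loss by its size so the epoch average is per example
        total_scaled_loss += loss_value * minibatch_size

    return total_scaled_loss / n_examples


def fit_with_early_stopping(x, y, train_step, batch_size=64, n_epochs=1000, patience=10):
    """Stop once the average epoch loss has not improved for `patience` epochs."""
    min_avg_loss_val, count_down = np.inf, patience
    for _ in range(n_epochs):
        if count_down == 0:
            break
        average_loss = train_epoch(x, y, batch_size, train_step)
        if average_loss < min_avg_loss_val:
            min_avg_loss_val, count_down = average_loss, patience
        else:
            count_down -= 1
    return min_avg_loss_val

Weighting each mini-batch loss by its size before averaging keeps the early-stopping criterion comparable across epochs even when the final batch is smaller than batch_size; dividing by the full example count also means a skipped one-example batch simply contributes nothing, which mirrors the behaviour of the committed loop.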