From 2d029cf41837e0a52d1a5b7630210fc8f23360fc Mon Sep 17 00:00:00 2001 From: TarekAbouChakra Date: Mon, 7 Aug 2023 08:14:42 +0200 Subject: [PATCH 1/6] tensorboard functionality --- .../convenience/neps_tblogger_tutorial.py | 424 ++++++++++++++ .../experimental/tensorboard_eval.py | 387 ------------- pyproject.toml | 4 + src/metahyper/api.py | 12 + src/neps/api.py | 7 + src/neps/plot/tensorboard_eval.py | 537 ++++++++++++++++++ 6 files changed, 984 insertions(+), 387 deletions(-) create mode 100644 neps_examples/convenience/neps_tblogger_tutorial.py delete mode 100644 neps_examples/experimental/tensorboard_eval.py create mode 100644 src/neps/plot/tensorboard_eval.py diff --git a/neps_examples/convenience/neps_tblogger_tutorial.py b/neps_examples/convenience/neps_tblogger_tutorial.py new file mode 100644 index 00000000..24fbc108 --- /dev/null +++ b/neps_examples/convenience/neps_tblogger_tutorial.py @@ -0,0 +1,424 @@ +""" +NePS tblogger With TensorBoard +==================================== +This tutorial demonstrates how to use TensorBoard plugin with NePS tblogger class +to detect performance data of the different model configurations during training. + + +Setup +----- +To install ``torchvision`` and ``tensorboard`` use the following command: + +.. code-block:: + + pip install torchvision + +""" +import argparse +import logging +import os +import random +import shutil +import time +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn +import torchvision +from torch.optim import lr_scheduler +from torch.utils.data.dataloader import DataLoader +from torch.utils.data.sampler import SubsetRandomSampler +from torchvision.transforms import transforms + +import neps +from neps.plot.tensorboard_eval import tblogger + +""" +Steps: + +#1 Define the seeds for reproducibility. +#2 Prepare the input data. +#3 Design the model. +#4 Design the pipeline search spaces. +#5 Design the run pipeline function. +#6 Use neps.run the run the entire search using your specified searcher. + +""" + +############################################################# +# Definig the seeds for reproducibility + + +def set_seed(seed=123): + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + +############################################################# +# Prepare the input data. For this tutorial we use the MNIST dataset. + + +def MNIST( + batch_size: int = 32, n_train: int = 8192, n_valid: int = 1024 +) -> Tuple[DataLoader, DataLoader, DataLoader]: + train_dataset = torchvision.datasets.MNIST( + root="./data", train=True, transform=transforms.ToTensor(), download=True + ) + test_dataset = torchvision.datasets.MNIST( + root="./data", train=False, transform=transforms.ToTensor(), download=True + ) + + train_sampler = SubsetRandomSampler(range(n_train)) + valid_sampler = SubsetRandomSampler(range(n_train, n_train + n_valid)) + train_dataloader = DataLoader( + dataset=train_dataset, batch_size=batch_size, shuffle=False, sampler=train_sampler + ) + val_dataloader = DataLoader( + dataset=train_dataset, batch_size=batch_size, shuffle=False, sampler=valid_sampler + ) + test_dataloader = DataLoader( + dataset=test_dataset, batch_size=batch_size, shuffle=False + ) + + return train_dataloader, val_dataloader, test_dataloader + + +############################################################# +# Design small MLP model to be able to represent the input data. 
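+# The MLP below flattens each 1x28x28 MNIST image into a 784-dimensional vector and
+# passes it through three fully connected layers (784 -> 392 -> 196 -> 10) with ReLU
+# activations in between; the 10 outputs are the raw class logits that are later fed
+# to nn.CrossEntropyLoss inside the run pipeline.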
+ + +class MLP(nn.Module): + def __init__(self) -> None: + super().__init__() + self.relu = nn.ReLU() + self.linear1 = nn.Linear(in_features=784, out_features=392) + self.linear2 = nn.Linear(in_features=392, out_features=196) + self.linear3 = nn.Linear(in_features=196, out_features=10) + + def forward(self, x): + x = x.view(x.size(0), -1) + x = self.relu(self.linear1(x)) + x = self.relu(self.linear2(x)) + x = self.linear3(x) + + return x + + +############################################################# +# Define the training step and return the validation error and misclassified images. + + +def loss_ev(model: nn.Module, data_loader: DataLoader) -> float: + model.eval() + correct = 0 + total = 0 + with torch.no_grad(): + for x, y in data_loader: + output = model(x) + _, predicted = torch.max(output.data, 1) + correct += (predicted == y).sum().item() + total += y.size(0) + + accuracy = correct / total + return 1 - accuracy + + +def training(model, optimizer, criterion, train_loader, validation_loader): + """ + Function that trains the model for one epoch and evaluates the model on the validation set. Used by the searcher. + + Args: + model (nn.Module): Model to be trained. + optimizer (torch.nn.optim): Optimizer used to train the weights (depends on the pipeline space). + criterion (nn.modules.loss) : Loss function to use. + train_loader (torch.utils.Dataloader): Data loader containing the training data. + validation_loader (torch.utils.Dataloader): Data loader containing the validation data. + + Returns: + (float) validation error for the epoch. + """ + incorrect_images = [] + model.train() + for x, y in train_loader: + optimizer.zero_grad() + output = model(x) + loss = criterion(output, y) + loss.backward() + optimizer.step() + + predicted_labels = torch.argmax(output, dim=1) + incorrect_mask = predicted_labels != y + incorrect_images.append(x[incorrect_mask]) + + validation_loss = loss_ev(model, validation_loader) + + if len(incorrect_images) > 0: + incorrect_images = torch.cat(incorrect_images, dim=0) + + return validation_loss, incorrect_images + + +############################################################# +# Design the pipeline search spaces. + + +# For BO: +def pipeline_space_BO() -> dict: + pipeline = dict( + lr=neps.FloatParameter(lower=1e-5, upper=1e-1, log=True), + optim=neps.CategoricalParameter(choices=["Adam", "SGD"]), + weight_decay=neps.FloatParameter(lower=1e-4, upper=1e-1, log=True), + ) + + return pipeline + + +# For Hyperband +def pipeline_space_Hyperband() -> dict: + pipeline = dict( + lr=neps.FloatParameter(lower=1e-5, upper=1e-1, log=True), + optim=neps.CategoricalParameter(choices=["Adam", "SGD"]), + weight_decay=neps.FloatParameter(lower=1e-4, upper=1e-1, log=True), + epochs=neps.IntegerParameter(lower=1, upper=9, is_fidelity=True), + ) + + return pipeline + + +############################################################# +# Implement the pipeline run search. 
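+# NePS calls a run pipeline function with the sampled hyperparameters passed as
+# keyword arguments whose names match the keys of the pipeline space defined above.
+# The Hyperband pipeline below additionally accepts `pipeline_directory` and
+# `previous_pipeline_directory`, which it uses to checkpoint and resume training
+# across fidelities. The function reports its objective back to NePS, e.g. as a
+# minimal sketch:
+#
+#   def run_pipeline(lr, optim, weight_decay):
+#       ...  # build, train, and validate the model
+#       return {"loss": validation_error, "info_dict": {...}}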
+ + +# For BO: +def run_pipeline_BO(lr, optim, weight_decay): + model = MLP() + + if optim == "Adam": + optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) + elif optim == "SGD": + optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay) + + max_epochs = 9 + + train_loader, validation_loader, test_loader = MNIST( + batch_size=64, n_train=4096, n_valid=512 + ) + + scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.75) + + criterion = nn.CrossEntropyLoss() + losses = [] + + tblogger.disable(False) + + for i in range(max_epochs): + loss, miss_img = training( + optimizer=optimizer, + model=model, + criterion=criterion, + train_loader=train_loader, + validation_loader=validation_loader, + ) + losses.append(loss) + + tblogger.log( + loss=loss, + current_epoch=i, + data={ + "lr_decay": tblogger.scalar_logging(value=scheduler.get_last_lr()[0]), + "miss_img": tblogger.image_logging(img_tensor=miss_img, counter=2), + "layer_gradient": tblogger.layer_gradient_logging(model=model), + }, + ) + + scheduler.step() + + print(f" Epoch {i + 1} / {max_epochs} Val Error: {loss} ") + + train_accuracy = loss_ev(model, train_loader) + test_accuracy = loss_ev(model, test_loader) + + return { + "loss": loss, + "info_dict": { + "train_accuracy": train_accuracy, + "test_accuracy": test_accuracy, + "val_errors": losses, + "cost": max_epochs, + }, + } + + +# For Hyperband +def run_pipeline_Hyperband(pipeline_directory, previous_pipeline_directory, **configs): + model = MLP() + checkpoint_name = "checkpoint.pth" + start_epoch = 0 + + train_loader, validation_loader, test_loader = MNIST( + batch_size=32, n_train=4096, n_valid=512 + ) + + # define loss + criterion = nn.CrossEntropyLoss() + + # Define the optimizer + if configs["optim"] == "Adam": + optimizer = torch.optim.Adam( + model.parameters(), lr=configs["lr"], weight_decay=configs["weight_decay"] + ) + elif configs["optim"] == "SGD": + optimizer = torch.optim.SGD( + model.parameters(), lr=configs["lr"], weight_decay=configs["weight_decay"] + ) + + scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.75) + + # We make use of checkpointing to resume training models on higher fidelities + if previous_pipeline_directory is not None: + # Read in state of the model after the previous fidelity rung + checkpoint = torch.load(previous_pipeline_directory / checkpoint_name) + model.load_state_dict(checkpoint["model_state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + epochs_previously_spent = checkpoint["epoch"] + else: + epochs_previously_spent = 0 + + start_epoch += epochs_previously_spent + + losses = list() + + tblogger.disable(False) + + epochs = configs["epochs"] + + for epoch in range(start_epoch, epochs): + # Call the training function, get the validation errors and append them to val errors + loss, miss_img = training( + model, optimizer, criterion, train_loader, validation_loader + ) + losses.append(loss) + + tblogger.log( + loss=loss, + current_epoch=epoch, + hparam_accuracy_mode=True, + data={ + "lr_decay": tblogger.scalar_logging(value=scheduler.get_last_lr()[0]), + "miss_img": tblogger.image_logging(img_tensor=miss_img, counter=2), + "layer_gradient": tblogger.layer_gradient_logging(model=model), + }, + ) + + scheduler.step() + + print(f" Epoch {epoch + 1} / {epochs} Val Error: {loss} ") + + train_accuracy = loss_ev(model, train_loader) + test_accuracy = loss_ev(model, test_loader) + + torch.save( + { + "epoch": epochs, + "model_state_dict": 
model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + }, + pipeline_directory / checkpoint_name, + ) + + return { + "loss": loss, + "info_dict": { + "train_accuracy": train_accuracy, + "test_accuracy": test_accuracy, + "val_errors": losses, + "cost": epochs - epochs_previously_spent, + }, + "cost": epochs - epochs_previously_spent, + } + + +############################################################# +""" +Defining the main with argument parsing to use either BO or Hyperband and specifying their +respective properties +""" + +if __name__ == "__main__": + argParser = argparse.ArgumentParser() + argParser.add_argument( + "--searcher", + type=str, + choices=["bayesian_optimization", "hyperband"], + default="bayesian_optimization", + help="Searcher type used", + ) + argParser.add_argument( + "--max_cost_total", type=int, default=30, help="Max cost used for Hyperband" + ) + argParser.add_argument( + "--max_evaluations_total", type=int, default=10, help="Max evaluation used for BO" + ) + args = argParser.parse_args() + + if args.searcher == "hyperband": + start_time = time.time() + set_seed(112) + logging.basicConfig(level=logging.INFO) + if os.path.exists("results/hyperband"): + shutil.rmtree("results/hyperband") + neps.run( + run_pipeline=run_pipeline_Hyperband, + pipeline_space=pipeline_space_Hyperband(), + root_directory="hyperband", + max_cost_total=args.max_cost_total, + searcher="hyperband", + ) + + """ + To check live plots during this command run, please open a new terminal with the directory of this saved project and run + + tensorboard --logdir hyperband + """ + + end_time = time.time() # Record the end time + execution_time = end_time - start_time + print(f"Execution time: {execution_time} seconds") + + elif args.searcher == "bayesian_optimization": + start_time = time.time() + set_seed(112) + logging.basicConfig(level=logging.INFO) + if os.path.exists("results/bayesian_optimization"): + shutil.rmtree("results/bayesian_optimization") + neps.run( + run_pipeline=run_pipeline_BO, + pipeline_space=pipeline_space_BO(), + root_directory="bayesian_optimization", + max_evaluations_total=args.max_evaluations_total, + searcher="bayesian_optimization", + ) + + """ + To check live plots during this command run, please open a new terminal with the directory of this saved project and run + + tensorboard --logdir bayesian_optimization + """ + + end_time = time.time() # Record the end time + execution_time = end_time - start_time + print(f"Execution time: {execution_time} seconds") + + """ + When running this code without any arguments, it will by default run bayesian optimization with 10 max evaluations + of 9 epochs each: + + python neps_tblogger_tutorial.py + + + If you wish to do this run with hyperband searcher with default max cost total of 30. 
Please run this command on the terminal: + + python neps_tblogger_tutorial.py --searcher hyperband + """ diff --git a/neps_examples/experimental/tensorboard_eval.py b/neps_examples/experimental/tensorboard_eval.py deleted file mode 100644 index d5667eac..00000000 --- a/neps_examples/experimental/tensorboard_eval.py +++ /dev/null @@ -1,387 +0,0 @@ -import math -import random -import warnings -from typing import Union - -import numpy as np -import torch -import torch.nn as nn -from torch.utils.tensorboard import SummaryWriter -from torch.utils.tensorboard.summary import hparams - - -# Inherit from class and change to fit purpose: -class SummaryWriter_(SummaryWriter): - """ - This function before the update used to create another subfolder inside the logdir and then create further 'tfevent' - which makes everything else uneasy to differentiate and hence this gives the same result with a much easier way and logs - everything on the same 'tfevent' as for other functions. - In addition, a change in the metric dictiornay was made for the cause of making the printed 'Loss' or 'Accuracy' display on the - Summary file - """ - - def add_hparams(self, hparam_dict, metric_dict, global_step): - if not isinstance(hparam_dict, dict) or not isinstance(metric_dict, dict): - raise TypeError("hparam_dict and metric_dict should be dictionary.") - updated_metric = {} - for key, value in metric_dict.items(): - updated_key = "Summary" + "/" + key - updated_metric[updated_key] = value - exp, ssi, sei = hparams(hparam_dict, updated_metric) - - self.file_writer.add_summary(exp) - self.file_writer.add_summary(ssi) - self.file_writer.add_summary(sei) - for k, v in updated_metric.items(): - self.add_scalar(tag=k, scalar_value=v, global_step=global_step) - - -class tensorboard_evaluations: - def __init__(self, log_dir: str = "/logs") -> None: - self._log_dir = log_dir - - self._best_incum_track = np.inf - self._step_update = 1 - - self._toggle_epoch_max_reached = False - - self._config_track = 1 - - self._fidelity_search_count = 0 - self._fidelity_counter = 0 - self._fidelity_bool = False - self._fidelity_was_bool = False - - self._config_dict: dict[str, dict[str, Union[list[str], float, int]]] = {} - self._config_track_last = 1 - self._prev_config_list: list[str] = [] - - self._writer_config = [] - self._writer_summary = SummaryWriter_(log_dir=self._log_dir + "/summary") - self._writer_config.append( - SummaryWriter_( - log_dir=self._log_dir + "/configs" + "/config_" + str(self._config_track) - ) - ) - - def _make_grid(self, images: torch.tensor, nrow: int, padding: int = 2): - batch_size, num_channels, height, width = images.size() - x_mapping = min(nrow, batch_size) - y_mapping = int(math.ceil(float(batch_size) / x_mapping)) - height, width = height + 2, width + 2 - - grid = torch.zeros( - (num_channels, height * y_mapping + padding, width * x_mapping + padding) - ) - - k = 0 - for y in range(y_mapping): - for x in range(x_mapping): - if k >= batch_size: - break - image = images[k] - grid[ - :, - y * height + padding : y * height + padding + height - padding, - x * width + padding : x * width + padding + width - padding, - ] = image - k += 1 - - return grid - - def _incumbent(self, **incum_data) -> None: - """ - A function used to mainly display out the incumbent trajectory based on the step update which is after finishing every computation. 
- In other words, epochs == max_epochs - """ - loss = incum_data["loss"] - if loss < self._best_incum_track: - self._best_incum_track = loss - self._writer_summary.add_scalar( - tag="Summary" + "/Incumbent_Graph", - scalar_value=self._best_incum_track, - global_step=self._step_update, - ) - self._step_update += 1 - - def _track_config(self, **config_data) -> None: - config_list = config_data["config_list"] - loss = float(config_data["loss"]) - - for config_dict in self._config_dict.values(): - if self._prev_config_list != config_list: - if config_dict["config_list"] == config_list: - if self._fidelity_search_count == 0: - self._config_track_last = self._config_track - self._fidelity_was_bool = True - self._fidelity_bool = True - self._fidelity_search_count += 1 - loss_prev = self._config_dict["config_" + str(self._config_track)][ - "loss" - ] - self._incumbent(loss=loss_prev) - config = config_dict["config"] - if isinstance(config, (int, float)): - self._config_track = int(config) - - if not self._fidelity_bool: - if len(self._prev_config_list) > 0: - if self._prev_config_list != config_list: - self._fidelity_search_count = 0 - if self._fidelity_was_bool: - loss_prev = self._config_dict[ - "config_" + str(self._config_track) - ]["loss"] - self._incumbent(loss=loss_prev) - self._config_track = self._config_track_last + 1 - self._fidelity_counter += 1 - self._config_dict.clear() - self._fidelity_was_bool = False - else: - loss_prev = self._config_dict[ - "config_" + str(self._config_track) - ]["loss"] - self._incumbent(loss=loss_prev) - self._config_track += 1 - self._writer_config.append( - SummaryWriter_( - log_dir=self._log_dir - + "/configs" - + "/config_" - + str(self._config_track) - ) - ) - else: - self._fidelity_bool = False - self._toggle_epoch_max_reached = False - - self._config_dict["config_" + str(self._config_track)] = { - "config_list": config_list, - "loss": float(loss), - "config": self._config_track, - } - - self._prev_config_list = config_list - - def write_scalar_configs( - self, config_list: list, current_epoch: int, loss: float, scalar: float, tag: str - ) -> None: - """ - Writes any scalar to the specific corresponding config, EX: Learning_rate decay tracking, Accuracy... - - Arguments: - conifg_list: a list (The configurations sved as a list in run_pipline and passed here as an argument) - current_epoch: an integer (The currecnt epoch running at the time) - loss: a float (The loss at the specific run, important for hypeband) - scalar: a float (The scalar value to be visualized) - tag: a string (The tag of the scalar EX: tag = 'Learning_Rate') - """ - if tag == "loss": - scalar = loss - - if loss is None or current_epoch is None or config_list is None: - raise ValueError( - "Loss, epochs, and max_epochs cannot be None. Please provide a valid value." - ) - - self._track_config(config_list=config_list, loss=loss) - - self._writer_config[self._config_track - 1].add_scalar( - tag="Config" + str(self._config_track) + "/" + tag, - scalar_value=scalar, - global_step=current_epoch, - ) - - def write_scalar_fidelity( - self, config_list: list, current_epoch: int, loss: float, Accuracy: bool = False - ) -> None: - """ - This function will take the each fidelity and show the accuracy or the loss during HPO search for each fidelity. 
- - Arguments: - conifg_list: a list (The configurations sved as a list in run_pipline and passed here as an argument) - current_epoch: an integer (The currecnt epoch running at the time) - loss: a float (The loss at the specific run, important for hypeband) - Accuracy: a bool (If true it will change the loss to accuracy % and display the results. - If false it will remain displaying with respect to the loss) - """ - if loss is None or current_epoch is None or config_list is None: - raise ValueError( - "Loss, epochs, and max_epochs cannot be None. Please provide a valid value." - ) - - self._track_config(config_list=config_list, loss=loss) - - if Accuracy: - acc = (1 - loss) * 100 - scalar_value = acc - else: - scalar_value = loss - - self._writer_config[self._config_track - 1].add_scalar( - tag="Summary" + "/Fidelity_" + str(self._fidelity_counter), - scalar_value=scalar_value, - global_step=current_epoch, - ) - - def write_histogram( - self, config_list: list, current_epoch: int, loss: float, model: nn.Module - ) -> None: - """ - By logging histograms for all parameters, you can gain insights into the distribution of different - parameter types and identify potential issues or patterns in their values. This comprehensive analysis - can help you better understand your model's behavior during training. - - Ex: Weights where their histograms do not show a change in shape from the first epoch up until the last prove to - mean that the training is not done properly and hence weights are not updated in the rythm they should - - Arguments: - conifg_list: a list (The configurations sved as a list in run_pipline and passed here as an argument) - current_epoch: an integer (The currecnt epoch running at the time) - loss: a float (The loss at the specific run, important for hypeband) - model: a nn.Module (The model which we want to analyze) - """ - if loss is None or current_epoch is None or config_list is None: - raise ValueError( - "Loss, epochs, and max_epochs cannot be None. Please provide a valid value." - ) - - self._track_config(config_list=config_list, loss=loss) - - for _, param in model.named_parameters(): - self._writer_config[self._config_track - 1].add_histogram( - "Config" + str(self._config_track), - param.clone().cpu().data.numpy(), - current_epoch, - ) - - def write_image( - self, - config_list: list, - max_epochs: int, - current_epoch: int, - loss: float, - image_input: torch.Tensor, - num_images: int = 10, - random_images: bool = False, - resize_images: np.array = None, - ignore_warning: bool = True, - ) -> None: - """ - The user is free on how they want to tackle image visualization on tensorboard, they specify the numebr of images - they want to show and if the images should be taken randomly or not. 
- - Arguments: - conifg_list: a list (The configurations sved as a list in run_pipline and passed here as an argument) - max_epochs: an integer (Maximum epoch that can be reached at that specific run) - current_epoch: an integer (The currecnt epoch running at the time) - loss: a float (The loss at the specific run) - image_imput: a Tensor (The input image in batch, shape: 12x3x28x28 'BxCxWxH') - num_images: an integer (The number of images ot be displayed for each config on tensorboard) - random_images: a bool (True is the images should be sampled randomly, False otherwise) - resize_images: an array (Resizing an the images to make them fit and be clearly visible on the grid) - ignore_warning: a bool (At the moment a warning is appearing, bug will be fixed later) - - Example code of displaying wrongly classified images: - - 1- In the trianing for loop: - predicted_labels = torch.argmax(output_of_model_after_input, dim=1) - misclassification_mask = predicted_labels != y_actual_labels - misclassified_images.append(x[misclassification_mask]) - - 2- Before the return, outside the training loop: - if len(misclassified_images) > 0: - misclassified_images = torch.cat(misclassified_images, dim=0) - - 3- Returning the misclassified images - return ..., misclassified_images - - Then use these misclassified_images as the image_input of this function - """ - if loss is None or current_epoch is None or config_list is None: - raise ValueError( - "Loss, epochs, and max_epochs cannot be None. Please provide a valid value." - ) - - self._track_config(config_list=config_list, loss=loss) - - if resize_images is None: - resize_images = [56, 56] - - if ignore_warning is True: - warnings.filterwarnings("ignore", category=DeprecationWarning) - - if current_epoch == max_epochs - 1: - if num_images > len(image_input): - num_images = len(image_input) - - if random_images is False: - subset_images = image_input[:num_images] - else: - random_indices = random.sample(range(len(image_input)), num_images) - subset_images = image_input[random_indices] - - resized_images = torch.nn.functional.interpolate( - subset_images, - size=(resize_images[0], resize_images[1]), - mode="bilinear", - align_corners=False, - ) - - nrow = int(resized_images.size(0) ** 0.75) - img_grid = self._make_grid(resized_images, nrow=nrow) - - self._writer_config[self._config_track - 1].add_image( - tag="IMG_config " + str(self._config_track), - img_tensor=img_grid, - global_step=self._config_track, - ) - - def write_hparam( - self, - config_list: list, - current_epoch: int, - loss: float, - Accuracy: bool = False, - **pipeline_space, - ) -> None: - """ - '.add_hparam' is a function in TensorBoard that allows you to log hyperparameters associated with your training run. - It takes a dictionary of hyperparameter names and values and associates them with the current run, making it easy to - compare and analyze different hyperparameter configurations. - - Arguments: - conifg_list: a list (The configurations sved as a list in run_pipline and passed here as an argument) - current_epoch: an integer (The currecnt epoch running at the time) - loss: a float (The loss at the specific run) - Accuracy: a bool (If true it will change the loss to accuracy % and display the results. - If false it will remain displaying with respect to the loss) - pipeline_space: The name of the hyperparameters in addition to their kwargs to be searched on. 
- """ - if loss is None or current_epoch is None or config_list is None: - raise ValueError( - "Loss, epochs, and max_epochs cannot be None. Please provide a valid value." - ) - - self._track_config(config_list=config_list, loss=loss) - - if Accuracy: - str_name = "Accuracy" - str_value = (1 - loss) * 100 - else: - str_name = "Loss" - str_value = loss - - values = {str_name: str_value} - - self._writer_config[self._config_track - 1].add_hparams( - pipeline_space, values, current_epoch - ) - - def close_writers(self) -> None: - """ - Closing the writers created after finishing all the tensorboard visualizations - """ - self._writer_summary.close() - for _, writer in enumerate(self._writer_config): - writer.close() diff --git a/pyproject.toml b/pyproject.toml index fa819ba0..fba3332d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,10 @@ more-itertools = "^9.0.0" portalocker = "^2.6.0" seaborn = "^0.12.1" pyyaml = "^6.0" +tensorboard = [ + {version = "^2.11", python = "<3.8"}, + {version = "^2.13", python = ">=3.8"} +] [tool.poetry.group.dev.dependencies] pre-commit = "^2.10" diff --git a/src/metahyper/api.py b/src/metahyper/api.py index 1babcdec..0da60626 100644 --- a/src/metahyper/api.py +++ b/src/metahyper/api.py @@ -12,6 +12,8 @@ from pathlib import Path from typing import Any +from neps.plot.tensorboard_eval import tblogger + from ._locker import Locker from .utils import YamlSerializer, find_files, non_empty_file @@ -391,6 +393,16 @@ def run( pipeline_directory, previous_pipeline_directory, ) = _sample_config(optimization_dir, sampler, serializer, logger) + # Take the config data in case tensorboard is to be used. + if tblogger.logger_init_bool or tblogger.logger_bool: + tblogger.config_track_init_api( + config_id=config_id, + config=config, + config_working_directory=pipeline_directory, + config_previous_directory=previous_pipeline_directory, + optim_path=optimization_dir, + ) + tblogger.logger_init_bool = False config_lock_file = pipeline_directory / ".config_lock" config_lock_file.touch(exist_ok=True) diff --git a/src/neps/api.py b/src/neps/api.py index 721919d2..6eb97bd9 100644 --- a/src/neps/api.py +++ b/src/neps/api.py @@ -15,6 +15,7 @@ from metahyper import instance_from_map from .optimizers import BaseOptimizer, SearcherMapping +from .plot.tensorboard_eval import tblogger from .search_spaces.parameter import Parameter from .search_spaces.search_space import SearchSpace, pipeline_space_from_configspace from .utils.result_utils import get_loss @@ -82,8 +83,14 @@ def write_loss_and_config(file_handle, loss_, config_id_, config_): f"Finished evaluating config {config_id}" f" -- new best with loss {float(loss) :.3f}" ) + if tblogger.logger_bool: + tblogger.tracking_incumbent_api(best_loss=loss) + else: logger.info(f"Finished evaluating config {config_id}") + # Track the incumbent from the best loss + if tblogger.logger_bool: + tblogger.tracking_incumbent_api(best_loss=best_loss) return _post_evaluation_hook diff --git a/src/neps/plot/tensorboard_eval.py b/src/neps/plot/tensorboard_eval.py new file mode 100644 index 00000000..18dca4c3 --- /dev/null +++ b/src/neps/plot/tensorboard_eval.py @@ -0,0 +1,537 @@ +import math +import os +import random +import warnings +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from torch.utils.tensorboard import SummaryWriter +from torch.utils.tensorboard.summary import hparams + + +class SummaryWriter_(SummaryWriter): + """ + This class inherits from the base SummaryWriter class and provides modifications 
to improve the logging. + It simplifies the logging structure and ensures consistent tag formatting for metrics. + + Changes Made: + - Avoids creating unnecessary subfolders in the log directory. + - Ensures all logs are stored in the same 'tfevent' directory for better organization. + - Updates metric keys to have a consistent 'Summary/' prefix for clarity. + - Improves the display of 'Loss' or 'Accuracy' on the Summary file. + + Methods: + - add_hparams: Overrides the base method to log hyperparameters and metrics with better formatting. + """ + + def add_hparams(self, hparam_dict, metric_dict, global_step): + if not isinstance(hparam_dict, dict) or not isinstance(metric_dict, dict): + raise TypeError("hparam_dict and metric_dict should be dictionary.") + updated_metric = {} + for key, value in metric_dict.items(): + updated_key = "Summary" + "/" + key + updated_metric[updated_key] = value + exp, ssi, sei = hparams(hparam_dict, updated_metric) + + self.file_writer.add_summary(exp) + self.file_writer.add_summary(ssi) + self.file_writer.add_summary(sei) + for k, v in updated_metric.items(): + self.add_scalar(tag=k, scalar_value=v, global_step=global_step) + + +class tblogger: + config = None + config_id: Optional[int] = None + config_working_directory = None + config_previous_directory = None + optim_path = None + + config_value_fid: Optional[str] = None + fidelity_mode: bool = False + + logger_init_bool: bool = True + logger_bool: bool = False + + image_logger: bool = False + image_value: Optional[torch.tensor] = None + image_name: Optional[str] = None + epoch_value: Optional[int] = None + + disable_logging: bool = False + + loss: Optional[float] = None + current_epoch: int + scalar_accuracy_mode: bool = False + hparam_accuracy_mode: bool = False + + config_writer: Optional[SummaryWriter_] = None + summary_writer: Optional[SummaryWriter_] = None + + logging_mode: list = [] + + @staticmethod + def config_track_init_api( + config_id, config, config_working_directory, config_previous_directory, optim_path + ): + """ + Track the Configuration space data from the way it is done on neps metahyper '_sample_config' to keep insinc with + config ids and directories NePS is operating on. 
+ """ + + tblogger.config = config + tblogger.config_id = config_id + tblogger.config_working_directory = config_working_directory + tblogger.config_previous_directory = config_previous_directory + tblogger.optim_path = optim_path + + @staticmethod + def _initialize_writers(): + if not tblogger.config_writer: + optim_config_path = tblogger.optim_path / "results" + if tblogger.config_previous_directory is not None: + tblogger.fidelity_mode = True + while not tblogger.config_writer: + if os.path.exists(tblogger.config_previous_directory / "tbevents"): + find_previous_config_id = ( + tblogger.config_working_directory / "previous_config.id" + ) + if os.path.exists(find_previous_config_id): + with open(find_previous_config_id) as file: + contents = file.read() + tblogger.config_value_fid = contents + tblogger.config_writer = SummaryWriter_( + tblogger.config_previous_directory / "tbevents" + ) + else: + find_previous_config_path = ( + tblogger.config_previous_directory / "previous_config.id" + ) + if os.path.exists(find_previous_config_path): + with open(find_previous_config_path) as file: + contents = file.read() + tblogger.config_value_fid = contents + tblogger.config_working_directory = ( + tblogger.config_previous_directory + ) + tblogger.config_previous_directory = ( + optim_config_path / f"config_{contents}" + ) + else: + tblogger.fidelity_mode = False + tblogger.config_writer = SummaryWriter_( + tblogger.config_working_directory / "tbevents" + ) + + @staticmethod + def _make_grid(images: torch.tensor, nrow: int, padding: int = 2): + """ + Create a grid of images from a batch of images. + + Args: + images (torch.Tensor): The input batch of images with shape (batch_size, num_channels, height, width). + nrow (int): The number rows on the grid. + padding (int, optional): The padding between images in the grid. Default is 2. + + Returns: + torch.Tensor: A grid of images with shape (num_channels, total_height, total_width), + where total_height and total_width depend on the number of images and the grid settings. + """ + batch_size, num_channels, height, width = images.size() + x_mapping = min(nrow, batch_size) + y_mapping = int(math.ceil(float(batch_size) / x_mapping)) + height, width = height + 2, width + 2 + + grid = torch.zeros( + (num_channels, height * y_mapping + padding, width * x_mapping + padding) + ) + + k = 0 + for y in range(y_mapping): + for x in range(x_mapping): + if k >= batch_size: + break + image = images[k] + grid[ + :, + y * height + padding : y * height + padding + height - padding, + x * width + padding : x * width + padding + width - padding, + ] = image + k += 1 + + return grid + + @staticmethod + def scalar_logging(value: float) -> list: + """ + Prepare a scalar value for logging. + + Args: + value (float): The scalar value to be logged. + + Returns: + list: A list containing the logging mode and the value for logging. + The list format is [logging_mode, value]. + """ + logging_mode = "scalar" + return [logging_mode, value] + + @staticmethod + def image_logging( + img_tensor: torch.Tensor, + counter: int, + resize_images: Optional[List[Optional[int]]] = None, + ignore_warning: bool = True, + random_images: bool = True, + num_images: int = 20, + ) -> List[Union[str, torch.Tensor, int, bool, List[Optional[int]]]]: + """ + Prepare an image tensor for logging. + + Args: + img_tensor (torch.Tensor): The image tensor to be logged. + counter (int): A counter value for teh frequency of image logging (ex: counter 2 means for every + 2 epochs a new set of images are logged). 
+ resize_images (list of int): A list of integers representing the image sizes + after resizing or None if no resizing required. + Default is None. + ignore_warning (bool, optional): Whether to ignore any warning during logging. Default is True. + random_images (bool, optional): Whether the images are selected randomly. Default is True. + num_images (int, optional): The number of images to log. Default is 20. + + Returns: + list: A list containing the logging mode and all the necessary parameters for image logging. + The list format is [logging_mode, img_tensor, counter, repetitive, resize_images, + ignore_warning, random_images, num_images]. + """ + logging_mode = "image" + return [ + logging_mode, + img_tensor, + counter, + resize_images, + ignore_warning, + random_images, + num_images, + ] + + @staticmethod + def layer_gradient_logging(model: nn.Module): + """ + Prepare a model for logging layer gradients. + + Args: + model (nn.Module): The PyTorch model for which layer gradients will be logged. + + Returns: + list: A list containing the logging mode and the model for layer gradient logging. + The list format is [logging_mode, model]. + """ + logging_mode = "gradient_mean" + return [logging_mode, model] + + @staticmethod + def _file_arrange(): + # TODO: Have only one tfevent file in the respective folders instead of multiple (especially in the summary folder) + pass + + @staticmethod + def _write_scalar_config(tag: str, value: Union[float, int]): + """ + Write scalar values to the TensorBoard log. + + Args: + tag (str): The tag for the scalar value. + value (float or int): The scalar value to be logged. Default is None. + + Note: + If the tag is 'Loss' and scalar_accuracy_mode is True, the tag will be changed to 'Accuracy', + and the value will be transformed accordingly. + + The function relies on the initialize_config_writer to ensure the TensorBoard writer is initialized at + the correct directory. + + It also depends on the following global variables: + - tblogger.scalar_accuracy_mode (bool) + - tblogger.fidelity_mode (bool) + - tblogger.config_writer (SummaryWriter_) + + The function will log the scalar value under different tags based on fidelity mode and other configurations. + """ + tblogger._initialize_writers() + + if tag == "Loss": + if tblogger.scalar_accuracy_mode: + tag = "Accuracy" + value = (1 - value) * 100 + if tblogger.config_writer is not None: + if tblogger.fidelity_mode: + tblogger.config_writer.add_scalar( + tag="Config_" + str(tblogger.config_value_fid) + "/" + tag, + scalar_value=value, + global_step=tblogger.current_epoch, + ) + else: + tblogger.config_writer.add_scalar( + tag="Config_" + str(tblogger.config_id) + "/" + tag, + scalar_value=value, + global_step=tblogger.current_epoch, + ) + + @staticmethod + def _write_image_config( + tag: str, + image: torch.tensor, + counter: int, + resize_images: Optional[List[Optional[int]]] = None, + ignore_warning: bool = True, + random_images: bool = True, + num_images: int = 20, + ): + """ + Write images to the TensorBoard log. + + Args: + tag (str): The tag for the images. + image (torch.Tensor): The image tensor to be logged. + counter (int): A counter value associated with the images. + resize_images (list of int): A list of integers representing the image sizes + after resizing or None if no resizing required. + Default is None. + ignore_warning (bool, optional): Whether to ignore any warning during logging. Default is True. + random_images (bool, optional): Whether the images are selected randomly. Default is True. 
+ num_images (int, optional): The number of images to log. Default is 20. + + Note: + The function relies on the initialize_config_writer to ensure the TensorBoard writer is initialized at + the correct directory. + + It also depends on the following global variables: + - tblogger.current_epoch (int) + - tblogger.fidelity_mode (bool) + - tblogger.config_writer (SummaryWriter_) + - tblogger.config_value_fid (int or None) + - tblogger.config_id (int) + + The function will log a subset of images to TensorBoard based on the given configurations. + """ + tblogger._initialize_writers() + + if resize_images is None: + resize_images = [32, 32] + + if ignore_warning is True: + warnings.filterwarnings("ignore", category=DeprecationWarning) + + if tblogger.current_epoch % counter == 0: + if num_images > len(image): + num_images = len(image) + + if random_images is False: + subset_images = image[:num_images] + else: + random_indices = random.sample(range(len(image)), num_images) + subset_images = image[random_indices] + + resized_images = torch.nn.functional.interpolate( + subset_images, + size=(resize_images[0], resize_images[1]), + mode="bilinear", + align_corners=False, + ) + + nrow = int(resized_images.size(0) ** 0.75) + img_grid = tblogger._make_grid(resized_images, nrow=nrow) + if tblogger.config_writer is not None: + if tblogger.fidelity_mode: + tblogger.config_writer.add_image( + tag="Config_" + str(tblogger.config_value_fid) + "/" + tag, + img_tensor=img_grid, + global_step=tblogger.current_epoch, + ) + else: + tblogger.config_writer.add_image( + tag="Config_" + str(tblogger.config_id) + "/" + tag, + img_tensor=img_grid, + global_step=tblogger.current_epoch, + ) + + @staticmethod + def _write_hparam_config(): + """ + Write hyperparameter configurations to the TensorBoard log, inspired by the 'hparam' original function of tensorboard. + + Note: + The function relies on the initialize_config_writer to ensure the TensorBoard writer is initialized at + the correct directory. + + It also depends on the following global variables: + - tblogger.hparam_accuracy_mode (bool) + - tblogger.loss (float) + - tblogger.config_writer (SummaryWriter_) + - tblogger.config (dict) + - tblogger.current_epoch (int) + + The function will log hyperparameter configurations along with a metric value (either accuracy or loss) + to TensorBoard based on the given configurations. + """ + tblogger._initialize_writers() + + if tblogger.hparam_accuracy_mode: + str_name = "Accuracy" + str_value = (1 - tblogger.loss) * 100 + else: + str_name = "Loss" + str_value = tblogger.loss + + values = {str_name: str_value} + if tblogger.config_writer is not None: + tblogger.config_writer.add_hparams( + hparam_dict=tblogger.config, + metric_dict=values, + global_step=tblogger.current_epoch, + ) + + @staticmethod + def tracking_incumbent_api(best_loss): + """ + Track the incumbent (best) loss and log it in the TensorBoard summary. + + Args: + best_loss (float): The best loss value to be tracked, according to the _post_hook_function of NePS. + + Note: + The function relies on the following global variables: + - tblogger.config_writer (SummaryWriter_) + - tblogger.optim_path (str) + - tblogger.incum_tracker (int) + - tblogger.incum_val (float) + - tblogger.summary_writer (SummaryWriter_) + + The function logs the incumbent loss in a TensorBoard summary with a graph. + It increments the incumbent tracker based on occurrences of "Config ID" in the 'all_losses_and_configs.txt' file. 
+ """ + if tblogger.config_writer: + tblogger.config_writer.close() + tblogger.config_writer = None + + file_path = str(tblogger.optim_path) + "/all_losses_and_configs.txt" + tblogger.incum_tracker = 0 + with open(file_path) as f: + for line in f: + tblogger.incum_tracker += line.count("Config ID") + + tblogger.incum_val = float(best_loss) + + logdir = str(tblogger.optim_path) + "/summary" + + if tblogger.summary_writer is None: + tblogger.summary_writer = SummaryWriter_(logdir) + + tblogger.summary_writer.add_scalar( + tag="Summary" + "/Incumbent_graph", + scalar_value=tblogger.incum_val, + global_step=tblogger.incum_tracker, + ) + + tblogger.summary_writer.flush() + tblogger.summary_writer.close() + + @staticmethod + def disable(disable_logger: bool = True): + """ + The function allows for enabling or disabling the logger functionality + throughout the program execution by updating the value of 'tblogger.disable_logging'. + When the logger is disabled, it will not perform any logging operations. + + Args: + disable_logger (bool, optional): A boolean flag to control the logger. + If True (default), the logger will be disabled. + If False, the logger will be enabled. + + Example: + # Disable the logger + tblogger.disable() + + # Enable the logger + tblogger.disable(False) + """ + tblogger.disable_logging = disable_logger + + @staticmethod + def log( + loss: float, + current_epoch: int, + writer_scalar: bool = True, + writer_hparam: bool = True, + scalar_accuracy_mode: bool = False, + hparam_accuracy_mode: bool = False, + data: Optional[dict] = None, + ): + """ + Log experiment data to the logger, including scalar values, hyperparameters, images, and layer gradients. + + Args: + loss (float): The current loss value in training. + current_epoch (int): The current epoch of the experiment. + writer_scalar (bool, optional): Whether to write the loss or accuracy for the + configs during training. Default is True. + writer_hparam (bool, optional): Whether to write hyperparameters logging + of the configs during training. Default is True. + scalar_accuracy_mode (bool, optional): If True, interpret the 'loss' as 'accuracy' and transform it's + value accordingliy. Default is False. + hparam_accuracy_mode (bool, optional): If True, interpret the 'loss' as 'accuracy' and transform it's + value accordingliy. Default is False. + data (dict, optional): Additional experiment data to be logged. It should be in the format: + { + 'tag1': tblogger.scalar_logging(value=value1), + 'tag2': tblogger.image_logging(img_tensor=img, counter=2), + 'tag3': tblogger.layer_gradient_logging(model=model), + } + Default is None. 
+ + """ + tblogger.current_epoch = current_epoch + tblogger.loss = loss + tblogger.scalar_accuracy_mode = scalar_accuracy_mode + tblogger.hparam_accuracy_mode = hparam_accuracy_mode + + if not tblogger.disable_logging: + tblogger.logger_bool = True + + if writer_scalar: + tblogger._write_scalar_config(tag="Loss", value=loss) + + if writer_hparam: + tblogger._write_hparam_config() + + if data is not None: + for key in data: + if data[key][0] == "scalar": + tblogger._write_scalar_config(tag=str(key), value=data[key][1]) + + elif data[key][0] == "image": + tblogger._write_image_config( + tag=str(key), + image=data[key][1], + counter=data[key][2], + resize_images=data[key][3], + ignore_warning=data[key][4], + random_images=data[key][5], + num_images=data[key][6], + ) + + elif data[key][0] == "gradient_mean": + for i, layer in enumerate(data[key][1].children()): + layer_gradients = [param.grad for param in layer.parameters()] + if layer_gradients: + mean_gradient = torch.mean( + torch.cat([grad.view(-1) for grad in layer_gradients]) + ) + tblogger._write_scalar_config( + tag=f"{key}_gradient_{i}", value=mean_gradient.item() + ) + + else: + tblogger.logger_bool = False From 7a83abb03451f11a62bdbc47ad63cf3132725f76 Mon Sep 17 00:00:00 2001 From: TarekAbouChakra Date: Thu, 17 Aug 2023 09:21:07 +0200 Subject: [PATCH 2/6] tensorboard example and class minor changes --- .../convenience/neps_tblogger_tutorial.py | 296 +++++++----------- src/metahyper/api.py | 3 + src/neps/plot/tensorboard_eval.py | 55 ++-- 3 files changed, 143 insertions(+), 211 deletions(-) diff --git a/neps_examples/convenience/neps_tblogger_tutorial.py b/neps_examples/convenience/neps_tblogger_tutorial.py index 24fbc108..cbb10618 100644 --- a/neps_examples/convenience/neps_tblogger_tutorial.py +++ b/neps_examples/convenience/neps_tblogger_tutorial.py @@ -1,19 +1,47 @@ """ NePS tblogger With TensorBoard ==================================== -This tutorial demonstrates how to use TensorBoard plugin with NePS tblogger class -to detect performance data of the different model configurations during training. +1- Introduction +--------------- +Welcome to the NePS tblogger with TensorBoard tutorial! This guide will walk you +through the process of using the NePS tblogger class to effectively monitor and +analyze performance data from various model configurations during training. -Setup ------ -To install ``torchvision`` and ``tensorboard`` use the following command: +Assuming you already have an experience in NePS the main reason of creating this tutorial is to showcase the +power of visualization using tblogger. if you wish to directly reach that part, check the lines +between 244-264 or search for 'Start Tensorboard Logging' -.. code-block:: +2- Learning Objectives +---------------------- + +By completing this tutorial, you will: + +- Understand the role of NePS tblogger and its importance in HPO and NAS. +- Learn how to define search spaces within NePS to explore different model configurations. +- Build a comprehensive run pipeline to train and evaluate models. +- Utilize TensorBoard to visualize and compare performance metrics of different model configurations. + +3- Setup +-------- + +Before we dive in, make sure you have the necessary dependencies installed. If you haven't already, +install the ``NePS`` package using the following command: + +```bash + + pip install neural-pipeline-search + +Additionally, please note that NePS does not include ``torchvision`` as a dependency. 
+You can install it with the following command: + +```bash pip install torchvision +These dependencies will ensure you have everything you need to follow along with this tutorial successfully. """ + import argparse import logging import os @@ -35,7 +63,7 @@ from neps.plot.tensorboard_eval import tblogger """ -Steps: +Steps for a successful training pipeline: #1 Define the seeds for reproducibility. #2 Prepare the input data. @@ -44,10 +72,12 @@ #5 Design the run pipeline function. #6 Use neps.run the run the entire search using your specified searcher. +Each step will be covered in detail thourghout the code + """ ############################################################# -# Definig the seeds for reproducibility +# 1 Definig the seeds for reproducibility def set_seed(seed=123): @@ -57,7 +87,7 @@ def set_seed(seed=123): ############################################################# -# Prepare the input data. For this tutorial we use the MNIST dataset. +# 2 Prepare the input data. For this tutorial we use the MNIST dataset. def MNIST( @@ -86,7 +116,7 @@ def MNIST( ############################################################# -# Design small MLP model to be able to represent the input data. +# 3 Design small MLP model to be able to represent the input data. class MLP(nn.Module): @@ -98,6 +128,7 @@ def __init__(self) -> None: self.linear3 = nn.Linear(in_features=196, out_features=10) def forward(self, x): + # Flattening the grayscaled image from 1x28x28 (CxWxH) to 784. x = x.view(x.size(0), -1) x = self.relu(self.linear1(x)) x = self.relu(self.linear2(x)) @@ -107,7 +138,7 @@ def forward(self, x): ############################################################# -# Define the training step and return the validation error and misclassified images. +# 4 Define the training step and return the validation error and misclassified images. def loss_ev(model: nn.Module, data_loader: DataLoader) -> float: @@ -127,7 +158,7 @@ def loss_ev(model: nn.Module, data_loader: DataLoader) -> float: def training(model, optimizer, criterion, train_loader, validation_loader): """ - Function that trains the model for one epoch and evaluates the model on the validation set. Used by the searcher. + Function that trains the model for one epoch and evaluates the model on the validation set. Args: model (nn.Module): Model to be trained. @@ -161,10 +192,9 @@ def training(model, optimizer, criterion, train_loader, validation_loader): ############################################################# -# Design the pipeline search spaces. +# 5 Design the pipeline search spaces. -# For BO: def pipeline_space_BO() -> dict: pipeline = dict( lr=neps.FloatParameter(lower=1e-5, upper=1e-1, log=True), @@ -175,23 +205,10 @@ def pipeline_space_BO() -> dict: return pipeline -# For Hyperband -def pipeline_space_Hyperband() -> dict: - pipeline = dict( - lr=neps.FloatParameter(lower=1e-5, upper=1e-1, log=True), - optim=neps.CategoricalParameter(choices=["Adam", "SGD"]), - weight_decay=neps.FloatParameter(lower=1e-4, upper=1e-1, log=True), - epochs=neps.IntegerParameter(lower=1, upper=9, is_fidelity=True), - ) - - return pipeline - - ############################################################# -# Implement the pipeline run search. +# 6 Implement the pipeline run search. 
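+# The revised pipeline below trains the MLP for 9 epochs and, once per epoch, logs to
+# TensorBoard through tblogger.log: the validation error, the current learning rate
+# ("lr_decay"), a grid of misclassified images (refreshed every 2 epochs via
+# counter=2), and the mean gradient of the first two linear layers as plain scalars.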
-# For BO: def run_pipeline_BO(lr, optim, weight_decay): model = MLP() @@ -211,8 +228,6 @@ def run_pipeline_BO(lr, optim, weight_decay): criterion = nn.CrossEntropyLoss() losses = [] - tblogger.disable(False) - for i in range(max_epochs): loss, miss_img = training( optimizer=optimizer, @@ -223,16 +238,38 @@ def run_pipeline_BO(lr, optim, weight_decay): ) losses.append(loss) + # Gathering the gradient mean in each layer to display some of them in tensorboard + mean_gradient = [] + for layer in model.children(): + layer_gradients = [param.grad for param in layer.parameters()] + if layer_gradients: + mean_gradient.append( + torch.mean(torch.cat([grad.view(-1) for grad in layer_gradients])) + ) + + ###################### Start Tensorboard Logging ###################### + + # tblogger for neps config loggings. This line will result in the following: + + # 1 Incumbent of the configs (best performance regardless of fiedlity budget if the searcher was fidelity depenedent). + # 2 Loss curves of each of the configsat each epochs. + # 3 lr_decay curve at each epoch. + # 4 miss_img which represents the wrongly classified images by the model according the the counter. + # 5 first two layer_gradients computed above and passed as scalar configs. + tblogger.log( loss=loss, current_epoch=i, data={ "lr_decay": tblogger.scalar_logging(value=scheduler.get_last_lr()[0]), "miss_img": tblogger.image_logging(img_tensor=miss_img, counter=2), - "layer_gradient": tblogger.layer_gradient_logging(model=model), + "layer_gradient1": tblogger.scalar_logging(value=mean_gradient[0]), + "layer_gradient2": tblogger.scalar_logging(value=mean_gradient[1]), }, ) + ###################### End Tensorboard Logging ###################### + scheduler.step() print(f" Epoch {i + 1} / {max_epochs} Val Error: {loss} ") @@ -251,174 +288,81 @@ def run_pipeline_BO(lr, optim, weight_decay): } -# For Hyperband -def run_pipeline_Hyperband(pipeline_directory, previous_pipeline_directory, **configs): - model = MLP() - checkpoint_name = "checkpoint.pth" - start_epoch = 0 +############################################################# +# 6 Running neps with BO as our main searcher, saving the results in a defined directory. 
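+# As implemented in tensorboard_eval.py above, tblogger writes one "tbevents" folder
+# per sampled configuration (inside that configuration's working directory) and a
+# "summary" folder holding the incumbent curve under the optimization directory, so
+# pointing `tensorboard --logdir` at the root directory shows all configurations
+# together.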
- train_loader, validation_loader, test_loader = MNIST( - batch_size=32, n_train=4096, n_valid=512 +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--max_evaluations_total", + type=int, + default=10, + help="Number of different configs to train", ) + args = parser.parse_args() - # define loss - criterion = nn.CrossEntropyLoss() - - # Define the optimizer - if configs["optim"] == "Adam": - optimizer = torch.optim.Adam( - model.parameters(), lr=configs["lr"], weight_decay=configs["weight_decay"] - ) - elif configs["optim"] == "SGD": - optimizer = torch.optim.SGD( - model.parameters(), lr=configs["lr"], weight_decay=configs["weight_decay"] - ) - - scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.75) - - # We make use of checkpointing to resume training models on higher fidelities - if previous_pipeline_directory is not None: - # Read in state of the model after the previous fidelity rung - checkpoint = torch.load(previous_pipeline_directory / checkpoint_name) - model.load_state_dict(checkpoint["model_state_dict"]) - optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) - epochs_previously_spent = checkpoint["epoch"] - else: - epochs_previously_spent = 0 + """ + When running this code without any arguments, it will by default run bayesian optimization with 10 max evaluations + of 9 epochs each: - start_epoch += epochs_previously_spent + ```bash: - losses = list() + python neps_tblogger_tutorial.py + """ - tblogger.disable(False) + start_time = time.time() - epochs = configs["epochs"] + set_seed(112) + logging.basicConfig(level=logging.INFO) - for epoch in range(start_epoch, epochs): - # Call the training function, get the validation errors and append them to val errors - loss, miss_img = training( - model, optimizer, criterion, train_loader, validation_loader - ) - losses.append(loss) + if os.path.exists("results/bayesian_optimization"): + shutil.rmtree("results/bayesian_optimization") - tblogger.log( - loss=loss, - current_epoch=epoch, - hparam_accuracy_mode=True, - data={ - "lr_decay": tblogger.scalar_logging(value=scheduler.get_last_lr()[0]), - "miss_img": tblogger.image_logging(img_tensor=miss_img, counter=2), - "layer_gradient": tblogger.layer_gradient_logging(model=model), - }, - ) + """ + For showcasing purposes. After completing the first run, one can uncomment the line below + and continue the search via: - scheduler.step() + ```bash: - print(f" Epoch {epoch + 1} / {epochs} Val Error: {loss} ") + python neps_tblogger_tutorial.py --max_evaluations_total 15 - train_accuracy = loss_ev(model, train_loader) - test_accuracy = loss_ev(model, test_loader) + This would result in continuing the search for 5 new different configurations in addition + to disabling the logging, hence tblogger can always be disabled using the line below. 
- torch.save( - { - "epoch": epochs, - "model_state_dict": model.state_dict(), - "optimizer_state_dict": optimizer.state_dict(), - }, - pipeline_directory / checkpoint_name, - ) + ```code: - return { - "loss": loss, - "info_dict": { - "train_accuracy": train_accuracy, - "test_accuracy": test_accuracy, - "val_errors": losses, - "cost": epochs - epochs_previously_spent, - }, - "cost": epochs - epochs_previously_spent, - } + tblogger.disable() + """ -############################################################# -""" -Defining the main with argument parsing to use either BO or Hyperband and specifying their -respective properties -""" + # tblogger.disable() -if __name__ == "__main__": - argParser = argparse.ArgumentParser() - argParser.add_argument( - "--searcher", - type=str, - choices=["bayesian_optimization", "hyperband"], - default="bayesian_optimization", - help="Searcher type used", - ) - argParser.add_argument( - "--max_cost_total", type=int, default=30, help="Max cost used for Hyperband" - ) - argParser.add_argument( - "--max_evaluations_total", type=int, default=10, help="Max evaluation used for BO" + neps.run( + run_pipeline=run_pipeline_BO, + pipeline_space=pipeline_space_BO(), + root_directory="bayesian_optimization", + max_evaluations_total=args.max_evaluations_total, + searcher="bayesian_optimization", ) - args = argParser.parse_args() - - if args.searcher == "hyperband": - start_time = time.time() - set_seed(112) - logging.basicConfig(level=logging.INFO) - if os.path.exists("results/hyperband"): - shutil.rmtree("results/hyperband") - neps.run( - run_pipeline=run_pipeline_Hyperband, - pipeline_space=pipeline_space_Hyperband(), - root_directory="hyperband", - max_cost_total=args.max_cost_total, - searcher="hyperband", - ) - - """ - To check live plots during this command run, please open a new terminal with the directory of this saved project and run - - tensorboard --logdir hyperband - """ - - end_time = time.time() # Record the end time - execution_time = end_time - start_time - print(f"Execution time: {execution_time} seconds") - - elif args.searcher == "bayesian_optimization": - start_time = time.time() - set_seed(112) - logging.basicConfig(level=logging.INFO) - if os.path.exists("results/bayesian_optimization"): - shutil.rmtree("results/bayesian_optimization") - neps.run( - run_pipeline=run_pipeline_BO, - pipeline_space=pipeline_space_BO(), - root_directory="bayesian_optimization", - max_evaluations_total=args.max_evaluations_total, - searcher="bayesian_optimization", - ) - """ - To check live plots during this command run, please open a new terminal with the directory of this saved project and run - - tensorboard --logdir bayesian_optimization - """ + """ + To check live plots during this search, please open a new terminal and make sure to be at the same level directory + of your project and run this commant on the file created by neps search algorithm. - end_time = time.time() # Record the end time - execution_time = end_time - start_time - print(f"Execution time: {execution_time} seconds") + ```bash: - """ - When running this code without any arguments, it will by default run bayesian optimization with 10 max evaluations - of 9 epochs each: + tensorboard --logdir bayesian_optimization - python neps_tblogger_tutorial.py + To be able to check the visualization of tensorboard make sure to follow the local link provided. + ```bash: - If you wish to do this run with hyperband searcher with default max cost total of 30. 
Please run this command on the terminal: + http://localhost:6006/ - python neps_tblogger_tutorial.py --searcher hyperband + If nothing was visualized and you followed the tutorial exactly, there could have been an error in passing the correct + directory, please double check. Tensorboard will always run in the command line without checking if the directory exists. """ + + end_time = time.time() # Record the end time + execution_time = end_time - start_time + logging.info(f"Execution time: {execution_time} seconds") diff --git a/src/metahyper/api.py b/src/metahyper/api.py index 0da60626..4e410fcc 100644 --- a/src/metahyper/api.py +++ b/src/metahyper/api.py @@ -395,6 +395,9 @@ def run( ) = _sample_config(optimization_dir, sampler, serializer, logger) # Take the config data in case tensorboard is to be used. if tblogger.logger_init_bool or tblogger.logger_bool: + # A trick to enter the condition once if tblogger is not used and always + # if it is, necessary to log the first config. (need to save the first config + # then check if tblogger is used during training in the run_pipeline.) tblogger.config_track_init_api( config_id=config_id, config=config, diff --git a/src/neps/plot/tensorboard_eval.py b/src/neps/plot/tensorboard_eval.py index 18dca4c3..185bd1a2 100644 --- a/src/neps/plot/tensorboard_eval.py +++ b/src/neps/plot/tensorboard_eval.py @@ -5,7 +5,6 @@ from typing import List, Optional, Union import torch -import torch.nn as nn from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard.summary import hparams @@ -76,7 +75,7 @@ def config_track_init_api( config_id, config, config_working_directory, config_previous_directory, optim_path ): """ - Track the Configuration space data from the way it is done on neps metahyper '_sample_config' to keep insinc with + Track the Configuration space data from the way handled by neps metahyper '_sample_config' to keep in sync with config ids and directories NePS is operating on. 
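The hunks above hand the sampled configuration to tblogger; the changes that follow are about finding where the TensorBoard event files should live when a configuration is resumed at a higher fidelity. A minimal sketch of that lookup, assuming the layout used throughout these patches (a "tbevents" folder in the first-fidelity directory, "previous_config.id" pointer files, and config directories under `<optim_path>/results/`); the helper name below is illustrative and not part of the patch:

```python
from __future__ import annotations

from pathlib import Path


def find_tbevents_dir(working_dir: Path, previous_dir: Path | None, optim_path: Path) -> Path:
    # No previous fidelity: events go straight into this config's own directory.
    if previous_dir is None:
        return working_dir / "tbevents"
    # Otherwise walk back through "previous_config.id" pointers until the
    # first-fidelity directory (the one owning "tbevents") is reached.
    while not (previous_dir / "tbevents").exists():
        previous_id = (previous_dir / "previous_config.id").read_text().strip()
        previous_dir = optim_path / "results" / f"config_{previous_id}"
    return previous_dir / "tbevents"
```

The real `_initialize_writers` additionally keeps the logged config id in sync with whichever directory the writer ends up pointing at.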
""" @@ -89,15 +88,21 @@ def config_track_init_api( @staticmethod def _initialize_writers(): if not tblogger.config_writer: + # If the writer is still not assgined optim_config_path = tblogger.optim_path / "results" if tblogger.config_previous_directory is not None: + # If a previous directory is available (Now the search is done for higher fidelity but logging is + # saved on the previous directory) tblogger.fidelity_mode = True while not tblogger.config_writer: if os.path.exists(tblogger.config_previous_directory / "tbevents"): + # If the previous directory was actually the first fidelity, + # tbevents is the folder holding the logging event files "tfevent" find_previous_config_id = ( tblogger.config_working_directory / "previous_config.id" ) if os.path.exists(find_previous_config_id): + # Get the ID of the previous config to log on the new train data with open(find_previous_config_id) as file: contents = file.read() tblogger.config_value_fid = contents @@ -105,6 +110,9 @@ def _initialize_writers(): tblogger.config_previous_directory / "tbevents" ) else: + # If the directory does not have the writer created, + # find the previous config and keep on looping backward until locating + # the inital config holding the tfevent files find_previous_config_path = ( tblogger.config_previous_directory / "previous_config.id" ) @@ -119,6 +127,7 @@ def _initialize_writers(): optim_config_path / f"config_{contents}" ) else: + # If no fidelities are there, define the writer via the normal config_id tblogger.fidelity_mode = False tblogger.config_writer = SummaryWriter_( tblogger.config_working_directory / "tbevents" @@ -216,26 +225,6 @@ def image_logging( num_images, ] - @staticmethod - def layer_gradient_logging(model: nn.Module): - """ - Prepare a model for logging layer gradients. - - Args: - model (nn.Module): The PyTorch model for which layer gradients will be logged. - - Returns: - list: A list containing the logging mode and the model for layer gradient logging. - The list format is [logging_mode, model]. - """ - logging_mode = "gradient_mean" - return [logging_mode, model] - - @staticmethod - def _file_arrange(): - # TODO: Have only one tfevent file in the respective folders instead of multiple (especially in the summary folder) - pass - @staticmethod def _write_scalar_config(tag: str, value: Union[float, int]): """ @@ -325,7 +314,10 @@ def _write_image_config( warnings.filterwarnings("ignore", category=DeprecationWarning) if tblogger.current_epoch % counter == 0: + # Log every multiple of "counter" + if num_images > len(image): + # Be safe if the number of images is not as the len (as in the batch size) num_images = len(image) if random_images is False: @@ -340,7 +332,7 @@ def _write_image_config( mode="bilinear", align_corners=False, ) - + # Create the grid according to the number of images and log the grid to tensorboard. nrow = int(resized_images.size(0) ** 0.75) img_grid = tblogger._make_grid(resized_images, nrow=nrow) if tblogger.config_writer is not None: @@ -379,6 +371,7 @@ def _write_hparam_config(): tblogger._initialize_writers() if tblogger.hparam_accuracy_mode: + # Changes the loss to accuracy and logs in accuracy terms. str_name = "Accuracy" str_value = (1 - tblogger.loss) * 100 else: @@ -413,6 +406,7 @@ def tracking_incumbent_api(best_loss): It increments the incumbent tracker based on occurrences of "Config ID" in the 'all_losses_and_configs.txt' file. 
""" if tblogger.config_writer: + # Close all the previous config writers tblogger.config_writer.close() tblogger.config_writer = None @@ -420,6 +414,8 @@ def tracking_incumbent_api(best_loss): tblogger.incum_tracker = 0 with open(file_path) as f: for line in f: + # Count the amount of presence of "Config ID" because it correlates to the + # step size of how many configurations were completed. tblogger.incum_tracker += line.count("Config ID") tblogger.incum_val = float(best_loss) @@ -474,7 +470,7 @@ def log( Args: loss (float): The current loss value in training. - current_epoch (int): The current epoch of the experiment. + current_epoch (int): The current epoch of the experiment. Used as the global step. writer_scalar (bool, optional): Whether to write the loss or accuracy for the configs during training. Default is True. writer_hparam (bool, optional): Whether to write hyperparameters logging @@ -522,16 +518,5 @@ def log( num_images=data[key][6], ) - elif data[key][0] == "gradient_mean": - for i, layer in enumerate(data[key][1].children()): - layer_gradients = [param.grad for param in layer.parameters()] - if layer_gradients: - mean_gradient = torch.mean( - torch.cat([grad.view(-1) for grad in layer_gradients]) - ) - tblogger._write_scalar_config( - tag=f"{key}_gradient_{i}", value=mean_gradient.item() - ) - else: tblogger.logger_bool = False From 52b41fcb9183ebebc83ef4f725f683c0735bf57a Mon Sep 17 00:00:00 2001 From: TarekAbouChakra Date: Thu, 24 Aug 2023 14:23:54 +0200 Subject: [PATCH 3/6] Fixing issues in tblogger and example --- .../convenience/neps_tblogger_tutorial.py | 65 +-- src/metahyper/api.py | 16 +- src/neps/plot/tensorboard_eval.py | 425 ++++++++++-------- 3 files changed, 291 insertions(+), 215 deletions(-) diff --git a/neps_examples/convenience/neps_tblogger_tutorial.py b/neps_examples/convenience/neps_tblogger_tutorial.py index cbb10618..c806aece 100644 --- a/neps_examples/convenience/neps_tblogger_tutorial.py +++ b/neps_examples/convenience/neps_tblogger_tutorial.py @@ -4,11 +4,11 @@ 1- Introduction --------------- -Welcome to the NePS tblogger with TensorBoard tutorial! This guide will walk you +Welcome to the NePS tblogger with TensorBoard tutorial. This guide will walk you through the process of using the NePS tblogger class to effectively monitor and analyze performance data from various model configurations during training. -Assuming you already have an experience in NePS the main reason of creating this tutorial is to showcase the +Assuming you already have experience in NePS, the main reason of creating this tutorial is to showcase the power of visualization using tblogger. if you wish to directly reach that part, check the lines between 244-264 or search for 'Start Tensorboard Logging' @@ -29,15 +29,15 @@ install the ``NePS`` package using the following command: ```bash - - pip install neural-pipeline-search +pip install neural-pipeline-search +``` Additionally, please note that NePS does not include ``torchvision`` as a dependency. You can install it with the following command: ```bash - - pip install torchvision +pip install torchvision==0.14.1 +``` These dependencies will ensure you have everything you need to follow along with this tutorial successfully. 
""" @@ -53,6 +53,7 @@ import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F import torchvision from torch.optim import lr_scheduler from torch.utils.data.dataloader import DataLoader @@ -122,7 +123,6 @@ def MNIST( class MLP(nn.Module): def __init__(self) -> None: super().__init__() - self.relu = nn.ReLU() self.linear1 = nn.Linear(in_features=784, out_features=392) self.linear2 = nn.Linear(in_features=392, out_features=196) self.linear3 = nn.Linear(in_features=196, out_features=10) @@ -130,8 +130,8 @@ def __init__(self) -> None: def forward(self, x): # Flattening the grayscaled image from 1x28x28 (CxWxH) to 784. x = x.view(x.size(0), -1) - x = self.relu(self.linear1(x)) - x = self.relu(self.linear2(x)) + x = F.relu(self.linear1(x)) + x = F.relu(self.linear2(x)) x = self.linear3(x) return x @@ -216,6 +216,10 @@ def run_pipeline_BO(lr, optim, weight_decay): optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) elif optim == "SGD": optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay) + else: + raise ValueError( + "Optimizer choices are defined differently in the pipeline_space" + ) max_epochs = 9 @@ -251,10 +255,10 @@ def run_pipeline_BO(lr, optim, weight_decay): # tblogger for neps config loggings. This line will result in the following: - # 1 Incumbent of the configs (best performance regardless of fiedlity budget if the searcher was fidelity depenedent). - # 2 Loss curves of each of the configsat each epochs. + # 1 Incumbent of the configs (best performance regardless of fidelity budget, if the searcher was fidelity dependent). + # 2 Loss curves of each of the configs at each epoch. # 3 lr_decay curve at each epoch. - # 4 miss_img which represents the wrongly classified images by the model according the the counter. + # 4 miss_img which represents the wrongly classified images by the model. # 5 first two layer_gradients computed above and passed as scalar configs. tblogger.log( @@ -262,7 +266,9 @@ def run_pipeline_BO(lr, optim, weight_decay): current_epoch=i, data={ "lr_decay": tblogger.scalar_logging(value=scheduler.get_last_lr()[0]), - "miss_img": tblogger.image_logging(img_tensor=miss_img, counter=2), + "miss_img": tblogger.image_logging( + img_tensor=miss_img, counter=2, seed=2 + ), "layer_gradient1": tblogger.scalar_logging(value=mean_gradient[0]), "layer_gradient2": tblogger.scalar_logging(value=mean_gradient[1]), }, @@ -305,9 +311,9 @@ def run_pipeline_BO(lr, optim, weight_decay): When running this code without any arguments, it will by default run bayesian optimization with 10 max evaluations of 9 epochs each: - ```bash: - - python neps_tblogger_tutorial.py + ```bash + python neps_tblogger_tutorial.py + ``` """ start_time = time.time() @@ -322,17 +328,16 @@ def run_pipeline_BO(lr, optim, weight_decay): For showcasing purposes. After completing the first run, one can uncomment the line below and continue the search via: - ```bash: - - python neps_tblogger_tutorial.py --max_evaluations_total 15 + ```bash: + python neps_tblogger_tutorial.py --max_evaluations_total 15 + ``` This would result in continuing the search for 5 new different configurations in addition to disabling the logging, hence tblogger can always be disabled using the line below. 
- ```code: - - tblogger.disable() - + ```python: + tblogger.disable() + ``` """ # tblogger.disable() @@ -347,17 +352,17 @@ def run_pipeline_BO(lr, optim, weight_decay): """ To check live plots during this search, please open a new terminal and make sure to be at the same level directory - of your project and run this commant on the file created by neps search algorithm. + of your project and run this command on the file created by neps search algorithm. - ```bash: - - tensorboard --logdir bayesian_optimization + ```bash: + tensorboard --logdir bayesian_optimization + ``` To be able to check the visualization of tensorboard make sure to follow the local link provided. - ```bash: - - http://localhost:6006/ + ```bash: + http://localhost:6006/ + ``` If nothing was visualized and you followed the tutorial exactly, there could have been an error in passing the correct directory, please double check. Tensorboard will always run in the command line without checking if the directory exists. diff --git a/src/metahyper/api.py b/src/metahyper/api.py index 4e410fcc..e6942bbe 100644 --- a/src/metahyper/api.py +++ b/src/metahyper/api.py @@ -393,11 +393,19 @@ def run( pipeline_directory, previous_pipeline_directory, ) = _sample_config(optimization_dir, sampler, serializer, logger) - # Take the config data in case tensorboard is to be used. if tblogger.logger_init_bool or tblogger.logger_bool: - # A trick to enter the condition once if tblogger is not used and always - # if it is, necessary to log the first config. (need to save the first config - # then check if tblogger is used during training in the run_pipeline.) + # The following code block handles configuration data for potential use with TensorBoard. + # If the TensorBoard logger has been initialized or is active, this block captures + # configuration details. During the first configuration sampling, this process captures + # the initial configuration regardless of whether TensorBoard is used. In the subsequent + # sampling rounds, if the `logger_bool` flag is True, indicating TensorBoard usage, + # the logger will continue to capture and track configurations. If `logger_bool` is False, + # and `logger_init_bool` is False as well, no more configurations will be captured. + + # The `run_pipeline` step occurs after this sampling, and at this initial stage, it's not possible to know + # whether TensorBoard will be used. Gathering initial configuration details at this point ensures their + # availability for later stages, even if TensorBoard is not employed and stops capturing if it actually + # is not employed. tblogger.config_track_init_api( config_id=config_id, config=config, diff --git a/src/neps/plot/tensorboard_eval.py b/src/neps/plot/tensorboard_eval.py index 185bd1a2..23e08100 100644 --- a/src/neps/plot/tensorboard_eval.py +++ b/src/neps/plot/tensorboard_eval.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import math import os -import random import warnings -from typing import List, Optional, Union +from pathlib import Path +import numpy as np import torch from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard.summary import hparams @@ -24,13 +26,10 @@ class SummaryWriter_(SummaryWriter): - add_hparams: Overrides the base method to log hyperparameters and metrics with better formatting. 
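The `SummaryWriter_` subclass described above only changes how hyperparameter metrics are keyed: metric names are re-grouped under a "Summary/" prefix so the comparisons for all configs land in one TensorBoard section. A hedged usage sketch of the override defined just below; the log directory and values here are illustrative:

```python
from neps.plot.tensorboard_eval import SummaryWriter_

writer = SummaryWriter_("results/config_1/tbevents")  # illustrative log directory
writer.add_hparams(
    hparam_dict={"lr": 1e-3, "optim": "Adam", "weight_decay": 1e-2},
    metric_dict={"Loss": 0.42},  # re-keyed and written back as "Summary/Loss"
    global_step=3,
)
writer.close()
```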
""" - def add_hparams(self, hparam_dict, metric_dict, global_step): + def add_hparams(self, hparam_dict: dict, metric_dict: dict, global_step: int) -> None: if not isinstance(hparam_dict, dict) or not isinstance(metric_dict, dict): raise TypeError("hparam_dict and metric_dict should be dictionary.") - updated_metric = {} - for key, value in metric_dict.items(): - updated_key = "Summary" + "/" + key - updated_metric[updated_key] = value + updated_metric = {f"Summary/{key}": val for key, val in metric_dict.items()} exp, ssi, sei = hparams(hparam_dict, updated_metric) self.file_writer.add_summary(exp) @@ -41,39 +40,43 @@ def add_hparams(self, hparam_dict, metric_dict, global_step): class tblogger: - config = None - config_id: Optional[int] = None - config_working_directory = None - config_previous_directory = None - optim_path = None - - config_value_fid: Optional[str] = None - fidelity_mode: bool = False + config_id: str | None = None + config: dict | None = None + config_working_directory: Path | None = None + optim_path: Path | None = None + config_previous_directory: Path | None = None logger_init_bool: bool = True - logger_bool: bool = False + """logger_init_bool is only used once to capture configuration data for the first ever configuration, + and then turned false for the entire run.""" - image_logger: bool = False - image_value: Optional[torch.tensor] = None - image_name: Optional[str] = None - epoch_value: Optional[int] = None + logger_bool: bool = False + """logger_bool is true only if tblogger.log is used by the user, hence this allows to always capturing + the configuration data for all configurations.""" disable_logging: bool = False + """disable_logging is a hard switch to disable the logging feature if it was turned true. + hence even when logger_bool is true it disables the logging process""" - loss: Optional[float] = None - current_epoch: int + loss: float | None = None + current_epoch: int | None = None scalar_accuracy_mode: bool = False hparam_accuracy_mode: bool = False - config_writer: Optional[SummaryWriter_] = None - summary_writer: Optional[SummaryWriter_] = None + incum_tracker: int | None = None + incum_val: float | None = None - logging_mode: list = [] + config_writer: SummaryWriter_ | None = None + summary_writer: SummaryWriter_ | None = None @staticmethod def config_track_init_api( - config_id, config, config_working_directory, config_previous_directory, optim_path - ): + config_id: str, + config: dict, + config_working_directory: Path, + optim_path: Path, + config_previous_directory: Path | None = None, + ) -> None: """ Track the Configuration space data from the way handled by neps metahyper '_sample_config' to keep in sync with config ids and directories NePS is operating on. @@ -82,65 +85,86 @@ def config_track_init_api( tblogger.config = config tblogger.config_id = config_id tblogger.config_working_directory = config_working_directory - tblogger.config_previous_directory = config_previous_directory tblogger.optim_path = optim_path + tblogger.config_previous_directory = config_previous_directory + + @staticmethod + def _is_initialized() -> bool: + # Returns 'True' if config_writer is already initialized. 
'False' otherwise + return tblogger.config_writer is not None @staticmethod - def _initialize_writers(): - if not tblogger.config_writer: - # If the writer is still not assgined - optim_config_path = tblogger.optim_path / "results" + def _initialize_writers() -> None: + # This code runs only once per config, which is at the very beginning to assign that config + # a config_writer in the correct directory. + if tblogger.config_previous_directory is None: + # If no fidelities are there yet, define the writer via the normal config_id + tblogger.config_writer = SummaryWriter_( + tblogger.config_working_directory / "tbevents" + ) + return + while not tblogger._is_initialized(): + # While no writer has been assigned yet, knowing previous directory exists. + # TensorBoard requires tfevent files for the same plot to be located in the same directory for proper data appending. + # In this case we search for the first fidelity directory and store all tfevent files there + + prev_dir_id_from_init = ( + tblogger.config_working_directory / "previous_config.id" + ) if tblogger.config_previous_directory is not None: - # If a previous directory is available (Now the search is done for higher fidelity but logging is - # saved on the previous directory) - tblogger.fidelity_mode = True - while not tblogger.config_writer: - if os.path.exists(tblogger.config_previous_directory / "tbevents"): - # If the previous directory was actually the first fidelity, - # tbevents is the folder holding the logging event files "tfevent" - find_previous_config_id = ( - tblogger.config_working_directory / "previous_config.id" - ) - if os.path.exists(find_previous_config_id): - # Get the ID of the previous config to log on the new train data - with open(find_previous_config_id) as file: - contents = file.read() - tblogger.config_value_fid = contents - tblogger.config_writer = SummaryWriter_( - tblogger.config_previous_directory / "tbevents" - ) - else: - # If the directory does not have the writer created, - # find the previous config and keep on looping backward until locating - # the inital config holding the tfevent files - find_previous_config_path = ( - tblogger.config_previous_directory / "previous_config.id" - ) - if os.path.exists(find_previous_config_path): - with open(find_previous_config_path) as file: - contents = file.read() - tblogger.config_value_fid = contents - tblogger.config_working_directory = ( - tblogger.config_previous_directory - ) - tblogger.config_previous_directory = ( - optim_config_path / f"config_{contents}" - ) + get_tbevent_dir = tblogger.config_previous_directory / "tbevents" else: - # If no fidelities are there, define the writer via the normal config_id - tblogger.fidelity_mode = False + warnings.warn( + "There should be a previous config directory at this stage." + "Prone to failure" + ) + + # This should execute when having Config_x_1 and Config_x_0 + if os.path.exists(get_tbevent_dir): + # When tfevents directory is detected => we are at the first fidelity directory, create writer. 
+ with open(prev_dir_id_from_init) as file: + contents = file.read() + tblogger.config_id = contents tblogger.config_writer = SummaryWriter_( - tblogger.config_working_directory / "tbevents" + tblogger.config_previous_directory / "tbevents" + ) + return + + # This should execute when having Config_x_y and Config_x_y where y > 0 + if tblogger.config_previous_directory is not None: + prev_dir_id_from_prev = ( + tblogger.config_previous_directory / "previous_config.id" + ) + else: + warnings.warn( + "There should be a previous config directory at this stage." + "Prone to failure" + ) + + if os.path.exists(prev_dir_id_from_prev): + # To get the new previous config directory + with open(prev_dir_id_from_prev) as file: + contents = file.read() + tblogger.config_id = contents + tblogger.config_working_directory = tblogger.config_previous_directory + tblogger.config_previous_directory = ( + tblogger.optim_path / "results" / f"config_{contents}" + ) + else: + # If we do not find tbevents, hence we passed through each and every + # directory for that config, raising error rather than staying in a 'while 1'. + raise FileNotFoundError( + "'tbevents' was not found in the directory of the initial fidelity." ) @staticmethod - def _make_grid(images: torch.tensor, nrow: int, padding: int = 2): + def _make_grid(images: torch.Tensor, nrow: int, padding: int = 2) -> torch.Tensor: """ Create a grid of images from a batch of images. Args: images (torch.Tensor): The input batch of images with shape (batch_size, num_channels, height, width). - nrow (int): The number rows on the grid. + nrow (int): The number of rows on the grid. padding (int, optional): The padding between images in the grid. Default is 2. Returns: @@ -172,7 +196,7 @@ def _make_grid(images: torch.tensor, nrow: int, padding: int = 2): return grid @staticmethod - def scalar_logging(value: float) -> list: + def scalar_logging(value: float) -> tuple: """ Prepare a scalar value for logging. @@ -180,53 +204,62 @@ def scalar_logging(value: float) -> list: value (float): The scalar value to be logged. Returns: - list: A list containing the logging mode and the value for logging. - The list format is [logging_mode, value]. + tuple: A tuple containing the logging mode and the value for logging. + The tuple format is (logging_mode, value). """ logging_mode = "scalar" - return [logging_mode, value] + return (logging_mode, value) @staticmethod def image_logging( img_tensor: torch.Tensor, - counter: int, - resize_images: Optional[List[Optional[int]]] = None, - ignore_warning: bool = True, + counter: int = 1, + resize_images: list[None | int] | None = None, random_images: bool = True, num_images: int = 20, - ) -> List[Union[str, torch.Tensor, int, bool, List[Optional[int]]]]: + seed: int | np.random.RandomState | None = None, + ) -> tuple[ + str, + torch.Tensor, + int, + list[None | int] | None, + bool, + int, + int | np.random.RandomState | None, + ]: """ Prepare an image tensor for logging. Args: img_tensor (torch.Tensor): The image tensor to be logged. - counter (int): A counter value for teh frequency of image logging (ex: counter 2 means for every + counter (int): A counter value for the frequency of image logging (ex: counter 2 means for every 2 epochs a new set of images are logged). resize_images (list of int): A list of integers representing the image sizes after resizing or None if no resizing required. Default is None. - ignore_warning (bool, optional): Whether to ignore any warning during logging. Default is True. 
random_images (bool, optional): Whether the images are selected randomly. Default is True. num_images (int, optional): The number of images to log. Default is 20. + seed (int or np.random.RandomState or None, optional): Seed value or RandomState instance to control + the randomness of image selection. Default is None. Returns: - list: A list containing the logging mode and all the necessary parameters for image logging. - The list format is [logging_mode, img_tensor, counter, repetitive, resize_images, - ignore_warning, random_images, num_images]. + tuple: A tuple containing the logging mode and all the necessary parameters for image logging. + The tuple format is (logging_mode, img_tensor, counter, repetitive, resize_images, + random_images, num_images, seed). """ logging_mode = "image" - return [ + return ( logging_mode, img_tensor, counter, resize_images, - ignore_warning, random_images, num_images, - ] + seed, + ) @staticmethod - def _write_scalar_config(tag: str, value: Union[float, int]): + def _write_scalar_config(tag: str, value: float | int) -> None: """ Write scalar values to the TensorBoard log. @@ -243,41 +276,41 @@ def _write_scalar_config(tag: str, value: Union[float, int]): It also depends on the following global variables: - tblogger.scalar_accuracy_mode (bool) - - tblogger.fidelity_mode (bool) - tblogger.config_writer (SummaryWriter_) + - tblogger.config_id (str) The function will log the scalar value under different tags based on fidelity mode and other configurations. """ - tblogger._initialize_writers() + if not tblogger._is_initialized(): + tblogger._initialize_writers() - if tag == "Loss": - if tblogger.scalar_accuracy_mode: - tag = "Accuracy" - value = (1 - value) * 100 + if tag == "Loss" and tblogger.scalar_accuracy_mode: + tag = "Accuracy" + value = (1 - value) * 100 + + # Just an extra safety measure if tblogger.config_writer is not None: - if tblogger.fidelity_mode: - tblogger.config_writer.add_scalar( - tag="Config_" + str(tblogger.config_value_fid) + "/" + tag, - scalar_value=value, - global_step=tblogger.current_epoch, - ) - else: - tblogger.config_writer.add_scalar( - tag="Config_" + str(tblogger.config_id) + "/" + tag, - scalar_value=value, - global_step=tblogger.current_epoch, - ) + tblogger.config_writer.add_scalar( + tag="Config_" + str(tblogger.config_id) + "/" + tag, + scalar_value=value, + global_step=tblogger.current_epoch, + ) + else: + raise ValueError( + "The 'config_writer' is None in _write_scalar_config. No loggings are performed. " + "An error occurred during the initialization process." + ) @staticmethod def _write_image_config( tag: str, - image: torch.tensor, - counter: int, - resize_images: Optional[List[Optional[int]]] = None, - ignore_warning: bool = True, + image: torch.Tensor, + counter: int = 1, + resize_images: list[None | int] | None = None, random_images: bool = True, num_images: int = 20, - ): + seed: int | np.random.RandomState | None = None, + ) -> None: """ Write images to the TensorBoard log. @@ -288,9 +321,10 @@ def _write_image_config( resize_images (list of int): A list of integers representing the image sizes after resizing or None if no resizing required. Default is None. - ignore_warning (bool, optional): Whether to ignore any warning during logging. Default is True. random_images (bool, optional): Whether the images are selected randomly. Default is True. num_images (int, optional): The number of images to log. Default is 20. 
+ seed (int or np.random.RandomState or None, optional): Seed value or RandomState instance to control + the randomness of image selection. Default is None. Note: The function relies on the initialize_config_writer to ensure the TensorBoard writer is initialized at @@ -298,33 +332,34 @@ def _write_image_config( It also depends on the following global variables: - tblogger.current_epoch (int) - - tblogger.fidelity_mode (bool) - tblogger.config_writer (SummaryWriter_) - - tblogger.config_value_fid (int or None) - - tblogger.config_id (int) + - tblogger.config_id (str) The function will log a subset of images to TensorBoard based on the given configurations. """ - tblogger._initialize_writers() + if not tblogger._is_initialized(): + tblogger._initialize_writers() if resize_images is None: resize_images = [32, 32] - if ignore_warning is True: - warnings.filterwarnings("ignore", category=DeprecationWarning) - if tblogger.current_epoch % counter == 0: - # Log every multiple of "counter" + # Log every multiple of "counter" if num_images > len(image): - # Be safe if the number of images is not as the len (as in the batch size) + # If the number of images requested by the user + # is more than the ones available. num_images = len(image) if random_images is False: subset_images = image[:num_images] else: - random_indices = random.sample(range(len(image)), num_images) - subset_images = image[random_indices] + if not isinstance(seed, np.random.RandomState): + seed = np.random.RandomState(seed) + # We do not interfere with any randomness from the pipeline + num_total_images = len(image) + indices = seed.choice(num_total_images, num_images, replace=False) + subset_images = image[indices] resized_images = torch.nn.functional.interpolate( subset_images, @@ -335,22 +370,21 @@ def _write_image_config( # Create the grid according to the number of images and log the grid to tensorboard. nrow = int(resized_images.size(0) ** 0.75) img_grid = tblogger._make_grid(resized_images, nrow=nrow) + # Just an extra safety measure if tblogger.config_writer is not None: - if tblogger.fidelity_mode: - tblogger.config_writer.add_image( - tag="Config_" + str(tblogger.config_value_fid) + "/" + tag, - img_tensor=img_grid, - global_step=tblogger.current_epoch, - ) - else: - tblogger.config_writer.add_image( - tag="Config_" + str(tblogger.config_id) + "/" + tag, - img_tensor=img_grid, - global_step=tblogger.current_epoch, - ) + tblogger.config_writer.add_image( + tag="Config_" + str(tblogger.config_id) + "/" + tag, + img_tensor=img_grid, + global_step=tblogger.current_epoch, + ) + else: + raise ValueError( + "The 'config_writer' is None in _write_image_config. No loggings are performed. " + "An error occurred during the initialization process." + ) @staticmethod - def _write_hparam_config(): + def _write_hparam_config() -> None: """ Write hyperparameter configurations to the TensorBoard log, inspired by the 'hparam' original function of tensorboard. @@ -368,7 +402,8 @@ def _write_hparam_config(): The function will log hyperparameter configurations along with a metric value (either accuracy or loss) to TensorBoard based on the given configurations. """ - tblogger._initialize_writers() + if not tblogger._is_initialized(): + tblogger._initialize_writers() if tblogger.hparam_accuracy_mode: # Changes the loss to accuracy and logs in accuracy terms. 
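Just above, the hparam writer's accuracy mode simply reports `(1 - loss) * 100` instead of the raw loss. The image path rewritten in this hunk is less obvious, so here is a minimal stand-alone sketch of the selection it performs, assuming an illustrative batch; the modulo check on the epoch, the `np.random.RandomState` seeding, and the default of 20 images mirror `_write_image_config`:

```python
import numpy as np
import torch

miss_img = torch.rand(64, 1, 28, 28)   # stand-in for a batch of misclassified images
epoch, counter, num_images = 4, 2, 20  # counter=2 and 20 images mirror the defaults above

if epoch % counter == 0:                # images are only written on epochs divisible by `counter`
    rng = np.random.RandomState(2)      # seed=2 makes the selection reproducible
    n = min(num_images, len(miss_img))  # guard when the batch holds fewer images than requested
    subset = miss_img[rng.choice(len(miss_img), n, replace=False)]
    # `subset` is what gets resized and arranged into the grid that is logged.
```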
@@ -379,15 +414,21 @@ def _write_hparam_config(): str_value = tblogger.loss values = {str_name: str_value} + # Just an extra safety measure if tblogger.config_writer is not None: tblogger.config_writer.add_hparams( hparam_dict=tblogger.config, metric_dict=values, global_step=tblogger.current_epoch, ) + else: + raise ValueError( + "The 'config_writer' is None in _write_hparam_config. No loggings are performed. " + "An error occurred during the initialization process." + ) @staticmethod - def tracking_incumbent_api(best_loss): + def tracking_incumbent_api(best_loss: float) -> None: """ Track the incumbent (best) loss and log it in the TensorBoard summary. @@ -406,17 +447,25 @@ def tracking_incumbent_api(best_loss): It increments the incumbent tracker based on occurrences of "Config ID" in the 'all_losses_and_configs.txt' file. """ if tblogger.config_writer: - # Close all the previous config writers + # Close and reset previous config writers to ensure proper handling of logging continuity. + + # This is important as writers can be reinitialized to the same directory for ongoing + # logging (across different fidelities). Closing writers after logging ensures consistency + # and avoids conflicts. tblogger.config_writer.close() tblogger.config_writer = None file_path = str(tblogger.optim_path) + "/all_losses_and_configs.txt" tblogger.incum_tracker = 0 - with open(file_path) as f: - for line in f: + if os.path.exists(file_path): + with open(file_path) as f: # Count the amount of presence of "Config ID" because it correlates to the # step size of how many configurations were completed. - tblogger.incum_tracker += line.count("Config ID") + tblogger.incum_tracker = sum(line.count("Config ID") for line in f) + else: + raise FileExistsError( + "all_losses_and_configs.txt does not exist in the optimization directory" + ) tblogger.incum_val = float(best_loss) @@ -426,34 +475,49 @@ def tracking_incumbent_api(best_loss): tblogger.summary_writer = SummaryWriter_(logdir) tblogger.summary_writer.add_scalar( - tag="Summary" + "/Incumbent_graph", + tag="Summary/Incumbent_graph", scalar_value=tblogger.incum_val, global_step=tblogger.incum_tracker, ) + # One challenge here is that the process of frequently closing existing writers and creating new ones + # leads to the creation of new 'tfevent' files (handled by TensorBoard). Although this outcome is unavoidable, + # it arises due to the necessity of allowing multiple writers to operate concurrently on the same directory. + # When multiple writers remain open simultaneously, conflicts can arise. This is the reason writers are + # flushed and closed after their use. (To make it work for parallelization) + tblogger.summary_writer.flush() tblogger.summary_writer.close() @staticmethod - def disable(disable_logger: bool = True): + def disable() -> None: """ - The function allows for enabling or disabling the logger functionality + The function allows for disabling the logger functionality throughout the program execution by updating the value of 'tblogger.disable_logging'. When the logger is disabled, it will not perform any logging operations. - Args: - disable_logger (bool, optional): A boolean flag to control the logger. - If True (default), the logger will be disabled. - If False, the logger will be enabled. + By default tblogger is enabled when used. If for any reason disabling is needed. This function does the job. 
Example: # Disable the logger tblogger.disable() + """ + tblogger.disable_logging = True + + @staticmethod + def enable() -> None: + """ + The function allows for enabling the logger functionality + throughout the program execution by updating the value of 'tblogger.disable_logging'. + When the logger is enabled, it will perform the logging operations. + By default this is enabled. Hence only needed when tblogger was once disabled. + + Example: # Enable the logger - tblogger.disable(False) + tblogger.enable() """ - tblogger.disable_logging = disable_logger + tblogger.disable_logging = False @staticmethod def log( @@ -463,8 +527,8 @@ def log( writer_hparam: bool = True, scalar_accuracy_mode: bool = False, hparam_accuracy_mode: bool = False, - data: Optional[dict] = None, - ): + data: dict | None = None, + ) -> None: """ Log experiment data to the logger, including scalar values, hyperparameters, images, and layer gradients. @@ -483,7 +547,6 @@ def log( { 'tag1': tblogger.scalar_logging(value=value1), 'tag2': tblogger.image_logging(img_tensor=img, counter=2), - 'tag3': tblogger.layer_gradient_logging(model=model), } Default is None. @@ -493,30 +556,30 @@ def log( tblogger.scalar_accuracy_mode = scalar_accuracy_mode tblogger.hparam_accuracy_mode = hparam_accuracy_mode - if not tblogger.disable_logging: - tblogger.logger_bool = True - - if writer_scalar: - tblogger._write_scalar_config(tag="Loss", value=loss) - - if writer_hparam: - tblogger._write_hparam_config() - - if data is not None: - for key in data: - if data[key][0] == "scalar": - tblogger._write_scalar_config(tag=str(key), value=data[key][1]) - - elif data[key][0] == "image": - tblogger._write_image_config( - tag=str(key), - image=data[key][1], - counter=data[key][2], - resize_images=data[key][3], - ignore_warning=data[key][4], - random_images=data[key][5], - num_images=data[key][6], - ) - - else: + if tblogger.disable_logging: tblogger.logger_bool = False + return + + tblogger.logger_bool = True + + if writer_scalar: + tblogger._write_scalar_config(tag="Loss", value=loss) + + if writer_hparam: + tblogger._write_hparam_config() + + if data is not None: + for key in data: + if data[key][0] == "scalar": + tblogger._write_scalar_config(tag=str(key), value=data[key][1]) + + elif data[key][0] == "image": + tblogger._write_image_config( + tag=str(key), + image=data[key][1], + counter=data[key][2], + resize_images=data[key][3], + random_images=data[key][4], + num_images=data[key][5], + seed=data[key][6], + ) From 7495eccddc56f845ac288645ab23ab8c313c1d28 Mon Sep 17 00:00:00 2001 From: TarekAbouChakra Date: Wed, 30 Aug 2023 09:43:36 +0200 Subject: [PATCH 4/6] tblogger changes in class and example --- .../convenience/neps_tblogger_tutorial.py | 12 +++- src/metahyper/api.py | 16 ++--- src/neps/plot/tensorboard_eval.py | 62 +++++++++---------- 3 files changed, 44 insertions(+), 46 deletions(-) diff --git a/neps_examples/convenience/neps_tblogger_tutorial.py b/neps_examples/convenience/neps_tblogger_tutorial.py index c806aece..1595c55d 100644 --- a/neps_examples/convenience/neps_tblogger_tutorial.py +++ b/neps_examples/convenience/neps_tblogger_tutorial.py @@ -325,7 +325,7 @@ def run_pipeline_BO(lr, optim, weight_decay): shutil.rmtree("results/bayesian_optimization") """ - For showcasing purposes. After completing the first run, one can uncomment the line below + For showcasing purposes. 
After completing the first run, one can uncomment line 351 and continue the search via: ```bash: @@ -338,8 +338,16 @@ def run_pipeline_BO(lr, optim, weight_decay): ```python: tblogger.disable() ``` + + Note that by default tblogger is enabled when used. However, there is also an enable toggle that can be used + ```python: + tblogger.enable() + ``` """ + # by defualt tblogger is enabled when used, one can also check the status using: + # tblogger.get_status() + # tblogger.disable() neps.run( @@ -352,7 +360,7 @@ def run_pipeline_BO(lr, optim, weight_decay): """ To check live plots during this search, please open a new terminal and make sure to be at the same level directory - of your project and run this command on the file created by neps search algorithm. + of your project and run this command on the file created by neps root_directory. ```bash: tensorboard --logdir bayesian_optimization diff --git a/src/metahyper/api.py b/src/metahyper/api.py index e6942bbe..a0b5e552 100644 --- a/src/metahyper/api.py +++ b/src/metahyper/api.py @@ -394,18 +394,10 @@ def run( previous_pipeline_directory, ) = _sample_config(optimization_dir, sampler, serializer, logger) if tblogger.logger_init_bool or tblogger.logger_bool: - # The following code block handles configuration data for potential use with TensorBoard. - # If the TensorBoard logger has been initialized or is active, this block captures - # configuration details. During the first configuration sampling, this process captures - # the initial configuration regardless of whether TensorBoard is used. In the subsequent - # sampling rounds, if the `logger_bool` flag is True, indicating TensorBoard usage, - # the logger will continue to capture and track configurations. If `logger_bool` is False, - # and `logger_init_bool` is False as well, no more configurations will be captured. - - # The `run_pipeline` step occurs after this sampling, and at this initial stage, it's not possible to know - # whether TensorBoard will be used. Gathering initial configuration details at this point ensures their - # availability for later stages, even if TensorBoard is not employed and stops capturing if it actually - # is not employed. + # This block manages configuration data, potentially for TensorBoard. + # Captures details during sampling; initial config always captured. + # In later rounds, captures if `logger_bool` is True; stops if False. + # Initial details gathered for `run_pipeline` pre-TensorBoard decision. tblogger.config_track_init_api( config_id=config_id, config=config, diff --git a/src/neps/plot/tensorboard_eval.py b/src/neps/plot/tensorboard_eval.py index 23e08100..47575aef 100644 --- a/src/neps/plot/tensorboard_eval.py +++ b/src/neps/plot/tensorboard_eval.py @@ -95,18 +95,16 @@ def _is_initialized() -> bool: @staticmethod def _initialize_writers() -> None: - # This code runs only once per config, which is at the very beginning to assign that config - # a config_writer in the correct directory. + # This code runs only once per config, to assign that config a config_writer. if tblogger.config_previous_directory is None: - # If no fidelities are there yet, define the writer via the normal config_id + # If no fidelities are there yet, define the writer via the config_id tblogger.config_writer = SummaryWriter_( tblogger.config_working_directory / "tbevents" ) return while not tblogger._is_initialized(): - # While no writer has been assigned yet, knowing previous directory exists. 
- # TensorBoard requires tfevent files for the same plot to be located in the same directory for proper data appending. - # In this case we search for the first fidelity directory and store all tfevent files there + # Ensure proper directory for TensorBoard data appending. + # Search for the first fidelity directory to store tfevent files. prev_dir_id_from_init = ( tblogger.config_working_directory / "previous_config.id" @@ -151,8 +149,8 @@ def _initialize_writers() -> None: tblogger.optim_path / "results" / f"config_{contents}" ) else: - # If we do not find tbevents, hence we passed through each and every - # directory for that config, raising error rather than staying in a 'while 1'. + # If no tbevents found after traversing all config directories, + # raise an error to prevent indefinite 'while 1' loop. raise FileNotFoundError( "'tbevents' was not found in the directory of the initial fidelity." ) @@ -233,10 +231,10 @@ def image_logging( Args: img_tensor (torch.Tensor): The image tensor to be logged. counter (int): A counter value for the frequency of image logging (ex: counter 2 means for every - 2 epochs a new set of images are logged). + 2 global steps a new set of images are logged). resize_images (list of int): A list of integers representing the image sizes - after resizing or None if no resizing required. - Default is None. + after resizing or None if no resizing required. + Default is None. random_images (bool, optional): Whether the images are selected randomly. Default is True. num_images (int, optional): The number of images to log. Default is 20. seed (int or np.random.RandomState or None, optional): Seed value or RandomState instance to control @@ -244,7 +242,7 @@ def image_logging( Returns: tuple: A tuple containing the logging mode and all the necessary parameters for image logging. - The tuple format is (logging_mode, img_tensor, counter, repetitive, resize_images, + The tuple format is (logging_mode, img_tensor, counter, resize_images, random_images, num_images, seed). """ logging_mode = "image" @@ -271,7 +269,7 @@ def _write_scalar_config(tag: str, value: float | int) -> None: If the tag is 'Loss' and scalar_accuracy_mode is True, the tag will be changed to 'Accuracy', and the value will be transformed accordingly. - The function relies on the initialize_config_writer to ensure the TensorBoard writer is initialized at + The function relies on the _initialize_writers to ensure the TensorBoard writer is initialized at the correct directory. It also depends on the following global variables: @@ -297,7 +295,7 @@ def _write_scalar_config(tag: str, value: float | int) -> None: ) else: raise ValueError( - "The 'config_writer' is None in _write_scalar_config. No loggings are performed. " + "The 'config_writer' is None in _write_scalar_config. No logging is performed." "An error occurred during the initialization process." ) @@ -327,7 +325,7 @@ def _write_image_config( the randomness of image selection. Default is None. Note: - The function relies on the initialize_config_writer to ensure the TensorBoard writer is initialized at + The function relies on the _initialize_writers to ensure the TensorBoard writer is initialized at the correct directory. It also depends on the following global variables: @@ -379,7 +377,7 @@ def _write_image_config( ) else: raise ValueError( - "The 'config_writer' is None in _write_image_config. No loggings are performed. " + "The 'config_writer' is None in _write_image_config. No logging is performed. 
" "An error occurred during the initialization process." ) @@ -389,7 +387,7 @@ def _write_hparam_config() -> None: Write hyperparameter configurations to the TensorBoard log, inspired by the 'hparam' original function of tensorboard. Note: - The function relies on the initialize_config_writer to ensure the TensorBoard writer is initialized at + The function relies on the _initialize_writers to ensure the TensorBoard writer is initialized at the correct directory. It also depends on the following global variables: @@ -423,7 +421,7 @@ def _write_hparam_config() -> None: ) else: raise ValueError( - "The 'config_writer' is None in _write_hparam_config. No loggings are performed. " + "The 'config_writer' is None in _write_hparam_config. No logging is performed. " "An error occurred during the initialization process." ) @@ -447,11 +445,8 @@ def tracking_incumbent_api(best_loss: float) -> None: It increments the incumbent tracker based on occurrences of "Config ID" in the 'all_losses_and_configs.txt' file. """ if tblogger.config_writer: - # Close and reset previous config writers to ensure proper handling of logging continuity. - - # This is important as writers can be reinitialized to the same directory for ongoing - # logging (across different fidelities). Closing writers after logging ensures consistency - # and avoids conflicts. + # Close and reset previous config writers for consistent logging. + # Prevent conflicts by reinitializing writers when logging ongoing. tblogger.config_writer.close() tblogger.config_writer = None @@ -459,8 +454,7 @@ def tracking_incumbent_api(best_loss: float) -> None: tblogger.incum_tracker = 0 if os.path.exists(file_path): with open(file_path) as f: - # Count the amount of presence of "Config ID" because it correlates to the - # step size of how many configurations were completed. + # Count "Config ID" occurrences to track completed configurations. tblogger.incum_tracker = sum(line.count("Config ID") for line in f) else: raise FileExistsError( @@ -480,11 +474,8 @@ def tracking_incumbent_api(best_loss: float) -> None: global_step=tblogger.incum_tracker, ) - # One challenge here is that the process of frequently closing existing writers and creating new ones - # leads to the creation of new 'tfevent' files (handled by TensorBoard). Although this outcome is unavoidable, - # it arises due to the necessity of allowing multiple writers to operate concurrently on the same directory. - # When multiple writers remain open simultaneously, conflicts can arise. This is the reason writers are - # flushed and closed after their use. (To make it work for parallelization) + # Frequent writer open/close creates new 'tfevent' files due to parallelization needs. + # Simultaneous open writers risk conflicts, so they're flushed and closed after use. tblogger.summary_writer.flush() tblogger.summary_writer.close() @@ -519,6 +510,13 @@ def enable() -> None: """ tblogger.disable_logging = False + @staticmethod + def get_status(): + """ + Returns the currect state of tblogger ie. whether the logger is enabled or not + """ + return not tblogger.disable_logging + @staticmethod def log( loss: float, @@ -530,7 +528,7 @@ def log( data: dict | None = None, ) -> None: """ - Log experiment data to the logger, including scalar values, hyperparameters, images, and layer gradients. + Log experiment data to the logger, including scalar values, hyperparameters, and images. Args: loss (float): The current loss value in training. 
@@ -546,7 +544,7 @@ def log( data (dict, optional): Additional experiment data to be logged. It should be in the format: { 'tag1': tblogger.scalar_logging(value=value1), - 'tag2': tblogger.image_logging(img_tensor=img, counter=2), + 'tag2': tblogger.image_logging(img_tensor=img, counter=2, seed=0), } Default is None. From af23e28c25a62071ef3183b0db42693799aa95dc Mon Sep 17 00:00:00 2001 From: TarekAbouChakra Date: Tue, 5 Sep 2023 20:34:29 +0200 Subject: [PATCH 5/6] fixing tblogger example --- .../convenience/neps_tblogger_tutorial.py | 178 ++++++++++-------- 1 file changed, 101 insertions(+), 77 deletions(-) diff --git a/neps_examples/convenience/neps_tblogger_tutorial.py b/neps_examples/convenience/neps_tblogger_tutorial.py index 1595c55d..49005970 100644 --- a/neps_examples/convenience/neps_tblogger_tutorial.py +++ b/neps_examples/convenience/neps_tblogger_tutorial.py @@ -1,45 +1,45 @@ """ NePS tblogger With TensorBoard -==================================== +============================== + 1- Introduction --------------- - Welcome to the NePS tblogger with TensorBoard tutorial. This guide will walk you -through the process of using the NePS tblogger class to effectively monitor and -analyze performance data from various model configurations during training. +through the process of using the NePS tblogger class to monitor performance +data for different hyperparameter configurations during optimization. -Assuming you already have experience in NePS, the main reason of creating this tutorial is to showcase the -power of visualization using tblogger. if you wish to directly reach that part, check the lines -between 244-264 or search for 'Start Tensorboard Logging' +Assuming you have experience with NePS, this tutorial aims to showcase the power +of visualization using tblogger. To go directly to that part, check lines 244-264 +or search for 'Start Tensorboard Logging'. 2- Learning Objectives ---------------------- - By completing this tutorial, you will: -- Understand the role of NePS tblogger and its importance in HPO and NAS. -- Learn how to define search spaces within NePS to explore different model configurations. +- Understand the role of NePS tblogger in HPO and NAS. +- Learn to define search spaces within NePS for different model configurations. - Build a comprehensive run pipeline to train and evaluate models. -- Utilize TensorBoard to visualize and compare performance metrics of different model configurations. +- Utilize TensorBoard to visualize and compare performance metrics of different + model configurations. 3- Setup -------- - -Before we dive in, make sure you have the necessary dependencies installed. If you haven't already, -install the ``NePS`` package using the following command: +Before we begin, ensure you have the necessary dependencies installed. To install +the 'NePS' package, use the following command: ```bash pip install neural-pipeline-search ``` -Additionally, please note that NePS does not include ``torchvision`` as a dependency. -You can install it with the following command: +Additionally, note that 'NePS' does not include 'torchvision' as a dependency. +You can install it with this command: ```bash pip install torchvision==0.14.1 ``` -These dependencies will ensure you have everything you need to follow along with this tutorial successfully. +These dependencies ensure you have everything you need for this tutorial. 
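With the dependencies in place, this is roughly the call the tutorial builds towards (a hedged preview with illustrative values; it is meant to run inside the `run_pipeline` function defined further down, during a `neps.run`, and will not work outside that context):

```python
from neps.plot.tensorboard_eval import tblogger

loss, epoch, current_lr = 0.08, 5, 1e-3  # illustrative values from a training loop

tblogger.log(
    loss=loss,                  # validation error of the current epoch
    current_epoch=epoch,        # used as the TensorBoard global step
    scalar_accuracy_mode=True,  # report the loss as accuracy: (1 - 0.08) * 100 = 92.0
    data={"lr_decay": tblogger.scalar_logging(value=current_lr)},
)
```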
+ """ import argparse @@ -94,20 +94,28 @@ def set_seed(seed=123): def MNIST( batch_size: int = 32, n_train: int = 8192, n_valid: int = 1024 ) -> Tuple[DataLoader, DataLoader, DataLoader]: + # Datasets downloading if required. train_dataset = torchvision.datasets.MNIST( - root="./data", train=True, transform=transforms.ToTensor(), download=True + root="./data", train=True, transform=transforms.ToTensor(), + download=True ) test_dataset = torchvision.datasets.MNIST( - root="./data", train=False, transform=transforms.ToTensor(), download=True + root="./data", train=False, transform=transforms.ToTensor(), + download=True ) + # Further sampling a validation dataset from the train dataset. train_sampler = SubsetRandomSampler(range(n_train)) valid_sampler = SubsetRandomSampler(range(n_train, n_train + n_valid)) + + # Creating the dataloaders. train_dataloader = DataLoader( - dataset=train_dataset, batch_size=batch_size, shuffle=False, sampler=train_sampler + dataset=train_dataset, batch_size=batch_size, shuffle=False, + sampler=train_sampler ) val_dataloader = DataLoader( - dataset=train_dataset, batch_size=batch_size, shuffle=False, sampler=valid_sampler + dataset=train_dataset, batch_size=batch_size, shuffle=False, + sampler=valid_sampler ) test_dataloader = DataLoader( dataset=test_dataset, batch_size=batch_size, shuffle=False @@ -138,7 +146,8 @@ def forward(self, x): ############################################################# -# 4 Define the training step and return the validation error and misclassified images. +# 4 Define the training step. Return the validation error and +# misclassified images. def loss_ev(model: nn.Module, data_loader: DataLoader) -> float: @@ -156,19 +165,27 @@ def loss_ev(model: nn.Module, data_loader: DataLoader) -> float: return 1 - accuracy -def training(model, optimizer, criterion, train_loader, validation_loader): +def training( + model: nn.Module, + optimizer: torch.optim, + criterion: nn.modules.loss, + train_loader: DataLoader, + validation_loader: DataLoader, + ) -> Tuple[float, torch.Tensor]: """ - Function that trains the model for one epoch and evaluates the model on the validation set. + Function that trains the model for one epoch and evaluates the model + on the validation set. Args: model (nn.Module): Model to be trained. - optimizer (torch.nn.optim): Optimizer used to train the weights (depends on the pipeline space). + optimizer (torch.optim): Optimizer used to train the weights. criterion (nn.modules.loss) : Loss function to use. - train_loader (torch.utils.Dataloader): Data loader containing the training data. - validation_loader (torch.utils.Dataloader): Data loader containing the validation data. + train_loader (Dataloader): Dataloader containing the training data. + validation_loader (Dataloader): Dataloader containing the validation data. Returns: - (float) validation error for the epoch. + Tuple[float, torch.Tensor]: A tuple containing the validation error (float) + and a tensor of misclassified images. 
""" incorrect_images = [] model.train() @@ -188,7 +205,7 @@ def training(model, optimizer, criterion, train_loader, validation_loader): if len(incorrect_images) > 0: incorrect_images = torch.cat(incorrect_images, dim=0) - return validation_loss, incorrect_images + return (validation_loss, incorrect_images) ############################################################# @@ -213,9 +230,13 @@ def run_pipeline_BO(lr, optim, weight_decay): model = MLP() if optim == "Adam": - optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) + optimizer = torch.optim.Adam( + model.parameters(), lr=lr, weight_decay=weight_decay + ) elif optim == "SGD": - optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay) + optimizer = torch.optim.SGD( + model.parameters(), lr=lr, weight_decay=weight_decay + ) else: raise ValueError( "Optimizer choices are defined differently in the pipeline_space" @@ -228,9 +249,7 @@ def run_pipeline_BO(lr, optim, weight_decay): ) scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.75) - criterion = nn.CrossEntropyLoss() - losses = [] for i in range(max_epochs): loss, miss_img = training( @@ -240,9 +259,8 @@ def run_pipeline_BO(lr, optim, weight_decay): train_loader=train_loader, validation_loader=validation_loader, ) - losses.append(loss) - # Gathering the gradient mean in each layer to display some of them in tensorboard + # Gathering the gradient mean in each layer mean_gradient = [] for layer in model.children(): layer_gradients = [param.grad for param in layer.parameters()] @@ -253,13 +271,14 @@ def run_pipeline_BO(lr, optim, weight_decay): ###################### Start Tensorboard Logging ###################### - # tblogger for neps config loggings. This line will result in the following: + # This followinf line will result in: - # 1 Incumbent of the configs (best performance regardless of fidelity budget, if the searcher was fidelity dependent). + # 1 Incumbent trajectory (best performance regardless of the + # fidelity budget, if the searcher was fidelity dependent). # 2 Loss curves of each of the configs at each epoch. # 3 lr_decay curve at each epoch. - # 4 miss_img which represents the wrongly classified images by the model. - # 5 first two layer_gradients computed above and passed as scalar configs. + # 4 The wrongly classified images by the model. + # 5 first two layer_gradients passed as scalar configs. tblogger.log( loss=loss, @@ -288,16 +307,23 @@ def run_pipeline_BO(lr, optim, weight_decay): "info_dict": { "train_accuracy": train_accuracy, "test_accuracy": test_accuracy, - "val_errors": losses, "cost": max_epochs, }, } ############################################################# -# 6 Running neps with BO as our main searcher, saving the results in a defined directory. +# 6 Running neps with BO as the searcher. 
if __name__ == "__main__": + """ + When running this code without any arguments, it will by default + run bayesian optimization with 10 evaluations of 9 epochs each: + + ```bash + python neps_tblogger_tutorial.py + ``` + """ parser = argparse.ArgumentParser() parser.add_argument( "--max_evaluations_total", @@ -307,15 +333,6 @@ def run_pipeline_BO(lr, optim, weight_decay): ) args = parser.parse_args() - """ - When running this code without any arguments, it will by default run bayesian optimization with 10 max evaluations - of 9 epochs each: - - ```bash - python neps_tblogger_tutorial.py - ``` - """ - start_time = time.time() set_seed(112) @@ -324,28 +341,7 @@ def run_pipeline_BO(lr, optim, weight_decay): if os.path.exists("results/bayesian_optimization"): shutil.rmtree("results/bayesian_optimization") - """ - For showcasing purposes. After completing the first run, one can uncomment line 351 - and continue the search via: - - ```bash: - python neps_tblogger_tutorial.py --max_evaluations_total 15 - ``` - - This would result in continuing the search for 5 new different configurations in addition - to disabling the logging, hence tblogger can always be disabled using the line below. - - ```python: - tblogger.disable() - ``` - - Note that by default tblogger is enabled when used. However, there is also an enable toggle that can be used - ```python: - tblogger.enable() - ``` - """ - - # by defualt tblogger is enabled when used, one can also check the status using: + # Check the status of tblogger via: # tblogger.get_status() # tblogger.disable() @@ -356,26 +352,54 @@ def run_pipeline_BO(lr, optim, weight_decay): root_directory="bayesian_optimization", max_evaluations_total=args.max_evaluations_total, searcher="bayesian_optimization", + # By default, NePS runs 10 random configurations before sampling + # from the acquisition function. We will change this behavior with + # the following keyword argument. + initial_design_size = 5, ) """ - To check live plots during this search, please open a new terminal and make sure to be at the same level directory - of your project and run this command on the file created by neps root_directory. + To check live plots during this search, please open a new terminal + and make sure to be at the same level directory of your project and + run the following command on the file created by neps root_directory. ```bash: tensorboard --logdir bayesian_optimization ``` - To be able to check the visualization of tensorboard make sure to follow the local link provided. + To be able to check the visualization of tensorboard make sure to + follow the local link provided. ```bash: http://localhost:6006/ ``` - If nothing was visualized and you followed the tutorial exactly, there could have been an error in passing the correct - directory, please double check. Tensorboard will always run in the command line without checking if the directory exists. + If nothing was visualized and you followed the tutorial exactly, + there could have been an error in passing the correct directory, + please double check. Tensorboard will always run in the command + line without checking if the directory exists. """ end_time = time.time() # Record the end time execution_time = end_time - start_time logging.info(f"Execution time: {execution_time} seconds") + + + """ + For showcasing purposes. 
After completing the first run, one can + uncomment line 348 and continue the search via: + + ```bash: + python neps_tblogger_tutorial.py --max_evaluations_total 15 + ``` + + This would result in continuing the search for 5 different configurations + in addition to disabling tblogger. + + Note that by default tblogger is enabled when used. However, + one can also enable when needed via. + + ```python: + tblogger.enable() + ``` + """ \ No newline at end of file From 43ced019548b2146299fa5220a4b76ebc38d84ab Mon Sep 17 00:00:00 2001 From: TarekAbouChakra Date: Wed, 6 Sep 2023 09:04:05 +0200 Subject: [PATCH 6/6] tblogger cleanup --- .../convenience/neps_tblogger_tutorial.py | 109 +++++----- src/neps/plot/tensorboard_eval.py | 191 +++++++++--------- 2 files changed, 157 insertions(+), 143 deletions(-) diff --git a/neps_examples/convenience/neps_tblogger_tutorial.py b/neps_examples/convenience/neps_tblogger_tutorial.py index 49005970..da7d44e4 100644 --- a/neps_examples/convenience/neps_tblogger_tutorial.py +++ b/neps_examples/convenience/neps_tblogger_tutorial.py @@ -9,7 +9,7 @@ data for different hyperparameter configurations during optimization. Assuming you have experience with NePS, this tutorial aims to showcase the power -of visualization using tblogger. To go directly to that part, check lines 244-264 +of visualization using tblogger. To go directly to that part, check lines 244-264 or search for 'Start Tensorboard Logging'. 2- Learning Objectives @@ -31,7 +31,7 @@ pip install neural-pipeline-search ``` -Additionally, note that 'NePS' does not include 'torchvision' as a dependency. +Additionally, note that 'NePS' does not include 'torchvision' as a dependency. You can install it with this command: ```bash @@ -44,9 +44,7 @@ import argparse import logging -import os import random -import shutil import time from typing import Tuple @@ -94,7 +92,7 @@ def set_seed(seed=123): def MNIST( batch_size: int = 32, n_train: int = 8192, n_valid: int = 1024 ) -> Tuple[DataLoader, DataLoader, DataLoader]: - # Datasets downloading if required. + # Download MNIST training and test datasets if not already downloaded. train_dataset = torchvision.datasets.MNIST( root="./data", train=True, transform=transforms.ToTensor(), download=True @@ -104,11 +102,12 @@ def MNIST( download=True ) - # Further sampling a validation dataset from the train dataset. + # Create a random subset of the training dataset for validation. + # We also opted on reducing the dataset sizes for faster training. train_sampler = SubsetRandomSampler(range(n_train)) valid_sampler = SubsetRandomSampler(range(n_train, n_train + n_valid)) - # Creating the dataloaders. + # Create DataLoaders for training, validation, and test datasets. train_dataloader = DataLoader( dataset=train_dataset, batch_size=batch_size, shuffle=False, sampler=train_sampler @@ -146,49 +145,60 @@ def forward(self, x): ############################################################# -# 4 Define the training step. Return the validation error and +# 4 Define the training step. Return the validation error and # misclassified images. def loss_ev(model: nn.Module, data_loader: DataLoader) -> float: + # Set the model in evaluation mode (no gradient computation). model.eval() + correct = 0 total = 0 + + # Disable gradient computation for efficiency. with torch.no_grad(): for x, y in data_loader: output = model(x) + + # Get the predicted class for each input. _, predicted = torch.max(output.data, 1) + + # Update the correct and total counts. 
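+            # (predicted == y) is a boolean tensor; summing it counts how many
+            # predictions in the batch were correct.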
correct += (predicted == y).sum().item() total += y.size(0) + # Calculate the accuracy and return the error rate. accuracy = correct / total - return 1 - accuracy + error_rate = 1 - accuracy + return error_rate def training( - model: nn.Module, - optimizer: torch.optim, - criterion: nn.modules.loss, - train_loader: DataLoader, - validation_loader: DataLoader, - ) -> Tuple[float, torch.Tensor]: + model: nn.Module, + optimizer: torch.optim.Optimizer, + criterion: nn.Module, + train_loader: DataLoader, + validation_loader: DataLoader, +) -> Tuple[float, torch.Tensor]: """ - Function that trains the model for one epoch and evaluates the model + Function that trains the model for one epoch and evaluates the model on the validation set. Args: model (nn.Module): Model to be trained. - optimizer (torch.optim): Optimizer used to train the weights. - criterion (nn.modules.loss) : Loss function to use. - train_loader (Dataloader): Dataloader containing the training data. - validation_loader (Dataloader): Dataloader containing the validation data. + optimizer (torch.optim.Optimizer): Optimizer used to train the weights. + criterion (nn.Module) : Loss function to use. + train_loader (DataLoader): DataLoader containing the training data. + validation_loader (DataLoader): DataLoader containing the validation data. Returns: - Tuple[float, torch.Tensor]: A tuple containing the validation error (float) + Tuple[float, torch.Tensor]: A tuple containing the validation error (float) and a tensor of misclassified images. """ incorrect_images = [] model.train() + for x, y in train_loader: optimizer.zero_grad() output = model(x) @@ -200,8 +210,10 @@ def training( incorrect_mask = predicted_labels != y incorrect_images.append(x[incorrect_mask]) + # Calculate validation loss using the loss_ev function. validation_loss = loss_ev(model, validation_loader) + # Return the misclassified image by during model training. if len(incorrect_images) > 0: incorrect_images = torch.cat(incorrect_images, dim=0) @@ -227,6 +239,7 @@ def pipeline_space_BO() -> dict: def run_pipeline_BO(lr, optim, weight_decay): + # Create the network model. model = MLP() if optim == "Adam": @@ -244,6 +257,7 @@ def run_pipeline_BO(lr, optim, weight_decay): max_epochs = 9 + # Load the MNIST dataset for training, validation, and testing. train_loader, validation_loader, test_loader = MNIST( batch_size=64, n_train=4096, n_valid=512 ) @@ -285,9 +299,7 @@ def run_pipeline_BO(lr, optim, weight_decay): current_epoch=i, data={ "lr_decay": tblogger.scalar_logging(value=scheduler.get_last_lr()[0]), - "miss_img": tblogger.image_logging( - img_tensor=miss_img, counter=2, seed=2 - ), + "miss_img": tblogger.image_logging(image=miss_img, counter=2, seed=2), "layer_gradient1": tblogger.scalar_logging(value=mean_gradient[0]), "layer_gradient2": tblogger.scalar_logging(value=mean_gradient[1]), }, @@ -299,9 +311,11 @@ def run_pipeline_BO(lr, optim, weight_decay): print(f" Epoch {i + 1} / {max_epochs} Val Error: {loss} ") + # Calculate training and test accuracy. train_accuracy = loss_ev(model, train_loader) test_accuracy = loss_ev(model, test_loader) + # Return a dictionary with relevant metrics and information. 
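+    # NePS minimizes the value returned under "loss"; the "info_dict" entries
+    # are stored alongside the result for later inspection.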
return { "loss": loss, "info_dict": { @@ -317,7 +331,7 @@ def run_pipeline_BO(lr, optim, weight_decay): if __name__ == "__main__": """ - When running this code without any arguments, it will by default + When running this code without any arguments, it will by default run bayesian optimization with 10 evaluations of 9 epochs each: ```bash @@ -338,10 +352,7 @@ def run_pipeline_BO(lr, optim, weight_decay): set_seed(112) logging.basicConfig(level=logging.INFO) - if os.path.exists("results/bayesian_optimization"): - shutil.rmtree("results/bayesian_optimization") - - # Check the status of tblogger via: + # To check the status of tblogger: # tblogger.get_status() # tblogger.disable() @@ -352,54 +363,46 @@ def run_pipeline_BO(lr, optim, weight_decay): root_directory="bayesian_optimization", max_evaluations_total=args.max_evaluations_total, searcher="bayesian_optimization", - # By default, NePS runs 10 random configurations before sampling + # By default, NePS runs 10 random configurations before sampling # from the acquisition function. We will change this behavior with # the following keyword argument. - initial_design_size = 5, + initial_design_size=5, ) """ - To check live plots during this search, please open a new terminal - and make sure to be at the same level directory of your project and + To check live plots during this search, please open a new terminal + and make sure to be at the same level directory of your project and run the following command on the file created by neps root_directory. ```bash: tensorboard --logdir bayesian_optimization ``` - To be able to check the visualization of tensorboard make sure to + To be able to check the visualization of tensorboard make sure to follow the local link provided. - ```bash: http://localhost:6006/ - ``` - If nothing was visualized and you followed the tutorial exactly, - there could have been an error in passing the correct directory, - please double check. Tensorboard will always run in the command - line without checking if the directory exists. + Double-check the directory path you've provided; if you're not seeing + any visualizations and have followed the tutorial closely, there + might be an error in the directory specification. Remember that + TensorBoard runs in the command line without checking if the directory + actually exists. """ end_time = time.time() # Record the end time execution_time = end_time - start_time logging.info(f"Execution time: {execution_time} seconds") - """ - For showcasing purposes. After completing the first run, one can - uncomment line 348 and continue the search via: + After your first run, you can continue with more experiments by + uncommenting line 361 and running the following command in your terminal: ```bash: python neps_tblogger_tutorial.py --max_evaluations_total 15 ``` - This would result in continuing the search for 5 different configurations - in addition to disabling tblogger. - - Note that by default tblogger is enabled when used. However, - one can also enable when needed via. - - ```python: - tblogger.enable() - ``` - """ \ No newline at end of file + This adds five more configurations to your search and turns off tblogger. + By default, tblogger is on, but you can control it with `tblogger.enable()` + or `tblogger.disable()` in your code." 
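+
+    For example, a sketch of the same search with TensorBoard logging turned
+    off (keyword names assumed to mirror the `neps.run` call above):
+
+    ```python:
+    tblogger.disable()
+    neps.run(
+        run_pipeline=run_pipeline_BO,
+        pipeline_space=pipeline_space_BO(),
+        root_directory="bayesian_optimization",
+        max_evaluations_total=15,
+        searcher="bayesian_optimization",
+    )
+    ```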
+ """ diff --git a/src/neps/plot/tensorboard_eval.py b/src/neps/plot/tensorboard_eval.py index 47575aef..774d50e8 100644 --- a/src/neps/plot/tensorboard_eval.py +++ b/src/neps/plot/tensorboard_eval.py @@ -13,17 +13,20 @@ class SummaryWriter_(SummaryWriter): """ - This class inherits from the base SummaryWriter class and provides modifications to improve the logging. - It simplifies the logging structure and ensures consistent tag formatting for metrics. + This class inherits from the base SummaryWriter class and provides + modifications to improve the logging. It simplifies the logging structure + and ensures consistent tag formatting for metrics. Changes Made: - Avoids creating unnecessary subfolders in the log directory. - - Ensures all logs are stored in the same 'tfevent' directory for better organization. + - Ensures all logs are stored in the same 'tfevent' directory for + better organization. - Updates metric keys to have a consistent 'Summary/' prefix for clarity. - Improves the display of 'Loss' or 'Accuracy' on the Summary file. Methods: - - add_hparams: Overrides the base method to log hyperparameters and metrics with better formatting. + - add_hparams: Overrides the base method to log hyperparameters and + metrics with better formatting. """ def add_hparams(self, hparam_dict: dict, metric_dict: dict, global_step: int) -> None: @@ -47,16 +50,16 @@ class tblogger: config_previous_directory: Path | None = None logger_init_bool: bool = True - """logger_init_bool is only used once to capture configuration data for the first ever configuration, - and then turned false for the entire run.""" + """logger_init_bool is only used once to capture configuration data + for the first configuration, and then turned false for the entire run.""" logger_bool: bool = False - """logger_bool is true only if tblogger.log is used by the user, hence this allows to always capturing - the configuration data for all configurations.""" + """logger_bool is true only if tblogger.log is used by the user, this + allows to always capturing the configuration data.""" disable_logging: bool = False - """disable_logging is a hard switch to disable the logging feature if it was turned true. - hence even when logger_bool is true it disables the logging process""" + """disable_logging is a hard switch to disable the logging feature + if it was turned true.""" loss: float | None = None current_epoch: int | None = None @@ -78,8 +81,9 @@ def config_track_init_api( config_previous_directory: Path | None = None, ) -> None: """ - Track the Configuration space data from the way handled by neps metahyper '_sample_config' to keep in sync with - config ids and directories NePS is operating on. + Track the Configuration space data from the way handled by neps metahyper + '_sample_config' to keep in sync with config ids and directories NePS is + operating on. """ tblogger.config = config @@ -119,7 +123,8 @@ def _initialize_writers() -> None: # This should execute when having Config_x_1 and Config_x_0 if os.path.exists(get_tbevent_dir): - # When tfevents directory is detected => we are at the first fidelity directory, create writer. + # When tfevents directory is detected => we are at the first + # fidelity directory, create writer. with open(prev_dir_id_from_init) as file: contents = file.read() tblogger.config_id = contents @@ -161,13 +166,15 @@ def _make_grid(images: torch.Tensor, nrow: int, padding: int = 2) -> torch.Tenso Create a grid of images from a batch of images. 
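+
+        The result is comparable to torchvision.utils.make_grid: the batch is
+        tiled into a single image tensor with `padding` pixels between
+        neighbouring images.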
Args: - images (torch.Tensor): The input batch of images with shape (batch_size, num_channels, height, width). + images (torch.Tensor): The input batch of images with shape + (batch_size, num_channels, height, width). nrow (int): The number of rows on the grid. - padding (int, optional): The padding between images in the grid. Default is 2. + padding (int, optional): The padding between images in the grid. + Default is 2. Returns: - torch.Tensor: A grid of images with shape (num_channels, total_height, total_width), - where total_height and total_width depend on the number of images and the grid settings. + torch.Tensor: A grid of images with shape: + (num_channels, total_height, total_width) """ batch_size, num_channels, height, width = images.size() x_mapping = min(nrow, batch_size) @@ -210,7 +217,7 @@ def scalar_logging(value: float) -> tuple: @staticmethod def image_logging( - img_tensor: torch.Tensor, + image: torch.Tensor, counter: int = 1, resize_images: list[None | int] | None = None, random_images: bool = True, @@ -229,26 +236,26 @@ def image_logging( Prepare an image tensor for logging. Args: - img_tensor (torch.Tensor): The image tensor to be logged. - counter (int): A counter value for the frequency of image logging (ex: counter 2 means for every - 2 global steps a new set of images are logged). - resize_images (list of int): A list of integers representing the image sizes - after resizing or None if no resizing required. - Default is None. - random_images (bool, optional): Whether the images are selected randomly. Default is True. - num_images (int, optional): The number of images to log. Default is 20. - seed (int or np.random.RandomState or None, optional): Seed value or RandomState instance to control - the randomness of image selection. Default is None. + image (torch.Tensor): Image tensor to be logged. + counter (int): Counter value associated with the images. + resize_images (list of int, optional): List of integers for image + sizes after resizing (default: None). + random_images (bool, optional): Images are randomly selected + if True (default: True). + num_images (int, optional): Number of images to log (default: 20). + seed (int or np.random.RandomState or None, optional): Seed value + or RandomState instance to control randomness (default: None). Returns: - tuple: A tuple containing the logging mode and all the necessary parameters for image logging. - The tuple format is (logging_mode, img_tensor, counter, resize_images, - random_images, num_images, seed). + tuple: A tuple containing the logging mode and all the necessary + parameters for image logging. + Tuple: (logging_mode, img_tensor, counter, resize_images, + random_images, num_images, seed). """ logging_mode = "image" return ( logging_mode, - img_tensor, + image, counter, resize_images, random_images, @@ -266,18 +273,19 @@ def _write_scalar_config(tag: str, value: float | int) -> None: value (float or int): The scalar value to be logged. Default is None. Note: - If the tag is 'Loss' and scalar_accuracy_mode is True, the tag will be changed to 'Accuracy', - and the value will be transformed accordingly. + If the tag is 'Loss' and scalar_accuracy_mode is True, the tag will + be changed to 'Accuracy', and the value will be transformed accordingly. - The function relies on the _initialize_writers to ensure the TensorBoard writer is initialized at - the correct directory. + The function relies on the _initialize_writers to ensure the + TensorBoard writer is initialized at the correct directory. 
It also depends on the following global variables: - tblogger.scalar_accuracy_mode (bool) - tblogger.config_writer (SummaryWriter_) - tblogger.config_id (str) - The function will log the scalar value under different tags based on fidelity mode and other configurations. + The function will log the scalar value under different tags based + on fidelity mode and other configurations. """ if not tblogger._is_initialized(): tblogger._initialize_writers() @@ -295,7 +303,7 @@ def _write_scalar_config(tag: str, value: float | int) -> None: ) else: raise ValueError( - "The 'config_writer' is None in _write_scalar_config. No logging is performed." + "The 'config_writer' is None in _write_scalar_config." "An error occurred during the initialization process." ) @@ -313,27 +321,28 @@ def _write_image_config( Write images to the TensorBoard log. Args: - tag (str): The tag for the images. - image (torch.Tensor): The image tensor to be logged. - counter (int): A counter value associated with the images. - resize_images (list of int): A list of integers representing the image sizes - after resizing or None if no resizing required. - Default is None. - random_images (bool, optional): Whether the images are selected randomly. Default is True. - num_images (int, optional): The number of images to log. Default is 20. - seed (int or np.random.RandomState or None, optional): Seed value or RandomState instance to control - the randomness of image selection. Default is None. + tag (str): Tag for the images. + image (torch.Tensor): Image tensor to be logged. + counter (int): Counter value associated with the images. + resize_images (list of int, optional): List of integers for image + sizes after resizing (default: None). + random_images (bool, optional): Images are randomly selected + if True (default: True). + num_images (int, optional): Number of images to log (default: 20). + seed (int or np.random.RandomState or None, optional): Seed value + or RandomState instance to control randomness (default: None). Note: - The function relies on the _initialize_writers to ensure the TensorBoard writer is initialized at - the correct directory. + The function relies on the _initialize_writers to ensure the + TensorBoard writer is initialized at the correct directory. It also depends on the following global variables: - tblogger.current_epoch (int) - tblogger.config_writer (SummaryWriter_) - tblogger.config_id (str) - The function will log a subset of images to TensorBoard based on the given configurations. + The function will log a subset of images to TensorBoard based on + the given configurations. """ if not tblogger._is_initialized(): tblogger._initialize_writers() @@ -365,7 +374,8 @@ def _write_image_config( mode="bilinear", align_corners=False, ) - # Create the grid according to the number of images and log the grid to tensorboard. + # Create the grid according to the number of images and log + # the grid to tensorboard. nrow = int(resized_images.size(0) ** 0.75) img_grid = tblogger._make_grid(resized_images, nrow=nrow) # Just an extra safety measure @@ -377,18 +387,19 @@ def _write_image_config( ) else: raise ValueError( - "The 'config_writer' is None in _write_image_config. No logging is performed. " + "The 'config_writer' is None in _write_image_config." "An error occurred during the initialization process." ) @staticmethod def _write_hparam_config() -> None: """ - Write hyperparameter configurations to the TensorBoard log, inspired by the 'hparam' original function of tensorboard. 
+ Write hyperparameter configurations to the TensorBoard log, inspired + by the 'hparam' original function of tensorboard. Note: - The function relies on the _initialize_writers to ensure the TensorBoard writer is initialized at - the correct directory. + The function relies on the _initialize_writers to ensure the + TensorBoard writer is initialized at the correct directory. It also depends on the following global variables: - tblogger.hparam_accuracy_mode (bool) @@ -397,8 +408,9 @@ def _write_hparam_config() -> None: - tblogger.config (dict) - tblogger.current_epoch (int) - The function will log hyperparameter configurations along with a metric value (either accuracy or loss) - to TensorBoard based on the given configurations. + The function will log hyperparameter configurations along + with a metric value (either accuracy or loss) to TensorBoard + based on the given configurations. """ if not tblogger._is_initialized(): tblogger._initialize_writers() @@ -421,7 +433,7 @@ def _write_hparam_config() -> None: ) else: raise ValueError( - "The 'config_writer' is None in _write_hparam_config. No logging is performed. " + "The 'config_writer' is None in _write_hparam_config." "An error occurred during the initialization process." ) @@ -431,7 +443,8 @@ def tracking_incumbent_api(best_loss: float) -> None: Track the incumbent (best) loss and log it in the TensorBoard summary. Args: - best_loss (float): The best loss value to be tracked, according to the _post_hook_function of NePS. + best_loss (float): The best loss value to be tracked, according + to the _post_hook_function of NePS. Note: The function relies on the following global variables: @@ -441,8 +454,10 @@ def tracking_incumbent_api(best_loss: float) -> None: - tblogger.incum_val (float) - tblogger.summary_writer (SummaryWriter_) - The function logs the incumbent loss in a TensorBoard summary with a graph. - It increments the incumbent tracker based on occurrences of "Config ID" in the 'all_losses_and_configs.txt' file. + The function logs the incumbent trajectory in TensorBoard. + + It increments the incumbent tracker based on occurrences of + "Config ID" in the 'all_losses_and_configs.txt' file. """ if tblogger.config_writer: # Close and reset previous config writers for consistent logging. @@ -474,8 +489,9 @@ def tracking_incumbent_api(best_loss: float) -> None: global_step=tblogger.incum_tracker, ) - # Frequent writer open/close creates new 'tfevent' files due to parallelization needs. - # Simultaneous open writers risk conflicts, so they're flushed and closed after use. + # Frequent writer open/close creates new 'tfevent' files due to + # parallelization needs. Simultaneous open writers risk conflicts, + # so they're flushed and closed after use. tblogger.summary_writer.flush() tblogger.summary_writer.close() @@ -483,11 +499,10 @@ def tracking_incumbent_api(best_loss: float) -> None: @staticmethod def disable() -> None: """ - The function allows for disabling the logger functionality - throughout the program execution by updating the value of 'tblogger.disable_logging'. - When the logger is disabled, it will not perform any logging operations. + The function allows for disabling the logger functionality. + When the logger is disabled, it will not perform logging operations. - By default tblogger is enabled when used. If for any reason disabling is needed. This function does the job. + By default tblogger is enabled when used. 
Example: # Disable the logger @@ -498,11 +513,10 @@ def disable() -> None: @staticmethod def enable() -> None: """ - The function allows for enabling the logger functionality - throughout the program execution by updating the value of 'tblogger.disable_logging'. + The function allows for enabling the logger functionality. When the logger is enabled, it will perform the logging operations. - By default this is enabled. Hence only needed when tblogger was once disabled. + By default this is enabled. Example: # Enable the logger @@ -513,7 +527,8 @@ def enable() -> None: @staticmethod def get_status(): """ - Returns the currect state of tblogger ie. whether the logger is enabled or not + Returns the currect state of tblogger ie. whether the logger is + enabled or not """ return not tblogger.disable_logging @@ -528,26 +543,22 @@ def log( data: dict | None = None, ) -> None: """ - Log experiment data to the logger, including scalar values, hyperparameters, and images. + Log experiment data to the logger, including scalar values, + hyperparameters, and images. Args: - loss (float): The current loss value in training. - current_epoch (int): The current epoch of the experiment. Used as the global step. - writer_scalar (bool, optional): Whether to write the loss or accuracy for the - configs during training. Default is True. - writer_hparam (bool, optional): Whether to write hyperparameters logging - of the configs during training. Default is True. - scalar_accuracy_mode (bool, optional): If True, interpret the 'loss' as 'accuracy' and transform it's - value accordingliy. Default is False. - hparam_accuracy_mode (bool, optional): If True, interpret the 'loss' as 'accuracy' and transform it's - value accordingliy. Default is False. - data (dict, optional): Additional experiment data to be logged. It should be in the format: - { - 'tag1': tblogger.scalar_logging(value=value1), - 'tag2': tblogger.image_logging(img_tensor=img, counter=2, seed=0), - } - Default is None. - + loss (float): Current loss value in training. + current_epoch (int): Current epoch of the experiment + (used as the global step). + writer_scalar (bool, optional): Displaying the loss or accuracy + curve on tensorboard (default: True) + writer_hparam (bool, optional): Write hyperparameters logging of + the configs (default: True). + scalar_accuracy_mode (bool, optional): Interpret 'loss' as 'accuracy' + and change value (default: False). + hparam_accuracy_mode (bool, optional): Interpret 'loss' as 'accuracy' + and change value (default: False). + data (dict, optional): Additional experiment data for logging. """ tblogger.current_epoch = current_epoch tblogger.loss = loss
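+        # Cache the values on the class; the _write_* helpers below read them
+        # when they emit the TensorBoard events.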