Skip to content

Commit

Permalink
Merge pull request #6 from matteo-rizzo/dev
Browse files Browse the repository at this point in the history
Minor clean up and improvements
  • Loading branch information
matteo-rizzo committed Jun 23, 2021
2 parents 8d769e8 + 26e964d commit b8f89fd
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 19 deletions.
6 changes: 3 additions & 3 deletions classes/data/ColorCheckerDataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from typing import Tuple

import numpy as np
import scipy.io
Expand All @@ -17,7 +18,7 @@ def __init__(self, train: bool = True, folds_num: int = 1):
self.__da = DataAugmenter()

path_to_folds = os.path.join("dataset", "folds.mat")
path_to_metadata = os.path.join("dataset", "color_checker_metadata.txt")
path_to_metadata = os.path.join("dataset", "metadata.txt")
self.__path_to_data = os.path.join("dataset", "preprocessed", "numpy_data")
self.__path_to_label = os.path.join("dataset", "preprocessed", "numpy_labels")

Expand All @@ -27,8 +28,7 @@ def __init__(self, train: bool = True, folds_num: int = 1):
metadata = open(path_to_metadata, 'r').readlines()
self.__fold_data = [metadata[i - 1] for i in img_idx]

def __getitem__(self, index: int) -> tuple:

def __getitem__(self, index: int) -> Tuple:
file_name = self.__fold_data[index].strip().split(' ')[1]
img = np.array(np.load(os.path.join(self.__path_to_data, file_name + '.npy')), dtype='float32')
illuminant = np.array(np.load(os.path.join(self.__path_to_label, file_name + '.npy')), dtype='float32')
Expand Down
9 changes: 2 additions & 7 deletions classes/fc4/ModelFC4.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Union
from typing import Union, Tuple

import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
Expand All @@ -18,7 +18,7 @@ def __init__(self):
super().__init__()
self._network = FC4().to(self._device)

def predict(self, img: Tensor, return_steps: bool = False) -> Union[Tensor, tuple]:
def predict(self, img: Tensor, return_steps: bool = False) -> Union[Tensor, Tuple]:
"""
Performs inference on the input image using the FC4 method.
@param img: the image for which an illuminant colour has to be estimated
Expand All @@ -42,11 +42,6 @@ def optimize(self, img: Tensor, label: Tensor) -> float:
self._optimizer.step()
return loss.item()

def get_regularized_loss(self, pred: Tensor, label: Tensor, attention_mask: Tensor) -> Tensor:
angular = self.get_loss(pred, label)
sparsity = self.__bs_loss(attention_mask)
return angular + sparsity

def save_vis(self, model_output: dict, path_to_plot: str):
model_output = {k: v.clone().detach().to(self._device) for k, v in model_output.items()}

Expand Down
20 changes: 11 additions & 9 deletions dataset/img2npy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,26 @@
and utilizing MCCs as a visual cue, all images are masked with provided locations of MCC during training and testing
"""

PATH_TO_NUMPY_DATA = os.path.join("preprocessed", "numpy_data")
PATH_TO_NUMPY_LABELS = os.path.join("preprocessed", "numpy_labels")
PATH_TO_LINEAR_IMAGES = os.path.join("preprocessed", "linear_images")
PATH_TO_GT_CORRECTED = os.path.join("preprocessed", "gt_corrected")
PATH_TO_IMAGES = os.path.join("images")
PATH_TO_COORDINATES = os.path.join("coordinates")
PATH_TO_CC_METADATA = os.path.join("color_checker_metadata.txt")
PATH_TO_CC_METADATA = os.path.join("metadata.txt")

BASE_PATH = "preprocessed"
PATH_TO_NUMPY_DATA = os.path.join(BASE_PATH, "numpy_data")
PATH_TO_NUMPY_LABELS = os.path.join(BASE_PATH, "numpy_labels")
PATH_TO_LINEAR_IMAGES = os.path.join(BASE_PATH, "linear_images")
PATH_TO_GT_CORRECTED = os.path.join(BASE_PATH, "gt_corrected")


def main():
print("\n=================================================\n")
print("\t Masking MCC charts")
print("\n=================================================\n")
print("Paths: \n"
"\t - Numpy data generated at .......... : {} \n"
"\t - Numpy labels generated at ........ : {} \n"
"\t - Images fetched from .............. : {} \n"
"\t - Coordinates fetched from ......... : {} \n"
"\t - Numpy data generated at ..... : {} \n"
"\t - Numpy labels generated at ... : {} \n"
"\t - Images fetched from ......... : {} \n"
"\t - Coordinates fetched from .... : {} \n"
.format(PATH_TO_NUMPY_DATA, PATH_TO_NUMPY_LABELS, PATH_TO_IMAGES, PATH_TO_COORDINATES))

os.makedirs(PATH_TO_NUMPY_DATA, exist_ok=True)
Expand Down
File renamed without changes.

0 comments on commit b8f89fd

Please sign in to comment.