Skip to content

Commit

Permalink
Add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
ViiSkor committed Oct 31, 2023
1 parent dada0bb commit 7fd72ad
Show file tree
Hide file tree
Showing 11 changed files with 311 additions and 13 deletions.
11 changes: 1 addition & 10 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,4 @@ jobs:
working-directory: .

- name: Run Pylint
run: pylint core/*

- name: Run Pytest
run: pytest tests/test_model.py

- name: Upload test results
uses: actions/upload-artifact@v2
with:
name: test-results
path: test-reports
run: pylint core/*
2 changes: 1 addition & 1 deletion config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ train:
min_lr: 0.00000001
cooldown: 0
EarlyStopping:
patience: 10
patience: 20
checkpoint:
save_best_only: True
save_weights_only: True
36 changes: 36 additions & 0 deletions core/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,42 @@


def get_callbacks(model_name: str, config: dict) -> list:
"""
Get a list of Keras callback objects for model training.
This function creates a list of callback objects to be used during the training of a Keras model.
Parameters:
- model_name (str): Name of the model, used to generate checkpoint file names.
- config (dict): A dictionary containing configuration settings for the callbacks.
Returns:
- list: A list of Keras callback objects.
Configuration Settings:
- 'checkpoint' (dict): Configuration settings for ModelCheckpoint.
- 'reduceLROnPlat' (dict): Configuration settings for ReduceLROnPlateau.
- 'EarlyStopping' (dict): Configuration settings for EarlyStopping.
Example Configuration:
config = {
'checkpoint': {
'save_best_only': True,
'save_weights_only': True,
'mode': 'min'
},
'reduceLROnPlat': {
'factor': 0.5,
'patience': 5,
'min_lr': 1e-6
},
'EarlyStopping': {
'patience': 10,
'min_delta': 0.0001
}
}
"""

weight_path=f'{model_name}_weights.best.hdf5'
checkpoint = ModelCheckpoint(
weight_path, monitor='val_loss', verbose=1, mode='min', **config['checkpoint']
Expand Down
43 changes: 43 additions & 0 deletions core/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,20 @@


def create_full_res_model(model, img_scaling: tuple[int, int]):
"""
Create a full-resolution model by adding scaling operations to the output of an existing model.
This function takes an existing Keras model and, if img_scaling is not None, creates a new model that scales the
input image, applies the provided model, and then scales the output back to full resolution.
Parameters:
- model: The Keras model to which scaling operations will be applied.
- img_scaling (Tuple[int, int]): Tuple specifying the scaling factors for height and width. Use None for no scaling.
Returns:
- models.Model: A Keras model with scaling operations applied if img_scaling is not None, or the original model.
"""

if img_scaling is not None:
fullres_model = models.Sequential()
fullres_model.add(layers.AvgPool2D(img_scaling, input_shape=(None, None, 3)))
Expand All @@ -19,13 +33,42 @@ def create_full_res_model(model, img_scaling: tuple[int, int]):


def raw_prediction(model, c_img_name : str, path: str) -> np.array:
"""
Perform raw image segmentation using the provided model.
This function reads an image from a specified path, preprocesses it, and uses the model to perform
image segmentation.
Parameters:
- model: The Keras model used for image segmentation.
- c_img_name (str): Name of the image file to be processed.
- path (str): The directory path where the image is located.
Returns:
- np.array: The raw segmentation result.
- np.array: The input image.
"""

c_img = imread(os.path.join(path, c_img_name))
c_img = np.expand_dims(c_img, 0) / 255.0
cur_seg = model.predict(c_img)[0]
return cur_seg, c_img[0]


def smooth(cur_seg: np.array) -> np.array:
"""
Apply morphological operations to smooth the segmentation mask.
This function applies binary opening to the provided segmentation mask.
Parameters:
- cur_seg (np.array): The input binary segmentation mask.
Returns:
- np.array: The smoothed binary segmentation mask.
"""

return binary_opening(cur_seg > 0.99, np.expand_dims(disk(2), -1))


Expand Down
14 changes: 14 additions & 0 deletions core/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,20 @@


def dice_p_bce(y_true, y_pred):
"""
Compute the Dice coefficient-based loss with an added binary cross-entropy term.
This loss function combines the Dice coefficient loss and binary cross-entropy loss, providing a weighted sum
of both terms to optimize a model for binary image segmentation tasks.
Parameters:
- y_true (tf.Tensor): The true binary ground truth mask.
- y_pred (tf.Tensor): The predicted binary segmentation mask.
Returns:
- tf.Tensor: The combined loss value, which is a weighted sum of binary cross-entropy and negative Dice coefficient.
"""

y_true = tf.cast(y_true, dtype=tf.float32)
y_pred = tf.cast(y_pred, dtype=tf.float32)

Expand Down
29 changes: 29 additions & 0 deletions core/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,21 @@


def dice_coef(y_true, y_pred, smooth=1):
"""
Compute the Dice coefficient for binary image segmentation.
The Dice coefficient is a measure of the similarity between the true binary mask and the predicted binary mask.
It quantifies the overlap between the two masks.
Parameters:
- y_true (tf.Tensor): The true binary ground truth mask.
- y_pred (tf.Tensor): The predicted binary segmentation mask.
- smooth (float): A smoothing factor to prevent division by zero. Default is 1.
Returns:
- tf.Tensor: The computed Dice coefficient.
"""

y_true = tf.cast(y_true, dtype=tf.float32)
y_pred = tf.cast(y_pred, dtype=tf.float32)

Expand All @@ -12,6 +27,20 @@ def dice_coef(y_true, y_pred, smooth=1):


def POD(y_true, y_pred):
"""
Compute the Probability of Detection (POD) for binary classification.
The POD measures the ability of a model to correctly detect positive cases.
It is commonly used in binary classification tasks.
Parameters:
- y_true (tf.Tensor): The true binary ground truth mask.
- y_pred (tf.Tensor): The predicted binary segmentation mask.
Returns:
- tf.Tensor: The computed Probability of Detection (POD).
"""

y_true = tf.cast(y_true, dtype=tf.float32)
y_pred = tf.cast(y_pred, dtype=tf.float32)

Expand Down
28 changes: 28 additions & 0 deletions core/model/UNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,34 @@


def init_model(config: dict, input_shape: tuple=(256, 256, 3)) -> models.Model:
"""
Initialize and build a U-Net-style image segmentation model.
This function constructs a U-Net architecture for image segmentation based on the provided configuration.
Parameters:
- config (dict): A dictionary containing model configuration settings.
- input_shape (tuple): Tuple specifying the input image shape, default is (256, 256, 3).
Returns:
- models.Model: A Keras Model representing the U-Net-style image segmentation model.
Configuration Settings:
- 'n_filters' (int): Number of initial filters in the model.
- 'upsample_mode' (str): Upsampling mode, either 'DECONV' for deconvolution or 'SIMPLE' for simple upsampling.
- 'net_scaling' (int, None): Scale factor for downscaling and subsequent upscaling of the network.
If None, no scaling is applied.
- 'gaussian_noise' (float): Standard deviation of Gaussian noise applied to the input.
Example Configuration:
config = {
'n_filters': 64,
'upsample_mode': 'DECONV',
'net_scaling': 2,
'gaussian_noise': 0.01
}
"""

n_filters = config['n_filters']

if config['upsample_mode'] == 'DECONV':
Expand Down
50 changes: 50 additions & 0 deletions core/model/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@ def upsample_conv(
strides: tuple[int, int],
padding: Union['valid', 'same']
) -> tf.Tensor:
"""
Create an Conv2DTranspose layer.
Parameters:
- filters (int): Number of output filters.
- kernel_size (Tuple[int, int]): Size of the convolutional kernel.
- strides (Tuple[int, int]): Stride values for the convolution operation.
- padding (str): Padding mode, either 'valid' or 'same'.
Returns:
- tf.Tensor: Output tensor after applying the upsample convolutional layer.
"""

return layers.Conv2DTranspose(filters, kernel_size, strides=strides, padding=padding)


Expand All @@ -19,10 +32,34 @@ def upsample_simple(
strides: tuple[int, int],
padding: Union['valid', 'same']
) -> tf.Tensor:
"""
Create a UpSampling2D layer.
Parameters:
- filters (int): Number of output filters.
- kernel_size (Tuple[int, int]): Size of the upsampling kernel.
- strides (Tuple[int, int]): Stride values for the upsampling operation.
- padding (str): Padding mode, either 'valid' or 'same'.
Returns:
- tf.Tensor: Output tensor after applying the simple upsampling layer.
"""

return layers.UpSampling2D(strides)


def encoder_block(prev_layer_inputs: tf.Tensor, n_filters: int) -> Tuple[tf.Tensor, tf.Tensor]:
"""
Create an encoder block consisting of convolutional layers and pooling.
Parameters:
- prev_layer_inputs (tf.Tensor): Input tensor to the encoder block.
- n_filters (int): Number of filters for the convolutional layers.
Returns:
- Tuple[tf.Tensor, tf.Tensor]: A tuple containing the pooled output tensor and the skip connection tensor.
"""

conv = layers.Conv2D(n_filters, (3, 3), activation='relu', padding='same')(prev_layer_inputs)
conv = layers.BatchNormalization()(conv)
skip_connection = layers.Conv2D(n_filters, (3, 3), activation='relu', padding='same')(conv)
Expand All @@ -34,6 +71,19 @@ def encoder_block(prev_layer_inputs: tf.Tensor, n_filters: int) -> Tuple[tf.Tens
def decoder_block(
prev_layer_input: tf.Tensor, skip_layer_input: tf.Tensor, n_filters: int, upsample: Callable
) -> tf.Tensor:
"""
Create a decoder block consisting of upsampling and convolutional layers.
Parameters:
- prev_layer_input (tf.Tensor): Input tensor to the decoder block.
- skip_layer_input (tf.Tensor): Skip connection tensor from the encoder block.
- n_filters (int): Number of filters for the convolutional layers.
- upsample (Callable): The upsampling function to use.
Returns:
- tf.Tensor: Output tensor after applying the decoder block.
"""

up = upsample(n_filters, (2, 2), strides=(2, 2), padding='same')(prev_layer_input)
up = layers.concatenate([up, skip_layer_input])
up = layers.Conv2D(n_filters, (3, 3), activation='relu', padding='same')(up)
Expand Down
28 changes: 28 additions & 0 deletions core/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,34 @@


def train(model, img_dir, config, img_scaling, callbacks, train_df, valid_df, transform):
"""
Train a semantic segmentation model using the specified configuration and data.
This function trains a semantic segmentation model using the provided Keras model, data directories,
configuration settings, and data generators. It compiles the model with a custom loss function, metrics,
and the AdamW optimizer.
Parameters:
- model: The Keras model to be trained for semantic segmentation.
- img_dir (str): The directory path where image data is located.
- config (dict): A dictionary containing configuration settings for training.
- img_scaling (tuple): A tuple specifying the scaling factors for height and width of input images.
- callbacks (list): A list of Keras callback objects for monitoring and controlling the training process.
- train_df (DataFrame): A Pandas DataFrame containing training dataset information.
- valid_df (DataFrame): A Pandas DataFrame containing validation dataset information.
- transform: A function or callable for data augmentation.
Returns:
- model: The trained Keras model.
- loss_history: The training history, including loss and metric values.
Configuration Settings:
- 'batch_size' (int): Batch size for training and validation.
- 'learning_rate' (float): Learning rate for the AdamW optimizer.
- 'do_augmentation' (bool): Whether to apply data augmentation during training.
- 'epochs' (int): Number of training epochs.
"""

batch_size = config['batch_size']
model.compile(optimizer=AdamW(learning_rate=config['learning_rate']), loss=dice_p_bce,
metrics=['binary_accuracy', dice_coef, POD], run_eagerly=True)
Expand Down
Loading

0 comments on commit 7fd72ad

Please sign in to comment.