Skip to content

Commit

Permalink
Merge pull request #65 from ANTsX/ShivaWmhPvs
Browse files Browse the repository at this point in the history
ENH:  Shiva PVS and WMH segmentation.
  • Loading branch information
ntustison authored Aug 20, 2024
2 parents 3da603e + aaed7be commit 199c8be
Show file tree
Hide file tree
Showing 8 changed files with 478 additions and 37 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ export(createResNetWithSpatialTransformerNetworkModel3D)
export(createResUnetModel2D)
export(createResUnetModel3D)
export(createRmnetGenerator)
export(createShivaUnetModel3D)
export(createSimpleClassificationWithSpatialTransformerNetworkModel2D)
export(createSimpleClassificationWithSpatialTransformerNetworkModel3D)
export(createSimpleFullyConvolutionalNeuralNetworkModel3D)
Expand Down Expand Up @@ -196,6 +197,7 @@ export(regressionMatchImage)
export(sampleFromCategoricalDistribution)
export(sampleFromOutput)
export(shivaPvsSegmentation)
export(shivaWmhSegmentation)
export(simulateBiasField)
export(splitMixtureParameters)
export(sysuMediaWmhSegmentation)
Expand Down
138 changes: 136 additions & 2 deletions R/createCustomUnetModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,140 @@ createHippMapp3rUnetModel3D <- function( inputImageSize,
return( unetModel )
}

#' Implementation of the "shiva" u-net architecture for PVS and WMH
#' segmentation
#'
#' Publications:
#'
#' * PVS: https://pubmed.ncbi.nlm.nih.gov/34262443/
#' * WMH: https://pubmed.ncbi.nlm.nih.gov/38050769/
#'
#' with respective GitHub repositories:
#'
#' * PVS: https://github.com/pboutinaud/SHIVA_PVS
#' * WMH: https://github.com/pboutinaud/SHIVA_WMH
#'
#' @param numberOfModalities Specifies number of channels in the
#' architecture.
#' @return a u-net keras model
#' @author Tustison NJ
#' @examples
#' \dontrun{
#'
#' model <- createShivaUnetModel3D()
#'
#' }
#' @import keras
#' @export
createShivaUnetModel3D <- function( numberOfModalities = 1 )
{
K <- tensorflow::tf$keras$backend

getPadShape <- function( targetLayer, referenceLayer )
{
padShape <- list()

delta <- K$int_shape( targetLayer )[[2]] - K$int_shape( referenceLayer )[[2]]
if( delta %% 2 != 0 )
{
padShape[[1]] <- c( as.integer( delta / 2 ), as.integer( delta / 2 ) + 1L )
} else {
padShape[[1]] <- c( as.integer( delta / 2 ), as.integer( delta / 2 ) )
}

delta <- K$int_shape( targetLayer )[[3]] - K$int_shape( referenceLayer )[[3]]
if( delta %% 2 != 0 )
{
padShape[[2]] <- c( as.integer( delta / 2 ), as.integer( delta / 2 ) + 1L )
} else {
padShape[[2]] <- c( as.integer( delta / 2 ), as.integer( delta / 2 ) )
}

delta <- K$int_shape( targetLayer )[[4]] - K$int_shape( referenceLayer )[[4]]
if( delta %% 2 != 0 )
{
padShape[[3]] <- c( as.integer( delta / 2 ), as.integer( delta / 2 ) + 1L )
} else {
padShape[[3]] <- c( as.integer( delta / 2 ), as.integer( delta / 2 ) )
}
if( all( padShape[[1]] == c( 0, 0 ) ) && all( padShape[[2]] == c( 0, 0 ) ) && all( padShape[[3]] == c( 0, 0 ) ) )
{
return( NULL )
} else {
return( padShape )
}
}
inputImageSize <- c( 160, 214, 176, numberOfModalities )
numberOfFilters <- c( 10, 18, 32, 58, 104, 187, 337 )

inputs <- layer_input( shape = inputImageSize )

# encoding layers

encodingLayers <- list()

outputs <- inputs
for( i in seq.int( length( numberOfFilters ) ) )
{
outputs <- outputs %>% layer_conv_3d( numberOfFilters[i], kernel_size = 3L, padding = 'same', use_bias = FALSE )
outputs <- outputs %>% layer_batch_normalization()
outputs <- outputs %>% layer_activation( "swish" )

outputs <- outputs %>% layer_conv_3d( numberOfFilters[i], kernel_size = 3L, padding = 'same', use_bias = FALSE )
outputs <- outputs %>% layer_batch_normalization()
outputs <- outputs %>% layer_activation( "swish" )

encodingLayers[[i]] <- outputs
outputs <- outputs %>% layer_max_pooling_3d( pool_size = 2L )
dropoutRate <- 0.05
if( i > 1 )
{
dropoutRate <- 0.5
}
outputs <- outputs %>% layer_spatial_dropout_3d( rate = dropoutRate )
}

# decoding layers

for( i in seq.int( from = length( encodingLayers ), to = 1, by = -1 ) )
{
upsampleLayer <- outputs %>% layer_upsampling_3d( size = 2L )
padShape <- getPadShape( encodingLayers[[i]], upsampleLayer )
if( i > 1 && ! is.null( padShape ) )
{
zeroLayer <- upsampleLayer %>% layer_zero_padding_3d( padding = padShape )
outputs <- layer_concatenate( list( zeroLayer, encodingLayers[[i]] ), axis = -1L, trainable = TRUE )
} else {
outputs <- layer_concatenate( list( upsampleLayer, encodingLayers[[i]] ), axis = -1L, trainable = TRUE )
}

outputs <- outputs %>% layer_conv_3d( K$int_shape( outputs )[[5]], kernel_size = 3L, padding = 'same', use_bias = FALSE )
outputs <- outputs %>% layer_batch_normalization()
outputs <- outputs %>% layer_activation( "swish" )

outputs <- outputs %>% layer_conv_3d( numberOfFilters[i], kernel_size = 3L, padding = 'same', use_bias = FALSE )
outputs <- outputs %>% layer_batch_normalization()
outputs <- outputs %>% layer_activation( "swish" )
outputs <- outputs %>% layer_spatial_dropout_3d( rate = 0.5 )
}

# final

outputs <- outputs %>% layer_conv_3d( 10, kernel_size = 3L, padding = 'same', use_bias = FALSE )
outputs <- outputs %>% layer_batch_normalization()
outputs <- outputs %>% layer_activation( "swish" )

outputs <- outputs %>% layer_conv_3d( 10, kernel_size = 3L, padding = 'same', use_bias = FALSE )
outputs <- outputs %>% layer_batch_normalization()
outputs <- outputs %>% layer_activation( "swish" )

outputs <- outputs %>% layer_conv_3d( 1, kernel_size = 1L, activation = "sigmoid", padding = 'same' )

unetModel <- keras_model( inputs = inputs, outputs = outputs )

return( unetModel )
}

#' Implementation of the "HyperMapp3r" U-net architecture
#'
#' Creates a keras model implementation of the u-net architecture
Expand Down Expand Up @@ -480,7 +614,7 @@ createSysuMediaUnetModel2D <- function( inputImageSize, anatomy = c( "wmh", "cla
{
getCropShape <- function( targetLayer, referenceLayer )
{
K <- keras::backend()
K <- tensorflow::tf$keras$backend

cropShape <- list()

Expand Down Expand Up @@ -592,7 +726,7 @@ createSysuMediaUnetModel3D <- function( inputImageSize,
{
getCropShape <- function( targetLayer, referenceLayer )
{
K <- keras::backend()
K <- tensorflow::tf$keras$backend

cropShape <- list()

Expand Down
44 changes: 32 additions & 12 deletions R/getPretrainedNetwork.R
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,16 @@ getPretrainedNetwork <- function(
"pvs_shiva_t1_flair_2",
"pvs_shiva_t1_flair_3",
"pvs_shiva_t1_flair_4",
"wmh_shiva_flair_0",
"wmh_shiva_flair_1",
"wmh_shiva_flair_2",
"wmh_shiva_flair_3",
"wmh_shiva_flair_4",
"wmh_shiva_t1_flair_0",
"wmh_shiva_t1_flair_1",
"wmh_shiva_t1_flair_2",
"wmh_shiva_t1_flair_3",
"wmh_shiva_t1_flair_4",
"protonLungMri",
"protonLobes",
"pulmonaryArteryWeights",
Expand All @@ -145,7 +155,7 @@ getPretrainedNetwork <- function(
"wholeHeadInpaintingFLAIR",
"wholeHeadInpaintingPatchBasedT1",
"wholeHeadInpaintingPatchBasedFLAIR",
"wholeTumorSegmentationT2Flair",
"wholeTumorSegmentationT2Flair",
"wholeLungMaskFromVentilation" ),
targetFileName, antsxnetCacheDirectory = NULL )
{
Expand Down Expand Up @@ -250,17 +260,27 @@ getPretrainedNetwork <- function(
mouseT2wBrainParcellation3DNick = "https://figshare.com/ndownloader/files/44714944",
mouseT2wBrainParcellation3DTct = "https://figshare.com/ndownloader/files/47214538",
mouseSTPTBrainParcellation3DJay = "https://figshare.com/ndownloader/files/46710592",
pvs_shiva_t1_0 = "https://figshare.com/ndownloader/files/48363799",
pvs_shiva_t1_1 = "https://figshare.com/ndownloader/files/48363832",
pvs_shiva_t1_2 = "https://figshare.com/ndownloader/files/48363814",
pvs_shiva_t1_3 = "https://figshare.com/ndownloader/files/48363790",
pvs_shiva_t1_4 = "https://figshare.com/ndownloader/files/48363829",
pvs_shiva_t1_5 = "https://figshare.com/ndownloader/files/48363823",
pvs_shiva_t1_flair_0 = "https://figshare.com/ndownloader/files/48363784",
pvs_shiva_t1_flair_1 = "https://figshare.com/ndownloader/files/48363820",
pvs_shiva_t1_flair_2 = "https://figshare.com/ndownloader/files/48363796",
pvs_shiva_t1_flair_3 = "https://figshare.com/ndownloader/files/48363793",
pvs_shiva_t1_flair_4 = "https://figshare.com/ndownloader/files/48363826",
pvs_shiva_t1_0 = "https://figshare.com/ndownloader/files/48660169",
pvs_shiva_t1_1 = "https://figshare.com/ndownloader/files/48660193",
pvs_shiva_t1_2 = "https://figshare.com/ndownloader/files/48660199",
pvs_shiva_t1_3 = "https://figshare.com/ndownloader/files/48660178",
pvs_shiva_t1_4 = "https://figshare.com/ndownloader/files/48660172",
pvs_shiva_t1_5 = "https://figshare.com/ndownloader/files/48660187",
pvs_shiva_t1_flair_0 = "https://figshare.com/ndownloader/files/48660181",
pvs_shiva_t1_flair_1 = "https://figshare.com/ndownloader/files/48660175",
pvs_shiva_t1_flair_2 = "https://figshare.com/ndownloader/files/48660184",
pvs_shiva_t1_flair_3 = "https://figshare.com/ndownloader/files/48660190",
pvs_shiva_t1_flair_4 = "https://figshare.com/ndownloader/files/48660196",
wmh_shiva_flair_0 = "https://figshare.com/ndownloader/files/48660487",
wmh_shiva_flair_1 = "https://figshare.com/ndownloader/files/48660496",
wmh_shiva_flair_2 = "https://figshare.com/ndownloader/files/48660493",
wmh_shiva_flair_3 = "https://figshare.com/ndownloader/files/48660490",
wmh_shiva_flair_4 = "https://figshare.com/ndownloader/files/48660511",
wmh_shiva_t1_flair_0 = "https://figshare.com/ndownloader/files/48660529",
wmh_shiva_t1_flair_1 = "https://figshare.com/ndownloader/files/48660547",
wmh_shiva_t1_flair_2 = "https://figshare.com/ndownloader/files/48660499",
wmh_shiva_t1_flair_3 = "https://figshare.com/ndownloader/files/48660550",
wmh_shiva_t1_flair_4 = "https://figshare.com/ndownloader/files/48660544",
protonLungMri = "https://ndownloader.figshare.com/files/13606799",
protonLobes = "https://figshare.com/ndownloader/files/30678455",
pulmonaryAirwayWeights = "https://figshare.com/ndownloader/files/45187168",
Expand Down
Loading

0 comments on commit 199c8be

Please sign in to comment.