Skip to content
This repository has been archived by the owner on Sep 26, 2020. It is now read-only.

Add some missing layers commonly used for image processing #99

Merged
merged 7 commits into from
Nov 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified libraries/kotlintest-core-jvm-4.0.2631-SNAPSHOT.jar
Binary file not shown.
3 changes: 3 additions & 0 deletions test-util/test-util.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ dependencies {
}
})

// Needed by kotlintest
api(group = "com.github.wumpz", name = "diffutils", version = "2.2")

api(
group = "org.junit.jupiter",
name = "junit-jupiter",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,18 @@ class DefaultLayerToCode : LayerToCode, KoinComponent {
)
}

is Layer.AveragePooling2D -> makeLayerCode(
"tf.keras.layers.AvgPool2D",
listOf(),
listOf(
"pool_size" to layer.poolSize,
"strides" to layer.strides,
"padding" to layer.padding.value,
"data_format" to layer.dataFormat?.value,
"name" to layer.name
)
).right()

is Layer.Dense -> makeLayerCode(
"tf.keras.layers.Dense",
listOf(),
Expand Down Expand Up @@ -106,6 +118,24 @@ class DefaultLayerToCode : LayerToCode, KoinComponent {
)
).right()

is Layer.GlobalAveragePooling2D -> makeLayerCode(
"tf.keras.layers.GlobalAveragePooling2D",
listOf(),
listOf(
"data_format" to layer.dataFormat?.value,
"name" to layer.name
)
).right()

is Layer.GlobalMaxPooling2D -> makeLayerCode(
"tf.keras.layers.GlobalMaxPooling2D",
listOf(),
listOf(
"data_format" to layer.dataFormat?.value,
"name" to layer.name
)
).right()

is Layer.MaxPooling2D -> makeLayerCode(
"tf.keras.layers.MaxPooling2D",
listOf(),
Expand All @@ -118,6 +148,26 @@ class DefaultLayerToCode : LayerToCode, KoinComponent {
)
).right()

is Layer.SpatialDropout2D -> makeLayerCode(
"tf.keras.layers.SpatialDropout2D",
listOf(layer.rate.toString()),
listOf(
"data_format" to layer.dataFormat?.value,
"name" to layer.name
)
).right()

is Layer.UpSampling2D -> makeLayerCode(
"tf.keras.layers.UpSampling2D",
listOf(),
listOf(
"size" to layer.size,
"data_format" to layer.dataFormat?.value,
"interpolation" to layer.interpolation.value,
"name" to layer.name
)
).right()

// TODO: Remove this
else -> "Cannot construct an unknown layer: $layer".left()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ import edu.wpi.axon.tfdata.layer.Activation
import edu.wpi.axon.tfdata.layer.Constraint
import edu.wpi.axon.tfdata.layer.DataFormat
import edu.wpi.axon.tfdata.layer.Initializer
import edu.wpi.axon.tfdata.layer.Interpolation
import edu.wpi.axon.tfdata.layer.Layer
import edu.wpi.axon.tfdata.layer.PoolingPadding
import edu.wpi.axon.tfdata.layer.Regularizer
import edu.wpi.axon.tfdata.layer.Layer
import io.kotlintest.shouldBe
import io.mockk.every
import io.mockk.mockk
Expand Down Expand Up @@ -183,6 +184,57 @@ internal class DefaultLayerToCodeTest : KoinTestFixture() {
),
("""tf.keras.layers.Flatten(data_format="channels_first", name="name")""").right(),
null
),
Arguments.of(
Layer.AveragePooling2D(
"name",
None,
Right(Tuple2(2, 2)),
Left(3),
PoolingPadding.Valid,
DataFormat.ChannelsLast
),
Right("""tf.keras.layers.AvgPool2D(pool_size=(2, 2), strides=3, padding="valid", data_format="channels_last", name="name")"""),
null
),
Arguments.of(
Layer.GlobalMaxPooling2D(
"name",
None,
DataFormat.ChannelsFirst
),
Right("""tf.keras.layers.GlobalMaxPooling2D(data_format="channels_first", name="name")"""),
null
),
Arguments.of(
Layer.SpatialDropout2D(
"name",
None,
0.2,
null
),
Right("""tf.keras.layers.SpatialDropout2D(0.2, data_format=None, name="name")"""),
null
),
Arguments.of(
Layer.UpSampling2D(
"name",
None,
Right(Tuple2(2, 2)),
null,
Interpolation.Nearest
),
Right("""tf.keras.layers.UpSampling2D(size=(2, 2), data_format=None, interpolation="nearest", name="name")"""),
null
),
Arguments.of(
Layer.GlobalAveragePooling2D(
"name",
None,
DataFormat.ChannelsLast
),
Right("""tf.keras.layers.GlobalAveragePooling2D(data_format="channels_last", name="name")"""),
null
)
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package edu.wpi.axon.tfdata.layer

/**
* Values for the `interpolation` parameter for sampling-type layers.
*/
enum class Interpolation(val value: String) {
Nearest("nearest"), Bilinear("bilinear")
}
71 changes: 71 additions & 0 deletions tf-data/src/main/kotlin/edu/wpi/axon/tfdata/layer/Layer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,17 @@ sealed class Layer {
override val inputs: Option<Set<String>>
) : Layer()

/**
* A layer that contains an entire model inside it.
*
* @param model The model that acts as this layer.
*/
data class ModelLayer(
override val name: String,
override val inputs: Option<Set<String>>,
val model: Model
) : Layer()

/**
* A layer that accepts input data and has no parameters.
*
Expand Down Expand Up @@ -141,6 +152,18 @@ sealed class Layer {
val virtualBatchSize: Int? = null
) : Layer()

/**
* https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/keras/layers/AveragePooling2D
*/
data class AveragePooling2D(
override val name: String,
override val inputs: Option<Set<String>>,
val poolSize: Either<Int, Tuple2<Int, Int>> = Right(Tuple2(2, 2)),
val strides: Either<Int, Tuple2<Int, Int>>? = null,
val padding: PoolingPadding = PoolingPadding.Valid,
val dataFormat: DataFormat? = null
) : Layer()

/**
* https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/keras/layers/Conv2D
*/
Expand Down Expand Up @@ -197,6 +220,24 @@ sealed class Layer {
val dataFormat: DataFormat? = null
) : Layer()

/**
* https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/keras/layers/GlobalAveragePooling2D
*/
data class GlobalAveragePooling2D(
override val name: String,
override val inputs: Option<Set<String>>,
val dataFormat: DataFormat?
) : Layer()

/**
* https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/keras/layers/GlobalMaxPool2D
*/
data class GlobalMaxPooling2D(
override val name: String,
override val inputs: Option<Set<String>>,
val dataFormat: DataFormat? = null
) : Layer()

/**
* https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/keras/layers/MaxPool2D
*/
Expand All @@ -208,4 +249,34 @@ sealed class Layer {
val padding: PoolingPadding = PoolingPadding.Valid,
val dataFormat: DataFormat? = null
) : Layer()

/**
* https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/keras/layers/SpatialDropout2D
*/
data class SpatialDropout2D(
override val name: String,
override val inputs: Option<Set<String>>,
val rate: Double,
val dataFormat: DataFormat? = null
) : Layer() {

init {
require(rate in 0.0..1.0) {
"rate ($rate) was outside the allowed range of [0, 1]."
}
}
}

/**
* https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/keras/layers/UpSampling2D
*
* Bug: TF does not export a value for [interpolation].
*/
data class UpSampling2D(
override val name: String,
override val inputs: Option<Set<String>>,
val size: Either<Int, Tuple2<Int, Int>> = Right(Tuple2(2, 2)),
val dataFormat: DataFormat? = null,
val interpolation: Interpolation = Interpolation.Nearest
) : Layer()
}
12 changes: 12 additions & 0 deletions tf-data/src/test/kotlin/edu/wpi/axon/tfdata/layer/LayerTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@ internal class LayerTest {

@Test
fun `dropout with invalid rate`() {
shouldThrow<IllegalArgumentException> { Layer.Dropout("", None, -0.1) }
shouldThrow<IllegalArgumentException> { Layer.Dropout("", None, 1.2) }
}

@Test
fun `spatialdropout2d with valid rate`() {
shouldNotThrow<IllegalArgumentException> { Layer.SpatialDropout2D("", None, 0.5) }
}

@Test
fun `spatialdropout2d with invalid rate`() {
shouldThrow<IllegalArgumentException> { Layer.SpatialDropout2D("", None, -0.1) }
shouldThrow<IllegalArgumentException> { Layer.SpatialDropout2D("", None, 1.2) }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import edu.wpi.axon.tfdata.layer.Activation
import edu.wpi.axon.tfdata.layer.Constraint
import edu.wpi.axon.tfdata.layer.DataFormat
import edu.wpi.axon.tfdata.layer.Initializer
import edu.wpi.axon.tfdata.layer.Interpolation
import edu.wpi.axon.tfdata.layer.Layer
import edu.wpi.axon.tfdata.layer.PoolingPadding
import edu.wpi.axon.tfdata.layer.Regularizer
Expand Down Expand Up @@ -129,6 +130,8 @@ class LoadLayersFromHDF5(
val json = data["config"] as JsonObject
val name = json["name"] as String
return when (className) {
"Sequential", "Model" -> Layer.ModelLayer(name, data.inboundNodes(), parseModel(data))

"InputLayer" -> Layer.InputLayer(
name,
(json["batch_input_shape"] as JsonArray<Int?>).toList().let {
Expand Down Expand Up @@ -166,8 +169,16 @@ class LoadLayersFromHDF5(
json["virtual_batch_size"] as Int?
)

"Conv2D"
-> Layer.Conv2D(
"AvgPool2D", "AveragePooling2D" -> Layer.AveragePooling2D(
name,
data.inboundNodes(),
json["pool_size"].tuple2OrInt(),
json["strides"].tuple2OrIntOrNull(),
json["padding"].poolingPadding(),
json["data_format"].dataFormatOrNull()
)

"Conv2D" -> Layer.Conv2D(
name,
data.inboundNodes(),
json["filters"] as Int,
Expand Down Expand Up @@ -208,6 +219,18 @@ class LoadLayersFromHDF5(
json["data_format"].dataFormatOrNull()
)

"GlobalAveragePooling2D", "GlobalAvgPool2D" -> Layer.GlobalAveragePooling2D(
name,
data.inboundNodes(),
json["data_format"].dataFormatOrNull()
)

"GlobalMaxPooling2D", "GlobalMaxPool2D" -> Layer.GlobalMaxPooling2D(
name,
data.inboundNodes(),
json["data_format"].dataFormatOrNull()
)

"MaxPool2D", "MaxPooling2D" -> Layer.MaxPooling2D(
name,
data.inboundNodes(),
Expand All @@ -217,6 +240,21 @@ class LoadLayersFromHDF5(
json["data_format"].dataFormatOrNull()
)

"SpatialDropout2D" -> Layer.SpatialDropout2D(
name,
data.inboundNodes(),
json["rate"].double(),
json["data_format"].dataFormatOrNull()
)

"UpSampling2D" -> Layer.UpSampling2D(
name,
data.inboundNodes(),
json["size"].tuple2OrInt(),
json["data_format"].dataFormatOrNull(),
json["interpolation"].interpolation()
)

else -> Layer.UnknownLayer(
name,
data.inboundNodes()
Expand Down Expand Up @@ -363,6 +401,13 @@ private fun Any?.dataFormatOrNull(): DataFormat? = when (this as? String) {
else -> throw IllegalArgumentException("Not convertible: $this")
}

private fun Any?.interpolation(): Interpolation = when (this as? String) {
// Null in versions < v1.15.0 (TF bug). Use nearest as the default
null, "nearest" -> Interpolation.Nearest
"bilinear" -> Interpolation.Bilinear
else -> throw IllegalArgumentException("Not convertible: $this")
}

private fun Any?.tuple2OrInt(): Either<Int, Tuple2<Int, Int>> = when {
this is Int -> Left(this)

Expand Down
Loading