Description
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow.js): true
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Arch Linux
- TensorFlow.js installed from (npm or script link): NPM
- TensorFlow.js version (use command below): 4.17.0
Describe the current behavior
When building custom layers, it is often useful to use "standard" layer types like `tf.layers.dense` and `tf.layers.lstm` from inside of that layer. However, layers added in this way have two major problems:
- Their trainable parameters are not reported by `model.summary()`.
- Their weights are not exported with `model.save()` (see the sketch after this list).
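As a concrete check of the second problem, the exported weight manifest can be inspected with `tf.io.withSaveHandler`. The following is only a sketch, assuming `model` contains a custom layer that nests `tf.layers.dense`; the nested layers' kernels and biases never appear in `weightSpecs`:

```js
// Inspect what model.save() actually exports, without writing to disk.
await model.save(tf.io.withSaveHandler(async (artifacts) => {
  // Names of every weight that would be serialized; the nested dense
  // layers' kernels and biases are missing from this list.
  console.log(artifacts.weightSpecs.map((spec) => spec.name))
  return {
    modelArtifactsInfo: {
      dateSaved: new Date(),
      modelTopologyType: 'JSON'
    }
  }
}))
```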
This is problematic for obvious reasons. The alternative is to use the `this.addWeight()` API; however, weights added in this way also have problems:
- It is wasteful and time-consuming to re-implement layer types that already exist in the standard API.
- Weights added via `this.addWeight()` cannot use string activations like `'mish'` and `'swish'` (see the sketch below).
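To illustrate the second point, here is roughly what a hand-rolled dense projection looks like with `this.addWeight()`. This is only a sketch; the class and weight names and the initializers are my own choices, and the activation has to be applied as a core op (`tf.relu`) because there is no string-based activation lookup at this level:

```js
class ManualDense extends tf.layers.Layer {
  constructor(config) {
    super({ ...config })
    this.units = config?.units || 256
  }
  build(inputShape) {
    // Every parameter has to be declared by hand.
    const inputDim = inputShape[inputShape.length - 1]
    this.kernel = this.addWeight(
      'kernel', [inputDim, this.units], 'float32',
      tf.initializers.glorotUniform({}))
    this.bias = this.addWeight(
      'bias', [this.units], 'float32',
      tf.initializers.zeros())
  }
  call(inputs, kwargs) {
    return tf.tidy(() => {
      // Assumes a 2-D [batch, features] input for brevity.
      const x = Array.isArray(inputs) ? inputs[0] : inputs
      // The activation is a raw op; there is no way to request
      // 'mish' or 'swish' by string here.
      return tf.relu(tf.add(tf.matMul(x, this.kernel.read()), this.bias.read()))
    })
  }
  computeOutputShape(inputShape) {
    return [...inputShape.slice(0, -1), this.units]
  }
  static get className() {
    return 'ManualDense'
  }
}
```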
If there is already a supported way to integrate the weights from a standard layer like `tf.layers.dense` into a custom layer, that method is not clear from any of the documentation I've seen.
Describe the expected behavior
I would expect weights used by the computational graph to be included in `model.summary()`'s "Trainable params" count, but they are not:
```
___________________________________________________________________________________________________________________
Layer (type) Input Shape Output shape Param # Receives inputs
===================================================================================================================
inp-t0B (InputLayer) [[null,null]] [null,null] 0
___________________________________________________________________________________________________________________
emb-gza (SharedEmbedding) [[null,null]],[[null,null,2 multiple 5091328 inp-t0B[0][0]
mlp-adG[0][0]
___________________________________________________________________________________________________________________
enc-RC2 (SinusoidalPositio [[null,null,256]] [null,null,256] 0 emb-gza[0][0]
___________________________________________________________________________________________________________________
attn-FBz (SelfAttention) [[null,null,256]] [null,null,256] 0 enc-RC2[0][0]
___________________________________________________________________________________________________________________
mlp-3kL (MultiLayerPercept [[null,null,256]] [null,null,256] 0 attn-FBz[0][0]
___________________________________________________________________________________________________________________
attn-VZK (SelfAttention) [[null,null,256]] [null,null,256] 0 mlp-3kL[0][0]
___________________________________________________________________________________________________________________
mlp-Jfy (MultiLayerPercept [[null,null,256]] [null,null,256] 0 attn-VZK[0][0]
___________________________________________________________________________________________________________________
attn-j0b (SelfAttention) [[null,null,256]] [null,null,256] 0 mlp-Jfy[0][0]
___________________________________________________________________________________________________________________
mlp-oyK (MultiLayerPercept [[null,null,256]] [null,null,256] 0 attn-j0b[0][0]
___________________________________________________________________________________________________________________
attn-L1y (SelfAttention) [[null,null,256]] [null,null,256] 0 mlp-oyK[0][0]
___________________________________________________________________________________________________________________
mlp-9r1 (MultiLayerPercept [[null,null,256]] [null,null,256] 0 attn-L1y[0][0]
___________________________________________________________________________________________________________________
attn-Yha (SelfAttention) [[null,null,256]] [null,null,256] 0 mlp-9r1[0][0]
___________________________________________________________________________________________________________________
mlp-GV8 (MultiLayerPercept [[null,null,256]] [null,null,256] 0 attn-Yha[0][0]
___________________________________________________________________________________________________________________
attn-R5D (SelfAttention) [[null,null,256]] [null,null,256] 0 mlp-GV8[0][0]
___________________________________________________________________________________________________________________
mlp-adG (MultiLayerPercept [[null,null,256]] [null,null,256] 0 attn-R5D[0][0]
===================================================================================================================
Total params: 5091328
Trainable params: 5091328
Non-trainable params: 0
```
Standalone code to reproduce the issue
Add the following custom layer to any model, then call `model.compile()`, then `model.summary()`. You will see that it reports 0 trainable parameters for this layer:
```js
import * as tf from '@tensorflow/tfjs'

class MultiLayerPerceptron extends tf.layers.Layer {
  constructor(config) {
    super({ ...config })
    this.units = config?.units || 256
    this.innerDim = config?.innerDim || 1024
  }
  build(inputShape) {
    // Standard layers, created the "obvious" way; their weights are
    // never registered with this layer.
    this.inProj = tf.layers.dense({
      units: this.innerDim,
      inputDim: this.units,
      activation: 'relu'
    })
    this.outProj = tf.layers.dense({
      units: this.units,
      inputDim: this.innerDim,
      activation: 'linear'
    })
  }
  call(inputs, kwargs) {
    return tf.tidy(() => {
      inputs = Array.isArray(inputs) ? inputs[0] : inputs
      return this.outProj.apply(this.inProj.apply(inputs))
    })
  }
  computeOutputShape(inputShape) {
    return inputShape
  }
  getConfig() {
    return {
      ...super.getConfig(),
      units: this.units,
      innerDim: this.innerDim
    }
  }
  static get className() {
    return 'MultiLayerPerceptron'
  }
}

tf.serialization.registerClass(MultiLayerPerceptron)
```
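For completeness, a minimal driver that triggers the symptom; the input shape, optimizer, and loss here are arbitrary choices for the repro:

```js
const input = tf.input({ shape: [256] })
const output = new MultiLayerPerceptron({ units: 256, innerDim: 1024 }).apply(input)
const model = tf.model({ inputs: input, outputs: output })
model.compile({ optimizer: 'adam', loss: 'meanSquaredError' })
model.summary() // the MultiLayerPerceptron row reports 0 params
```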
Other info / logs
If there is a supported way to add the trainable parameters from `tf.layers.dense()` to my custom layer, please let me know!
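One idea would be to build the inner layers eagerly and re-register their weights on the wrapper. This pokes at the private `_trainableWeights` field of `tf.layers.Layer`, so I assume it is unsupported; the sketch below is a drop-in replacement for the `build()` method in the repro class above:

```js
build(inputShape) {
  this.inProj = tf.layers.dense({ units: this.innerDim, activation: 'relu' })
  this.outProj = tf.layers.dense({ units: this.units, activation: 'linear' })
  // Force the inner layers to create their weights now, instead of on
  // first call.
  this.inProj.build(inputShape)
  this.outProj.build([...inputShape.slice(0, -1), this.innerDim])
  // Re-register the inner weights on the wrapper so that summary() and
  // save() can see them. _trainableWeights is not public API.
  this._trainableWeights = [
    ...this.inProj.trainableWeights,
    ...this.outProj.trainableWeights
  ]
}
```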