diff --git a/src/layers/core/Lambda.js b/src/layers/core/Lambda.js new file mode 100644 index 00000000..08b7d2d2 --- /dev/null +++ b/src/layers/core/Lambda.js @@ -0,0 +1,47 @@ +import Layer from '../../Layer' + +/** + * Lambda layer class + * This layer requires you to re-implement lambda nodes in javascript, + * as we do not have a python runtime available. + */ +export default class Lambda extends Layer { + + /** + * Creates a Lambda layer + * + * @param {Object} [attrs] - layer config attributes + */ + constructor(attrs = {}) { + super(attrs); + this.layerClass = 'Lambda'; + + if(this.functions[attrs.name]) { + this._call = this.functions[attrs.name].bind(this); + } else { + console.log("Missing lambda, using No_op! Implement it by defining require(\"keras-js/layers/core/Lambda\").functions["+attrs.name+"]"); + this._call = x => ({output: x, inputShape: x.tensor.shape}); + } + + if(this.initializers[attrs.name]) { + this.initializers[attrs.name].bind(this)(); + } else { + console.log("No lambda initializer. Disable this warning by defining require(\"keras-js/layers/core/Lambda\").initializers["+attrs.name+"]"); + } + } + + /** + * Method for layer computational logic + * + * @param {Tensor} x + * @returns {Tensor} + */ + call(x) { + return Object.assign(this, this._call(x)).output; + } + +} +Lambda.prototype.functions = {}; +Lambda.prototype.initializers = {} + +exports.default = Lambda; diff --git a/src/layers/core/index.js b/src/layers/core/index.js index b7af847e..d0d5cb8e 100644 --- a/src/layers/core/index.js +++ b/src/layers/core/index.js @@ -5,6 +5,7 @@ import SpatialDropout1D from './SpatialDropout1D' import SpatialDropout2D from './SpatialDropout2D' import SpatialDropout3D from './SpatialDropout3D' import Flatten from './Flatten' +import Lambda from './Lambda' import Reshape from './Reshape' import Permute from './Permute' import RepeatVector from './RepeatVector' @@ -17,6 +18,7 @@ export { SpatialDropout2D, SpatialDropout3D, Flatten, + Lambda, Reshape, Permute, RepeatVector