Skip to content
This repository has been archived by the owner on Nov 11, 2021. It is now read-only.

Commit

Permalink
feat(playground): remove regularization from link internal state + …
Browse files Browse the repository at this point in the history
…apply regularization from state, on the fly

inspired by this PR:
tensorflow#139
  • Loading branch information
Bamdad Sabbagh committed Oct 28, 2021
1 parent 894f04c commit 8627dec
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 19 deletions.
36 changes: 20 additions & 16 deletions src/nn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,23 +166,18 @@ export class Link {
accErrorDer = 0
/** Number of accumulated derivatives since the last update. */
numAccumulatedDers = 0
regularization: RegularizationFunction

/**
* Constructs a link in the neural network initialized with random weight.
*
* @param source The source node.
* @param dest The destination node.
* @param regularization The regularization function that computes the
* penalty for this weight. If null, there will be no regularization.
* @param initZero
*/
constructor (source: Node, dest: Node,
regularization: RegularizationFunction, initZero?: boolean) {
constructor (source: Node, dest: Node, initZero?: boolean) {
this.id = source.id + '-' + dest.id
this.source = source
this.dest = dest
this.regularization = regularization
if (initZero) {
this.weight = 0
}
Expand All @@ -197,16 +192,12 @@ export class Link {
* 3 nodes in second hidden layer and 1 output node.
* @param activation The activation function of every hidden node.
* @param outputActivation The activation function for the output nodes.
* @param regularization The regularization function that computes a penalty
* for a given weight (parameter) in the network. If null, there will be
* no regularization.
* @param inputIds List of ids for the input nodes.
* @param initZero
*/
export function buildNetwork (
networkShape: number[], activation: ActivationFunction,
outputActivation: ActivationFunction,
regularization: RegularizationFunction,
inputIds: string[], initZero?: boolean): Node[][] {
let numLayers = networkShape.length
let id = 1
Expand All @@ -232,7 +223,7 @@ export function buildNetwork (
// Add links from nodes in the previous layer to this node.
for (let j = 0; j < network[layerIdx - 1].length; j++) {
let prevNode = network[layerIdx - 1][j]
let link = new Link (prevNode, node, regularization, initZero)
let link = new Link (prevNode, node, initZero)
prevNode.outputs.push (link)
node.inputLinks.push (link)
}
Expand Down Expand Up @@ -330,12 +321,25 @@ export function backProp (network: Node[][], target: number,
}
}

type UpdateWeights = {
network: Node[][],
learningRate: number,
regularization: RegularizationFunction,
regularizationRate: number,
}

/**
* Updates the weights of the network using the previously accumulated error
* derivatives.
*/
export function updateWeights (network: Node[][], learningRate: number,
regularizationRate: number) {
export function updateWeights (
{
network,
learningRate,
regularization,
regularizationRate,
}: UpdateWeights,
) {
for (let layerIdx = 1; layerIdx < network.length; layerIdx++) {
let currentLayer = network[layerIdx]
for (let i = 0; i < currentLayer.length; i++) {
Expand All @@ -352,16 +356,16 @@ export function updateWeights (network: Node[][], learningRate: number,
if (link.isDead) {
continue
}
let regulDer = link.regularization ?
link.regularization.der (link.weight) : 0
let regulDer = regularization ?
regularization.der (link.weight) : 0
if (link.numAccumulatedDers > 0) {
// Update the weight based on dE/dw.
link.weight = link.weight -
(learningRate / link.numAccumulatedDers) * link.accErrorDer
// Further update the weight based on regularization.
let newLinkWeight = link.weight -
(learningRate * regularizationRate) * regulDer
if (link.regularization === RegularizationFunction.L1 &&
if (regularization === RegularizationFunction.L1 &&
link.weight * newLinkWeight < 0) {
// The weight crossed 0 due to the regularization term. Set it to 0.
link.weight = 0
Expand Down
10 changes: 7 additions & 3 deletions src/playground.ts
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,12 @@ function oneStep (): void {
nn.forwardProp (network, input)
nn.backProp (network, point.label, nn.Errors.SQUARE)
if ((i + 1) % state.batchSize === 0) {
nn.updateWeights (network, state.learningRate, state.regularizationRate)
nn.updateWeights ({
network,
learningRate: state.learningRate,
regularization: state.regularization,
regularizationRate: state.regularizationRate,
})
}
})
// Compute the loss.
Expand Down Expand Up @@ -992,8 +997,7 @@ function reset (onStartup = false) {
let shape = [numInputs].concat (state.networkShape).concat ([1])
let outputActivation = (state.problem === Problem.REGRESSION) ?
nn.Activations.LINEAR : nn.Activations.TANH
network = nn.buildNetwork (shape, state.activation, outputActivation,
state.regularization, constructInputIds (), state.initZero)
network = nn.buildNetwork (shape, state.activation, outputActivation, constructInputIds (), state.initZero)
lossTrain = getLoss (network, trainData)
lossTest = getLoss (network, testData)
drawNetwork (network)
Expand Down

0 comments on commit 8627dec

Please sign in to comment.