From 9b190767b2857ed75010d495fdbe86b4a1fe03fc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=B8ren=20S=C3=B8nderby?=
Date: Tue, 17 Jun 2014 16:26:49 +0200
Subject: [PATCH] add learning rate function to allow decaying learning rate

---
 NN/nnapplygrads.m       |  2 +-
 NN/nnsetup.m            |  9 +++++++--
 NN/nntrain.m            | 12 ++++++++++--
 tests/test_example_NN.m | 18 +++++++++++++++++-
 4 files changed, 35 insertions(+), 6 deletions(-)

diff --git a/NN/nnapplygrads.m b/NN/nnapplygrads.m
index 781163b..8c67882 100644
--- a/NN/nnapplygrads.m
+++ b/NN/nnapplygrads.m
@@ -10,7 +10,7 @@
             dW = nn.dW{i};
         end
 
-        dW = nn.learningRate * dW;
+        dW = nn.currentLearningRate * dW;
 
         if(nn.momentum>0)
             nn.vW{i} = nn.momentum*nn.vW{i} + dW;
diff --git a/NN/nnsetup.m b/NN/nnsetup.m
index b8ec742..8971d96 100644
--- a/NN/nnsetup.m
+++ b/NN/nnsetup.m
@@ -6,10 +6,15 @@
     nn.size   = architecture;
     nn.n      = numel(nn.size);
+
+
 
     nn.activation_function   = 'tanh_opt';   % Activation functions of hidden layers: 'sigm' (sigmoid) or 'tanh_opt' (optimal tanh).
-    nn.learningRate          = 2;            % learning rate Note: typically needs to be lower when using 'sigm' activation function and non-normalized inputs.
+
+    % learningRate is a function that takes the epoch number as input and returns
+    % the desired learning rate for that epoch.
+    nn.learningRate          = @(epoch) 2;
     nn.momentum              = 0.5;          % Momentum
-    nn.scaling_learningRate  = 1;            % Scaling factor for the learning rate (each epoch)
+    %nn.scaling_learningRate = 1;            % Scaling factor for the learning rate (each epoch)
     nn.weightPenaltyL2       = 0;            % L2 regularization
     nn.nonSparsityPenalty    = 0;            % Non sparsity penalty
     nn.sparsityTarget        = 0.05;         % Sparsity target
diff --git a/NN/nntrain.m b/NN/nntrain.m
index af844a6..6f8fc31 100644
--- a/NN/nntrain.m
+++ b/NN/nntrain.m
@@ -35,6 +35,10 @@
     L = zeros(numepochs*numbatches,1);
     n = 1;
     for i = 1 : numepochs
+
+        % calculate the current learning rate for this epoch
+        nn.currentLearningRate = nn.learningRate(i);
+
         tic;
 
         kk = randperm(m);
@@ -70,8 +74,12 @@
             nnupdatefigures(nn, fhandle, loss, opts, i);
         end
 
 
-        disp(['epoch ' num2str(i) '/' num2str(opts.numepochs) '. Took ' num2str(t) ' seconds' '. Mini-batch mean squared error on training set is ' num2str(mean(L((n-numbatches):(n-1)))) str_perf]);
-        nn.learningRate = nn.learningRate * nn.scaling_learningRate;
+        disp(['epoch ' num2str(i) '/' num2str(opts.numepochs)...
+            '. Took ' num2str(t) ' seconds'...
+            '. Mini-batch mean squared error on training set is '...
+            num2str(mean(L((n-numbatches):(n-1)))) str_perf ...
+            ' Learning rate: ' num2str(nn.currentLearningRate)]);
+
     end
 end
diff --git a/tests/test_example_NN.m b/tests/test_example_NN.m
index c254ee7..22c819d 100644
--- a/tests/test_example_NN.m
+++ b/tests/test_example_NN.m
@@ -53,7 +53,7 @@
 nn = nnsetup([784 100 10]);
 
 nn.activation_function = 'sigm';         % Sigmoid activation function
-nn.learningRate = 1;                     % Sigm require a lower learning rate
+nn.learningRate = @(epoch) 1;            % Sigm requires a lower learning rate
 opts.numepochs = 1;                      % Number of full sweeps through data
 opts.batchsize = 100;                    % Take a mean gradient step over this many samples
 
@@ -92,3 +92,19 @@
 
 [er, bad] = nntest(nn, test_x, test_y);
 assert(er < 0.1, 'Too big error');
+
+
+% ex7 vanilla net with exponentially decaying learning rate
+rand('state',0)
+nn = nnsetup([784 100 10]);
+nn.learningRate = @(epoch) 2*exp(-0.5*epoch);
+opts.numepochs = 5;                      % Number of full sweeps through data
+opts.batchsize = 100;                    % Take a mean gradient step over this many samples
+[nn, L] = nntrain(nn, train_x, train_y, opts);
+
+[er, bad] = nntest(nn, test_x, test_y);
+
+assert(er < 0.08, 'Too big error');
+
+
+
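
Usage note (illustrative, not part of the patch): with this change nn.learningRate holds a function of the epoch number rather than a scalar, and nntrain evaluates it once per epoch into nn.currentLearningRate. The MATLAB sketch below shows a few schedules this enables, assuming train_x/train_y are loaded as in test_example_NN.m; the decay constants are arbitrary examples, and only the exponential schedule mirrors the ex7 test above.

    % Sketch of per-epoch learning-rate schedules with the function-valued API.
    % Constants are illustrative, not values prescribed by the patch.
    nn = nnsetup([784 100 10]);

    nn.learningRate = @(epoch) 2;                          % constant rate (old default behaviour)
    nn.learningRate = @(epoch) 2 * exp(-0.5 * epoch);      % exponential decay, as in the ex7 test
    nn.learningRate = @(epoch) 2 * 0.5^floor(epoch/10);    % hypothetical step decay: halve every 10 epochs

    opts.numepochs = 5;
    opts.batchsize = 100;
    [nn, L] = nntrain(nn, train_x, train_y, opts);         % nntrain calls nn.learningRate(i) once per epoch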