add learning rate function to allow decaying learning rate #99

Closed · wants to merge 1 commit
2 changes: 1 addition & 1 deletion NN/nnapplygrads.m
@@ -10,7 +10,7 @@
dW = nn.dW{i};
end

- dW = nn.learningRate * dW;
+ dW = nn.currentLearningRate * dW;

if(nn.momentum>0)
nn.vW{i} = nn.momentum*nn.vW{i} + dW;
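The only behavioural change in this file is that the gradient is scaled by the per-epoch rate stored in nn.currentLearningRate rather than by the scalar nn.learningRate (which is now a function handle). A minimal standalone sketch of that update step, using a toy nn struct with illustrative values rather than the toolbox code verbatim:

% Toy struct; values and sizes are illustrative.
nn.currentLearningRate = 0.5;       % set by nntrain at the start of each epoch
nn.momentum            = 0.5;
nn.dW{1}               = ones(3);   % dummy gradient for one weight matrix
nn.vW{1}               = zeros(3);  % momentum buffer

dW = nn.currentLearningRate * nn.dW{1};
if(nn.momentum > 0)
    nn.vW{1} = nn.momentum * nn.vW{1} + dW;
    dW = nn.vW{1};
end
% dW would then be subtracted from the corresponding weight matrix.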
9 changes: 7 additions & 2 deletions NN/nnsetup.m
@@ -6,10 +6,15 @@
nn.size = architecture;
nn.n = numel(nn.size);

+
+
nn.activation_function = 'tanh_opt'; % Activation functions of hidden layers: 'sigm' (sigmoid) or 'tanh_opt' (optimal tanh).
- nn.learningRate = 2; % learning rate Note: typically needs to be lower when using 'sigm' activation function and non-normalized inputs.
+
+ % learningRate is a function that takes the epoch number as input and returns
+ % the desired learning rate for that epoch.
+ nn.learningRate = @(epoch) 2;
nn.momentum = 0.5; % Momentum
- nn.scaling_learningRate = 1; % Scaling factor for the learning rate (each epoch)
+ %nn.scaling_learningRate = 1; % Scaling factor for the learning rate (each epoch)

Owner commented:
Please properly remove all references to nn.scaling_learningRate. Your new code deprecates it.

nn.weightPenaltyL2 = 0; % L2 regularization
nn.nonSparsityPenalty = 0; % Non sparsity penalty
nn.sparsityTarget = 0.05; % Sparsity target
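With learningRate defined as a function handle of the epoch number, any decay schedule can be expressed at setup time. A short usage sketch, assigned after nnsetup as in the tests (the commented-out schedules are illustrative and not part of this patch):

nn = nnsetup([784 100 10]);
nn.learningRate = @(epoch) 2;                        % constant rate (the new default)
% nn.learningRate = @(epoch) 2 * exp(-0.5 * epoch);  % exponential decay, as in the new test below
% nn.learningRate = @(epoch) 2 / (1 + 0.1 * epoch);  % 1/t-style decay (illustrative)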
12 changes: 10 additions & 2 deletions NN/nntrain.m
@@ -35,6 +35,10 @@
L = zeros(numepochs*numbatches,1);
n = 1;
for i = 1 : numepochs
+
+ % calculate current learning rate
+ nn.currentLearningRate = nn.learningRate(i);
+
tic;

kk = randperm(m);
@@ -70,8 +74,12 @@
nnupdatefigures(nn, fhandle, loss, opts, i);
end

- disp(['epoch ' num2str(i) '/' num2str(opts.numepochs) '. Took ' num2str(t) ' seconds' '. Mini-batch mean squared error on training set is ' num2str(mean(L((n-numbatches):(n-1)))) str_perf]);
- nn.learningRate = nn.learningRate * nn.scaling_learningRate;
+ disp(['epoch ' num2str(i) '/' num2str(opts.numepochs)...
+     '. Took ' num2str(t) ' seconds'...
+     '. Mini-batch mean squared error on training set is '...
+     num2str(mean(L((n-numbatches):(n-1)))) str_perf ...
+     ' Learning rate: ' num2str(nn.currentLearningRate)]);

end
end

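The rate is now evaluated once at the top of the epoch loop, and the old multiplicative update of nn.learningRate at the end of each epoch is gone. If the deprecated nn.scaling_learningRate behaviour is still wanted, it can be reproduced through the new interface; a minimal sketch, assuming a base rate of 2 and a per-epoch factor of 0.99 (both values illustrative, not taken from the patch):

% Equivalent of the old per-epoch multiplicative scaling, as a function handle.
base    = 2;     % illustrative base rate
scaling = 0.99;  % illustrative per-epoch scaling factor
nn.learningRate = @(epoch) base * scaling^(epoch - 1);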
18 changes: 17 additions & 1 deletion tests/test_example_NN.m
@@ -53,7 +53,7 @@
nn = nnsetup([784 100 10]);

nn.activation_function = 'sigm'; % Sigmoid activation function
- nn.learningRate = 1; % Sigm require a lower learning rate
+ nn.learningRate = @(epoch) 1; % Sigm require a lower learning rate
opts.numepochs = 1; % Number of full sweeps through data
opts.batchsize = 100; % Take a mean gradient step over this many samples

@@ -92,3 +92,19 @@

[er, bad] = nntest(nn, test_x, test_y);
assert(er < 0.1, 'Too big error');


+ % ex7 vanilla net with exponentially decaying learning rate
+ rand('state',0)
+ nn = nnsetup([784 100 10]);
+ nn.learningRate = @(epoch) 2*exp(-0.5*epoch);
+ opts.numepochs = 5; % Number of full sweeps through data
+ opts.batchsize = 100; % Take a mean gradient step over this many samples
+ [nn, L] = nntrain(nn, train_x, train_y, opts);
+
+ [er, bad] = nntest(nn, test_x, test_y);
+
+ assert(er < 0.08, 'Too big error');