-
Notifications
You must be signed in to change notification settings - Fork 2
/
LSTM_SiteClass.m
103 lines (79 loc) · 2.11 KB
/
LSTM_SiteClass.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
%% First LSTM Tests
%
% Description : First experiments that apply LSTM networks to our data
% in order to perform site classification
%
% Author :
% Stefan Herdy
% m01610562
%
% Date: 10.04.2020
% --------------------------------------------------
% (c) 2020, Stefan Herdy
% Chair of Automation, University of Leoben, Austria
% email: [email protected]
% --------------------------------------------------
%
%% Prepare Workspace
% Start from a clean state: close every open figure window and remove all
% variables from the current workspace.
close('all');
clear;
%% Load Data
% generateSiteInput is a project helper that is not defined in this file.
% Its outputs are brace-indexed further down, i.e. they are cell arrays:
% DataMatrix{i,1} is presumably one input sequence and Labels{i,1} its
% site label (NOTE(review): confirm against generateSiteInput).
% The trailing semicolon keeps the whole dataset from being echoed to the
% command window.
[DataMatrix,Labels] = generateSiteInput();
%% Prepare data
% Shuffle the observations once, then split them into disjoint train and
% test sets. DataMatrix and Labels are cell arrays whose rows pair one
% observation (column 1) with its label (column 1).
trainFraction = 0.9;   % fraction of the observations used for training
nObs = size(DataMatrix, 1);
if nObs < 2
    error('LSTM_SiteClass:tooFewObservations', ...
        'Need at least 2 observations to build a train/test split, got %d.', nObs);
end
% Shuffle the data. The permutation is stored in "perm" rather than in a
% variable called "rand", which would shadow MATLAB's rand function for
% the rest of the script.
perm = randperm(nObs);
X = DataMatrix(perm, 1);
Y = Labels(perm, 1);
% Convert the labels once, before splitting, so that the training and test
% labels share the same set of categories.
Y = categorical(Y);
% Split into train and test data. The test set starts at nTrain+1 so that
% no observation appears in both sets; nTrain is capped at nObs-1 so at
% least one observation is always held out for testing.
nTrain = min(round(trainFraction * nObs), nObs - 1);
X_Train = X(1:nTrain, 1);
X_Test  = X(nTrain+1:end, 1);
Y_Train = Y(1:nTrain, 1);
Y_Test  = Y(nTrain+1:end, 1);
%% Define LSTM Network Architecture
% Bidirectional-LSTM sequence classifier. With 'OutputMode','last' the
% BiLSTM emits one vector per sequence, which the fully connected and
% softmax layers turn into class scores.
% The input and output sizes are derived from the training data so that
% the network always matches it: trainNetwork fails if the input size
% differs from the feature count or the output size differs from the
% number of label categories.
% Assumes each sequence is stored features-by-time (the layout
% sequenceInputLayer expects) - TODO confirm with generateSiteInput.
inputSize = size(X_Train{1}, 1);
numHiddenUnits = 50;
numClasses = numel(categories(Y_Train));
layers = [ ...
    sequenceInputLayer(inputSize)
    bilstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
%% Specify the training options.
% Adam optimiser, gradients clipped at a threshold of 1, executed on the
% GPU. Each mini-batch is padded to its longest sequence. The data order
% is never reshuffled between epochs (it is shuffled once while the data
% is prepared). Command-window output is off; a training-progress plot is
% shown instead.
maxEpochs = 10;
miniBatchSize = 32;
options = trainingOptions('adam', ...
    'ExecutionEnvironment', 'gpu', ...
    'MaxEpochs',            maxEpochs, ...
    'MiniBatchSize',        miniBatchSize, ...
    'GradientThreshold',    1, ...
    'SequenceLength',       'longest', ...
    'Shuffle',              'never', ...
    'Plots',                'training-progress', ...
    'Verbose',              0);
%% Train the network
% Fit the layer stack defined above to the training split.
net = trainNetwork(X_Train, Y_Train, layers, options);
%% Test LSTM Network
% Predict a class for every held-out sequence, padding each mini-batch to
% its longest sequence just as during training.
Y_Pred = classify(net, X_Test, 'SequenceLength', 'longest', 'MiniBatchSize', miniBatchSize);
%% Calculate the classification accuracy of the predictions.
% Fraction of test sequences whose predicted class equals the true label
% (left without a semicolon so the value is displayed).
acc = mean(Y_Pred == Y_Test)
% Confusion matrix over the test set. confusionmat also returns the class
% order of its rows/columns; passing it to confusionchart labels the axes
% with the class names instead of plain indices 1..N.
[C, classOrder] = confusionmat(Y_Test, Y_Pred);
CC = confusionchart(C, classOrder);
CC.Title = 'Site Classification using LSTM Net';
% Optional normalised summaries:
%CC.RowSummary = 'row-normalized';
%CC.ColumnSummary = 'column-normalized';