-
Notifications
You must be signed in to change notification settings - Fork 1
/
decodeData.m
118 lines (105 loc) · 4.58 KB
/
decodeData.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
function decoderResults = decodeData(X, Y, params)
% DECODEDATA Decode data and perform cross-validation
% Based on Scott Brincat's code for BMI neural signal decoding
% ----------------------------------------
% Inputs:
%   X      - signal / neural data, one row per trial (the trial count is size(X, 1))
%   Y      - response / intended BMI targets, one label per trial
%   params - (optional) parameters structure; missing fields get the defaults below
%
% Fields of the structure "params":
%   'crossValType'   - type of cross-validation: 'none', 'leave-one-out' or 'N-fold'
%                      (where N is an integer >= 2). Default: '10-fold'
%   'decType'        - type of decoder: 'lda', 'knn', 'mlr' (multinomial logistic regression)
%                      or 'map' (maximum a-posteriori Bayesian inference). Default: 'lda'
%   'LDAdiscrimType' - (for 'lda' decoder only) discriminant type (see MATLAB help for
%                      "ClassificationDiscriminant" class). Default: 'pseudolinear'
%   'numNeighbors'   - (for 'knn' decoder only) number of neighbors to use. Default: 100
%   'distance'       - (for 'knn' decoder only) distance metric to use (see MATLAB help for
%                      "ClassificationKNN" class). Default: 'cityblock'
%   'spkPreds'       - (for 'map' decoder only) logical vector flagging which predictors are
%                      based on spike count vs something else (ie, LFP power). Default: 0
%
% Output "decoderResults" is a structure with fields:
%   'params'        - the parameters actually used (input params plus filled-in defaults)
%   'prediction'    - decoder output, one entry per trial. Without cross-validation the
%                     decoder is trained and tested on all trials and entries follow the
%                     row order of X. With cross-validation, entry k is the held-out
%                     prediction for trial trialOrder(k).
%   'trialOrder'    - trial index (row of X) that each entry of 'prediction' refers to
%   'trainAccuracy' - (cross-validation only) percent correct when the decoder is trained
%                     and tested on all trials
%   'accuracy'      - percent correct over all test trials
%   'confusionMtx'  - confusion matrix (actual, predicted) of trial counts
%   'tgtAccuracy'   - percent correct for each actual response target
%
% Errors:
%   'CANNOT RECOGNIZE CROSS-VALIDATION TYPE!' - crossValType is none of the supported forms
%   'CANNOT RECOGNIZE DECODER TYPE!'          - decType is none of the supported decoders
%
% Example of preparing inputs (kept from the original file, commented out):
% % load features
% Ftr = load('C:\!analysis\CS20120505\CS20120505-features-lfp-[80-500Hz]-[0-750ms]-SEF.mat');
% features = Ftr.features(:, find(Ftr.params.chs==49));
%
% % load events
% loadParams.session = Ftr.params.session;
% loadParams.dataType = 'evt';
% Evt = loadData(loadParams);
%
% % select conditions
% conds = selectConditions(Ftr.params.session, Evt, 'target-direction');
% conds = conds(Ftr.params.trials);

% set parameters
% (an empty params, e.g. [] or struct([]), is treated like a missing one, because
% field assignment on an empty struct array is an error)
if ~exist('params', 'var') || isempty(params)
    params = struct;
end
if ~isfield(params, 'crossValType')
    params.crossValType = '10-fold';
end
if ~isfield(params, 'decType')
    params.decType = 'lda';
end
if strcmp(params.decType, 'lda') && ~isfield(params, 'LDAdiscrimType')
    params.LDAdiscrimType = 'pseudolinear';
end
if strcmp(params.decType, 'knn') && ~isfield(params, 'numNeighbors')
    params.numNeighbors = 100;
end
if strcmp(params.decType, 'knn') && ~isfield(params, 'distance')
    params.distance = 'cityblock';
end
if strcmp(params.decType, 'map') && ~isfield(params, 'spkPreds')
    params.spkPreds = 0;
end
decoderResults.params = params;

nTr = size(X, 1);

% validate the cross-validation request up front, so that a bad value fails
% before any (potentially slow) decoding has been done
useCrossVal = ~strcmp(params.crossValType, 'none');
if useCrossVal
    [nTestSets, testSetSize] = parseCrossValType(params.crossValType, nTr);
end

% decode data without cross-validation (train and test on all trials)
decoderResults.prediction = runDecoder(params, X, Y, X);
testY = Y;
trialOrder = (1:nTr)';

% perform cross-validation if requested
if useCrossVal
    % (:) makes the comparison independent of row/column orientation
    decoderResults.trainAccuracy = 100*sum(decoderResults.prediction(:)==testY(:))/nTr;

    % shuffle data
    % NOTE: this reseeds MATLAB's global random number generator on every call
    rng('shuffle');
    shuffledTr = randperm(nTr);

    % decode data
    for n = 1:nTestSets
        testSet = (testSetSize*(n-1)+1) : min(testSetSize*n, nTr);
        if isempty(testSet)
            % testSetSize is rounded up, so the trials can run out before the
            % last fold; every later fold would be empty as well
            break
        end
        testTr = shuffledTr(testSet);
        trainX = X;
        trainX(testTr, :) = [];
        trainY = Y;
        trainY(testTr) = [];
        % held-out predictions overwrite the all-trials predictions in place,
        % in shuffled order
        decoderResults.prediction(testSet, 1) = runDecoder(params, trainX, trainY, X(testTr, :));
    end
    testY = Y(shuffledTr);
    trialOrder = shuffledTr(:);
end
decoderResults.trialOrder = trialOrder;

% compute overall cross-validation decoder accuracy (percent correct choices over all test trials)
decoderResults.accuracy = 100*sum(decoderResults.prediction(:)==testY(:))/nTr;
% calculate confusion matrix(actual,predicted) showing number of trials where decoder predicts given response target for each actual target
decoderResults.confusionMtx = confusionmat(testY, decoderResults.prediction);
% calculate cross-validated decoder accuracy (pct correct) for each actual response target
decoderResults.tgtAccuracy = 100*diag(decoderResults.confusionMtx)./sum(decoderResults.confusionMtx,2);


function prediction = runDecoder(params, trainX, trainY, testX)
% Train the decoder selected by params.decType on (trainX, trainY) and return its
% output for testX. The *decoder functions are defined outside this file; they are
% assumed to return one prediction per row of testX - TODO confirm.
switch params.decType
    case 'lda'
        prediction = LDAdecoder(trainX, trainY, testX, params.LDAdiscrimType);
    case 'knn'
        prediction = KNNdecoder(trainX, trainY, testX, params.numNeighbors, params.distance);
    case 'mlr'
        prediction = MLRdecoder(trainX, trainY, testX);
    case 'map'
        prediction = MAPdecoder(trainX, trainY, testX, params.spkPreds);
    otherwise
        error('CANNOT RECOGNIZE DECODER TYPE!');
end


function [nTestSets, testSetSize] = parseCrossValType(crossValType, nTr)
% Translate a cross-validation type ('leave-one-out' or 'N-fold') into the number
% of test sets and the number of trials per test set, for nTr trials in total.
if ischar(crossValType) && strcmp(crossValType, 'leave-one-out')
    nTestSets = nTr;
    testSetSize = 1;
    return
end

suffix = '-fold';
nFolds = NaN;
if ischar(crossValType) && numel(crossValType) > numel(suffix) ...
        && strcmp(crossValType(end-numel(suffix)+1:end), suffix)
    nFolds = str2double(crossValType(1:end-numel(suffix)));   % NaN if not a number
end
if ~isfinite(nFolds) || nFolds ~= round(nFolds)
    error('CANNOT RECOGNIZE CROSS-VALIDATION TYPE!');
end
if nFolds < 2
    % a single fold would put every trial in the test set and leave nothing to train on
    error('N-FOLD CROSS-VALIDATION REQUIRES N >= 2!');
end
nTestSets = nFolds;
testSetSize = ceil(nTr/nTestSets);