-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathsigInspectClassify.m
More file actions
332 lines (288 loc) · 12.3 KB
/
sigInspectClassify.m
File metadata and controls
332 lines (288 loc) · 12.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
function annot = sigInspectClassify(signal, signalId, fs, method, varargin)
% function annot = sigInspectClassify(signal,varargin)
% classify artifacts in each second of provided micro-EEG signal
%
% IN:
% signal - micro-EEG signal vector (or signal matrix with channels in rows)
% fs - sampling frequency in Hz
% method - classification method:
% 'psd' - normalized PSD spectrum thresholding, based on [1]
% (default) 91% train /88% test set accuracy on EMBC data
% threshold based on [2]
% 'tree'- pre-trained decision tree, based on multiple features,
% trained on a multi-centric database from [2]
% 'cov' - not yet implemented
% 'svm' - 3 svm classifiers - POW, BASE, FREQ artifacts
% params - optional parameters for some of the classification methods:
% for 'psd': detection trheshold (default:0.01)
% for 'cov' three parameters:
% - threshold (default 1.2 based on [2])
% - winLength: length of signal segment (default 0.25s)
% - aggregPerc: what proportion of one second window has
% to be marked as artifact to mark the whole sec. as
% artifact (deafault: winLength)
% (default values based on [2])
% for 'svm': struct containing artifact names and optionally threshold values
% OUT:
% annot - logical vector of annotation for each second of input signal.
% true = artifact, false = clean signal
%
% E. Bakstein 2015-06-29
%
% [1] Bakstein, E. et. al.: Supervised Segmentation of Microelectrode Recording Artifacts Using Power Spectral Density, in Proceedings of IEEE EMBS, 2015
% [2] Bakstein, E.: Deep Brain Recordings in Parkinson's Disease: Processing, Analysis and Fusion with Anatomical Models, Doctoral Thesis, October 2016
% input checks
if(nargin<2)
error('sampling frequency must be specified')
end
if(nargin<3 || isempty(method))
method='psd';
end
if(isempty(signal))
annot=[];
return;
end
if(length(signal)<fs)
error('signal must be at least 1s long');
end
[Nch,N] = size(signal); % number of channels + samples
Ns = ceil(N/fs); % number of seconds
% ---- CLASSIFIER PARAMETERS ----
% features to be computed
switch(method)
case {'psd','psdPrg'}
% Default threshold values
if strcmpi(method,'psd')
psdThr = 0.01; % trained on multi-center data
else
psdThr = 0.0085; % trained on Prague data
end
% If user passed a custom numeric threshold, override
if nargin > 3 && ~isempty(varargin) && isnumeric(varargin{1}) && isscalar(varargin{1})
psdThr = varargin{1};
elseif nargin > 3 && ~isempty(varargin) && ~(isnumeric(varargin{1}) && isempty(varargin{1}))
warning('Invalid PSD threshold argument. Using default value: %.3f', psdThr);
end
% Safety check (avoid non-scalar logic error)
if isnumeric(psdThr) && isscalar(psdThr)
if psdThr < 0 || psdThr > 0.03
warning('Recommended PSD threshold range is 0.005–0.02. Value %.3f may lead to unexpected results.', psdThr);
end
else
warning('PSD threshold was not a valid scalar. Using default value: %.3f', 0.01);
psdThr = 0.01;
end
% features
featNames={'maxNormPSD'}; % definition of dataset columns
featComp = [1]; % features actually computet (for compatibility with dec. tree)
method = 'psd';
case {'tree', 'treePrg'}
% load classifier params
try
classif = load('sigInspectClassifiers.mat'); % load pre-trained classifiers
catch err
error('could not load precalculated classifiers: search sigInspect root for the file sigInspectClassifers.mat (necessary for the tree-based classifiers)')
end
if(strcmp(method,'tree'))
classif.tree = classif.treeAll;
else
classif.tree = classif.treePrg;
end
% features to compute
featNames = classif.featNames; % all 19 features have to be in the set
featComp = setdiff(unique(classif.tree.var),0); % only some are needed (non-nan)
method = 'tree';
case 'cov'
featNames = {};
covThr = 1.2;
winLength = .25;
aggregPerc = winLength;
% Unwrap any cells that contain numeric values
if numel(varargin) == 1 && iscell(varargin{1}) && numel(varargin{1}) > 1
args = varargin{1};
else
args = varargin;
end
% Unwrap each cell value safely
args = cellfun(@(x) unwrapCell(x), args, 'UniformOutput', false);
% Threshold (arg 1)
if numel(args) >= 1 && isnumeric(args{1}) && isscalar(args{1}) && args{1} >= 1
covThr = args{1};
elseif numel(args) >= 1 && ~isempty(args{1})
warning('Invalid COV threshold argument. Using default value: %.2f', covThr);
end
% Window length (arg 2)
if numel(args) >= 2 && isnumeric(args{2}) && isscalar(args{2}) && args{2} > 0 && args{2} < 1
winLength = args{2};
elseif numel(args) >= 2 && ~isempty(args{2})
warning('Invalid COV window length argument. Using default value: %.2f', winLength);
end
% Aggregation percentage (arg 3)
if numel(args) >= 3 && isnumeric(args{3}) && isscalar(args{3}) && args{3} > 0 && args{3} <= 1
aggregPerc = args{3};
elseif numel(args) >= 3 && ~isempty(args{3})
warning('Invalid COV aggregation argument. Using default value: %.2f', aggregPerc);
end
case 'svm'
try
classif = load('sigInspectSVMClassifiers.mat'); % load pre-trained classifiers
classif = classif.classifiers;
catch err
error('could not load precalculated classifiers: search sigInspect root for the file sigInspectSVMClassifiers.mat (necessary for the svm classifiers)')
end
if nargin >= 5 && ~isempty(varargin)
param = varargin{1};
else
param = struct( ...
'POW', struct(), ...
'BASE', struct(), ...
'FREQ', struct());
end
% Extract artifact types from user's input struct (fields of param)
artifactTypes = intersect(fieldnames(param), fieldnames(classif)); % make sure it's valid
if isempty(artifactTypes)
artifactTypes = fieldnames(classif); % default to all if none provided
end
% Fill in default threshold where not provided
thresholds = struct();
for i = 1:numel(artifactTypes)
art = artifactTypes{i};
if isfield(param.(art), 'threshold')
thresholds.(art) = param.(art).threshold;
else
thresholds.(art) = 0.5;
end
end
% Collect unique relevant feature names
featNames = {};
for i = 1:numel(artifactTypes)
featList = classif.(artifactTypes{i}).featNames;
featNames = [featNames, featList]; % concatenate all
end
featNames = unique(featNames); % remove duplicates
featComp = 1:length(featNames);
% Init annotation matrix
Nartif = length(artifactTypes);
method = 'svm';
otherwise
error('Unknown method: %s',method)
end
% ---- COMPUTE FEATURES ----
cacheFile = 'featuresCache.mat'; % or full path if needed
if ~strcmpi(method, 'cov')
featVals = getOrComputeFeatures(signal, signalId, method, featNames, fs, cacheFile);
else
featVals = [];
end
% ---- CLASSIFY ----
switch(method)
case 'psd'
% compare to preset threshold
annot = featVals>psdThr;
% change dims to be Nch*Ns
annot = reshape(annot,Nch,Ns);
case 'tree'
% classify using decision tree
annot = eval(classif.tree,featVals);
annot=strcmp(annot,'1');
% change dims to be Nch*Ns
annot = reshape(annot,Nch,Ns);
case 'cov'
annot = false(Nch,Ns);
for chi=1:Nch
annot(chi,:) = sigInspectClassifyCov(signal(chi,:),fs,'cov', covThr, winLength, aggregPerc,false);
end
case 'svm'
artifactTypes = fieldnames(param); % Update artifact types from param keys
Nartif = length(artifactTypes);
annot = false(Nch, Ns, Nartif);
scores = zeros(Nch*Ns, Nartif); % soft outputs
for i = 1:Nartif
art = artifactTypes{i};
featureSet = classif.(art).featNames; % features for this artifact
[~, featureInds] = ismember(featureSet, featNames); % indices in featVals, in correct order
if any(featureInds == 0)
error('Some features required by the SVM are missing in the computed features!');
end
X = featVals(:, featureInds);
% Get classifier
model = classif.(art).svmProbModel;
% Check for NaNs and replace with zeros
X(isnan(X)) = 0;
% Predict using trained SVM – get scores (posterior probabilities)
[~, score] = predict(model, X); % score is N x 2, we take the 2nd column (class == 1)
scores(:, i) = score(:, 2); % probability of being artifact
% Apply threshold to convert to binary
annot(:,:,i) = reshape(scores(:, i) > thresholds.(art), Nch, Ns);
end
end
function featVals = getOrComputeFeatures(signal, signalId, method, featNames, fs, cacheFile)
% signal: matrix (channels x samples)
% signalId: string, e.g. '#1'
% method: string, e.g. 'svm'
% featNames: cell array of feature names
% fs: sampling frequency
% cacheFile: string, path to .mat file
% Make valid field names
signalIdField = matlab.lang.makeValidName(signalId);
methodField = matlab.lang.makeValidName(method);
% Try to load cache
featuresCache = struct();
if exist(cacheFile, 'file')
S = load(cacheFile);
% If the file contains a featuresCache struct, unwrap it to top-level fields
if isfield(S, 'featuresCache')
% Unwrap: assign each field of featuresCache to the workspace struct
fcFields = fieldnames(S.featuresCache);
for k = 1:numel(fcFields)
featuresCache.(fcFields{k}) = S.featuresCache.(fcFields{k});
end
else
% Already top-level signalId fields
featuresCache = S;
end
end
% Compute expected dimensions
Nch = size(signal, 1);
N = size(signal, 2);
Ns = ceil(N / fs);
Nfeat = numel(featNames);
expectedSize = [Nch * Ns, Nfeat];
% Check if features already cached
if isfield(featuresCache, signalIdField) && isfield(featuresCache.(signalIdField), methodField)
featVals = featuresCache.(signalIdField).(methodField);
% Verify cached size
if isequal(size(featVals), expectedSize)
disp(['Using cached features for ', signalIdField, ' (', method, ')']);
return;
else
warning(['Cached features for ', signalIdField, ' (', method, ...
') have mismatched size. Expected [', num2str(expectedSize), ...
'], got [', num2str(size(featVals)), ']. Recomputing...']);
end
end
% Compute features
featVals = nan(expectedSize);
for si = 1:Ns
inds = (si-1)*fs+1 : min(si*fs, N);
fv = sigInspectComputeFeatures(signal(:,inds), featNames, fs);
nChHere = size(fv,1);
featVals((si-1)*Nch + (1:nChHere), :) = fv;
end
% Save to cache (append or create) as top-level signalId fields
featuresCache.(signalIdField).(methodField) = featVals;
if exist(cacheFile, 'file')
save(cacheFile, '-struct', 'featuresCache', '-append');
else
save(cacheFile, '-struct', 'featuresCache');
end
end
function val = unwrapCell(x)
% unwrap { [value] } → value
if iscell(x)
val = x{1};
else
val = x;
end
end
end