-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
326 lines (265 loc) · 9.2 KB
/
utils.py
File metadata and controls
326 lines (265 loc) · 9.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
#!/usr/bin/env python
# coding: utf-8
"""
Miscellaneous functions for fNIRS data processing (mostly in context of using MNE).
"""
# File Path Manipulation
import pathlib
# Additional Iteration Utilities
from itertools import compress
from collections import Counter
# Additional Function Utilities
from functools import partial
# Regex
# TIP - Use regexr.com to create/test/learn Regex
import re
# Array
import numpy as np
# Additional Configuration/Metadata
import constants
# MNE Constants
from mne.defaults import HEAD_SIZE_DEFAULT
def dec_to_hex(ch_names):
    """Rewrite the source/detector IDs of channel names from decimal to hexadecimal.

    Parameters
    ----------
    ch_names : array-like
        Channel names of the form ``S<dec>_D<dec><rest>``.

    Returns
    -------
    list
        The same names with both IDs written as uppercase hexadecimal; any
        trailing text (e.g. ``' 760'``) is kept unchanged.
    """
    pattern = re.compile(r'S(\d+)_D(\d+)(.*)')
    converted = []
    for name in ch_names:
        # A name that does not fit the pattern makes .match() return None,
        # so .groups() raises AttributeError.
        source, detector, rest = pattern.match(name).groups()
        converted.append(f'S{int(source):X}_D{int(detector):X}{rest}')
    return converted
def hex_to_dec(ch_names):
    """Rewrite the source/detector IDs of channel names from hexadecimal to decimal.

    Parameters
    ----------
    ch_names : array-like
        Channel names of the form ``S<hex>_D<hex><rest>``; the hex digits must
        be uppercase (``0-9A-F``).

    Returns
    -------
    list
        The same names with both IDs written in decimal; any trailing text is
        kept unchanged.
    """
    pattern = re.compile(r'S([0-9A-F]+)_D([0-9A-F]+)(.*)')

    def _to_decimal(name):
        # Non-matching names make .match() return None -> AttributeError.
        source, detector, rest = pattern.match(name).groups()
        return f'S{int(source, base=16)}_D{int(detector, base=16)}{rest}'

    return [_to_decimal(name) for name in ch_names]
def get_s_d(ch_names):
    """Return the distinct source-detector names, dropping wavelength/chromophore labels.

    The source-detector name is the first whitespace-separated token of each
    channel name (e.g. ``'S1_D1'`` from ``'S1_D1 760'``).

    Parameters
    ----------
    ch_names : array-like
        Channel names carrying wavelength/chromophore labels.

    Returns
    -------
    list
        Unique source-detector names in order of first appearance.
    """
    seen = {}
    for name in ch_names:
        # dict keeps insertion order, so the first occurrence decides position.
        seen.setdefault(name.split()[0], None)
    return list(seen)
def is_short_channel(ch_name):
    """Check if channel is short based on its name.

    A channel counts as short when its source and detector numbers are
    identical, e.g. ``'S3_D3 760'``.

    Parameters
    ----------
    ch_name : str
        Channel name.

    Returns
    -------
    bool
        True if channel is short.
    """
    # The negative lookahead keeps the back-reference from matching only a
    # prefix of the detector number (otherwise 'S1_D10' would look short).
    return re.match(r'S(\d+)_D\1(?!\d)', ch_name) is not None
def is_long_channel(ch_name):
    """Check if channel is long (i.e. not short) based on its name.

    Parameters
    ----------
    ch_name : str
        Channel name.

    Returns
    -------
    bool
        True if ``is_short_channel`` rejects the channel.
    """
    return not is_short_channel(ch_name)
def is_channel_type(ch_type, ch_name):
    """Tell whether a channel name carries the given type label.

    ``ch_type`` is inserted into a regular expression without escaping and
    the match is anchored only at the start, so it is interpreted as a regex
    prefix of the text after the space (e.g. ``'hb'`` matches both
    ``'S1_D1 hbo'`` and ``'S1_D1 hbr'``).

    Parameters
    ----------
    ch_type : str
        Channel type label (or regex fragment), e.g. ``'hbo'`` or ``'760'``.
    ch_name : str
        Channel name.

    Returns
    -------
    bool
        True if the channel is of the given type.
    """
    pattern = fr'S\d+_D\d+ {ch_type}'
    return re.match(pattern, ch_name) is not None
def find_short_channels(ch_names):
    """Locate the short channels in a list of channel names.

    Parameters
    ----------
    ch_names : array-like
        Channel names to scan.

    Returns
    -------
    list
        Names accepted by ``is_short_channel``, in input order.
    list
        Positions of those names within ``ch_names``.
    """
    names, indices = [], []
    for index, name in enumerate(ch_names):
        if is_short_channel(name):
            names.append(name)
            indices.append(index)
    return names, indices
def find_long_channels(ch_names):
    """Locate the long channels in a list of channel names.

    Parameters
    ----------
    ch_names : array-like
        Channel names to scan.

    Returns
    -------
    list
        Names accepted by ``is_long_channel``, in input order.
    list
        Positions of those names within ``ch_names``.
    """
    hits = [(index, name) for index, name in enumerate(ch_names) if is_long_channel(name)]
    return [name for _, name in hits], [index for index, _ in hits]
def find_channels_type(ch_type, ch_names):
    """Locate the channels whose names carry a given type label.

    Parameters
    ----------
    ch_type : str
        Channel type passed to ``is_channel_type`` (e.g. ``'hbo'``, ``'760'``).
    ch_names : array-like
        Channel names to scan.

    Returns
    -------
    list
        Names of the channels of the given type, in input order.
    list
        Positions of those names within ``ch_names``.
    """
    matches_type = partial(is_channel_type, ch_type)
    hits = [(index, name) for index, name in enumerate(ch_names) if matches_type(name)]
    names = [name for _, name in hits]
    indices = [index for index, _ in hits]
    return names, indices
def find_channels(ch_type, separation, ch_names):
    """Locate the channels with a given type label and source-detector separation.

    Parameters
    ----------
    ch_type : str
        Channel type passed to ``find_channels_type``.
    separation : {'long', 'short'}
        Which separation to keep.
    ch_names : array-like
        Channel names to scan.

    Returns
    -------
    list
        Names of the matching channels, in input order.
    list
        Positions of those names within ``ch_names``.

    Raises
    ------
    KeyError
        If ``separation`` is neither ``'long'`` nor ``'short'``.
    """
    typed_names, typed_indices = find_channels_type(ch_type, ch_names)
    select = {'long': find_long_channels, 'short': find_short_channels}[separation]
    selected_names, local_indices = select(typed_names)
    # local_indices point into typed_names; translate them back to ch_names.
    return selected_names, list(map(typed_indices.__getitem__, local_indices))
def filter_channels(regex, function, ch_names):
    """Match channels based on the given regex and filter using the given function.

    Parameters
    ----------
    regex : str
        Regex to match by (applied with ``re.match``, i.e. anchored at the start).
    function : callable
        Function that takes the captured groups as positional arguments and
        returns a boolean.
    ch_names : array-like
        List of channel names.

    Returns
    -------
    list
        Channels (names) that match the regex and pass ``function``.
    list
        Indices of these channels within ``ch_names`` (one per occurrence, so
        duplicate names get their own positions).
    """
    pattern = re.compile(regex)
    picks, indices = [], []
    for index, ch_name in enumerate(ch_names):
        found = pattern.match(ch_name)
        # Names the regex does not match cannot be handed to `function`;
        # they are simply not selected.
        if found is not None and function(*found.groups()):
            picks.append(ch_name)
            indices.append(index)
    return picks, indices
def select_best_wavelengths(wavelengths, *args):
    """Placeholder for automatic wavelength selection; not implemented yet.

    Currently ignores its arguments and returns None.
    """
    # TODO: Automate wavelength selection based on absorption spectra of hemoglobin.
    return None
def find_ch_pairs(ch_names, channels):
    """Find the other channels with the same source-detector pair as the queried channels.

    Parameters
    ----------
    ch_names : array-like
        List of channel names.
    channels : array-like of str or int, or a single str/int
        Queried channels, given either as names or as indices into
        ``ch_names``. The kind is decided by the first element.

    Returns
    -------
    list
        Channels of ``ch_names`` that are not among the queried ones but share
        a source-detector pair with one of them. Indices are returned when
        ``channels`` holds indices, names otherwise. Empty when ``channels``
        is empty.

    Raises
    ------
    TypeError
        If the elements of ``channels`` are neither strings nor integers.
    """
    # Accept a single channel as well as a collection of channels.
    if isinstance(channels, (str, int, np.integer)):
        channels = [channels]
    channels = list(channels)
    if not channels:
        return []
    # np.integer is checked explicitly because NumPy integers (e.g. from
    # np.where) are not subclasses of int.
    if isinstance(channels[0], (int, np.integer)):
        return_ids = True
        channels = [ch_names[ch] for ch in channels]
    elif isinstance(channels[0], str):
        return_ids = False
    else:
        raise TypeError(f'channels must contain str or int, not {type(channels[0]).__name__}')
    # Build the lookups once rather than once per candidate channel.
    queried = set(channels)
    queried_pairs = set(get_s_d(channels))
    return [
        idx if return_ids else ch_name
        for idx, ch_name in enumerate(ch_names)
        if ch_name not in queried and get_s_d([ch_name])[0] in queried_pairs
    ]
def find_ch_paired(channels, ch_names=None):
    """Find the channels that are paired.

    A channel is considered paired when exactly two of the given channels
    share its source-detector pair.

    Parameters
    ----------
    channels : array-like of str or int, or a single str/int
        Channel names or ids to be filtered. The kind is decided by the first
        element.
    ch_names : array-like, optional
        List of channel names.
        Required if channels is a list of ids.

    Returns
    -------
    list
        The paired entries of ``channels``, in input order, as names or ids
        (matching how they were given). Empty when ``channels`` is empty.

    Raises
    ------
    TypeError
        If ids are given without ``ch_names``, or the elements of ``channels``
        are neither strings nor integers.
    """
    # Accept a single channel as well as a collection of channels.
    if isinstance(channels, (str, int, np.integer)):
        channels = [channels]
    channels = list(channels)
    if not channels:
        return []
    # np.integer is checked explicitly because NumPy integers are not
    # subclasses of int.
    if isinstance(channels[0], (int, np.integer)):
        if ch_names is None:
            raise TypeError('ch_names is required when channels are given as ids.')
        names = [ch_names[ch] for ch in channels]
    elif isinstance(channels[0], str):
        names = channels
    else:
        raise TypeError(f'channels must contain str or int, not {type(channels[0]).__name__}')
    s_d = [get_s_d([name])[0] for name in names]
    counts = Counter(s_d)
    # Return the entries exactly as given, so ids stay correct even when
    # ch_names contains duplicate names.
    return [ch for ch, pair in zip(channels, s_d) if counts[pair] == 2]
def has_location(source, pos):
    """Tell whether a named position is available in a file or montage.

    Parameters
    ----------
    source : str | pathlib.PurePath | montage-like
        Either the path of a text file, or an object with a
        ``get_positions()`` method (e.g. an MNE montage).
    pos : str
        Position name: ``'nasion'``, ``'lpa'``, ``'rpa'`` or a channel name.

    Returns
    -------
    bool
        For a file: True if some line's first token (split on tabs, runs of
        spaces or newline) equals ``pos``. For a montage: True if the
        fiducial is not None, or the name is a key of ``'ch_pos'``.
    """
    if isinstance(source, (pathlib.PurePath, str)):
        with open(source) as file:
            return any(pos == re.split('\t| +|\n', line)[0] for line in file)
    positions = source.get_positions()
    if pos in ('nasion', 'lpa', 'rpa'):
        return positions[pos] is not None
    return pos in positions['ch_pos']
def get_location(source, pos):
match source:
case pathlib.PurePath() | str(): # source is a filename
def _get_line_number(word, file):
file.seek(0)
for i, words in enumerate(file, 1):
if word == re.split('\t| +|\n', words)[0]:
return i
def _get_word(num, file):
file.seek(0)
for i, words in enumerate(file, 1):
if num == i:
return words
with open(source) as file:
line = _get_line_number('Positions', file) - _get_line_number('Labels', file) + _get_line_number(pos, file)
return list(map(float, re.split('\t| +', _get_word(line, file).rsplit('\n')[0])))
case _: # source is a montage
if pos in ('nasion', 'lpa', 'rpa'):
return source.get_positions()[pos]
else:
return source.get_positions()['ch_pos'][pos]
def get_transformation(montage, reference=constants.DEFAULT_REFERENCE_LOCATIONS, scale=1/HEAD_SIZE_DEFAULT):
    """Estimate the affine transformation mapping montage points onto reference points.

    Parameters
    ----------
    montage : str | pathlib.PurePath | montage-like
        Source of the measured locations; anything accepted by
        ``has_location`` / ``get_location`` (a text file path or an object
        with ``get_positions()``).
    reference : mapping
        Position name -> expected (target) coordinates. Only names that
        ``montage`` provides are used.
    scale : float
        Factor applied to the montage coordinates before fitting.

    Returns
    -------
    numpy.ndarray
        Transpose of the least-squares solution ``trans`` of
        ``base @ trans = target``, where both sides carry an appended column
        of ones (homogeneous coordinates). For d-dimensional coordinates the
        result has shape (d + 1, d + 1).

    Raises
    ------
    ValueError
        If fewer than 4 reference positions are available in ``montage``.
    """
    available_pos = list(compress(reference, map(partial(has_location, montage), reference)))
    # A 3-D affine map needs at least 4 (non-coplanar) point correspondences.
    if len(available_pos) < 4:
        raise ValueError(
            f'At least 4 points are required to get complete transformation; found {len(available_pos)}.'
        )
    available = set(available_pos)
    # Both arrays follow the iteration order of `reference`, so rows correspond.
    target = np.array([loc for pos, loc in reference.items() if pos in available])
    target = np.c_[target, np.ones((len(available_pos), 1))]
    base = np.array([get_location(montage, pos) for pos in available_pos]) * scale
    base = np.c_[base, np.ones((len(available_pos), 1))]
    # pinv yields the (minimum-norm) least-squares solution of base @ trans = target.
    trans = np.linalg.pinv(base) @ target
    return trans.T
if __name__ == '__main__':
    # The module only defines helper functions; running it directly does nothing.
    ...