-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsample.py
More file actions
75 lines (56 loc) · 2.64 KB
/
Copy pathsample.py
File metadata and controls
75 lines (56 loc) · 2.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from pathlib import Path
import numpy as np
from loguru import logger
from scipy.io import loadmat
class SampleHeader:
    """Read-only view over the stripped lines of a ``.hea`` header file.

    Accessors read fixed line positions:

    - line 0 is the record line;
    - line 13 holds ``#Age: ...``;
    - line 14 holds ``#Sex: ...``;
    - line 15 holds ``#Dx: ...``.

    NOTE(review): these hard-coded indices presumably assume a record line
    followed by exactly 12 signal lines (12-lead records). Headers with a
    different number of signals will be mis-parsed. Confirm against the dataset.
    """

    def __init__(self, header_data: list[str]):
        """Keep the raw header lines; they are expected to be already stripped."""
        self.header_data = header_data

    def __str__(self) -> str:
        """Return the header re-joined as newline-separated text."""
        return '\n'.join(self.header_data)

    def __len__(self) -> int:
        """Return the 4th field of the record line as an int.

        In the WFDB header format this field is the number of samples per
        signal. Fields are separated by spaces or tabs, so split on any
        whitespace run. ``split(' ')`` would yield empty fields, and hence
        the wrong index, on repeated blanks or tabs.
        """
        return int(self.header_data[0].split()[3])

    @property
    def age(self) -> int:
        """Return the value after the 6-char ``#Age: `` prefix of line 13.

        Raises:
            ValueError: if the value is not an integer literal.
                NOTE(review): some datasets store a missing age as ``NaN``,
                which would raise here. Confirm whether that needs handling.
        """
        return int(self.header_data[13][6:])

    @property
    def gender(self) -> str:
        """Return the text after the 6-char ``#Sex: `` prefix of line 14."""
        return self.header_data[14][6:]

    @property
    def codes(self) -> list[int]:
        """Return the comma-separated integer codes after the 5-char ``#Dx: `` prefix of line 15."""
        codes_str = self.header_data[15][5:]
        return [int(code) for code in codes_str.split(',')]

    def filtered_codes(self, relevant_codes: list[int]) -> list[int]:
        """Return the codes of this sample that also appear in *relevant_codes*.

        Order and duplicates follow :attr:`codes`. *relevant_codes* is
        materialised into a set once. This makes each membership test O(1)
        instead of a list scan. It also makes one-shot iterables such as
        generators work: with a generator, ``in`` would consume it across
        iterations.
        """
        wanted = set(relevant_codes)
        return [code for code in self.codes if code in wanted]
def load_mat(path: Path) -> np.ndarray:
    """Read the ``val`` variable of a record's ``.mat`` file as a float64 array.

    *path* may be the ``.mat`` file itself or any path of the same record
    (bare stem, ``.hea`` file, ...). Any other suffix is swapped for ``.mat``.
    The original author documents the result as shape ``[12, N]``, presumably
    (leads, samples). The array is returned as stored in ``val``.
    """
    if path.suffix != '.mat':
        path = path.with_suffix('.mat')
    contents = loadmat(path)
    signal = contents['val']
    return np.asarray(signal, dtype=np.float64)
def load_hea(path: Path) -> SampleHeader:
    """Read a record's ``.hea`` header file and wrap its lines in a SampleHeader.

    *path* may be the ``.hea`` file itself or any path of the same record
    (bare stem, ``.mat`` file, ...). Any other suffix is swapped for ``.hea``.
    Every line is stripped of surrounding whitespace, including the trailing
    newline, before being stored.
    """
    hea_path = path if path.suffix == '.hea' else path.with_suffix('.hea')
    # Decode explicitly rather than with the platform's locale encoding, so
    # the same header parses identically on every machine.
    with open(hea_path, 'r', encoding='utf-8') as f:
        lines = [line.strip() for line in f]
    return SampleHeader(lines)
def load_sample(path: Path) -> tuple[np.ndarray, SampleHeader]:
    """Return ``(signal, header)`` for the record identified by *path*.

    The signal comes from ``load_mat`` and the header from ``load_hea``. Both
    resolve their file from the same *path* by suffix substitution.
    """
    # Tuple elements evaluate left to right: the .mat is read before the .hea,
    # as before.
    return load_mat(path), load_hea(path)
def get_samples_paths(data_dir: Path, limit_sample_length: tuple[int, int] | int | None = None) -> list[Path]:
    """Collect suffix-less record paths from every dataset directory under *data_dir*.

    Each sub-directory of *data_dir* is treated as one dataset. Every file in
    it contributes its path with the suffix removed, so ``X.mat`` and ``X.hea``
    collapse into a single ``X`` entry.

    Args:
        data_dir: Root directory holding the dataset sub-directories. Plain
            files directly under it are skipped.
        limit_sample_length: Optional filter on the header-reported length
            (``len(load_hea(path))``).
            - An int keeps only samples of exactly that length.
            - A ``(min, max)`` tuple keeps samples with
              ``min <= length <= max``, inclusive.
            - ``None`` or any other falsy value disables filtering.

    Returns:
        Paths grouped by dataset directory (sorted by path), each group sorted
        by record path, so the result does not depend on filesystem order.
    """
    samples_paths: list[Path] = []
    # iterdir() order is filesystem-dependent; sort the dataset directories so
    # the overall sample order (and any split derived from it) is reproducible.
    # Non-directories are skipped: iterating a plain file would raise
    # NotADirectoryError.
    dataset_dirs = sorted(p for p in data_dir.iterdir() if p.is_dir())
    for ds_path in dataset_dirs:
        # One entry per record: dropping the suffix de-duplicates .mat/.hea pairs.
        # NOTE(review): any other file in a dataset directory (e.g. an index
        # file) also becomes an entry here.
        samples_paths.extend(sorted({f.with_suffix('') for f in ds_path.iterdir()}))
    logger.info(f'Got total of {len(samples_paths):,} samples')

    if limit_sample_length:
        if isinstance(limit_sample_length, int):
            min_length = max_length = limit_sample_length
        else:
            min_length, max_length = limit_sample_length
        # Each header is read exactly once, in list order.
        samples_paths = [path for path in samples_paths
                         if min_length <= len(load_hea(path)) <= max_length]
        logger.info(f'Filtering samples that not satisfies: {min_length} <= length <= {max_length}')
        logger.info(f'Got total of {len(samples_paths):,} samples after filtering')
    return samples_paths