-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathWaveletUtils.py
More file actions
146 lines (131 loc) · 4.87 KB
/
WaveletUtils.py
File metadata and controls
146 lines (131 loc) · 4.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import numpy as np
import torch
import pywt
import pandas as pd
import math
def imputeVals(X, imputation='mean'):
    """
    Fill NaN entries of a 3-D array, one slice of the last axis at a time.

    Every slice X[:, :, i] is transposed into a DataFrame so that the second
    axis of X runs down the rows; pandas then fills along that axis,
    independently for each index of the first axis.

    Parameters:
        X (numpy.ndarray): 3-D array that may contain NaNs.
        imputation (str): 'mean' (mean of the non-missing values), 'forward'
            (ffill), 'zero' (constant 0) or 'backward' (bfill). Any other
            value leaves the data unchanged.

    Returns:
        numpy.ndarray: Array with the same shape as X, rebuilt by stacking
        the processed slices along the last axis.
    """
    def _fill(frame):
        # Guard-clause dispatch; falling through means "no imputation".
        if imputation == 'mean':
            return frame.fillna(frame.mean())
        if imputation == 'forward':
            return frame.ffill()
        if imputation == 'zero':
            return frame.fillna(0)
        if imputation == 'backward':
            return frame.bfill()
        return frame

    filled = [
        _fill(pd.DataFrame(X[:, :, k].transpose())).to_numpy().T
        for k in range(X.shape[-1])
    ]
    return np.stack(filled, -1)
def Regularize(X, times, imputation='mean'):
    """
    Align several series on a shared time index and impute the gaps.

    Each entry of X is transposed and wrapped in a DataFrame indexed by the
    matching entry of `times`; the frames are concatenated column-wise, so
    rows that a series does not cover become NaN and are then filled by
    imputeVals.

    Parameters:
        X (list): 2-D arrays, one per series. The first dimension of X[0]
            is used as the common `size` for the reshape, so every entry is
            assumed to share it (TODO confirm with callers).
        times (list): One time vector per entry of X, used as the row index
            of that entry's transposed data.
        imputation (str): Method forwarded to imputeVals
            ('mean', 'forward', 'zero', 'backward').

    Returns:
        torch.Tensor: Imputed data laid out as (size, Len, feats), where Len
            is the length of the combined index.
        torch.Tensor: The combined index values as a float tensor.
    """
    n_rows = X[0].shape[0]
    frames = [
        pd.DataFrame(values.T, index=stamps)
        for values, stamps in zip(X, times)
    ]
    merged = pd.concat(frames, axis=1)
    AllTimes = torch.tensor(merged.index.values).float()

    # (Len, size*feats) -> (Len, feats, size) -> (size, Len, feats)
    grid = merged.to_numpy()
    grid = grid.reshape([grid.shape[0], -1, n_rows]).transpose([2, 0, 1])

    return torch.tensor(imputeVals(grid, imputation=imputation)), AllTimes
def getdeltaTimes(times):
    """
    Convert a tensor of timestamps into successive differences.

    The first entry is kept unchanged and every later entry i becomes
    times[i] - times[i - 1]. The input tensor is not modified.

    Parameters:
        times (torch.Tensor): Input times tensor, indexed along its first
            dimension.

    Returns:
        torch.Tensor: New tensor of the same shape holding the delta times.
    """
    deltas = times.clone()
    # Differences are always taken against the untouched input, so a single
    # vectorised slice subtraction replaces the element-by-element loop.
    deltas[1:] = times[1:] - times[:-1]
    return deltas
def getRNNFreqGroups_mr(data, times, device=torch.device("cuda:0"), maxlevels=4, waveletType='haar', imputation='mean', fulldata=None, regularize=True, return_times=False):
    """
    Computes the frequency groups for RNN using multi-resolution wavelet decomposition.

    Every series in `data` is decomposed along axis 1 with pywt.wavedec
    (mode='periodization'). Coefficient array j of each series is collected
    in group j, together with the series' time vector sub-sampled to the
    length of that coefficient array.

    Parameters:
        data (list): List of input tensors; each must expose `.shape[1]` and,
            when `fulldata` is None, `.cpu().numpy()`.
        times (list): List of time vectors, one per entry of `data`.
        device (torch.device): Not used by this function; kept so existing
            callers that pass it keep working.
        maxlevels (int): Maximum levels for wavelet decomposition.
        waveletType (str): Type of wavelet to use.
        imputation (str): Imputation method ('mean', 'forward', 'zero', 'backward').
        fulldata (list, optional): If given, appended as the last group in
            place of the raw `data`. No matching entry is added to the
            returned times in that case.
        regularize (bool): Whether to align each group on a common time grid
            with Regularize.
        return_times (bool): Whether to return times.

    Returns:
        list: List of frequency groups.
        list (optional): Only if return_times is True. With regularize=True,
            one tensor of delta times per group (divided by the largest
            input time first). With regularize=False, the raw sub-sampled
            time vectors per group and series.
    """
    WL = pywt.Wavelet(waveletType)
    Outs = [[] for _ in range(maxlevels + 1)]
    Ts = [[] for _ in range(maxlevels + 1)]

    # Deepest decomposition each series supports, given its length on axis 1.
    MLs = [pywt.dwt_max_level(d.shape[1], WL) for d in data]
    # Offset that brings the longest series down to exactly `maxlevels` levels.
    dL = max(MLs) - maxlevels
    MaxT = max([max(t) for t in times])

    for i, d in enumerate(data):
        out = pywt.wavedec(d, WL, level=MLs[i] - dL, axis=1, mode='periodization')
        for j, o in enumerate(out):
            Outs[j].append(o)
            # Stride chosen so the sub-sampled times line up with the
            # number of coefficients at this level.
            TSubSamp = math.ceil(times[i].shape[0] / o.shape[1])
            Ts[j].append(times[i][::TSubSamp])

    if fulldata is None:
        # Raw series become the final group, paired with their full times.
        Outs.append([d.cpu().numpy() for d in data])
        Ts.append(times)

    if regularize:
        Times = []
        Outs_ls = []
        for x, t in zip(Outs, Ts):
            o, time = Regularize(x, t, imputation)
            time /= MaxT
            time = getdeltaTimes(time)
            Outs_ls.append(o)
            Times.append(time)
        Outs = Outs_ls
    else:
        Outs = [[torch.tensor(x) for x in x_arr] for x_arr in Outs]
        # No regularised grid exists in this mode. Previously `Times` was
        # never assigned here, so return_times=True raised
        # UnboundLocalError; return the raw sub-sampled times instead.
        Times = Ts

    if fulldata is not None:
        Outs.append(fulldata)

    if return_times:
        return Outs, Times
    return Outs
def getRNNFreqGroups(data, device=torch.device("cuda:0"), maxlevels=4, waveletType='haar'):
    """
    Split the input into wavelet coefficient bands plus the original signal.

    The decomposition runs along axis 1 with pywt.wavedec, using at most
    `maxlevels` levels and never more than pywt.dwt_max_level reports for
    the length of that axis.

    Parameters:
        data (torch.Tensor): Input tensor.
        device (torch.device): Not used by this function.
        maxlevels (int): Upper bound on the number of decomposition levels.
        waveletType (str): Name of the wavelet passed to pywt.Wavelet.

    Returns:
        list: torch tensors, one per coefficient array returned by
        pywt.wavedec, followed by a tensor built from `data` itself.
    """
    wavelet = pywt.Wavelet(waveletType)
    depth = min(maxlevels, pywt.dwt_max_level(data.shape[1], wavelet))
    bands = pywt.wavedec(data, wavelet, level=depth, axis=1)
    # The undecomposed input rides along as the last group.
    return [torch.tensor(band) for band in [*bands, data]]