-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathEKF_test.py
More file actions
83 lines (58 loc) · 2.57 KB
/
EKF_test.py
File metadata and controls
83 lines (58 loc) · 2.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch.nn as nn
import torch
import time
from EKF import ExtendedKalmanFilter
def EKFTest(SysModel, test_input, test_target, modelKnowledge='full', allStates=True, loc=None):
    """Run the Extended Kalman Filter over every test sequence and report its MSE.

    One ExtendedKalmanFilter is built from ``SysModel`` and ``modelKnowledge``
    and initialised once with ``SysModel.m1x_0`` / ``SysModel.m2x_0``. It is then
    run on each test sequence ``test_input[j, :, :]``. The average MSE (in dB)
    and the total elapsed time are printed.

    Args:
        SysModel: System model forwarded to ExtendedKalmanFilter. Must expose
            ``m1x_0``, ``m2x_0``, ``m`` and ``T_test``.
        test_input: Per-sequence filter inputs, indexed as ``test_input[j, :, :]``.
            NOTE(review): presumably (N_T, n, T) observations; confirm against caller.
        test_target: Per-sequence ground truth, indexed as ``test_target[j, :, :]``.
            Its first dimension defines the number of test sequences N_T.
        modelKnowledge: Forwarded unchanged to ExtendedKalmanFilter.
        allStates: If True, the MSE compares every row of the estimate ``EKF.x``.
            If False, only the rows selected by ``loc`` are compared.
        loc: Boolean row mask used when ``allStates`` is False. Defaults to
            ``torch.tensor([True, False, True, False])``, the mask previously
            hard-coded here, which only fits a 4-dimensional state.

    Returns:
        list: ``[MSE_EKF_linear_arr, MSE_EKF_linear_avg, MSE_EKF_dB_avg,
        KG_array, EKF_out]`` where
        ``MSE_EKF_linear_arr`` holds one linear MSE per sequence (shape (N_T,)),
        ``MSE_EKF_linear_avg`` is its mean,
        ``MSE_EKF_dB_avg`` is that mean in dB,
        ``KG_array`` is ``EKF.KG_array`` averaged over the sequences, and
        ``EKF_out`` stacks the full state estimates with shape
        (N_T, SysModel.m, SysModel.T_test).
    """
    N_T = test_target.size()[0]

    # LOSS: 'mean' reduction gives one scalar MSE per sequence
    loss_fn = nn.MSELoss(reduction='mean')

    # MSE [Linear], one entry per test sequence
    MSE_EKF_linear_arr = torch.empty(N_T)

    # The row mask is loop-invariant, so build it once instead of per sequence.
    if not allStates and loc is None:
        loc = torch.tensor([True, False, True, False])

    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # so it is the appropriate clock for measuring elapsed run time.
    start = time.perf_counter()

    EKF = ExtendedKalmanFilter(SysModel, modelKnowledge)
    EKF.InitSequence(SysModel.m1x_0, SysModel.m2x_0)

    # Running sum of the per-sequence Kalman-gain arrays; divided by N_T below
    KG_array = torch.zeros_like(EKF.KG_array)
    EKF_out = torch.empty([N_T, SysModel.m, SysModel.T_test])

    for j in range(N_T):
        EKF.GenerateSequence(test_input[j, :, :], EKF.T_test)

        # Compare either the full estimate or only the masked state rows
        estimate = EKF.x if allStates else EKF.x[loc, :]
        MSE_EKF_linear_arr[j] = loss_fn(estimate, test_target[j, :, :]).item()

        KG_array = torch.add(EKF.KG_array, KG_array)
        # The full state estimate is stored regardless of allStates
        EKF_out[j, :, :] = EKF.x

    end = time.perf_counter()
    t = end - start

    # Average KG_array over Test Examples
    KG_array /= N_T

    MSE_EKF_linear_avg = torch.mean(MSE_EKF_linear_arr)
    MSE_EKF_dB_avg = 10 * torch.log10(MSE_EKF_linear_avg)

    print("Extended Kalman Filter - MSE LOSS:", MSE_EKF_dB_avg, "[dB]")
    # Print Run Time
    print("Inference Time:", t)

    return [MSE_EKF_linear_arr, MSE_EKF_linear_avg, MSE_EKF_dB_avg, KG_array, EKF_out]
def EKFTest_evol(SysModel, test_input, test_target, modelKnowledge='full'):
    """Track EKF error and Kalman-gain trace over time, averaged across test sequences.

    An ExtendedKalmanFilter is built and initialised once. It is then run on
    each test sequence ``test_input[j, :, :]``. The element-wise squared error
    and the Kalman-gain array of every run are collected. Both are reduced to
    one value per time step.

    Args:
        SysModel: System model forwarded to ExtendedKalmanFilter. Must expose
            ``m1x_0``, ``m2x_0``, ``m``, ``n`` and ``T_test``.
        test_input: Per-sequence filter inputs, indexed as ``test_input[j, :, :]``.
        test_target: Per-sequence ground truth. Its first dimension defines N_T.
            NOTE(review): each ``test_target[j]`` must match ``EKF.x``, presumably
            (SysModel.m, SysModel.T_test); confirm against caller.
        modelKnowledge: Forwarded unchanged to ExtendedKalmanFilter.

    Returns:
        list: ``[MSE_EKF_dB_avg, trace_dB_avg]``, both 1-D tensors of length
        ``SysModel.T_test``.
        ``MSE_EKF_dB_avg`` is the squared error averaged over sequences and
        state components, in dB.
        ``trace_dB_avg`` is the trace of the sequence-averaged Kalman gain,
        in dB.
    """
    N_T = test_target.size()[0]

    # LOSS: 'none' keeps the element-wise squared error so it can be averaged per time step
    loss_fn = nn.MSELoss(reduction='none')

    # MSE [Linear], indexed (sequence, state component, time step)
    MSE_EKF_linear_arr = torch.empty(N_T, SysModel.m, SysModel.T_test)

    EKF = ExtendedKalmanFilter(SysModel, modelKnowledge)
    EKF.InitSequence(SysModel.m1x_0, SysModel.m2x_0)

    # Kalman gains of every sequence, indexed (sequence, time step, m, n)
    KG_array = torch.empty([N_T, SysModel.T_test, SysModel.m, SysModel.n])

    for j in range(N_T):
        EKF.GenerateSequence(test_input[j, :, :], EKF.T_test)
        MSE_EKF_linear_arr[j, :, :] = loss_fn(EKF.x, test_target[j, :, :])
        KG_array[j, :, :, :] = EKF.KG_array

    # Average KG_array over Test Examples -> (T_test, m, n)
    KG_avg = torch.mean(KG_array, 0)

    # Per-time-step trace: sum of the main diagonal of each (m, n) gain matrix.
    # Equivalent to torch.trace on every slice, including non-square gains.
    KG_trace = torch.diagonal(KG_avg, dim1=-2, dim2=-1).sum(-1)

    # Average over sequences and state components, keeping the time axis
    MSE_EKF_linear_avg = torch.mean(MSE_EKF_linear_arr, [0, 1])
    MSE_EKF_dB_avg = 10 * torch.log10(MSE_EKF_linear_avg)

    # NOTE(review): a non-positive trace yields -inf/NaN here; confirm this cannot occur.
    trace_dB_avg = 10 * torch.log10(KG_trace)

    return [MSE_EKF_dB_avg, trace_dB_avg]