-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_kf.py
More file actions
38 lines (30 loc) · 1.9 KB
/
test_kf.py
File metadata and controls
38 lines (30 loc) · 1.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import argparse
import os

import torch
import torch.nn.functional as F
from rich.progress import track

from dataset import load_data
from kalman_filter import kf_predict
def _filter_errors(est_states, P_history, truth):
    """Return the per-step squared error and covariance diagonal of one filter run.

    The first entry of ``est_states`` and ``P_history`` is dropped (``[1:]``)
    so the remaining rows line up with ``truth``.
    NOTE(review): this assumes that first entry is an initial state/covariance
    with no matching ground-truth row — confirm against ``kf_predict``.

    Args:
        est_states: NumPy array of state estimates, indexed by time step first.
        P_history: NumPy array of covariance matrices, indexed by time step
            first; the diagonal of each matrix is extracted.
        truth: Ground-truth tensor with one row of 3 values per time step.

    Returns:
        A ``(squared_error, variance)`` pair of tensors:
        - the elementwise squared difference between the estimates and ``truth``;
        - the diagonal of each covariance matrix.
    """
    squared_error = F.mse_loss(torch.from_numpy(est_states[1:]), truth, reduction='none')
    variance = torch.from_numpy(P_history[1:]).diagonal(dim1=1, dim2=2)
    return squared_error, variance


def main():
    """Evaluate ``kf_predict`` on every simulation of a test data set.

    For each simulation in the file given by ``--data``, ``kf_predict`` is run
    twice: once with its first argument ``False`` (results saved as
    ``kf_*``) and once with ``True`` (results saved as ``kf_ext_*``).
    Per time step and per component it records:
    - the squared error against the simulation's real position/attitude;
    - the diagonal of the filter's covariance matrix.

    Each of the four result tensors has shape
    ``(num_simulations, num_states, 3)`` and is written with ``torch.save``
    into the ``errors/`` directory. The directory is created if it is missing.
    """
    parser = argparse.ArgumentParser(
        description='Test kalman filter'
    )
    parser.add_argument('-d', '--data', type=str, help='Path to the test data', default='simulations/test.pkl')
    args = parser.parse_args()
    test_data = load_data(args.data)

    simulations = test_data.simulations
    if not simulations:
        # Fail with a readable message instead of an IndexError on simulations[0].
        parser.error(f'no simulations found in {args.data}')

    # NOTE(review): every simulation is assumed to have as many states as the
    # first one; a run of a different length fails on the row assignment below.
    shape = (len(simulations), len(simulations[0].states), 3)
    real_errors = torch.zeros(shape)
    cov_errors = torch.zeros(shape)
    real_ext_errors = torch.zeros(shape)
    cov_ext_errors = torch.zeros(shape)

    with torch.no_grad():
        # enumerate() has no len(), so rich cannot infer the total on its own;
        # pass it explicitly to get a determinate progress bar.
        for idx, simulation in track(
            enumerate(simulations),
            total=len(simulations),
            description='Testing Kalman filter',
        ):
            ekf_states, ekf_P_history = kf_predict(False, simulation, test_data)
            ekf_ext_states, ekf_ext_P_history = kf_predict(True, simulation, test_data)
            # One row per state: (real_position[0], real_position[1], real_attitude).
            # NOTE(review): the attitude error is a plain difference, not wrapped
            # to a single period — confirm angles cannot cross the wrap boundary.
            states = torch.tensor([
                (state.real_position[0], state.real_position[1], state.real_attitude)
                for state in simulation.states
            ])
            real_errors[idx], cov_errors[idx] = _filter_errors(ekf_states, ekf_P_history, states)
            real_ext_errors[idx], cov_ext_errors[idx] = _filter_errors(
                ekf_ext_states, ekf_ext_P_history, states
            )

    # torch.save does not create missing parent directories.
    os.makedirs('errors', exist_ok=True)
    torch.save(real_errors, 'errors/kf_real_errors.pt')
    torch.save(cov_errors, 'errors/kf_cov_errors.pt')
    torch.save(real_ext_errors, 'errors/kf_ext_real_errors.pt')
    torch.save(cov_ext_errors, 'errors/kf_ext_cov_errors.pt')


if __name__ == '__main__':
    main()