Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions cpp/modmesh/linalg/kalman_filter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,41 @@ template <typename T>
struct select_real_t;
} /* end namespace detail */

/**
* @brief Return type for `KalmanFilter<T>::batch_filter(...)`.
*
* @details
* Due to the iteration of the predict and update steps in `KalmanFilter<T>::batch_filter(...)`, the
* intermediate results are stored for each time step, including prior and posterior
* states and their covariances.
* @see BFType<T> KalmanFilter<T>::batch_filter(array_type const & zs)
* @see BFType<T> KalmanFilter<T>::batch_filter(array_type const & zs, array_type const & us)
*/
template <typename T>
struct BFType
{
using tuple_type = std::tuple<SimpleArray<T>, SimpleArray<T>, SimpleArray<T>, SimpleArray<T>>;

BFType(size_t observation_size, size_t state_size)
{
small_vector<size_t> xs_shape{observation_size, state_size};
small_vector<size_t> ps_shape{observation_size, state_size, state_size};
prior_state = SimpleArray<T>(xs_shape);
prior_state_covariance = SimpleArray<T>(ps_shape);
posterior_state = SimpleArray<T>(xs_shape);
posterior_state_covariance = SimpleArray<T>(ps_shape);
}
tuple_type to_tuple() const
{
return std::make_tuple(prior_state, prior_state_covariance, posterior_state, posterior_state_covariance);
}

SimpleArray<T> prior_state;
SimpleArray<T> prior_state_covariance;
SimpleArray<T> posterior_state;
SimpleArray<T> posterior_state_covariance;
}; /* end struct BFType */
Comment on lines +54 to +76
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change to POD design, and name it as BFtype.


/**
* Reference: FilterPy KalmanFilter documentation
* https://filterpy.readthedocs.io/en/latest/kalman/KalmanFilter.html
Expand Down Expand Up @@ -205,6 +240,31 @@ class KalmanFilter
update_covariance(k);
}

/**
* @brief Predict and update in batch mode without a batch of control input us.
*
* @ref https://filterpy.readthedocs.io/en/latest/_modules/filterpy/kalman/kalman_filter.html#KalmanFilter.batch_filter
*
* @param zs A batch of measurement inputs.
*
* @see BFType<T> KalmanFilter<T>::batch_filter(array_type const & zs, array_type const & us)
* @see struct BFType<T>;
*/
Comment on lines +243 to +252
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The api design is refered from filter.py, and it is the origin of the api naming.

BFType<T> batch_filter(array_type const & zs);

/**
* @brief Predict and update in batch mode with a batch of control input us.
*
* @ref https://filterpy.readthedocs.io/en/latest/_modules/filterpy/kalman/kalman_filter.html#KalmanFilter.batch_filter
*
* @param zs A batch of measurement inputs.
* @param us A batch of control inputs.
*
* @see BFType<T> KalmanFilter<T>::batch_filter(array_type const & zs)
* @see struct BFType<T>;
*/
BFType<T> batch_filter(array_type const & zs, array_type const & us);

private:

void check_dimensions();
Expand All @@ -223,6 +283,10 @@ class KalmanFilter
void update_state(array_type const & k, array_type const & y);
void update_covariance(array_type const & k);

// Batch filter
void predict_and_update(array_type const & z, BFType<T> & bfs, size_t iter);
void predict_and_update(array_type const & z, array_type const & u, BFType<T> & bfs, size_t iter);

}; /* end class KalmanFilter */

template <typename T>
Expand Down Expand Up @@ -493,6 +557,117 @@ void KalmanFilter<T>::update_covariance(array_type const & k)
m_p = m_p.symmetrize();
}

template <typename T>
BFType<T> KalmanFilter<T>::batch_filter(array_type const & zs)
{
size_t z_m = zs.shape(0);
size_t z_n = zs.shape(1);
array_type z(small_vector<size_t>{z_n});
BFType<T> bfs(z_m, m_state_size);

for (size_t iter = 0; iter < z_m; ++iter)
{
for (size_t j = 0; j < z_n; ++j)
{
z(j) = zs(iter, j);
}
predict_and_update(z, bfs, iter);
}
return bfs;
}

template <typename T>
void KalmanFilter<T>::predict_and_update(array_type const & z, BFType<T> & bfs, size_t iter)
{
SimpleArray<T> & prior_state = bfs.prior_state;
SimpleArray<T> & prior_state_covariance = bfs.prior_state_covariance;
SimpleArray<T> & posterior_state = bfs.posterior_state;
SimpleArray<T> & posterior_state_covariance = bfs.posterior_state_covariance;

predict();
for (size_t j = 0; j < m_state_size; ++j)
{
prior_state(iter, j) = m_x(j);
for (size_t k = 0; k < m_state_size; ++k)
{
prior_state_covariance(iter, j, k) = m_p(j, k);
}
}

update(z);
for (size_t j = 0; j < m_state_size; ++j)
{
posterior_state(iter, j) = m_x(j);
for (size_t k = 0; k < m_state_size; ++k)
{
posterior_state_covariance(iter, j, k) = m_p(j, k);
}
}
}

template <typename T>
BFType<T> KalmanFilter<T>::batch_filter(array_type const & zs, array_type const & us)
{
size_t z_m = zs.shape(0);
size_t z_n = zs.shape(1);
array_type z(small_vector<size_t>{z_n});
BFType<T> bfs(z_m, m_state_size);

array_type u;
size_t u_n = 0;

size_t u_m = us.shape(0);
if (u_m != z_m)
{
throw std::invalid_argument("KalmanFilter::batch_filter: The number of control inputs must match the number of measurements.");
}
u_n = us.shape(1);
u = array_type(small_vector<size_t>{u_n});

for (size_t iter = 0; iter < z_m; ++iter)
{
for (size_t j = 0; j < z_n; ++j)
{
z(j) = zs(iter, j);
}
for (size_t j = 0; j < u_n; ++j)
{
u(j) = us(iter, j);
}
predict_and_update(z, u, bfs, iter);
}
return bfs;
}

template <typename T>
void KalmanFilter<T>::predict_and_update(array_type const & z, array_type const & u, BFType<T> & bfs, size_t iter)
{
SimpleArray<T> & prior_state = bfs.prior_state;
SimpleArray<T> & prior_state_covariance = bfs.prior_state_covariance;
SimpleArray<T> & posterior_state = bfs.posterior_state;
SimpleArray<T> & posterior_state_covariance = bfs.posterior_state_covariance;

predict(u);
for (size_t j = 0; j < m_state_size; ++j)
{
prior_state(iter, j) = m_x(j);
for (size_t k = 0; k < m_state_size; ++k)
{
prior_state_covariance(iter, j, k) = m_p(j, k);
}
}

update(z);
for (size_t j = 0; j < m_state_size; ++j)
{
posterior_state(iter, j) = m_x(j);
for (size_t k = 0; k < m_state_size; ++k)
{
posterior_state_covariance(iter, j, k) = m_p(j, k);
}
}
}

} /* end namespace modmesh */

// vim: set ff=unix fenc=utf8 et sw=4 ts=4 sts=4:
17 changes: 16 additions & 1 deletion cpp/modmesh/linalg/pymod/wrap_kalman_filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,22 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapKalmanFilter
.def(
"update",
&wrapped_type::update,
py::arg("z"));
py::arg("z"))
.def(
"batch_filter",
[](wrapped_type & self, array_type const & zs, array_type const & us)
{
return self.batch_filter(zs, us).to_tuple();
},
py::arg("zs"),
py::arg("us"))
.def(
"batch_filter",
[](wrapped_type & self, array_type const & zs)
{
return self.batch_filter(zs).to_tuple();
},
py::arg("zs"));
Comment on lines +97 to +111
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To pass a few return variables, the helper class should be casted to std::tuple.

}

}; /* end class WrapKalmanFilter */
Expand Down
117 changes: 117 additions & 0 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,3 +671,120 @@ def test_wrong_shape_control_predict(self):
"must be 1D of length control_size \\(1\\), but got shape "
"\\(2\\)"):
kf.predict(u_wrong_sa)


class KalmanFilterBatchFilterTC(unittest.TestCase):

def kf_batchfilter_numpy(self, kf, zs, us=None):
m = zs.shape[0]
n = kf.state.shape[0]
xs_pred_np = np.zeros((m, n))
ps_pred_np = np.zeros((m, n, n))
xs_upd_np = np.zeros((m, n))
ps_upd_np = np.zeros((m, n, n))
for i in range(m):
if us is not None:
u = us[i]
u_sa = sa_from_np(u, type(kf.state))
kf.predict(u_sa)
else:
kf.predict()
x_pred = kf.state.ndarray
p_pred = kf.covariance.ndarray
xs_pred_np[i] = x_pred
ps_pred_np[i] = p_pred

z = zs[i]
z_sa = sa_from_np(z, type(kf.state))
kf.update(z_sa)
x_upd = kf.state.ndarray
p_upd = kf.covariance.ndarray
xs_upd_np[i] = x_upd
ps_upd_np[i] = p_upd
return xs_pred_np, ps_pred_np, xs_upd_np, ps_upd_np

def test_batchfilter(self):
m = 50
x0 = np.array([1.0, 2.0, 3.0])
f = np.array([[1.1, 0.2, 0.3],
[0.1, 0.9, 0.7],
[4.7, 5.2, 6.7]])
h = np.array([[1.0, 3.0, 2.0],
[4.0, 0.2, 0.1]])
sigma_w = 0.316
zs = np.zeros((m, 2))
for i in range(m):
zs[i] = np.array([i * i, i * 0.5 + 1.0])
x_sa = mm.SimpleArrayFloat64(array=x0)
f_sa = mm.SimpleArrayFloat64(array=f)
h_sa = mm.SimpleArrayFloat64(array=h)
zs_sa = mm.SimpleArrayFloat64(array=zs)

kf = mm.KalmanFilterFp64(
x=x_sa, f=f_sa, h=h_sa,
process_noise=sigma_w,
measurement_noise=1.0,
)
bps = kf.batch_filter(zs_sa)
xs_pred, ps_pred, xs_upd, ps_upd = bps

kf = mm.KalmanFilterFp64(
x=x_sa, f=f_sa, h=h_sa,
process_noise=sigma_w,
measurement_noise=1.0,
)
bps_np = self.kf_batchfilter_numpy(kf, zs)
xs_pred_np, ps_pred_np, xs_upd_np, ps_upd_np = bps_np

np.testing.assert_allclose(xs_pred, xs_pred_np, atol=1e-12, rtol=0.0)
np.testing.assert_allclose(ps_pred, ps_pred_np, atol=1e-12, rtol=0.0)
np.testing.assert_allclose(xs_upd, xs_upd_np, atol=1e-12, rtol=0.0)
np.testing.assert_allclose(ps_upd, ps_upd_np, atol=1e-12, rtol=0.0)

def test_batchfilter_with_control(self):
m = 50
x0 = np.array([1.0, 2.0, 3.0])
f = np.array([[1.1, 0.2, 0.3],
[0.1, 0.9, 0.7],
[4.7, 5.2, 6.7]])
h = np.array([[1.0, 3.0, 2.0],
[4.0, 0.2, 0.1]])
b = np.array([[0.7, 0.2, 5.3],
[3.1, 0.9, 1.7],
[4.7, 5.2, 6.7]])
sigma_w = 0.316
zs = np.zeros((m, 2))
for i in range(m):
zs[i] = np.array([i * i, i * 0.5 + 1.0])
us = np.zeros((m, 3))
for i in range(m):
us[i] = np.array([i, pow(i, 3.5), pow(i, 0.5)])
x_sa = mm.SimpleArrayFloat64(array=x0)
f_sa = mm.SimpleArrayFloat64(array=f)
b_sa = mm.SimpleArrayFloat64(array=b)
h_sa = mm.SimpleArrayFloat64(array=h)
zs_sa = mm.SimpleArrayFloat64(array=zs)
us_sa = mm.SimpleArrayFloat64(array=us)

kf = mm.KalmanFilterFp64(
x=x_sa, f=f_sa, b=b_sa, h=h_sa,
process_noise=sigma_w,
measurement_noise=1.0,
)
bps = kf.batch_filter(zs_sa, us_sa)
xs_pred, ps_pred, xs_upd, ps_upd = bps

kf = mm.KalmanFilterFp64(
x=x_sa, f=f_sa, b=b_sa, h=h_sa,
process_noise=sigma_w,
measurement_noise=1.0,
)
bps_np = self.kf_batchfilter_numpy(kf, zs, us)
xs_pred_np, ps_pred_np, xs_upd_np, ps_upd_np = bps_np

np.testing.assert_allclose(xs_pred, xs_pred_np, atol=1e-12, rtol=0.0)
np.testing.assert_allclose(ps_pred, ps_pred_np, atol=1e-12, rtol=0.0)
np.testing.assert_allclose(xs_upd, xs_upd_np, atol=1e-12, rtol=0.0)
np.testing.assert_allclose(ps_upd, ps_upd_np, atol=1e-12, rtol=0.0)

# vim: set ff=unix fenc=utf8 et sw=4 ts=4 sts=4:
Loading