From 9d141e7a870fcc01547d807fcbea53770bed6b3d Mon Sep 17 00:00:00 2001 From: ThreeMonth03 Date: Mon, 19 Jan 2026 02:22:45 +0800 Subject: [PATCH] Feat kalman_filter.hpp: Implement batch filter --- cpp/modmesh/linalg/kalman_filter.hpp | 175 ++++++++++++++++++ .../linalg/pymod/wrap_kalman_filter.cpp | 17 +- tests/test_linalg.py | 117 ++++++++++++ 3 files changed, 308 insertions(+), 1 deletion(-) diff --git a/cpp/modmesh/linalg/kalman_filter.hpp b/cpp/modmesh/linalg/kalman_filter.hpp index 5d3ed9fe..245bb1aa 100644 --- a/cpp/modmesh/linalg/kalman_filter.hpp +++ b/cpp/modmesh/linalg/kalman_filter.hpp @@ -40,6 +40,41 @@ template struct select_real_t; } /* end namespace detail */ +/** + * @brief Return type for `KalmanFilter::batch_filter(...)`. + * + * @details + * Due to the iteration of the predict and update steps in `KalmanFilter::batch_filter(...)`, the + * intermediate results are stored for each time step, including prior and posterior + * states and their covariances. + * @see BFType KalmanFilter::batch_filter(array_type const & zs) + * @see BFType KalmanFilter::batch_filter(array_type const & zs, array_type const & us) + */ +template +struct BFType +{ + using tuple_type = std::tuple, SimpleArray, SimpleArray, SimpleArray>; + + BFType(size_t observation_size, size_t state_size) + { + small_vector xs_shape{observation_size, state_size}; + small_vector ps_shape{observation_size, state_size, state_size}; + prior_state = SimpleArray(xs_shape); + prior_state_covariance = SimpleArray(ps_shape); + posterior_state = SimpleArray(xs_shape); + posterior_state_covariance = SimpleArray(ps_shape); + } + tuple_type to_tuple() const + { + return std::make_tuple(prior_state, prior_state_covariance, posterior_state, posterior_state_covariance); + } + + SimpleArray prior_state; + SimpleArray prior_state_covariance; + SimpleArray posterior_state; + SimpleArray posterior_state_covariance; +}; /* end struct BFType */ + /** * Reference: FilterPy KalmanFilter documentation * https://filterpy.readthedocs.io/en/latest/kalman/KalmanFilter.html @@ -205,6 +240,31 @@ class KalmanFilter update_covariance(k); } + /** + * @brief Predict and update in batch mode without a batch of control input us. + * + * @ref https://filterpy.readthedocs.io/en/latest/_modules/filterpy/kalman/kalman_filter.html#KalmanFilter.batch_filter + * + * @param zs A batch of measurement inputs. + * + * @see BFType KalmanFilter::batch_filter(array_type const & zs, array_type const & us) + * @see struct BFType; + */ + BFType batch_filter(array_type const & zs); + + /** + * @brief Predict and update in batch mode with a batch of control input us. + * + * @ref https://filterpy.readthedocs.io/en/latest/_modules/filterpy/kalman/kalman_filter.html#KalmanFilter.batch_filter + * + * @param zs A batch of measurement inputs. + * @param us A batch of control inputs. + * + * @see BFType KalmanFilter::batch_filter(array_type const & zs) + * @see struct BFType; + */ + BFType batch_filter(array_type const & zs, array_type const & us); + private: void check_dimensions(); @@ -223,6 +283,10 @@ class KalmanFilter void update_state(array_type const & k, array_type const & y); void update_covariance(array_type const & k); + // Batch filter + void predict_and_update(array_type const & z, BFType & bfs, size_t iter); + void predict_and_update(array_type const & z, array_type const & u, BFType & bfs, size_t iter); + }; /* end class KalmanFilter */ template @@ -493,6 +557,117 @@ void KalmanFilter::update_covariance(array_type const & k) m_p = m_p.symmetrize(); } +template +BFType KalmanFilter::batch_filter(array_type const & zs) +{ + size_t z_m = zs.shape(0); + size_t z_n = zs.shape(1); + array_type z(small_vector{z_n}); + BFType bfs(z_m, m_state_size); + + for (size_t iter = 0; iter < z_m; ++iter) + { + for (size_t j = 0; j < z_n; ++j) + { + z(j) = zs(iter, j); + } + predict_and_update(z, bfs, iter); + } + return bfs; +} + +template +void KalmanFilter::predict_and_update(array_type const & z, BFType & bfs, size_t iter) +{ + SimpleArray & prior_state = bfs.prior_state; + SimpleArray & prior_state_covariance = bfs.prior_state_covariance; + SimpleArray & posterior_state = bfs.posterior_state; + SimpleArray & posterior_state_covariance = bfs.posterior_state_covariance; + + predict(); + for (size_t j = 0; j < m_state_size; ++j) + { + prior_state(iter, j) = m_x(j); + for (size_t k = 0; k < m_state_size; ++k) + { + prior_state_covariance(iter, j, k) = m_p(j, k); + } + } + + update(z); + for (size_t j = 0; j < m_state_size; ++j) + { + posterior_state(iter, j) = m_x(j); + for (size_t k = 0; k < m_state_size; ++k) + { + posterior_state_covariance(iter, j, k) = m_p(j, k); + } + } +} + +template +BFType KalmanFilter::batch_filter(array_type const & zs, array_type const & us) +{ + size_t z_m = zs.shape(0); + size_t z_n = zs.shape(1); + array_type z(small_vector{z_n}); + BFType bfs(z_m, m_state_size); + + array_type u; + size_t u_n = 0; + + size_t u_m = us.shape(0); + if (u_m != z_m) + { + throw std::invalid_argument("KalmanFilter::batch_filter: The number of control inputs must match the number of measurements."); + } + u_n = us.shape(1); + u = array_type(small_vector{u_n}); + + for (size_t iter = 0; iter < z_m; ++iter) + { + for (size_t j = 0; j < z_n; ++j) + { + z(j) = zs(iter, j); + } + for (size_t j = 0; j < u_n; ++j) + { + u(j) = us(iter, j); + } + predict_and_update(z, u, bfs, iter); + } + return bfs; +} + +template +void KalmanFilter::predict_and_update(array_type const & z, array_type const & u, BFType & bfs, size_t iter) +{ + SimpleArray & prior_state = bfs.prior_state; + SimpleArray & prior_state_covariance = bfs.prior_state_covariance; + SimpleArray & posterior_state = bfs.posterior_state; + SimpleArray & posterior_state_covariance = bfs.posterior_state_covariance; + + predict(u); + for (size_t j = 0; j < m_state_size; ++j) + { + prior_state(iter, j) = m_x(j); + for (size_t k = 0; k < m_state_size; ++k) + { + prior_state_covariance(iter, j, k) = m_p(j, k); + } + } + + update(z); + for (size_t j = 0; j < m_state_size; ++j) + { + posterior_state(iter, j) = m_x(j); + for (size_t k = 0; k < m_state_size; ++k) + { + posterior_state_covariance(iter, j, k) = m_p(j, k); + } + } +} + } /* end namespace modmesh */ // vim: set ff=unix fenc=utf8 et sw=4 ts=4 sts=4: diff --git a/cpp/modmesh/linalg/pymod/wrap_kalman_filter.cpp b/cpp/modmesh/linalg/pymod/wrap_kalman_filter.cpp index 4e4c0f83..6eeba9c5 100644 --- a/cpp/modmesh/linalg/pymod/wrap_kalman_filter.cpp +++ b/cpp/modmesh/linalg/pymod/wrap_kalman_filter.cpp @@ -93,7 +93,22 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapKalmanFilter .def( "update", &wrapped_type::update, - py::arg("z")); + py::arg("z")) + .def( + "batch_filter", + [](wrapped_type & self, array_type const & zs, array_type const & us) + { + return self.batch_filter(zs, us).to_tuple(); + }, + py::arg("zs"), + py::arg("us")) + .def( + "batch_filter", + [](wrapped_type & self, array_type const & zs) + { + return self.batch_filter(zs).to_tuple(); + }, + py::arg("zs")); } }; /* end class WrapKalmanFilter */ diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 24278105..df3ef129 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -671,3 +671,120 @@ def test_wrong_shape_control_predict(self): "must be 1D of length control_size \\(1\\), but got shape " "\\(2\\)"): kf.predict(u_wrong_sa) + + +class KalmanFilterBatchFilterTC(unittest.TestCase): + + def kf_batchfilter_numpy(self, kf, zs, us=None): + m = zs.shape[0] + n = kf.state.shape[0] + xs_pred_np = np.zeros((m, n)) + ps_pred_np = np.zeros((m, n, n)) + xs_upd_np = np.zeros((m, n)) + ps_upd_np = np.zeros((m, n, n)) + for i in range(m): + if us is not None: + u = us[i] + u_sa = sa_from_np(u, type(kf.state)) + kf.predict(u_sa) + else: + kf.predict() + x_pred = kf.state.ndarray + p_pred = kf.covariance.ndarray + xs_pred_np[i] = x_pred + ps_pred_np[i] = p_pred + + z = zs[i] + z_sa = sa_from_np(z, type(kf.state)) + kf.update(z_sa) + x_upd = kf.state.ndarray + p_upd = kf.covariance.ndarray + xs_upd_np[i] = x_upd + ps_upd_np[i] = p_upd + return xs_pred_np, ps_pred_np, xs_upd_np, ps_upd_np + + def test_batchfilter(self): + m = 50 + x0 = np.array([1.0, 2.0, 3.0]) + f = np.array([[1.1, 0.2, 0.3], + [0.1, 0.9, 0.7], + [4.7, 5.2, 6.7]]) + h = np.array([[1.0, 3.0, 2.0], + [4.0, 0.2, 0.1]]) + sigma_w = 0.316 + zs = np.zeros((m, 2)) + for i in range(m): + zs[i] = np.array([i * i, i * 0.5 + 1.0]) + x_sa = mm.SimpleArrayFloat64(array=x0) + f_sa = mm.SimpleArrayFloat64(array=f) + h_sa = mm.SimpleArrayFloat64(array=h) + zs_sa = mm.SimpleArrayFloat64(array=zs) + + kf = mm.KalmanFilterFp64( + x=x_sa, f=f_sa, h=h_sa, + process_noise=sigma_w, + measurement_noise=1.0, + ) + bps = kf.batch_filter(zs_sa) + xs_pred, ps_pred, xs_upd, ps_upd = bps + + kf = mm.KalmanFilterFp64( + x=x_sa, f=f_sa, h=h_sa, + process_noise=sigma_w, + measurement_noise=1.0, + ) + bps_np = self.kf_batchfilter_numpy(kf, zs) + xs_pred_np, ps_pred_np, xs_upd_np, ps_upd_np = bps_np + + np.testing.assert_allclose(xs_pred, xs_pred_np, atol=1e-12, rtol=0.0) + np.testing.assert_allclose(ps_pred, ps_pred_np, atol=1e-12, rtol=0.0) + np.testing.assert_allclose(xs_upd, xs_upd_np, atol=1e-12, rtol=0.0) + np.testing.assert_allclose(ps_upd, ps_upd_np, atol=1e-12, rtol=0.0) + + def test_batchfilter_with_control(self): + m = 50 + x0 = np.array([1.0, 2.0, 3.0]) + f = np.array([[1.1, 0.2, 0.3], + [0.1, 0.9, 0.7], + [4.7, 5.2, 6.7]]) + h = np.array([[1.0, 3.0, 2.0], + [4.0, 0.2, 0.1]]) + b = np.array([[0.7, 0.2, 5.3], + [3.1, 0.9, 1.7], + [4.7, 5.2, 6.7]]) + sigma_w = 0.316 + zs = np.zeros((m, 2)) + for i in range(m): + zs[i] = np.array([i * i, i * 0.5 + 1.0]) + us = np.zeros((m, 3)) + for i in range(m): + us[i] = np.array([i, pow(i, 3.5), pow(i, 0.5)]) + x_sa = mm.SimpleArrayFloat64(array=x0) + f_sa = mm.SimpleArrayFloat64(array=f) + b_sa = mm.SimpleArrayFloat64(array=b) + h_sa = mm.SimpleArrayFloat64(array=h) + zs_sa = mm.SimpleArrayFloat64(array=zs) + us_sa = mm.SimpleArrayFloat64(array=us) + + kf = mm.KalmanFilterFp64( + x=x_sa, f=f_sa, b=b_sa, h=h_sa, + process_noise=sigma_w, + measurement_noise=1.0, + ) + bps = kf.batch_filter(zs_sa, us_sa) + xs_pred, ps_pred, xs_upd, ps_upd = bps + + kf = mm.KalmanFilterFp64( + x=x_sa, f=f_sa, b=b_sa, h=h_sa, + process_noise=sigma_w, + measurement_noise=1.0, + ) + bps_np = self.kf_batchfilter_numpy(kf, zs, us) + xs_pred_np, ps_pred_np, xs_upd_np, ps_upd_np = bps_np + + np.testing.assert_allclose(xs_pred, xs_pred_np, atol=1e-12, rtol=0.0) + np.testing.assert_allclose(ps_pred, ps_pred_np, atol=1e-12, rtol=0.0) + np.testing.assert_allclose(xs_upd, xs_upd_np, atol=1e-12, rtol=0.0) + np.testing.assert_allclose(ps_upd, ps_upd_np, atol=1e-12, rtol=0.0) + +# vim: set ff=unix fenc=utf8 et sw=4 ts=4 sts=4: