-
Notifications
You must be signed in to change notification settings - Fork 55
Feat: Implement batch filter #660
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,6 +40,41 @@ template <typename T> | |
| struct select_real_t; | ||
| } /* end namespace detail */ | ||
|
|
||
| /** | ||
| * @brief Return type for `KalmanFilter<T>::batch_filter(...)`. | ||
| * | ||
| * @details | ||
| * Due to the iteration of the predict and update steps in `KalmanFilter<T>::batch_filter(...)`, the | ||
| * intermediate results are stored for each time step, including prior and posterior | ||
| * states and their covariances. | ||
| * @see BFType<T> KalmanFilter<T>::batch_filter(array_type const & zs) | ||
| * @see BFType<T> KalmanFilter<T>::batch_filter(array_type const & zs, array_type const & us) | ||
| */ | ||
| template <typename T> | ||
| struct BFType | ||
| { | ||
| using tuple_type = std::tuple<SimpleArray<T>, SimpleArray<T>, SimpleArray<T>, SimpleArray<T>>; | ||
|
|
||
| BFType(size_t observation_size, size_t state_size) | ||
| { | ||
| small_vector<size_t> xs_shape{observation_size, state_size}; | ||
| small_vector<size_t> ps_shape{observation_size, state_size, state_size}; | ||
| prior_state = SimpleArray<T>(xs_shape); | ||
| prior_state_covariance = SimpleArray<T>(ps_shape); | ||
| posterior_state = SimpleArray<T>(xs_shape); | ||
| posterior_state_covariance = SimpleArray<T>(ps_shape); | ||
| } | ||
| tuple_type to_tuple() const | ||
| { | ||
| return std::make_tuple(prior_state, prior_state_covariance, posterior_state, posterior_state_covariance); | ||
| } | ||
|
|
||
| SimpleArray<T> prior_state; | ||
| SimpleArray<T> prior_state_covariance; | ||
| SimpleArray<T> posterior_state; | ||
| SimpleArray<T> posterior_state_covariance; | ||
| }; /* end struct BFType */ | ||
|
|
||
| /** | ||
| * Reference: FilterPy KalmanFilter documentation | ||
| * https://filterpy.readthedocs.io/en/latest/kalman/KalmanFilter.html | ||
|
|
@@ -205,6 +240,31 @@ class KalmanFilter | |
| update_covariance(k); | ||
| } | ||
|
|
||
| /** | ||
| * @brief Predict and update in batch mode without a batch of control input us. | ||
| * | ||
| * @ref https://filterpy.readthedocs.io/en/latest/_modules/filterpy/kalman/kalman_filter.html#KalmanFilter.batch_filter | ||
| * | ||
| * @param zs A batch of measurement inputs. | ||
| * | ||
| * @see BFType<T> KalmanFilter<T>::batch_filter(array_type const & zs, array_type const & us) | ||
| * @see struct BFType<T>; | ||
| */ | ||
|
Comment on lines
+243
to
+252
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The api design is refered from filter.py, and it is the origin of the api naming. |
||
| BFType<T> batch_filter(array_type const & zs); | ||
|
|
||
| /** | ||
| * @brief Predict and update in batch mode with a batch of control input us. | ||
| * | ||
| * @ref https://filterpy.readthedocs.io/en/latest/_modules/filterpy/kalman/kalman_filter.html#KalmanFilter.batch_filter | ||
| * | ||
| * @param zs A batch of measurement inputs. | ||
| * @param us A batch of control inputs. | ||
| * | ||
| * @see BFType<T> KalmanFilter<T>::batch_filter(array_type const & zs) | ||
| * @see struct BFType<T>; | ||
| */ | ||
| BFType<T> batch_filter(array_type const & zs, array_type const & us); | ||
|
|
||
| private: | ||
|
|
||
| void check_dimensions(); | ||
|
|
@@ -223,6 +283,10 @@ class KalmanFilter | |
| void update_state(array_type const & k, array_type const & y); | ||
| void update_covariance(array_type const & k); | ||
|
|
||
| // Batch filter | ||
| void predict_and_update(array_type const & z, BFType<T> & bfs, size_t iter); | ||
| void predict_and_update(array_type const & z, array_type const & u, BFType<T> & bfs, size_t iter); | ||
|
|
||
| }; /* end class KalmanFilter */ | ||
|
|
||
| template <typename T> | ||
|
|
@@ -493,6 +557,117 @@ void KalmanFilter<T>::update_covariance(array_type const & k) | |
| m_p = m_p.symmetrize(); | ||
| } | ||
|
|
||
| template <typename T> | ||
| BFType<T> KalmanFilter<T>::batch_filter(array_type const & zs) | ||
| { | ||
| size_t z_m = zs.shape(0); | ||
| size_t z_n = zs.shape(1); | ||
| array_type z(small_vector<size_t>{z_n}); | ||
| BFType<T> bfs(z_m, m_state_size); | ||
|
|
||
| for (size_t iter = 0; iter < z_m; ++iter) | ||
| { | ||
| for (size_t j = 0; j < z_n; ++j) | ||
| { | ||
| z(j) = zs(iter, j); | ||
| } | ||
| predict_and_update(z, bfs, iter); | ||
| } | ||
| return bfs; | ||
| } | ||
|
|
||
| template <typename T> | ||
| void KalmanFilter<T>::predict_and_update(array_type const & z, BFType<T> & bfs, size_t iter) | ||
| { | ||
| SimpleArray<T> & prior_state = bfs.prior_state; | ||
| SimpleArray<T> & prior_state_covariance = bfs.prior_state_covariance; | ||
| SimpleArray<T> & posterior_state = bfs.posterior_state; | ||
| SimpleArray<T> & posterior_state_covariance = bfs.posterior_state_covariance; | ||
|
|
||
| predict(); | ||
| for (size_t j = 0; j < m_state_size; ++j) | ||
| { | ||
| prior_state(iter, j) = m_x(j); | ||
| for (size_t k = 0; k < m_state_size; ++k) | ||
| { | ||
| prior_state_covariance(iter, j, k) = m_p(j, k); | ||
| } | ||
| } | ||
|
|
||
| update(z); | ||
| for (size_t j = 0; j < m_state_size; ++j) | ||
| { | ||
| posterior_state(iter, j) = m_x(j); | ||
| for (size_t k = 0; k < m_state_size; ++k) | ||
| { | ||
| posterior_state_covariance(iter, j, k) = m_p(j, k); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| template <typename T> | ||
| BFType<T> KalmanFilter<T>::batch_filter(array_type const & zs, array_type const & us) | ||
| { | ||
| size_t z_m = zs.shape(0); | ||
| size_t z_n = zs.shape(1); | ||
| array_type z(small_vector<size_t>{z_n}); | ||
| BFType<T> bfs(z_m, m_state_size); | ||
|
|
||
| array_type u; | ||
| size_t u_n = 0; | ||
|
|
||
| size_t u_m = us.shape(0); | ||
| if (u_m != z_m) | ||
| { | ||
| throw std::invalid_argument("KalmanFilter::batch_filter: The number of control inputs must match the number of measurements."); | ||
| } | ||
| u_n = us.shape(1); | ||
| u = array_type(small_vector<size_t>{u_n}); | ||
|
|
||
| for (size_t iter = 0; iter < z_m; ++iter) | ||
| { | ||
| for (size_t j = 0; j < z_n; ++j) | ||
| { | ||
| z(j) = zs(iter, j); | ||
| } | ||
| for (size_t j = 0; j < u_n; ++j) | ||
| { | ||
| u(j) = us(iter, j); | ||
| } | ||
| predict_and_update(z, u, bfs, iter); | ||
| } | ||
| return bfs; | ||
| } | ||
|
|
||
| template <typename T> | ||
| void KalmanFilter<T>::predict_and_update(array_type const & z, array_type const & u, BFType<T> & bfs, size_t iter) | ||
| { | ||
| SimpleArray<T> & prior_state = bfs.prior_state; | ||
| SimpleArray<T> & prior_state_covariance = bfs.prior_state_covariance; | ||
| SimpleArray<T> & posterior_state = bfs.posterior_state; | ||
| SimpleArray<T> & posterior_state_covariance = bfs.posterior_state_covariance; | ||
|
|
||
| predict(u); | ||
| for (size_t j = 0; j < m_state_size; ++j) | ||
| { | ||
| prior_state(iter, j) = m_x(j); | ||
| for (size_t k = 0; k < m_state_size; ++k) | ||
| { | ||
| prior_state_covariance(iter, j, k) = m_p(j, k); | ||
| } | ||
| } | ||
|
|
||
| update(z); | ||
| for (size_t j = 0; j < m_state_size; ++j) | ||
| { | ||
| posterior_state(iter, j) = m_x(j); | ||
| for (size_t k = 0; k < m_state_size; ++k) | ||
| { | ||
| posterior_state_covariance(iter, j, k) = m_p(j, k); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| } /* end namespace modmesh */ | ||
|
|
||
| // vim: set ff=unix fenc=utf8 et sw=4 ts=4 sts=4: | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -93,7 +93,22 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapKalmanFilter | |
| .def( | ||
| "update", | ||
| &wrapped_type::update, | ||
| py::arg("z")); | ||
| py::arg("z")) | ||
| .def( | ||
| "batch_filter", | ||
| [](wrapped_type & self, array_type const & zs, array_type const & us) | ||
| { | ||
| return self.batch_filter(zs, us).to_tuple(); | ||
| }, | ||
| py::arg("zs"), | ||
| py::arg("us")) | ||
| .def( | ||
| "batch_filter", | ||
| [](wrapped_type & self, array_type const & zs) | ||
| { | ||
| return self.batch_filter(zs).to_tuple(); | ||
| }, | ||
| py::arg("zs")); | ||
|
Comment on lines
+97
to
+111
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To pass a few return variables, the helper class should be casted to std::tuple. |
||
| } | ||
|
|
||
| }; /* end class WrapKalmanFilter */ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change to POD design, and name it as
BFtype.