Skip to content

Commit 3983fc8

Browse files
committed
Feat SimpleArray.hpp, kalman_filter.hpp and test_gemm.py: Modify matmul() to support the argument of 1 dimensional and 2 dimensional SimpleArray
1 parent 6ced276 commit 3983fc8

3 files changed

Lines changed: 166 additions & 76 deletions

File tree

cpp/modmesh/buffer/SimpleArray.hpp

Lines changed: 89 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -937,8 +937,8 @@ detail::SimpleArrayMixinCalculators<A, T>::median_freq(small_vector<value_type>
937937
}
938938

939939
/**
940-
* Perform matrix multiplication for 2D arrays.
941-
* This implementation supports only 2D x 2D matrix multiplication.
940+
* Perform matrix multiplication for SimpleArrays.
941+
* This implementation supports 1D x 1D, 1D x 2D, 2D x 1D, and 2D x 2D matrix multiplication.
942942
*/
943943
template <typename A, typename T>
944944
A SimpleArrayMixinCalculators<A, T>::matmul(A const & other) const
@@ -969,43 +969,112 @@ A SimpleArrayMixinCalculators<A, T>::matmul(A const & other) const
969969
}
970970
};
971971

972-
if (this_ndim != 2 || other_ndim != 2)
972+
auto check_product_shape = [&](A const * athis, A const * other, ssize_t athis_idx, ssize_t other_idx) -> void
973+
{
974+
if (athis->shape(athis_idx) != other->shape(other_idx))
975+
{
976+
throw std::out_of_range(
977+
std::format("SimpleArray::matmul(): shape mismatch: this={} other={}",
978+
format_shape(athis),
979+
format_shape(other)));
980+
}
981+
};
982+
983+
if ((this_ndim != 2 && this_ndim != 1) || (other_ndim != 2 && other_ndim != 1))
973984
{
974985
const std::string err = std::format("SimpleArray::matmul(): unsupported dimensions: "
975-
"this={} other={}. Only 2D x 2D matrix multiplication is supported",
986+
"this={} other={}. SimpleArray must be 1D or 2D.",
976987
format_shape(athis),
977988
format_shape(&other));
978989
throw std::out_of_range(err);
979990
}
980991

981-
const size_t m = athis->shape(0);
982-
const size_t k = athis->shape(1);
983-
const size_t n = other.shape(1);
992+
bool this_is_1d = (this_ndim == 1);
993+
bool other_is_1d = (other_ndim == 1);
984994

985-
if (k != other.shape(0))
995+
// 1D x 1D
996+
if (this_is_1d && other_is_1d)
986997
{
987-
throw std::out_of_range(
988-
std::format("SimpleArray::matmul(): shape mismatch: this={} other={}",
989-
format_shape(athis),
990-
format_shape(&other)));
998+
check_product_shape(athis, &other, 0, 0);
999+
A result(1);
1000+
value_type v = 0;
1001+
for (size_t i = 0; i < athis->shape(0); ++i)
1002+
{
1003+
v += (*athis)(i)*other.data(i);
1004+
}
1005+
result.data(0) = v;
1006+
return result;
9911007
}
992-
993-
typename detail::SimpleArrayInternalTypes<T>::shape_type result_shape{m, n};
994-
A result(result_shape);
995-
result.fill(static_cast<value_type>(0));
996-
997-
for (size_t i = 0; i < m; ++i)
1008+
// 1D x 2D
1009+
else if (this_is_1d)
9981010
{
1011+
const size_t k = athis->shape(0);
1012+
const size_t n = other.shape(1);
1013+
check_product_shape(athis, &other, 0, 0);
1014+
A result(n);
1015+
9991016
for (size_t j = 0; j < n; ++j)
10001017
{
1018+
value_type v = 0;
1019+
for (size_t l = 0; l < k; ++l)
1020+
{
1021+
v += (*athis)(l)*other(l, j);
1022+
}
1023+
result.data(j) = v;
1024+
}
1025+
return result;
1026+
}
1027+
// 2D x 1D
1028+
else if (other_is_1d)
1029+
{
1030+
const size_t m = athis->shape(0);
1031+
const size_t k = athis->shape(1);
1032+
1033+
check_product_shape(athis, &other, 1, 0);
1034+
A result(m);
1035+
1036+
for (size_t i = 0; i < m; ++i)
1037+
{
1038+
value_type v = 0;
10011039
for (size_t l = 0; l < k; ++l)
10021040
{
1003-
result(i, j) += athis->operator()(i, l) * other(l, j);
1041+
v += (*athis)(i, l) * other(l);
1042+
}
1043+
result.data(i) = v;
1044+
}
1045+
return result;
1046+
}
1047+
// 2D x 2D
1048+
else
1049+
{
1050+
const size_t m = athis->shape(0);
1051+
const size_t k = athis->shape(1);
1052+
const size_t n = other.shape(1);
1053+
check_product_shape(athis, &other, 1, 0);
1054+
1055+
shape_type result_shape{m, n};
1056+
A result(result_shape);
1057+
1058+
for (size_t i = 0; i < m; ++i)
1059+
{
1060+
for (size_t j = 0; j < n; ++j)
1061+
{
1062+
value_type v = 0;
1063+
for (size_t l = 0; l < k; ++l)
1064+
{
1065+
v += (*athis)(i, l) * other(l, j);
1066+
}
1067+
result(i, j) = v;
10041068
}
10051069
}
1070+
return result;
10061071
}
10071072

1008-
return result;
1073+
throw std::out_of_range(
1074+
std::format("SimpleArray::matmul(): this={} other={}"
1075+
" cannot perform matrix multiplication.",
1076+
format_shape(athis),
1077+
format_shape(&other)));
10091078
}
10101079

10111080
/**

cpp/modmesh/linalg/kalman_filter.hpp

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -367,19 +367,14 @@ template <typename T>
367367
void KalmanFilter<T>::predict_state()
368368
{
369369
// x <- F x (m_x <- m_f @ m_x)
370-
array_type x_col = m_x.reshape(small_vector<size_t>{m_state_size, 1});
371-
x_col = m_f.matmul(x_col);
372-
m_x = x_col.reshape(small_vector<size_t>{m_state_size});
370+
m_x = m_f.matmul(m_x);
373371
}
374372

375373
template <typename T>
376374
void KalmanFilter<T>::predict_state(array_type const & u)
377375
{
378376
// x <- F x + B u (m_x <- m_f @ m_x + m_b @ u)
379-
array_type x_col = m_x.reshape(small_vector<size_t>{m_state_size, 1});
380-
array_type u_col = u.reshape(small_vector<size_t>{m_control_size, 1});
381-
x_col = m_f.matmul(x_col).add(m_b.matmul(u_col));
382-
m_x = x_col.reshape(small_vector<size_t>{m_state_size});
377+
m_x = m_f.matmul(m_x).add(m_b.matmul(u));
383378
}
384379

385380
template <typename T>
@@ -440,10 +435,7 @@ template <typename T>
440435
typename KalmanFilter<T>::array_type KalmanFilter<T>::innovation(array_type const & z)
441436
{
442437
// y <- z - H x (y <- z - m_h @ m_x)
443-
array_type x_col = m_x.reshape(small_vector<size_t>{m_state_size, 1});
444-
array_type hx = m_h.matmul(x_col);
445-
array_type hx_vec = hx.reshape(small_vector<size_t>{m_measurement_size});
446-
return z.sub(hx_vec);
438+
return z.sub(m_h.matmul(m_x));
447439
}
448440

449441
template <typename T>
@@ -471,10 +463,7 @@ template <typename T>
471463
void KalmanFilter<T>::update_state(array_type const & k, array_type const & y)
472464
{
473465
// x <- x + K y (m_x <- m_x + k @ y)
474-
array_type y_col = y.reshape(small_vector<size_t>{m_measurement_size, 1});
475-
array_type ky = k.matmul(y_col);
476-
array_type ky_vec = ky.reshape(small_vector<size_t>{m_state_size});
477-
m_x = m_x.add(ky_vec);
466+
m_x = m_x.add(k.matmul(y));
478467
}
479468

480469
template <typename T>

tests/test_gemm.py

Lines changed: 73 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -207,10 +207,40 @@ def test_compare_with_numpy(self):
207207
[8.0, 11.0, 11.0]
208208
], dtype=self.dtype)
209209

210+
# Test case 4: (4x6) x (6)
211+
a_data_4 = np.array([
212+
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
213+
[7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
214+
[13.0, 14.0, 15.0, 16.0, 17.0, 18.0],
215+
[19.0, 20.0, 21.0, 22.0, 23.0, 24.0]
216+
], dtype=self.dtype)
217+
b_data_4 = np.array([1., 2., 3., 4., 5., 6], dtype=self.dtype)
218+
expected_4 = np.array([91., 217., 343., 469.], dtype=self.dtype)
219+
220+
# Test case 5: (6) x (6 x 4)
221+
a_data_5 = np.array([1., 2., 3., 4., 5., 6], dtype=self.dtype)
222+
b_data_5 = np.array([
223+
[1.0, 2.0, 3.0, 4.0],
224+
[5.0, 6.0, 7.0, 8.0],
225+
[9.0, 10.0, 11.0, 12.0],
226+
[13.0, 14.0, 15.0, 16.0],
227+
[17.0, 18.0, 19.0, 20.0],
228+
[21.0, 22.0, 23.0, 24.0]
229+
], dtype=self.dtype)
230+
expected_5 = np.array([301., 322., 343., 364.], dtype=self.dtype)
231+
232+
# Test case 6: (3) x (3)
233+
a_data_6 = np.array([1., 2., 3.], dtype=self.dtype)
234+
b_data_6 = np.array([4., 5., 6.], dtype=self.dtype)
235+
expected_6 = np.array([32.], dtype=self.dtype)
236+
210237
test_cases = [
211238
(a_data_1, b_data_1, expected_1, "2x3 x 3x4"),
212239
(a_data_2, b_data_2, expected_2, "4x6 x 6x3"),
213-
(a_data_3, b_data_3, expected_3, "3x3 x 3x3")
240+
(a_data_3, b_data_3, expected_3, "3x3 x 3x3"),
241+
(a_data_4, b_data_4, expected_4, "4x6 x 6"),
242+
(a_data_5, b_data_5, expected_5, "6 x 6x3"),
243+
(a_data_6, b_data_6, expected_6, "3 x 3x3")
214244
]
215245

216246
for a_data, b_data, expected, description in test_cases:
@@ -234,64 +264,66 @@ def test_compare_with_numpy(self):
234264
np.testing.assert_array_almost_equal(
235265
result.ndarray, expected, decimal=10)
236266

237-
def test_unsupported_dimensions_error(self):
238-
"""Test error handling for unsupported dimensions"""
267+
def test_wrong_shape_error(self):
268+
"""Test error handling for wrong shapes"""
239269

240-
# Test 1D x 1D (not supported)
241-
a_1d = self.SimpleArray(array=np.array([1.0, 2.0, 3.0],
242-
dtype=self.dtype))
243-
b_1d = self.SimpleArray(array=np.array([4.0, 5.0, 6.0],
244-
dtype=self.dtype))
270+
a_3d_data = np.array([[[1.0, 2.0], [3.0, 4.0]],
271+
[[5.0, 6.0], [7.0, 8.0]]], dtype=self.dtype)
272+
b_3d_data = np.array([[[1.0, 0.0], [0.0, 1.0]],
273+
[[2.0, 0.0], [0.0, 2.0]]], dtype=self.dtype)
274+
a_3d = self.SimpleArray(array=a_3d_data)
275+
b_3d = self.SimpleArray(array=b_3d_data)
245276

246277
with self.assertRaisesRegex(
247278
IndexError,
248-
r"SimpleArray::matmul\(\): unsupported dimensions: this=\(3\) "
249-
r"other=\(3\)\. Only 2D x 2D matrix multiplication is supported"
279+
r"SimpleArray::matmul\(\): unsupported dimensions: "
280+
r"this=\(2,2,2\) other=\(2,2,2\)\. SimpleArray must be 1D or 2D."
250281
):
251-
a_1d.matmul(b_1d)
252-
253-
# Test 1D x 2D (not supported)
254-
a_1d = self.SimpleArray(array=np.array([1.0, 2.0], dtype=self.dtype))
255-
b_2d = self.SimpleArray(array=np.array([[1.0, 2.0], [3.0, 4.0]],
256-
dtype=self.dtype))
282+
a_3d.matmul(b_3d)
257283

284+
a = np.zeros((3, 3), dtype=self.dtype)
285+
b = np.zeros((2, 3), dtype=self.dtype)
286+
a = self.SimpleArray(array=a)
287+
b = self.SimpleArray(array=b)
258288
with self.assertRaisesRegex(
259289
IndexError,
260-
r"SimpleArray::matmul\(\): unsupported dimensions: this=\(2\) "
261-
r"other=\(2,2\)\. Only 2D x 2D matrix multiplication is supported"
290+
r"SimpleArray::matmul\(\): shape mismatch: "
291+
r"this=\(3,3\) other=\(2,3\)"
262292
):
263-
a_1d.matmul(b_2d)
264-
265-
# Test 2D x 1D (not supported)
266-
a_2d = self.SimpleArray(array=np.array([[1.0, 2.0, 3.0],
267-
[4.0, 5.0, 6.0]],
268-
dtype=self.dtype))
269-
b_1d = self.SimpleArray(array=np.array([7.0, 8.0, 9.0],
270-
dtype=self.dtype))
293+
a.matmul(b)
271294

295+
a = np.zeros((3, 3), dtype=self.dtype)
296+
b = np.zeros((2), dtype=self.dtype)
297+
a = self.SimpleArray(array=a)
298+
b = self.SimpleArray(array=b)
272299
with self.assertRaisesRegex(
273300
IndexError,
274-
r"SimpleArray::matmul\(\): unsupported dimensions: this=\(2,3\) "
275-
r"other=\(3\)\. Only 2D x 2D matrix multiplication is supported"
301+
r"SimpleArray::matmul\(\): shape mismatch: "
302+
r"this=\(3,3\) other=\(2\)"
276303
):
277-
a_2d.matmul(b_1d)
278-
279-
# Test 3D x 3D (not supported - tensor operation)
280-
a_3d_data = np.array([[[1.0, 2.0], [3.0, 4.0]],
281-
[[5.0, 6.0], [7.0, 8.0]]], dtype=self.dtype)
282-
b_3d_data = np.array([[[1.0, 0.0], [0.0, 1.0]],
283-
[[2.0, 0.0], [0.0, 2.0]]], dtype=self.dtype)
304+
a.matmul(b)
284305

285-
a_3d = self.SimpleArray(array=a_3d_data)
286-
b_3d = self.SimpleArray(array=b_3d_data)
306+
a = np.zeros((2), dtype=self.dtype)
307+
b = np.zeros((3, 3), dtype=self.dtype)
308+
a = self.SimpleArray(array=a)
309+
b = self.SimpleArray(array=b)
310+
with self.assertRaisesRegex(
311+
IndexError,
312+
r"SimpleArray::matmul\(\): shape mismatch: "
313+
r"this=\(2\) other=\(3,3\)"
314+
):
315+
a.matmul(b)
287316

317+
a = np.zeros((2), dtype=self.dtype)
318+
b = np.zeros((3), dtype=self.dtype)
319+
a = self.SimpleArray(array=a)
320+
b = self.SimpleArray(array=b)
288321
with self.assertRaisesRegex(
289322
IndexError,
290-
r"SimpleArray::matmul\(\): unsupported dimensions: "
291-
r"this=\(2,2,2\) other=\(2,2,2\)\. Only 2D x 2D matrix "
292-
r"multiplication is supported"
323+
r"SimpleArray::matmul\(\): shape mismatch: "
324+
r"this=\(2\) other=\(3\)"
293325
):
294-
a_3d.matmul(b_3d)
326+
a.matmul(b)
295327

296328
def test_matmul_operator(self):
297329
"""Test @ operator for matrix multiplication"""

0 commit comments

Comments
 (0)