-
Notifications
You must be signed in to change notification settings - Fork 55
Feat: Modify __matmul__ to support SimpleArray of 1 dimension and 2 dimension #661
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
ThreeMonth03
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yungyuc Please review this pull request. Thank you.
| bool other_is_1d = (other_ndim == 1); | ||
|
|
||
| if (k != other.shape(0)) | ||
| if (this_is_1d && other_is_1d) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The condition looks verbose. However, the input argument should be constant value, and if we try to reshape the input argument to another SimpleArry, the assignment operator would clone the buffer.
| array_type x_col = m_x.reshape(small_vector<size_t>{m_state_size, 1}); | ||
| array_type u_col = u.reshape(small_vector<size_t>{m_control_size, 1}); | ||
| x_col = m_f.matmul(x_col).add(m_b.matmul(u_col)); | ||
| m_x = x_col.reshape(small_vector<size_t>{m_state_size}); | ||
| m_x = m_f.matmul(m_x).add(m_b.matmul(u)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simplify the logic.
| result.fill(static_cast<value_type>(0)); | ||
|
|
||
| for (size_t i = 0; i < m; ++i) | ||
| else if (this_is_1d) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will add a comment for "1D x 2D"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
| } | ||
| return result; | ||
| } | ||
| else if (other_is_1d) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would leave a comment "2D x 1D".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
cpp/modmesh/buffer/SimpleArray.hpp
Outdated
| check_product_shape(athis, &other, 0, 0); | ||
| shape_type result_shape{n}; | ||
| A result(result_shape); | ||
| result.fill(static_cast<value_type>(0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need? Any element will not be modified later?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SimpleArray is constructed by std::malloc, which doesn't initialize the buffer. Furthermore, the elements are multiplied and added in the following for loop, so it is necessary to filled 0 in SimpleArray.
modmesh/cpp/modmesh/buffer/ConcreteBuffer.hpp
Lines 294 to 297 in 2a316f4
| else | |
| { | |
| ptr = std::malloc(nbytes); | |
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This zero-initialization loop should be removed by using a temporary scalar for the inner loop below.
for (size_t j = 0; j < n; ++j)
{
T v = 0.0;
for (size_t l = 0; l < k; ++l)
{
v += (*athis)(l)*other(l, j);
}
result.data(j);
}
cpp/modmesh/buffer/SimpleArray.hpp
Outdated
| } | ||
|
|
||
| return result; | ||
| return {}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will let it throw here as it should never reach here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
355d5c0 to
eb7d443
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good enhancement. Points to address:
- Update code comments for dimension order.
- Update error messages for dimension order.
- Remove the zero-initialization loop by using a temporary scalar for the inner loop below.
- Write informative summary in the commit log.
cpp/modmesh/buffer/SimpleArray.hpp
Outdated
| * Perform matrix multiplication for 2D arrays. | ||
| * This implementation supports only 2D x 2D matrix multiplication. | ||
| * Perform matrix multiplication for SimpleArrays. | ||
| * This implementation supports 2D x 2D, 2D x 1D, 1D x 2D, and 1D x 1D matrix multiplication. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
List the dimension from low to high: 1D x 1D, 1D x 2D, 2D x 1D, 2D x 2D.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
cpp/modmesh/buffer/SimpleArray.hpp
Outdated
| { | ||
| const std::string err = std::format("SimpleArray::matmul(): unsupported dimensions: " | ||
| "this={} other={}. Only 2D x 2D matrix multiplication is supported", | ||
| "this={} other={}. SimpleArray must be 2D or 1D.", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Order dimension in the error message from low to high: 1D or 2D.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
cpp/modmesh/buffer/SimpleArray.hpp
Outdated
| check_product_shape(athis, &other, 0, 0); | ||
| shape_type result_shape{n}; | ||
| A result(result_shape); | ||
| result.fill(static_cast<value_type>(0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This zero-initialization loop should be removed by using a temporary scalar for the inner loop below.
for (size_t j = 0; j < n; ++j)
{
T v = 0.0;
for (size_t l = 0; l < k; ++l)
{
v += (*athis)(l)*other(l, j);
}
result.data(j);
}2015d49 to
3cb8236
Compare
ThreeMonth03
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good enhancement. Points to address:
- Update code comments for dimension order.
- Update error messages for dimension order.
- Remove the zero-initialization loop by using a temporary scalar for the inner loop below.
- Write informative summary in the commit log.
@yungyuc Please review again. Thanks.
yungyuc
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Clarify why
static_castis needed for value initialization instead of justvalue_type v = 0.
cpp/modmesh/buffer/SimpleArray.hpp
Outdated
|
|
||
| for (size_t j = 0; j < n; ++j) | ||
| { | ||
| value_type v = static_cast<value_type>(0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you need to write value_type v = static_cast<value_type>(0); instead of just:
value_type v = 0;There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, we don't need to explicitly cast here, and I will fix it.
…ul() to support the argument of 1 dimensional and 2 dimensional SimpleArray
|
|
||
| for (size_t j = 0; j < n; ++j) | ||
| { | ||
| value_type v = 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
Problem
There are some
__matmul__operations in Kalman Filter. However, the current design only supports the arguments of 2 dimensions SimpleArray. If we want to matrix multiply from 2 dimensions SimpleArray to 1 dimension SimpleArray, we should reshape the SimpleArray twice. The detail of the discussion could be found in #601 .Solution
In this pull request, I modify
__matmul__operation, so that it can support the following number of dimension of SimpleArray: