@@ -207,10 +207,40 @@ def test_compare_with_numpy(self):
207207 [8.0 , 11.0 , 11.0 ]
208208 ], dtype = self .dtype )
209209
210+ # Test case 4: (4x6) x (6)
211+ a_data_4 = np .array ([
212+ [1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ],
213+ [7.0 , 8.0 , 9.0 , 10.0 , 11.0 , 12.0 ],
214+ [13.0 , 14.0 , 15.0 , 16.0 , 17.0 , 18.0 ],
215+ [19.0 , 20.0 , 21.0 , 22.0 , 23.0 , 24.0 ]
216+ ], dtype = self .dtype )
217+ b_data_4 = np .array ([1. , 2. , 3. , 4. , 5. , 6 ], dtype = self .dtype )
218+ expected_4 = np .array ([91. , 217. , 343. , 469. ], dtype = self .dtype )
219+
220+ # Test case 5: (6) x (6 x 4)
221+ a_data_5 = np .array ([1. , 2. , 3. , 4. , 5. , 6 ], dtype = self .dtype )
222+ b_data_5 = np .array ([
223+ [1.0 , 2.0 , 3.0 , 4.0 ],
224+ [5.0 , 6.0 , 7.0 , 8.0 ],
225+ [9.0 , 10.0 , 11.0 , 12.0 ],
226+ [13.0 , 14.0 , 15.0 , 16.0 ],
227+ [17.0 , 18.0 , 19.0 , 20.0 ],
228+ [21.0 , 22.0 , 23.0 , 24.0 ]
229+ ], dtype = self .dtype )
230+ expected_5 = np .array ([301. , 322. , 343. , 364. ], dtype = self .dtype )
231+
232+ # Test case 6: (3) x (3)
233+ a_data_6 = np .array ([1. , 2. , 3. ], dtype = self .dtype )
234+ b_data_6 = np .array ([4. , 5. , 6. ], dtype = self .dtype )
235+ expected_6 = np .array ([32. ], dtype = self .dtype )
236+
210237 test_cases = [
211238 (a_data_1 , b_data_1 , expected_1 , "2x3 x 3x4" ),
212239 (a_data_2 , b_data_2 , expected_2 , "4x6 x 6x3" ),
213- (a_data_3 , b_data_3 , expected_3 , "3x3 x 3x3" )
240+ (a_data_3 , b_data_3 , expected_3 , "3x3 x 3x3" ),
241+ (a_data_4 , b_data_4 , expected_4 , "4x6 x 6" ),
242+ (a_data_5 , b_data_5 , expected_5 , "6 x 6x3" ),
243+ (a_data_6 , b_data_6 , expected_6 , "3 x 3x3" )
214244 ]
215245
216246 for a_data , b_data , expected , description in test_cases :
@@ -234,64 +264,66 @@ def test_compare_with_numpy(self):
234264 np .testing .assert_array_almost_equal (
235265 result .ndarray , expected , decimal = 10 )
236266
237- def test_unsupported_dimensions_error (self ):
238- """Test error handling for unsupported dimensions """
267+ def test_wrong_shape_error (self ):
268+ """Test error handling for wrong shapes """
239269
240- # Test 1D x 1D (not supported)
241- a_1d = self .SimpleArray (array = np .array ([1.0 , 2.0 , 3.0 ],
242- dtype = self .dtype ))
243- b_1d = self .SimpleArray (array = np .array ([4.0 , 5.0 , 6.0 ],
244- dtype = self .dtype ))
270+ a_3d_data = np .array ([[[1.0 , 2.0 ], [3.0 , 4.0 ]],
271+ [[5.0 , 6.0 ], [7.0 , 8.0 ]]], dtype = self .dtype )
272+ b_3d_data = np .array ([[[1.0 , 0.0 ], [0.0 , 1.0 ]],
273+ [[2.0 , 0.0 ], [0.0 , 2.0 ]]], dtype = self .dtype )
274+ a_3d = self .SimpleArray (array = a_3d_data )
275+ b_3d = self .SimpleArray (array = b_3d_data )
245276
246277 with self .assertRaisesRegex (
247278 IndexError ,
248- r"SimpleArray::matmul\(\): unsupported dimensions: this=\(3\) "
249- r"other=\(3 \)\. Only 2D x 2D matrix multiplication is supported "
279+ r"SimpleArray::matmul\(\): unsupported dimensions: "
280+ r"this=\(2,2,2\) other=\(2,2,2 \)\. SimpleArray must be 1D or 2D. "
250281 ):
251- a_1d .matmul (b_1d )
252-
253- # Test 1D x 2D (not supported)
254- a_1d = self .SimpleArray (array = np .array ([1.0 , 2.0 ], dtype = self .dtype ))
255- b_2d = self .SimpleArray (array = np .array ([[1.0 , 2.0 ], [3.0 , 4.0 ]],
256- dtype = self .dtype ))
282+ a_3d .matmul (b_3d )
257283
284+ a = np .zeros ((3 , 3 ), dtype = self .dtype )
285+ b = np .zeros ((2 , 3 ), dtype = self .dtype )
286+ a = self .SimpleArray (array = a )
287+ b = self .SimpleArray (array = b )
258288 with self .assertRaisesRegex (
259289 IndexError ,
260- r"SimpleArray::matmul\(\): unsupported dimensions: this=\(2\) "
261- r"other=\(2,2\)\. Only 2D x 2D matrix multiplication is supported "
290+ r"SimpleArray::matmul\(\): shape mismatch: "
291+ r"this=\(3,3\) other=\(2,3\) "
262292 ):
263- a_1d .matmul (b_2d )
264-
265- # Test 2D x 1D (not supported)
266- a_2d = self .SimpleArray (array = np .array ([[1.0 , 2.0 , 3.0 ],
267- [4.0 , 5.0 , 6.0 ]],
268- dtype = self .dtype ))
269- b_1d = self .SimpleArray (array = np .array ([7.0 , 8.0 , 9.0 ],
270- dtype = self .dtype ))
293+ a .matmul (b )
271294
295+ a = np .zeros ((3 , 3 ), dtype = self .dtype )
296+ b = np .zeros ((2 ), dtype = self .dtype )
297+ a = self .SimpleArray (array = a )
298+ b = self .SimpleArray (array = b )
272299 with self .assertRaisesRegex (
273300 IndexError ,
274- r"SimpleArray::matmul\(\): unsupported dimensions: this=\(2,3\) "
275- r"other =\(3\)\. Only 2D x 2D matrix multiplication is supported "
301+ r"SimpleArray::matmul\(\): shape mismatch: "
302+ r"this =\(3,3\) other=\(2\) "
276303 ):
277- a_2d .matmul (b_1d )
278-
279- # Test 3D x 3D (not supported - tensor operation)
280- a_3d_data = np .array ([[[1.0 , 2.0 ], [3.0 , 4.0 ]],
281- [[5.0 , 6.0 ], [7.0 , 8.0 ]]], dtype = self .dtype )
282- b_3d_data = np .array ([[[1.0 , 0.0 ], [0.0 , 1.0 ]],
283- [[2.0 , 0.0 ], [0.0 , 2.0 ]]], dtype = self .dtype )
304+ a .matmul (b )
284305
285- a_3d = self .SimpleArray (array = a_3d_data )
286- b_3d = self .SimpleArray (array = b_3d_data )
306+ a = np .zeros ((2 ), dtype = self .dtype )
307+ b = np .zeros ((3 , 3 ), dtype = self .dtype )
308+ a = self .SimpleArray (array = a )
309+ b = self .SimpleArray (array = b )
310+ with self .assertRaisesRegex (
311+ IndexError ,
312+ r"SimpleArray::matmul\(\): shape mismatch: "
313+ r"this=\(2\) other=\(3,3\)"
314+ ):
315+ a .matmul (b )
287316
317+ a = np .zeros ((2 ), dtype = self .dtype )
318+ b = np .zeros ((3 ), dtype = self .dtype )
319+ a = self .SimpleArray (array = a )
320+ b = self .SimpleArray (array = b )
288321 with self .assertRaisesRegex (
289322 IndexError ,
290- r"SimpleArray::matmul\(\): unsupported dimensions: "
291- r"this=\(2,2,2\) other=\(2,2,2\)\. Only 2D x 2D matrix "
292- r"multiplication is supported"
323+ r"SimpleArray::matmul\(\): shape mismatch: "
324+ r"this=\(2\) other=\(3\)"
293325 ):
294- a_3d .matmul (b_3d )
326+ a .matmul (b )
295327
296328 def test_matmul_operator (self ):
297329 """Test @ operator for matrix multiplication"""
0 commit comments