@@ -26,7 +26,7 @@ def add(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
2626 return [[matrix_a [i ][j ] + matrix_b [i ][j ] for j in range (n )] for i in range (n )]
2727
2828
29- def sub (matrix_a : Matrix , matrix_b : Matrix ) -> Matrix :
29+ def subtract (matrix_a : Matrix , matrix_b : Matrix ) -> Matrix :
3030 """
3131 Subtract matrix_b from matrix_a.
3232
@@ -37,7 +37,7 @@ def sub(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
3737 return [[matrix_a [i ][j ] - matrix_b [i ][j ] for j in range (n )] for i in range (n )]
3838
3939
40- def naive_mul (matrix_a : Matrix , matrix_b : Matrix ) -> Matrix :
40+ def naive_multiplication (matrix_a : Matrix , matrix_b : Matrix ) -> Matrix :
4141 """
4242 Multiply two square matrices using the naive O(n^3) method.
4343
@@ -160,7 +160,7 @@ def strassen(matrix_a: Matrix, matrix_b: Matrix, threshold: int = 64) -> Matrix:
160160def _strassen_recursive (matrix_a : Matrix , matrix_b : Matrix , threshold : int ) -> Matrix :
161161 n = len (matrix_a )
162162 if n <= threshold :
163- return naive_mul (matrix_a , matrix_b )
163+ return naive_multiplication (matrix_a , matrix_b )
164164 if n == 1 :
165165 return [[matrix_a [0 ][0 ] * matrix_b [0 ][0 ]]]
166166
@@ -169,16 +169,16 @@ def _strassen_recursive(matrix_a: Matrix, matrix_b: Matrix, threshold: int) -> M
169169
170170 m1 = _strassen_recursive (add (a11 , a22 ), add (b11 , b22 ), threshold )
171171 m2 = _strassen_recursive (add (a21 , a22 ), b11 , threshold )
172- m3 = _strassen_recursive (a11 , sub (b12 , b22 ), threshold )
173- m4 = _strassen_recursive (a22 , sub (b21 , b11 ), threshold )
172+ m3 = _strassen_recursive (a11 , subtract (b12 , b22 ), threshold )
173+ m4 = _strassen_recursive (a22 , subtract (b21 , b11 ), threshold )
174174 m5 = _strassen_recursive (add (a11 , a12 ), b22 , threshold )
175- m6 = _strassen_recursive (sub (a21 , a11 ), add (b11 , b12 ), threshold )
176- m7 = _strassen_recursive (sub (a12 , a22 ), add (b21 , b22 ), threshold )
175+ m6 = _strassen_recursive (subtract (a21 , a11 ), add (b11 , b12 ), threshold )
176+ m7 = _strassen_recursive (subtract (a12 , a22 ), add (b21 , b22 ), threshold )
177177
178- c11 = add (sub (add (m1 , m4 ), m5 ), m7 )
178+ c11 = add (subtract (add (m1 , m4 ), m5 ), m7 )
179179 c12 = add (m3 , m5 )
180180 c21 = add (m2 , m4 )
181- c22 = add (sub (add (m1 , m3 ), m2 ), m6 )
181+ c22 = add (subtract (add (m1 , m3 ), m2 ), m6 )
182182
183183 return join (c11 , c12 , c21 , c22 )
184184
@@ -192,6 +192,6 @@ def _strassen_recursive(matrix_a: Matrix, matrix_b: Matrix, threshold: int) -> M
192192 for row in C :
193193 print (row )
194194
195- expected = naive_mul (A , B )
195+ expected = naive_multiplication (A , B )
196196 assert expected == C , "Strassen result differs from naive multiplication!"
197197 print ("Verified: result matches naive multiplication." )
0 commit comments