Skip to content

Commit 6724990

Browse files
committed
feat: Strassen's matrix multiplication algorithm added
1 parent 5e951b6 commit 6724990

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

matrix/strassen_matrix_multiply.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def add(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
2626
return [[matrix_a[i][j] + matrix_b[i][j] for j in range(n)] for i in range(n)]
2727

2828

29-
def sub(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
29+
def subtract(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
3030
"""
3131
Subtract matrix_b from matrix_a.
3232
@@ -37,7 +37,7 @@ def sub(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
3737
return [[matrix_a[i][j] - matrix_b[i][j] for j in range(n)] for i in range(n)]
3838

3939

40-
def naive_mul(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
40+
def naive_multiplication(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
4141
"""
4242
Multiply two square matrices using the naive O(n^3) method.
4343
@@ -160,7 +160,7 @@ def strassen(matrix_a: Matrix, matrix_b: Matrix, threshold: int = 64) -> Matrix:
160160
def _strassen_recursive(matrix_a: Matrix, matrix_b: Matrix, threshold: int) -> Matrix:
161161
n = len(matrix_a)
162162
if n <= threshold:
163-
return naive_mul(matrix_a, matrix_b)
163+
return naive_multiplication(matrix_a, matrix_b)
164164
if n == 1:
165165
return [[matrix_a[0][0] * matrix_b[0][0]]]
166166

@@ -169,16 +169,16 @@ def _strassen_recursive(matrix_a: Matrix, matrix_b: Matrix, threshold: int) -> M
169169

170170
m1 = _strassen_recursive(add(a11, a22), add(b11, b22), threshold)
171171
m2 = _strassen_recursive(add(a21, a22), b11, threshold)
172-
m3 = _strassen_recursive(a11, sub(b12, b22), threshold)
173-
m4 = _strassen_recursive(a22, sub(b21, b11), threshold)
172+
m3 = _strassen_recursive(a11, subtract(b12, b22), threshold)
173+
m4 = _strassen_recursive(a22, subtract(b21, b11), threshold)
174174
m5 = _strassen_recursive(add(a11, a12), b22, threshold)
175-
m6 = _strassen_recursive(sub(a21, a11), add(b11, b12), threshold)
176-
m7 = _strassen_recursive(sub(a12, a22), add(b21, b22), threshold)
175+
m6 = _strassen_recursive(subtract(a21, a11), add(b11, b12), threshold)
176+
m7 = _strassen_recursive(subtract(a12, a22), add(b21, b22), threshold)
177177

178-
c11 = add(sub(add(m1, m4), m5), m7)
178+
c11 = add(subtract(add(m1, m4), m5), m7)
179179
c12 = add(m3, m5)
180180
c21 = add(m2, m4)
181-
c22 = add(sub(add(m1, m3), m2), m6)
181+
c22 = add(subtract(add(m1, m3), m2), m6)
182182

183183
return join(c11, c12, c21, c22)
184184

@@ -192,6 +192,6 @@ def _strassen_recursive(matrix_a: Matrix, matrix_b: Matrix, threshold: int) -> M
192192
for row in C:
193193
print(row)
194194

195-
expected = naive_mul(A, B)
195+
expected = naive_multiplication(A, B)
196196
assert expected == C, "Strassen result differs from naive multiplication!"
197197
print("Verified: result matches naive multiplication.")

0 commit comments

Comments
 (0)