Skip to content

Commit 46ef515

Browse files
committed
feat: Strassen's matrix multiplication algorithm added
1 parent 4c9e72a commit 46ef515

File tree

1 file changed

+123
-73
lines changed

1 file changed

+123
-73
lines changed

matrix/strassen_matrix_multiply.py

Lines changed: 123 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -14,122 +14,173 @@
1414

1515
Matrix = list[list[int]]
1616

17-
def add(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
    """
    Return the element-wise sum of two matrices of the same shape.

    Rows and entries are paired with ``zip``, so any rectangular shape is
    supported, not only square matrices. The operands are expected to have
    identical shapes; if they differ, pairing stops at the shorter operand
    and the surplus rows/entries are ignored.

    >>> add([[1,2],[3,4]], [[5,6],[7,8]])
    [[6, 8], [10, 12]]
    >>> add([[1, 2, 3]], [[4, 5, 6]])
    [[5, 7, 9]]
    >>> add([], [])
    []
    """
    return [
        [entry_a + entry_b for entry_a, entry_b in zip(row_a, row_b)]
        for row_a, row_b in zip(matrix_a, matrix_b)
    ]
27+
28+
29+
def sub(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
    """
    Return the element-wise difference ``matrix_a - matrix_b``.

    Rows and entries are paired with ``zip``, so any rectangular shape is
    supported, not only square matrices. The operands are expected to have
    identical shapes; if they differ, pairing stops at the shorter operand
    and the surplus rows/entries are ignored.

    >>> sub([[5,6],[7,8]], [[1,2],[3,4]])
    [[4, 4], [4, 4]]
    >>> sub([[4, 5, 6]], [[1, 2, 3]])
    [[3, 3, 3]]
    >>> sub([], [])
    []
    """
    return [
        [entry_a - entry_b for entry_a, entry_b in zip(row_a, row_b)]
        for row_a, row_b in zip(matrix_a, matrix_b)
    ]
2539

40+
def naive_mul(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
    """
    Multiply two matrices with the schoolbook triple loop (O(n^3) for n x n).

    Works for any compatible shapes: an (r x k) matrix times a (k x c)
    matrix gives an (r x c) result. The inner dimensions are expected to
    match; the result width is taken from the first row of ``matrix_b``.
    Two empty matrices produce an empty result.

    >>> naive_mul([[1,2],[3,4]], [[5,6],[7,8]])
    [[19, 22], [43, 50]]
    >>> naive_mul([[1, 2, 3]], [[1], [2], [3]])
    [[14]]
    >>> naive_mul([], [])
    []
    """
    cols = len(matrix_b[0]) if matrix_b else 0
    result = []
    for row_a in matrix_a:
        row_result = [0] * cols
        # i-k-j order: each entry of row_a scales one whole row of matrix_b,
        # so the innermost loop walks both rows sequentially.
        for a_ik, row_b in zip(row_a, matrix_b):
            for j, b_kj in enumerate(row_b):
                row_result[j] += a_ik * b_kj
        result.append(row_result)
    return result
3958

4059

4160
def next_power_of_two(n: int) -> int:
    """
    Return the smallest power of two that is greater than or equal to n.

    A value that is already a power of two is returned unchanged, and any
    n <= 1 (including zero and negative numbers) yields 1.

    >>> next_power_of_two(5)
    8
    >>> next_power_of_two(8)
    8
    >>> next_power_of_two(0)
    1
    """
    if n <= 1:
        return 1
    # (n - 1).bit_length() is the number of bits needed for n - 1, i.e. the
    # exponent of the first power of two that is not below n.
    return 1 << (n - 1).bit_length()
71+
4672

73+
def pad_matrix(matrix: Matrix, size: int) -> Matrix:
    """
    Copy ``matrix`` into the top-left corner of a ``size`` x ``size`` zero matrix.

    The input is not modified; a new matrix is returned. An empty input
    yields an all-zero ``size`` x ``size`` matrix.

    Raises:
        IndexError: if ``matrix`` has more rows or columns than ``size``.

    >>> pad_matrix([[1,2],[3,4]], 4)
    [[1, 2, 0, 0], [3, 4, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
    >>> pad_matrix([], 2)
    [[0, 0], [0, 0]]
    """
    padded = [[0] * size for _ in range(size)]
    # Iterate the rows directly instead of reading matrix[0], which would
    # raise IndexError for an empty matrix.
    for i, row in enumerate(matrix):
        padded_row = padded[i]
        for j, value in enumerate(row):
            padded_row[j] = value
    return padded
5587

5688

57-
def unpad_matrix(matrix: Matrix, rows: int, cols: int) -> Matrix:
    """
    Return the top-left ``rows`` x ``cols`` corner of ``matrix``.

    Both the row list and each kept row are sliced, so the result is made
    of new lists and the input is left untouched.

    >>> unpad_matrix([[1,2,0],[3,4,0],[0,0,0]], 2, 2)
    [[1, 2], [3, 4]]
    """
    kept_rows = matrix[:rows]
    return [kept[:cols] for kept in kept_rows]
97+
5998

99+
def split(matrix: Matrix) -> tuple[Matrix, Matrix, Matrix, Matrix]:
    """
    Split a square matrix into four quadrants.

    The split point is ``len(matrix) // 2`` on both axes. The quadrants are
    returned in the order (top-left, top-right, bottom-left, bottom-right)
    and are built from slices, so they are new lists independent of the
    input.

    >>> split([[1,2],[3,4]])
    ([[1]], [[2]], [[3]], [[4]])
    >>> split([])
    ([], [], [], [])
    """
    n = len(matrix)
    mid = n // 2
    upper_rows = matrix[:mid]
    lower_rows = matrix[mid:n]
    # Column slices stop at n, the row count, because the matrix is
    # treated as n x n.
    top_left = [row[:mid] for row in upper_rows]
    top_right = [row[mid:n] for row in upper_rows]
    bottom_left = [row[:mid] for row in lower_rows]
    bottom_right = [row[mid:n] for row in lower_rows]
    return top_left, top_right, bottom_left, bottom_right
69113

70114

71-
def join(c11: Matrix, c12: Matrix, c21: Matrix, c22: Matrix) -> Matrix:
    """
    Assemble four equally sized square quadrants into one matrix.

    ``c11`` is placed top-left, ``c12`` top-right, ``c21`` bottom-left and
    ``c22`` bottom-right. The quadrant size is taken from ``len(c11)`` and
    the result has twice that many rows and columns.

    >>> join([[1]], [[2]], [[3]], [[4]])
    [[1, 2], [3, 4]]
    """
    half = len(c11)
    full = half * 2
    result = [[0] * full for _ in range(full)]
    for i in range(half):
        # Row i of the output takes the top quadrants, row i + half the bottom.
        upper_row = result[i]
        lower_row = result[i + half]
        for j in range(half):
            upper_row[j] = c11[i][j]
            upper_row[j + half] = c12[i][j]
            lower_row[j] = c21[i][j]
            lower_row[j + half] = c22[i][j]
    return result
82132

83133

84-
def strassen(matrix_a: Matrix, matrix_b: Matrix, threshold: int = 64) -> Matrix:
    """
    Multiply two square matrices using Strassen's algorithm.

    Inputs whose size is not a power of two are zero-padded up to the next
    power of two before the recursion, and the padding is stripped from the
    result. Sub-problems of size ``threshold`` or smaller are handed to the
    naive multiplication.

    Two empty matrices multiply to an empty matrix.

    Raises:
        AssertionError: if the operands are not square matrices of the same
            size (judged from the row count and the first row's length).

    >>> strassen([[1,2],[3,4]], [[5,6],[7,8]])
    [[19, 22], [43, 50]]
    >>> strassen([], [])
    []
    """
    n_orig = len(matrix_a)
    size_b = len(matrix_b)

    # The empty case is handled before any row is inspected: looking at
    # matrix_a[0] first would raise IndexError for an empty matrix.
    if n_orig == 0 and size_b == 0:
        return []

    # Comparing the row counts first short-circuits the first-row lookups
    # when exactly one operand is empty.
    # NOTE: kept as an assert so callers still see AssertionError; asserts
    # are skipped when Python runs with -O.
    assert (
        n_orig == size_b
        and len(matrix_a[0]) == n_orig
        and len(matrix_b[0]) == n_orig
    ), "Only square matrices supported"

    padded_size = next_power_of_two(n_orig)
    if padded_size != n_orig:
        a_pad = pad_matrix(matrix_a, padded_size)
        b_pad = pad_matrix(matrix_b, padded_size)
    else:
        a_pad, b_pad = matrix_a, matrix_b

    c_pad = _strassen_recursive(a_pad, b_pad, threshold)
    return unpad_matrix(c_pad, n_orig, n_orig)
107158

108159

109-
def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix:
110-
n = len(A)
160+
def _strassen_recursive(matrix_a: Matrix, matrix_b: Matrix, threshold: int) -> Matrix:
161+
n = len(matrix_a)
111162
if n <= threshold:
112-
return naive_mul(A, B)
163+
return naive_mul(matrix_a, matrix_b)
113164
if n == 1:
114-
return [[A[0][0] * B[0][0]]]
165+
return [[matrix_a[0][0] * matrix_b[0][0]]]
115166

116-
A11, A12, A21, A22 = split(A)
117-
B11, B12, B21, B22 = split(B)
167+
a11, a12, a21, a22 = split(matrix_a)
168+
b11, b12, b21, b22 = split(matrix_b)
118169

119-
M1 = _strassen_recursive(add(A11, A22), add(B11, B22), threshold)
120-
M2 = _strassen_recursive(add(A21, A22), B11, threshold)
121-
M3 = _strassen_recursive(A11, sub(B12, B22), threshold)
122-
M4 = _strassen_recursive(A22, sub(B21, B11), threshold)
123-
M5 = _strassen_recursive(add(A11, A12), B22, threshold)
124-
M6 = _strassen_recursive(sub(A21, A11), add(B11, B12), threshold)
125-
M7 = _strassen_recursive(sub(A12, A22), add(B21, B22), threshold)
170+
m1 = _strassen_recursive(add(a11, a22), add(b11, b22), threshold)
171+
m2 = _strassen_recursive(add(a21, a22), b11, threshold)
172+
m3 = _strassen_recursive(a11, sub(b12, b22), threshold)
173+
m4 = _strassen_recursive(a22, sub(b21, b11), threshold)
174+
m5 = _strassen_recursive(add(a11, a12), b22, threshold)
175+
m6 = _strassen_recursive(sub(a21, a11), add(b11, b12), threshold)
176+
m7 = _strassen_recursive(sub(a12, a22), add(b21, b22), threshold)
126177

127-
C11 = add(sub(add(M1, M4), M5), M7)
128-
C12 = add(M3, M5)
129-
C21 = add(M2, M4)
130-
C22 = add(sub(add(M1, M3), M2), M6)
178+
c11 = add(sub(add(m1, m4), m5), m7)
179+
c12 = add(m3, m5)
180+
c21 = add(m2, m4)
181+
c22 = add(sub(add(m1, m3), m2), m6)
131182

132-
return join(C11, C12, C21, C22)
183+
return join(c11, c12, c21, c22)
133184

134185

135186
if __name__ == "__main__":
@@ -141,7 +192,6 @@ def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix:
141192
for row in C:
142193
print(row)
143194

144-
# verify against naive
145195
expected = naive_mul(A, B)
146196
assert C == expected, "Strassen result differs from naive multiplication!"
147197
print("Verified: result matches naive multiplication.")

0 commit comments

Comments
 (0)