1414
1515Matrix = list [list [int ]]
1616
17- def add (A : Matrix , B : Matrix ) -> Matrix :
18- n = len (A )
19- return [[A [i ][j ] + B [i ][j ] for j in range (n )] for i in range (n )]
2017
18+ def add (matrix_a : Matrix , matrix_b : Matrix ) -> Matrix :
19+ """
20+ Add two square matrices of the same size.
21+
22+ >>> add([[1,2],[3,4]], [[5,6],[7,8]])
23+ [[6, 8], [10, 12]]
24+ """
25+ n = len (matrix_a )
26+ return [[matrix_a [i ][j ] + matrix_b [i ][j ] for j in range (n )] for i in range (n )]
27+
28+
29+ def sub (matrix_a : Matrix , matrix_b : Matrix ) -> Matrix :
30+ """
31+ Subtract matrix_b from matrix_a.
32+
33+ >>> sub([[5,6],[7,8]], [[1,2],[3,4]])
34+ [[4, 4], [4, 4]]
35+ """
36+ n = len (matrix_a )
37+ return [[matrix_a [i ][j ] - matrix_b [i ][j ] for j in range (n )] for i in range (n )]
2138
22- def sub (A : Matrix , B : Matrix ) -> Matrix :
23- n = len (A )
24- return [[A [i ][j ] - B [i ][j ] for j in range (n )] for i in range (n )]
2539
40+ def naive_mul (matrix_a : Matrix , matrix_b : Matrix ) -> Matrix :
41+ """
42+ Multiply two square matrices using the naive O(n^3) method.
2643
27- def naive_mul (A : Matrix , B : Matrix ) -> Matrix :
28- n = len (A )
29- C = [[0 ] * n for _ in range (n )]
44+ >>> naive_mul([[1,2],[3,4]], [[5,6],[7,8]])
45+ [[19, 22], [43, 50]]
46+ """
47+ n = len (matrix_a )
48+ result = [[0 ] * n for _ in range (n )]
3049 for i in range (n ):
31- ai = A [i ]
32- ci = C [i ]
50+ row_a = matrix_a [i ]
51+ row_result = result [i ]
3352 for k in range (n ):
34- a_ik = ai [k ]
35- bk = B [k ]
53+ a_ik = row_a [k ]
54+ col_b = matrix_b [k ]
3655 for j in range (n ):
37- ci [j ] += a_ik * bk [j ]
38- return C
56+ row_result [j ] += a_ik * col_b [j ]
57+ return result
3958
4059
4160def next_power_of_two (n : int ) -> int :
42- p = 1
43- while p < n :
44- p <<= 1
45- return p
61+ """
62+ Return the next power of two greater than or equal to n.
63+
64+ >>> next_power_of_two(5)
65+ 8
66+ """
67+ power = 1
68+ while power < n :
69+ power <<= 1
70+ return power
71+
4672
73+ def pad_matrix (matrix : Matrix , size : int ) -> Matrix :
74+ """
75+ Pad a matrix with zeros to reach the given size.
4776
48- def pad_matrix (A : Matrix , size : int ) -> Matrix :
49- n = len (A )
77+ >>> pad_matrix([[1,2],[3,4]], 4)
78+ [[1, 2, 0, 0], [3, 4, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
79+ """
80+ rows = len (matrix )
81+ cols = len (matrix [0 ])
5082 padded = [[0 ] * size for _ in range (size )]
51- for i in range (n ):
52- for j in range (len ( A [ 0 ]) ):
53- padded [i ][j ] = A [i ][j ]
83+ for i in range (rows ):
84+ for j in range (cols ):
85+ padded [i ][j ] = matrix [i ][j ]
5486 return padded
5587
5688
57- def unpad_matrix (A : Matrix , rows : int , cols : int ) -> Matrix :
58- return [row [:cols ] for row in A [:rows ]]
89+ def unpad_matrix (matrix : Matrix , rows : int , cols : int ) -> Matrix :
90+ """
91+ Remove padding from a matrix.
92+
93+ >>> unpad_matrix([[1,2,0],[3,4,0],[0,0,0]], 2, 2)
94+ [[1, 2], [3, 4]]
95+ """
96+ return [row [:cols ] for row in matrix [:rows ]]
97+
5998
99+ def split (matrix : Matrix ) -> tuple :
100+ """
101+ Split a matrix into four quadrants (top-left, top-right, bottom-left, bottom-right).
60102
61- def split (A : Matrix ) -> tuple :
62- n = len (A )
103+ >>> split([[1,2],[3,4]])
104+ ([[1]], [[2]], [[3]], [[4]])
105+ """
106+ n = len (matrix )
63107 mid = n // 2
64- A11 = [[A [i ][j ] for j in range (mid )] for i in range (mid )]
65- A12 = [[A [i ][j ] for j in range (mid , n )] for i in range (mid )]
66- A21 = [[A [i ][j ] for j in range (mid )] for i in range (mid , n )]
67- A22 = [[A [i ][j ] for j in range (mid , n )] for i in range (mid , n )]
68- return A11 , A12 , A21 , A22
108+ top_left = [[matrix [i ][j ] for j in range (mid )] for i in range (mid )]
109+ top_right = [[matrix [i ][j ] for j in range (mid , n )] for i in range (mid )]
110+ bottom_left = [[matrix [i ][j ] for j in range (mid )] for i in range (mid , n )]
111+ bottom_right = [[matrix [i ][j ] for j in range (mid , n )] for i in range (mid , n )]
112+ return top_left , top_right , bottom_left , bottom_right
69113
70114
71- def join (C11 : Matrix , C12 : Matrix , C21 : Matrix , C22 : Matrix ) -> Matrix :
72- n2 = len (C11 )
115+ def join (c11 : Matrix , c12 : Matrix , c21 : Matrix , c22 : Matrix ) -> Matrix :
116+ """
117+ Join four quadrants into a single matrix.
118+
119+ >>> join([[1]], [[2]], [[3]], [[4]])
120+ [[1, 2], [3, 4]]
121+ """
122+ n2 = len (c11 )
73123 n = n2 * 2
74- C = [[0 ] * n for _ in range (n )]
124+ result = [[0 ] * n for _ in range (n )]
75125 for i in range (n2 ):
76126 for j in range (n2 ):
77- C [i ][j ] = C11 [i ][j ]
78- C [i ][j + n2 ] = C12 [i ][j ]
79- C [i + n2 ][j ] = C21 [i ][j ]
80- C [i + n2 ][j + n2 ] = C22 [i ][j ]
81- return C
127+ result [i ][j ] = c11 [i ][j ]
128+ result [i ][j + n2 ] = c12 [i ][j ]
129+ result [i + n2 ][j ] = c21 [i ][j ]
130+ result [i + n2 ][j + n2 ] = c22 [i ][j ]
131+ return result
82132
83133
84- def strassen (A : Matrix , B : Matrix , threshold : int = 64 ) -> Matrix :
134+ def strassen (matrix_a : Matrix , matrix_b : Matrix , threshold : int = 64 ) -> Matrix :
85135 """
86- Multiply square matrices A and B using Strassen algorithm.
87- threshold: below this size, uses naive multiplication (tweakable).
136+ Multiply two square matrices using Strassen's algorithm.
137+ Uses naive multiplication for matrices smaller than threshold.
138+
139+ >>> strassen([[1,2],[3,4]], [[5,6],[7,8]])
140+ [[19, 22], [43, 50]]
88141 """
89- assert len (A ) == len (A [0 ]) == len (B ) == len (B [0 ]), (
90- "Only square matrices supported in this implementation "
142+ assert len (matrix_a ) == len (matrix_a [0 ]) == len (matrix_b ) == len (matrix_b [0 ]), (
143+ "Only square matrices supported"
91144 )
92145
93- n_orig = len (A )
146+ n_orig = len (matrix_a )
94147 if n_orig == 0 :
95148 return []
96149
97150 if (m := next_power_of_two (n_orig )) != n_orig :
98- A_pad = pad_matrix (A , m )
99- B_pad = pad_matrix (B , m )
151+ a_pad = pad_matrix (matrix_a , m )
152+ b_pad = pad_matrix (matrix_b , m )
100153 else :
101- A_pad , B_pad = A , B
102-
103- C_pad = _strassen_recursive (A_pad , B_pad , threshold )
154+ a_pad , b_pad = matrix_a , matrix_b
104155
105- C = unpad_matrix ( C_pad , n_orig , n_orig )
106- return C
156+ c_pad = _strassen_recursive ( a_pad , b_pad , threshold )
157+ return unpad_matrix ( c_pad , n_orig , n_orig )
107158
108159
109- def _strassen_recursive (A : Matrix , B : Matrix , threshold : int ) -> Matrix :
110- n = len (A )
160+ def _strassen_recursive (matrix_a : Matrix , matrix_b : Matrix , threshold : int ) -> Matrix :
161+ n = len (matrix_a )
111162 if n <= threshold :
112- return naive_mul (A , B )
163+ return naive_mul (matrix_a , matrix_b )
113164 if n == 1 :
114- return [[A [0 ][0 ] * B [0 ][0 ]]]
165+ return [[matrix_a [0 ][0 ] * matrix_b [0 ][0 ]]]
115166
116- A11 , A12 , A21 , A22 = split (A )
117- B11 , B12 , B21 , B22 = split (B )
167+ a11 , a12 , a21 , a22 = split (matrix_a )
168+ b11 , b12 , b21 , b22 = split (matrix_b )
118169
119- M1 = _strassen_recursive (add (A11 , A22 ), add (B11 , B22 ), threshold )
120- M2 = _strassen_recursive (add (A21 , A22 ), B11 , threshold )
121- M3 = _strassen_recursive (A11 , sub (B12 , B22 ), threshold )
122- M4 = _strassen_recursive (A22 , sub (B21 , B11 ), threshold )
123- M5 = _strassen_recursive (add (A11 , A12 ), B22 , threshold )
124- M6 = _strassen_recursive (sub (A21 , A11 ), add (B11 , B12 ), threshold )
125- M7 = _strassen_recursive (sub (A12 , A22 ), add (B21 , B22 ), threshold )
170+ m1 = _strassen_recursive (add (a11 , a22 ), add (b11 , b22 ), threshold )
171+ m2 = _strassen_recursive (add (a21 , a22 ), b11 , threshold )
172+ m3 = _strassen_recursive (a11 , sub (b12 , b22 ), threshold )
173+ m4 = _strassen_recursive (a22 , sub (b21 , b11 ), threshold )
174+ m5 = _strassen_recursive (add (a11 , a12 ), b22 , threshold )
175+ m6 = _strassen_recursive (sub (a21 , a11 ), add (b11 , b12 ), threshold )
176+ m7 = _strassen_recursive (sub (a12 , a22 ), add (b21 , b22 ), threshold )
126177
127- C11 = add (sub (add (M1 , M4 ), M5 ), M7 )
128- C12 = add (M3 , M5 )
129- C21 = add (M2 , M4 )
130- C22 = add (sub (add (M1 , M3 ), M2 ), M6 )
178+ c11 = add (sub (add (m1 , m4 ), m5 ), m7 )
179+ c12 = add (m3 , m5 )
180+ c21 = add (m2 , m4 )
181+ c22 = add (sub (add (m1 , m3 ), m2 ), m6 )
131182
132- return join (C11 , C12 , C21 , C22 )
183+ return join (c11 , c12 , c21 , c22 )
133184
134185
135186if __name__ == "__main__" :
@@ -141,7 +192,6 @@ def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix:
141192 for row in C :
142193 print (row )
143194
144- # verify against naive
145195 expected = naive_mul (A , B )
146196 assert C == expected , "Strassen result differs from naive multiplication!"
147197 print ("Verified: result matches naive multiplication." )
0 commit comments