@@ -70,7 +70,7 @@ def solve(arr: NDArray, row: int, cols: set[int], cache: dict[str, int]) -> int:
7070 included in `cols`. `cache` is used for caching intermediate results.
7171
7272 >>> solve(arr=np.array([[1, 2], [3, 4]]), row=0, cols={0, 1}, cache={})
73- np.int64(5)
73+ 5
7474 """
7575
7676 cache_id = f"{ row } , { sorted (cols )} "
@@ -85,7 +85,7 @@ def solve(arr: NDArray, row: int, cols: set[int], cache: dict[str, int]) -> int:
8585 new_cols = cols - {col }
8686 max_sum = max (
8787 max_sum ,
88- arr [row , col ] + solve (arr = arr , row = row + 1 , cols = new_cols , cache = cache ),
88+ int ( arr [row , col ]) + solve (arr = arr , row = row + 1 , cols = new_cols , cache = cache ),
8989 )
9090 cache [cache_id ] = max_sum
9191 return max_sum
@@ -102,7 +102,7 @@ def solution(matrix_str: list[str] = MATRIX_2) -> int:
102102 """
103103
104104 n = len (matrix_str )
105- arr = np .empty (shape = (n , n ), dtype = np . int64 )
105+ arr = np .empty (shape = (n , n ), dtype = int )
106106 for row , matrix_row_str in enumerate (matrix_str ):
107107 matrix_row_list_str = matrix_row_str .split ()
108108 for col , elem_str in enumerate (matrix_row_list_str ):
0 commit comments