-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathstarter_kit_baseline.py
More file actions
41 lines (32 loc) · 1.06 KB
/
starter_kit_baseline.py
File metadata and controls
41 lines (32 loc) · 1.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import numpy as np
import warnings
from numba import njit
def safe_njit(f):
    """Decorate ``f`` with ``njit``, falling back to plain ``f`` on failure.

    If the ``njit(f)`` call raises any ``Exception``, a warning naming the
    function and the error is emitted and the original, un-jitted function
    is returned so callers still get a working callable.

    NOTE(review): per the Numba documentation, ``njit`` without an explicit
    signature compiles lazily on the first call, so typing/compilation
    errors raised at call time would not pass through this handler --
    confirm whether that is the intended scope of "safe".
    """
    try:
        compiled = njit(f)
    except Exception as e:
        # Best-effort: report the problem but keep the pure-Python version.
        warnings.warn(f"Numba JIT compilation failed for function {f.__name__}: {e}")
        compiled = f
    return compiled
@safe_njit
def matmul(a: np.ndarray, b: np.ndarray, tile_size: int = 16) -> np.ndarray:
    """Multiply two 2-D matrices with a tiled (blocked) triple loop.

    Intended to match ``np.matmul(a, b)`` for 2-D inputs. The iteration
    space is split into ``tile_size``-sized blocks along each of the three
    dimensions; edge tiles are clipped to the matrix bounds.

    Args:
        a: Left operand of shape ``(m, n)``.
        b: Right operand of shape ``(n, k)``.
        tile_size: Edge length of the tiles; must be positive.

    Returns:
        A newly allocated ``(m, k)`` array. Its dtype is ``a.dtype``, so
        mixed-dtype inputs are not promoted the way ``np.matmul`` would.

    Raises:
        ValueError: If ``tile_size`` is not positive, or if the inner
            dimensions of ``a`` and ``b`` differ.
    """
    # A zero step makes range() fail and a negative one silently skips
    # every loop (returning all zeros), so reject both up front.
    if tile_size <= 0:
        raise ValueError("tile_size must be a positive integer")

    m, n = a.shape
    n_b, k = b.shape
    # Without this check a mismatch would index out of range or silently
    # produce a wrong result instead of failing like np.matmul does.
    if n != n_b:
        raise ValueError("matmul: inner dimensions of a and b do not match")

    c = np.zeros((m, k), dtype=a.dtype)

    # Loop over tiles; each tile's upper bound is computed once per tile.
    for i0 in range(0, m, tile_size):
        i1 = min(i0 + tile_size, m)
        for j0 in range(0, k, tile_size):
            j1 = min(j0 + tile_size, k)
            for p0 in range(0, n, tile_size):
                p1 = min(p0 + tile_size, n)
                # Accumulate this tile's contribution into c.
                for ii in range(i0, i1):
                    for jj in range(j0, j1):
                        for pp in range(p0, p1):
                            c[ii, jj] += a[ii, pp] * b[pp, jj]
    return c
@safe_njit
def reduce(a: np.ndarray) -> float:
    """Sum every element of ``a``, as ``np.sum(a)`` does.

    Elements are visited through ``a.flat`` so arrays of any number of
    dimensions are reduced to a single scalar; for 1-D input this visits
    the elements in the same order as indexing ``a[0], a[1], ...``.

    Args:
        a: Array whose elements are summed.

    Returns:
        The total, accumulated starting from the float ``0.0`` (so an empty
        array yields ``0.0``).
    """
    total = 0.0
    for value in a.flat:
        total += value
    return total