Describe the bug
Matrix multiplication is the core computation kernel of LLMs, yet matrix multiplication in NKI is significantly slower than torch.matmul, as measured with Python's time.time().
Relevant discussions can be found here (1), (2), (3).
I performed a matrix multiplication with shapes [128,128] and [128,512] using both torch.matmul and nisa.nc_matmul, and obtained the respective Neuron Profile web UI outputs shown below.
The two profiles mainly diverge after the x-axis passes 20,000 (the area highlighted in the red box). The reason nisa.nc_matmul is slower may be that it is not as well optimized as torch.matmul.
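For context on the measurement itself: with torch_xla, operations execute lazily, so a wall-clock timing only covers device work if the loop flushes the graph and waits for the device before reading the clock. Below is a minimal sketch of such a timing helper; the timed_ms name, the warm-up loop, and the use of time.perf_counter are my assumptions, not part of the original scripts:

import time
from torch_xla.core import xla_model as xm

def timed_ms(fn, warmup=10, iters=100):
    """Time fn on the XLA device, using warm-up runs to exclude trace/compile cost."""
    for _ in range(warmup):
        fn()
        xm.mark_step()        # flush the lazy graph so the op actually runs
    xm.wait_device_ops()      # make sure all warm-up work has finished

    start = time.perf_counter()
    for _ in range(iters):
        fn()
        xm.mark_step()        # flush each timed iteration to the device
    xm.wait_device_ops()      # block until all timed device ops complete
    return (time.perf_counter() - start) * 1000 / iters

# Hypothetical usage: timed_ms(lambda: torch.matmul(A, B))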
Expected Behavior
Measured with Python's time.time(), nki_matmul should be faster than torch.matmul.
Current Behavior
Matrix multiplication in NKI is significantly slower than torch.matmul, as measured with Python's time.time().
Reproduction Steps
torch_matmul.py:
import os
import time

import torch
from torch_xla.core import xla_model as xm

os.environ["NEURON_FRAMEWORK_DEBUG"] = "1"
os.environ["NEURON_CC_FLAGS"] = " --disable-dge "

if __name__ == "__main__":
    K, M, N = 128, 128, 512
    device = xm.xla_device()

    A = torch.rand((M, K), dtype=torch.bfloat16, device=device)
    B = torch.rand((K, N), dtype=torch.bfloat16, device=device)

    for _ in range(100):
        start_time = time.time()
        output_torch = torch.matmul(A, B)
        xm.mark_step()        # flush the lazy graph to the device
        xm.wait_device_ops()  # block until device execution finishes
        end_time = time.time()
        print(f"output_torch={output_torch}")
        print("torch matmul time (ms): ", (end_time - start_time) * 1000)
nki_matmul.py:
import os
import time

import torch
from neuronxcc import nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as nisa
from torch_xla.core import xla_model as xm

os.environ["NEURON_FRAMEWORK_DEBUG"] = "1"
os.environ["NEURON_CC_FLAGS"] = " --disable-dge "


@nki.jit
def nki_matmul_basic_(lhsT, rhs):
    """NKI kernel to compute a 128x128x512 matrix multiplication operation.

    Args:
      lhsT: an input tensor of shape [128,128], the left-hand-side argument of
        the matrix multiplication, delivered transposed for optimal performance
      rhs: an input tensor of shape [128,512], the right-hand-side argument of
        the matrix multiplication

    Returns:
      result: the resulting output tensor of shape [128,512]
    """
    result = nl.ndarray((128, 512), dtype=lhsT.dtype, buffer=nl.shared_hbm)

    # Index grids covering the full input and output tiles.
    i_lhsT_p, i_lhsT_f = nl.mgrid[0:128, 0:128]
    i_rhs_p, i_rhs_f = nl.mgrid[0:128, 0:512]
    i_out_p, i_out_f = nl.mgrid[0:128, 0:512]

    # Load both operands from HBM into SBUF.
    lhs_tile = nl.load(lhsT[i_lhsT_p, i_lhsT_f])
    rhs_tile = nl.load(rhs[i_rhs_p, i_rhs_f])

    # Multiply on the tensor engine; the result lands in PSUM.
    result_psum = nisa.nc_matmul(lhs_tile, rhs_tile)

    # Copy PSUM -> SBUF with a cast to the output dtype, then store to HBM.
    result_sbuf = nl.copy(result_psum, dtype=result.dtype)
    nl.store(result[i_out_p, i_out_f], value=result_sbuf)
    return result


if __name__ == "__main__":
    K, M, N = 128, 128, 512
    device = xm.xla_device()

    A = torch.rand((M, K), dtype=torch.bfloat16, device=device)
    B = torch.rand((K, N), dtype=torch.bfloat16, device=device)

    for _ in range(100):
        start_time = time.time()
        output_nki = nki_matmul_basic_(A.T, B)
        end_time = time.time()
        print(f"output_nki={output_nki}")
        print("nki matmul time (ms): ", (end_time - start_time) * 1000)
Regression Issue
- Select this option if this issue appears to be a regression.
Possible Solution
No response
Additional Information/Context
No response
neuronx-cc version used
NeuronX Compiler version 2.16.372.0+4a9b2326
Framework(s) and their versions used (JAX, PyTorch, etc.)
Python version 3.10.12
HWM version 2.16.0.372+4a9b2326
NumPy version 1.25.2
Running on trn1

