
The nisa.nc_matmul in NKI is not as well-optimized as torch.matmul. #59

@dinghongsong

Description


Describe the bug

Matrix multiplication is the core computation kernel of LLMs, yet matrix multiplication in NKI (nisa.nc_matmul) is significantly slower than torch.matmul, as measured by Python's time.time().

Relevant discussions can be found here (1), (2), (3).

I performed a matrix multiplication with shapes [128, 128] and [128, 512] using both torch.matmul and nisa.nc_matmul, and captured the corresponding web UI outputs with Neuron Profile, shown below.

torch_matmul: [Neuron Profile timeline screenshot]

nki_matmul: [Neuron Profile timeline screenshot]

Their differences mainly appear after the x-axis passes 20,000 (the area highlighted in the red box). The reason nisa.nc_matmul is slower may be that it is not as well-optimized as torch.matmul.

Expected Behavior

When measured with Python's time.time(), nki_matmul should be faster than torch.matmul.

Current Behavior

Matrix multiplication in NKI (nisa.nc_matmul) is significantly slower than torch.matmul, as measured by Python's time.time().

Reproduction Steps

torch_matmul.py:

import os
import time

import torch
from torch_xla.core import xla_model as xm

# Dump framework debug artifacts and pass extra compiler flags to neuronx-cc
os.environ["NEURON_FRAMEWORK_DEBUG"] = "1"
os.environ["NEURON_CC_FLAGS"] = " --disable-dge "


if __name__ == "__main__":
  K, M, N = 128, 128, 512
  device = xm.xla_device()
  cpu = torch.device('cpu')

  A = torch.rand((M, K), dtype=torch.bfloat16, device=device)
  B = torch.rand((K, N), dtype=torch.bfloat16, device=device)
  
  for _ in range(100):
    start_time = time.time()
    output_torch = torch.matmul(A, B)
    xm.mark_step()        # flush the lazily traced graph to the device
    xm.wait_device_ops()  # block until device execution finishes
    end_time = time.time()
    print(f"output_torch={output_torch}")
    print("torch matmul time (ms): ", (end_time - start_time) * 1000)

nki_matmul.py:

from neuronxcc import nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as nisa
import torch
import time
import os
from torch_xla.core import xla_model as xm

# Dump framework debug artifacts and pass extra compiler flags to neuronx-cc
os.environ["NEURON_FRAMEWORK_DEBUG"] = "1"
os.environ["NEURON_CC_FLAGS"] = " --disable-dge "

@nki.jit
def nki_matmul_basic_(lhsT, rhs):
  """NKI kernel to compute a 128x128x512 matrix multiplication operation

  Args:
      lhsT: an input tensor of shape [128,128], a left hand side argument of the
        matrix multiplication, delivered transposed for optimal performance
      rhs: an input tensor of shape [128,512], a right hand side argument of the
        matrix multiplication
  Returns:
      result: the resulting output tensor of shape [128,512]
  """
  result = nl.ndarray((128, 512), dtype=lhsT.dtype, buffer=nl.shared_hbm)

  # Index grids over the partition and free dimensions of each tile
  i_lhsT_p, i_lhsT_f = nl.mgrid[0:128, 0:128]
  i_rhs_p, i_rhs_f = nl.mgrid[0:128, 0:512]
  i_out_p, i_out_f = nl.mgrid[0:128, 0:512]

  # Load the input tiles from HBM into the on-chip SBUF
  lhs_tile = nl.load(lhsT[i_lhsT_p, i_lhsT_f])
  rhs_tile = nl.load(rhs[i_rhs_p, i_rhs_f])

  # Multiply on the Tensor Engine; the result is written to PSUM
  result_psum = nisa.nc_matmul(lhs_tile, rhs_tile)

  # Copy PSUM -> SBUF with the output dtype, then store the tile back to HBM
  result_sbuf = nl.copy(result_psum, dtype=result.dtype)
  nl.store(result[i_out_p, i_out_f], value=result_sbuf)

  return result


if __name__ == "__main__":
  K, M, N = 128, 128, 512
  device = xm.xla_device()
  cpu = torch.device('cpu')

  A = torch.rand((M, K), dtype=torch.bfloat16, device=device)
  B = torch.rand((K, N), dtype=torch.bfloat16, device=device)
  
  for _ in range(100):
    start_time = time.time()
    output_nki = nki_matmul_basic_(A.T, B)
    end_time = time.time()
    print(f"output_nki={output_nki}")
    print("nki matmul time (ms): ", (end_time - start_time ) * 1000)

Regression Issue

  • Select this option if this issue appears to be a regression.

Possible Solution

No response

Additional Information/Context

No response

neuronx-cc version used

NeuronX Compiler version 2.16.372.0+4a9b2326

Framework(s) and their versions used (JAX, PyTorch, etc.)

Python version 3.10.12
HWM version 2.16.0.372+4a9b2326
NumPy version 1.25.2
Running on trn1
