GEMM Kernel

The gemm operator computes a general matrix-matrix multiplication (GEMM): the product of two matrices, scaled and accumulated into an output matrix.

Mathematical Definition

Given input matrices A and B, along with an output matrix C and scalars α and β, the kernel evaluates

$$ C = \alpha A B + \beta C $$

The product AB is scaled by α, the existing matrix C is scaled by β, and the two scaled terms are added to produce the updated matrix C.
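To make the formula concrete, here is a minimal sketch in plain Python that evaluates it on small 2×2 matrices. The helper name `gemm_naive` and the sample values are illustrative assumptions and are not part of this repository's API.

```python
def gemm_naive(A, B, C, alpha, beta):
    """Return alpha * (A @ B) + beta * C for matrices given as nested lists.

    Hypothetical helper for illustration only; not one of the kernel backends.
    """
    m, k, n = len(A), len(B), len(B[0])
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            # Dot product of row i of A with column j of B.
            acc = 0.0
            for p in range(k):
                acc += A[i][p] * B[p][j]
            # Scale the product by alpha and add the beta-scaled C entry.
            out[i][j] = alpha * acc + beta * C[i][j]
    return out


A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
C = [[1.0, 1.0], [1.0, 1.0]]

# A @ B = [[19, 22], [43, 50]], so 2 * (A @ B) + 3 * C = [[41, 47], [89, 103]].
result = gemm_naive(A, B, C, alpha=2.0, beta=3.0)
print(result)
```

Setting β = 0 reduces the operation to a scaled product, while α = 1 and β = 1 accumulate the product into C.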

Kernel Implementations

All backends share the interface:

import torch

def gemm(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, alpha: float, beta: float) -> torch.Tensor:
    ...  # evaluates alpha * (A @ B) + beta * C
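As a point of comparison for this interface, the following is a minimal reference sketch built on PyTorch's `torch.addmm`, which computes `beta * input + alpha * (mat1 @ mat2)`. The name `gemm_reference` and the tensor shapes are assumptions for illustration. The source does not say whether the backends update C in place or return a new tensor; this sketch returns a new tensor.

```python
import torch


def gemm_reference(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                   alpha: float, beta: float) -> torch.Tensor:
    """Hypothetical reference: returns alpha * (A @ B) + beta * C as a new tensor."""
    # torch.addmm(input, mat1, mat2, beta=..., alpha=...) evaluates
    # beta * input + alpha * (mat1 @ mat2) without modifying `input`.
    return torch.addmm(C, A, B, beta=beta, alpha=alpha)


# Assumed shapes: A is (M, K), B is (K, N), C is (M, N).
torch.manual_seed(0)
A = torch.randn(4, 3)
B = torch.randn(3, 5)
C = torch.randn(4, 5)

out = gemm_reference(A, B, C, alpha=0.5, beta=2.0)

# Cross-check against the formula written out explicitly.
expected = 0.5 * (A @ B) + 2.0 * C
print(torch.allclose(out, expected, atol=1e-6))
```

A reference like this can serve as the expected value when comparing a backend's output, with a tolerance suited to the dtype in use.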

Testing

See the test suite for the validation harness that exercises every backend.

pytest tests/test_gemm.py -s