Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions nkigen/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Python
__pycache__/
*.py[cod]
*.so

# Distribution / packaging
build/
dist/
*.egg-info/
.eggs/
*.whl

# Built MLIR bindings (generated during build)
nkigen/_mlir/

# Top-level repo's .gitignore has `lib/` (for Python venv lib/ dirs); don't
# silently exclude the MLIR C++ pass sources in mlir/lib/.
!mlir/lib/

# Virtual environments
venv/
.env

# Testing
.pytest_cache/
.coverage
tests/**/outputs/
tests/**/artifacts/

# IDE
.vscode/
.idea/

# OS
.DS_Store
Thumbs.db

# Logs
*.log

# LLVM lit test outputs
.lit_test_times.txt
Output/
308 changes: 308 additions & 0 deletions nkigen/CLAUDE.md

Large diffs are not rendered by default.

94 changes: 94 additions & 0 deletions nkigen/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@

# NKIPy KernelGen

A Python-to-NISA MLIR compiler for Trainium. Trace NumPy functions, annotate
with knobs, and lower through a pipeline of MLIR passes to NKI/NISA dialect.

- https://quip-amazon.com/UYybAmfN2E2k/Design-Doc-NKIPy-KernelGen-in-MLIR

## Features

- **`@trace` decorator** — trace Python functions with NumPy operations into linalg MLIR
- **`knob()` API** — annotate tensors with tiling, memory placement, and partitioning hints (similar to OpenMP pragmas)
- **MLIR pass pipeline** — prepare-arithmetic, assign-op-ids, infer-layout, knob-driven-tiling, legalize-layout, linalg-to-nisa, and more
- **Compiler Explorer** — interactive web UI for inspecting IR at each compilation stage

## Setup

NKIPyKernelGen depends on LLVM/MLIR and the NKI compiler (NISA dialect). Build
them first by following the
[NKI build instructions](https://github.com/aws-neuron/private-nki-staging/blob/main/docs/BUILDING.md).

Then set up this project:

```bash
# Set up the NKI development environment (venv, LLVM/MLIR paths, PYTHONPATH)
source scripts/setup_nki.sh

# Install the Python package (editable)
pip install -e .
```

## Quick Start

```python
import numpy as np
from nkigen import trace, knob

@trace(input_specs=[((256, 256), "f32"), ((256, 256), "f32")])
def matmul_add(A, B):
C = np.matmul(A, B)
knob.knob(C, mem_space="Sbuf", tile_size=[128, 128, 128])
result = np.exp(C)
knob.knob(result, mem_space="Sbuf", tile_size=[128, 128])
return result
```

The `knob()` API injects `nkipy.annotate` ops into the IR to guide tiling and
buffer placement. The `infer-layout` pass propagates annotations to unannotated
intermediate ops.

## Project Structure

```
NKIPyKernelGen/
├── nkigen/ # Python package (tracer, knob API, transforms)
├── mlir/ # MLIR dialects and C++ passes
│ └── lib/Transforms/ # Pass implementations
├── tests/ # Test suite
│ ├── unit/ # Unit tests for ops and execution engine
│ ├── passes/ # MLIR pass tests (tiling, layout, etc.)
│ └── e2e/ # End-to-end compilation tests
└── scripts/ # Environment setup and build scripts
```

## MLIR Passes

The compilation pipeline runs these passes in order:

| Pass | Description |
|------|-------------|
| `prepare-arithmetic` | Normalize arithmetic (e.g. `divf by N` → `mulf by 1/N`) |
| `assign-linalg-op-ids` | Tag each linalg op with a unique `nkipy.op_id` |
| `infer-layout` | Propagate tile_size/mem_space to unannotated ops (elementwise chains, reduce chains) |
| `knob-driven-tiling` | Generate transform dialect sequences from knob annotations |
| `transform-interpreter` | Apply the generated tiling transforms |
| `legalize-layout` | Insert 4D physical layout transformations |
| `linalg-to-nisa` | Lower linalg ops to NISA dialect |
| `prepare-for-nki` | Final preparation for NKI backend |

## Testing

```bash
# Set up environment first
source scripts/setup_nki.sh

# Run all tests
cd tests && python -m pytest . -v

# Run a specific test category
python -m pytest passes/infer_layout/ -v
python -m pytest e2e/ -v
python -m pytest unit/ -v
```

48 changes: 48 additions & 0 deletions nkigen/mlir/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
cmake_minimum_required(VERSION 3.20.0)
cmake_policy(SET CMP0116 NEW)

project(nkipy-kg LANGUAGES CXX C)

set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON)

set(CMAKE_CXX_STANDARD 17 CACHE STRING "C++ standard to conform to")
set(CMAKE_CXX_FLAGS "-Wfatal-errors -std=c++17")
add_compile_options ( -w )

set(CMAKE_BUILD_TYPE Debug)

# Define NDEBUG to match Release-built MLIR libraries (avoids undefined
# reference to debug-only functions like checkImplementsTransformOpInterface)
add_definitions(-DNDEBUG)

find_package(MLIR REQUIRED CONFIG)

message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}")
message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")

set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin)
set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib)
set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})

list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
include(TableGen)
include(AddLLVM)
include(AddMLIR)
include(HandleLLVMOptions)

include_directories(${LLVM_INCLUDE_DIRS})
include_directories(${MLIR_INCLUDE_DIRS})
include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_BINARY_DIR}/include)

link_directories(${LLVM_BUILD_LIBRARY_DIR})
add_definitions(${LLVM_DEFINITIONS})

message(STATUS "Using Python binding")
include(MLIRDetectPythonEnv)
mlir_configure_python_dev_packages()

add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(tools)
1 change: 1 addition & 0 deletions nkigen/mlir/include/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(nkipy)
17 changes: 17 additions & 0 deletions nkigen/mlir/include/nkipy-c/Dialect/Dialects.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@

#ifndef NKIPY_C_DIALECT__H
#define NKIPY_C_DIALECT__H

#include "mlir-c/RegisterEverything.h"

#ifdef __cplusplus
extern "C" {
#endif

MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(NKIPY, nkipy);

#ifdef __cplusplus
}
#endif

#endif // NKIPY_C_DIALECT__H
24 changes: 24 additions & 0 deletions nkigen/mlir/include/nkipy-c/Dialect/NkipyAttributes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef NKIPY_MLIR_C_ATTRIBUTES__H
#define NKIPY_MLIR_C_ATTRIBUTES__H

#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Support.h"
#include "mlir/CAPI/IntegerSet.h"

#ifdef __cplusplus
extern "C" {
#endif

// MLIR_CAPI_EXPORTED bool mlirAttributeIsAIntegerSet(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set);

MLIR_CAPI_EXPORTED bool mlirAttributeIsAMemSpace(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirAttribute mlirMemSpaceGet(MlirContext ctx,
MlirAttribute space);

#ifdef __cplusplus
}
#endif

#endif // NKIPY_MLIR_C_ATTRIBUTES__H
25 changes: 25 additions & 0 deletions nkigen/mlir/include/nkipy-c/Dialect/Registration.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@

#ifndef NKIPY_MLIR_C_REGISTRATION_H
#define NKIPY_MLIR_C_REGISTRATION_H

#include "mlir/CAPI/IR.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"

#ifdef __cplusplus
extern "C" {
#endif

/** Registers all dialects with a context.
* This is needed before creating IR for these Dialects.
*/
MLIR_CAPI_EXPORTED void nkipyMlirRegisterAllDialects(MlirContext context);

/** Registers all passes for symbolic access with the global registry. */
MLIR_CAPI_EXPORTED void nkipyMlirRegisterAllPasses();

#ifdef __cplusplus
}
#endif

#endif // NKIPY_MLIR_C_REGISTRATION_H
96 changes: 96 additions & 0 deletions nkigen/mlir/include/nkipy/Bindings/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
include(AddMLIRPython)

# The directory at which the Python import tree begins.
# See documentation for `declare_mlir_python_sources`'s ROOT_DIR
# argument.
set(NKIPY_MLIR_PYTHON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/nkipy")
set(NKIPY_MLIR_PYTHON_PACKAGES_DIR "${PROJECT_BINARY_DIR}/tools/nkipy")
set(MLIR_PYTHON_SOURCE_DIR "${MLIR_MAIN_SRC_DIR}/lib/Bindings")
set(NKIPY_PYTHON_SOURCE_DIR "${PROJECT_SOURCE_DIR}/lib/Bindings")

include_directories(${MLIR_PYTHON_SOURCE_DIR})

# Use the system MLIR package prefix to ensure capsule compatibility
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=mlir.")

################################################################################
# Sources
################################################################################

declare_mlir_python_sources(NkipyMLIRPythonSources)
declare_mlir_python_sources(NkipyMLIRPythonExtensions)

declare_mlir_python_sources(NkipyMLIRPythonSources.Dialects
ROOT_DIR "${NKIPY_MLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT NkipyMLIRPythonSources
)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT NkipyMLIRPythonSources.Dialects
ROOT_DIR "${NKIPY_MLIR_PYTHON_ROOT_DIR}"
TD_FILE dialects/NkipyBinding.td
SOURCES
dialects/nkipy.py
dialects/_ods_common.py
exceptions.py
__init__.py
DIALECT_NAME nkipy
)

################################################################################
# Extensions
################################################################################

declare_mlir_python_extension(NkipyMLIRPythonExtensions.Main
MODULE_NAME _nkipy
ADD_TO_PARENT NkipyMLIRPythonExtensions
ROOT_DIR "/"
PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
${NKIPY_PYTHON_SOURCE_DIR}/NkipyModule.cpp
${NKIPY_PYTHON_SOURCE_DIR}/NkipyAttributes.cpp
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPIDebug
MLIRNkipyCAPI
PRIVATE_LINK_LIBS
MLIRPass
MLIRFuncDialect
MLIRMemRefDialect
MLIRAffineDialect
LLVMSupport
)

################################################################################
# Generate packages and shared library
# Downstreams typically will not use these, but they are useful for local
# testing.
################################################################################

# Only build NKIPy custom dialect extensions, not full MLIR bindings
# This avoids nanobind type conflicts with system mlir/nki.compiler._internal
set(_source_components
NkipyMLIRPythonSources
NkipyMLIRPythonExtensions
# MLIRPythonSources - REMOVED: Use system mlir package instead
# MLIRPythonExtension.RegisterEverything - REMOVED: Use system mlir package instead
)

add_mlir_python_common_capi_library(NkipyMLIRAggregateCAPI
INSTALL_COMPONENT NkipyMLIRPythonModules
INSTALL_DESTINATION _mlir
OUTPUT_DIRECTORY "${NKIPY_MLIR_PYTHON_PACKAGES_DIR}/_mlir"
RELATIVE_INSTALL_ROOT "../.."
DECLARED_HEADERS
MLIRPythonCAPI.HeaderSources
DECLARED_SOURCES
${_source_components}
)

add_mlir_python_modules(NkipyMLIRPythonModules
ROOT_PREFIX "${NKIPY_MLIR_PYTHON_PACKAGES_DIR}/_mlir"
INSTALL_PREFIX "_mlir"
DECLARED_SOURCES ${_source_components}
COMMON_CAPI_LINK_LIBS
NkipyMLIRAggregateCAPI
)
17 changes: 17 additions & 0 deletions nkigen/mlir/include/nkipy/Bindings/NkipyModule.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef NKIPY_BINDINGS_PYTHON_IRMODULES_H
#define NKIPY_BINDINGS_PYTHON_IRMODULES_H

// #include "NanobindUtils.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"

namespace mlir {
namespace python {

// void populateNkipyIRTypes(nanobind::module_ &m);
void populateNkipyAttributes(nanobind::module_ &m);

} // namespace python
} // namespace mlir

#endif // NKIPY_BINDINGS_PYTHON_IRMODULES_H
10 changes: 10 additions & 0 deletions nkigen/mlir/include/nkipy/Bindings/nkipy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Import from system MLIR package (not bundled)
# This avoids nanobind type conflicts with nki.compiler._internal
from mlir import ir
from mlir import dialects

# Import NKIPy-specific extensions
from ._mlir_libs._nkipy import nkipy

# Re-export for convenience
__all__ = ['ir', 'dialects', 'nkipy']
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#ifndef BINDINGS_PYTHON_NKIPY_OPS_TD
#define BINDINGS_PYTHON_NKIPY_OPS_TD

include "nkipy/Dialect/NkipyOps.td"

#endif // BINDINGS_PYTHON_NKIPY_OPS_TD
Loading