From 9e15f2aef5cb06a08a6b2f24887f078a7da06772 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Fri, 14 Nov 2025 12:38:03 +0100 Subject: [PATCH 1/4] [python][utils] MemRef Manager Adds a utility for manual memory management of memref buffers across Python and jitted MLIR modules. Explicit memory management becomes required when an MLIR function returns a newly allocated buffer e.g., results of a computation. This can become a complex task due to difference in memory models between Python and the MLIR runtime allocators. By default, returned MLIR buffers' lifetime cannot be automatically managed by the Python environment. The Python memref manager aims to address the following challenges: - use of the same runtime allocators as a jitted MLIR module for consistent memory management - lean abstraction using memref descriptors directly - buffers usable both by Python and jitted MLIR modules Current implementation assumes that memref allocation ops are lowered to standard C functions, like 'malloc' and 'free', which are preloaded together with the Python process. --- examples/mlir/memref_management.py | 119 +++++++++++++++++++++++++++++ lighthouse/utils/__init__.py | 3 + lighthouse/utils/memref_manager.py | 98 ++++++++++++++++++++++++ 3 files changed, 220 insertions(+) create mode 100644 examples/mlir/memref_management.py create mode 100644 lighthouse/utils/memref_manager.py diff --git a/examples/mlir/memref_management.py b/examples/mlir/memref_management.py new file mode 100644 index 0000000..9cf4b4e --- /dev/null +++ b/examples/mlir/memref_management.py @@ -0,0 +1,119 @@ +# RUN: %PYTHON %s + +import torch +import ctypes + +from mlir import ir +from mlir.dialects import func, memref +from mlir.runtime import np_to_memref +from mlir.execution_engine import ExecutionEngine +from mlir.passmanager import PassManager + +import lighthouse.utils as lh_utils + + +def create_mlir_module(ctx: ir.Context, shape: list[int]) -> ir.Module: + with ctx, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): + mem_type = ir.MemRefType.get(shape, ir.F32Type.get()) + + # Return a new buffer initialized with input's data. + @func.func(mem_type) + def copy(input): + new_buf = memref.alloc(mem_type, [], []) + memref.copy(input, new_buf) + return new_buf + + # Free given buffer. + @func.func(mem_type) + def module_dealloc(input): + memref.dealloc(input) + + return module + + +def lower_to_llvm(operation: ir.Operation) -> None: + with operation.context: + pm = PassManager("builtin.module") + pm.add("func.func(llvm-request-c-wrappers)") + pm.add("convert-to-llvm") + pm.add("reconcile-unrealized-casts") + pm.add("cse") + pm.add("canonicalize") + pm.run(operation) + + +def main(): + # Validate basic functionality. + print("Testing memref allocator...") + mem = lh_utils.MemRefManager() + # Check allocation. + buf = mem.alloc(32, 8, 16, ctype=ctypes.c_float) + assert buf.allocated != 0, "Invalid allocation" + assert list(buf.shape) == [32, 8, 16], "Invalid shape" + assert list(buf.strides) == [128, 16, 1], "Invalid strides" + # Check deallocation. + mem.dealloc(buf) + assert buf.allocated == 0, "Failed deallocation" + # Double free must not crash. + mem.dealloc(buf) + + # Zero rank buffer. + buf = mem.alloc(ctype=ctypes.c_float) + mem.dealloc(buf) + # Small buffer. + buf = mem.alloc(8, ctype=ctypes.c_int8) + mem.dealloc(buf) + # Large buffer. + buf = mem.alloc(1024, 1024, ctype=ctypes.c_int32) + mem.dealloc(buf) + + # Validate functionality across Python-MLIR boundary. + print("Testing JIT module memory management...") + # Buffer shape for testing. + shape = [16, 32] + + # Create and compile test module. + ctx = ir.Context() + kernel = create_mlir_module(ctx, shape) + lower_to_llvm(kernel.operation) + eng = ExecutionEngine(kernel, opt_level=3) + eng.initialize() + + # Validate passing memrefs between Python and jitted module. + print("...copy test...") + fn_copy = eng.lookup("copy") + + # Alloc buffer in Python and initialize it. + in_mem = mem.alloc(*shape, ctype=ctypes.c_float) + in_np = np_to_memref.ranked_memref_to_numpy([in_mem]) + assert not in_np.flags.owndata, "Expected non-owning memref conversion" + in_tensor = torch.from_numpy(in_np) + torch.randn(in_tensor.shape, out=in_tensor) + + out_mem = np_to_memref.make_nd_memref_descriptor(in_tensor.dim(), ctypes.c_float)() + out_mem.allocated = 0 + + args = lh_utils.memrefs_to_packed_args([out_mem, in_mem]) + fn_copy(args) + assert out_mem.allocated != 0, "Invalid buffer returned" + + out_tensor = torch.from_numpy(np_to_memref.ranked_memref_to_numpy([out_mem])) + torch.testing.assert_close(out_tensor, in_tensor) + + mem.dealloc(out_mem) + assert out_mem.allocated == 0, "Failed to dealloc returned buffer" + mem.dealloc(in_mem) + + # Validate external allocation with deallocation from within jitted module. + print("...dealloc test...") + fn_mlir_dealloc = eng.lookup("module_dealloc") + buf_mem = mem.alloc(*shape, ctype=ctypes.c_float) + fn_mlir_dealloc(lh_utils.memrefs_to_packed_args([buf_mem])) + + print("SUCCESS") + + +if __name__ == "__main__": + main() diff --git a/lighthouse/utils/__init__.py b/lighthouse/utils/__init__.py index 474b748..a544744 100644 --- a/lighthouse/utils/__init__.py +++ b/lighthouse/utils/__init__.py @@ -1,5 +1,7 @@ """A collection of utility tools""" +from .memref_manager import MemRefManager + from .runtime_args import ( get_packed_arg, memref_to_ctype, @@ -10,6 +12,7 @@ ) __all__ = [ + "MemRefManager", "get_packed_arg", "memref_to_ctype", "memrefs_to_packed_args", diff --git a/lighthouse/utils/memref_manager.py b/lighthouse/utils/memref_manager.py new file mode 100644 index 0000000..e10243d --- /dev/null +++ b/lighthouse/utils/memref_manager.py @@ -0,0 +1,98 @@ +import ctypes + +from itertools import accumulate +from functools import reduce +import operator + +import mlir.runtime.np_to_memref as np_mem + + +class MemRefManager: + """ + A utility class for manual management of MLIR memrefs. + + When used together with memref operation from within a jitted MLIR module, + it is assumed that Memref dialect allocations and deallocation are performed + through standard runtime `malloc` and `free` functions. + + Custom allocators are currently not supported. For more details, see: + https://mlir.llvm.org/docs/TargetLLVMIR/#generic-alloction-and-deallocation-functions + """ + + def __init__(self) -> None: + # Library name is left unspecified to allow for symbol search + # in the global symbol table of the current process. + # For more details, see: + # https://github.com/python/cpython/issues/78773 + self.dll = ctypes.CDLL(name=None) + self.fn_malloc = self.dll.malloc + self.fn_malloc.argtypes = [ctypes.c_size_t] + self.fn_malloc.restype = ctypes.c_void_p + self.fn_free = self.dll.free + self.fn_free.argtypes = [ctypes.c_void_p] + self.fn_free.restype = None + + def alloc(self, *shape: int, ctype: ctypes._SimpleCData) -> ctypes.Structure: + """ + Allocate an empty memory buffer. + Returns an MLIR ranked memref descriptor. + + Args: + shape: A sequence of integers defining the buffer's shape. + ctype: A C type of buffer's elements. + """ + assert issubclass(ctype, ctypes._SimpleCData), "Expected a simple data ctype" + size_bytes = reduce(operator.mul, shape, ctypes.sizeof(ctype)) + buf = self.fn_malloc(size_bytes) + assert buf, "Failed to allocate memory" + + rank = len(shape) + if rank == 0: + desc = np_mem.make_zero_d_memref_descriptor(ctype)() + desc.allocated = buf + desc.aligned = ctypes.cast(buf, ctypes.POINTER(ctype)) + desc.offset = ctypes.c_longlong(0) + return desc + + desc = np_mem.make_nd_memref_descriptor(rank, ctype)() + desc.allocated = buf + desc.aligned = ctypes.cast(buf, ctypes.POINTER(ctype)) + desc.offset = ctypes.c_longlong(0) + shape_ctype_t = ctypes.c_longlong * rank + desc.shape = shape_ctype_t(*shape) + + strides = list(accumulate(reversed(shape[1:]), func=operator.mul)) + strides.reverse() + strides.append(1) + desc.strides = shape_ctype_t(*strides) + return desc + + def dealloc(self, memref_desc: ctypes.Structure) -> None: + """ + Free underlying memory buffer. + + Args: + memref_desc: An MLIR memref descriptor. + """ + # TODO: Expose upstream MemrefDescriptor classes for easier handling + assert memref_desc.__class__.__name__ == "MemRefDescriptor" or isinstance( + memref_desc, np_mem.UnrankedMemRefDescriptor + ), "Invalid memref descriptor" + + if isinstance(memref_desc, np_mem.UnrankedMemRefDescriptor): + # Unranked memref holds the underlying descriptor as an opaque pointer. + # Cast the descriptor to a zero ranked memref with an arbitrary type to + # access the base allocated memory pointer. + ranked_desc_type = np_mem.make_zero_d_memref_descriptor(ctypes.c_char) + ranked_desc = ctypes.cast( + memref_desc.descriptor, ctypes.POINTER(ranked_desc_type) + ) + memref_desc = ranked_desc[0] + + alloc_ptr = memref_desc.allocated + if alloc_ptr == 0: + return + + c_ptr = ctypes.cast(alloc_ptr, ctypes.c_void_p) + self.fn_free(c_ptr) + memref_desc.allocated = 0 From 7808e51c4eb0bfe20f71579d0a8ea9f1daf00e18 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Thu, 20 Nov 2025 10:40:27 +0100 Subject: [PATCH 2/4] Simplify ctx usage --- examples/mlir/memref_management.py | 52 ++++++++++++++---------------- 1 file changed, 25 insertions(+), 27 deletions(-) diff --git a/examples/mlir/memref_management.py b/examples/mlir/memref_management.py index 9cf4b4e..2e6e859 100644 --- a/examples/mlir/memref_management.py +++ b/examples/mlir/memref_management.py @@ -12,35 +12,33 @@ import lighthouse.utils as lh_utils -def create_mlir_module(ctx: ir.Context, shape: list[int]) -> ir.Module: - with ctx, ir.Location.unknown(): - module = ir.Module.create() - with ir.InsertionPoint(module.body): - mem_type = ir.MemRefType.get(shape, ir.F32Type.get()) - - # Return a new buffer initialized with input's data. - @func.func(mem_type) - def copy(input): - new_buf = memref.alloc(mem_type, [], []) - memref.copy(input, new_buf) - return new_buf - - # Free given buffer. - @func.func(mem_type) - def module_dealloc(input): - memref.dealloc(input) +def create_mlir_module(shape: list[int]) -> ir.Module: + module = ir.Module.create() + with ir.InsertionPoint(module.body): + mem_type = ir.MemRefType.get(shape, ir.F32Type.get()) + + # Return a new buffer initialized with input's data. + @func.func(mem_type) + def copy(input): + new_buf = memref.alloc(mem_type, [], []) + memref.copy(input, new_buf) + return new_buf + + # Free given buffer. + @func.func(mem_type) + def module_dealloc(input): + memref.dealloc(input) return module def lower_to_llvm(operation: ir.Operation) -> None: - with operation.context: - pm = PassManager("builtin.module") - pm.add("func.func(llvm-request-c-wrappers)") - pm.add("convert-to-llvm") - pm.add("reconcile-unrealized-casts") - pm.add("cse") - pm.add("canonicalize") + pm = PassManager("builtin.module") + pm.add("func.func(llvm-request-c-wrappers)") + pm.add("convert-to-llvm") + pm.add("reconcile-unrealized-casts") + pm.add("cse") + pm.add("canonicalize") pm.run(operation) @@ -75,8 +73,7 @@ def main(): shape = [16, 32] # Create and compile test module. - ctx = ir.Context() - kernel = create_mlir_module(ctx, shape) + kernel = create_mlir_module(shape) lower_to_llvm(kernel.operation) eng = ExecutionEngine(kernel, opt_level=3) eng.initialize() @@ -116,4 +113,5 @@ def main(): if __name__ == "__main__": - main() + with ir.Context(), ir.Location.unknown(): + main() From f8066f744c615c1b3c3000fe14f33ca98fdcc7a0 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Wed, 10 Dec 2025 12:37:33 +0100 Subject: [PATCH 3/4] Move to runtime submodule --- examples/mlir/memref_management.py | 5 +++-- lighthouse/runtime/__init__.py | 5 +++++ lighthouse/{utils => runtime}/memref_manager.py | 0 lighthouse/utils/__init__.py | 3 --- 4 files changed, 8 insertions(+), 5 deletions(-) create mode 100644 lighthouse/runtime/__init__.py rename lighthouse/{utils => runtime}/memref_manager.py (100%) diff --git a/examples/mlir/memref_management.py b/examples/mlir/memref_management.py index 2e6e859..1733e27 100644 --- a/examples/mlir/memref_management.py +++ b/examples/mlir/memref_management.py @@ -9,7 +9,8 @@ from mlir.execution_engine import ExecutionEngine from mlir.passmanager import PassManager -import lighthouse.utils as lh_utils +from lighthouse import runtime as lh_runtime +from lighthouse import utils as lh_utils def create_mlir_module(shape: list[int]) -> ir.Module: @@ -45,7 +46,7 @@ def lower_to_llvm(operation: ir.Operation) -> None: def main(): # Validate basic functionality. print("Testing memref allocator...") - mem = lh_utils.MemRefManager() + mem = lh_runtime.MemRefManager() # Check allocation. buf = mem.alloc(32, 8, 16, ctype=ctypes.c_float) assert buf.allocated != 0, "Invalid allocation" diff --git a/lighthouse/runtime/__init__.py b/lighthouse/runtime/__init__.py new file mode 100644 index 0000000..d590ee3 --- /dev/null +++ b/lighthouse/runtime/__init__.py @@ -0,0 +1,5 @@ +from .memref_manager import MemRefManager + +__all__ = [ + "MemRefManager", +] diff --git a/lighthouse/utils/memref_manager.py b/lighthouse/runtime/memref_manager.py similarity index 100% rename from lighthouse/utils/memref_manager.py rename to lighthouse/runtime/memref_manager.py diff --git a/lighthouse/utils/__init__.py b/lighthouse/utils/__init__.py index a544744..474b748 100644 --- a/lighthouse/utils/__init__.py +++ b/lighthouse/utils/__init__.py @@ -1,7 +1,5 @@ """A collection of utility tools""" -from .memref_manager import MemRefManager - from .runtime_args import ( get_packed_arg, memref_to_ctype, @@ -12,7 +10,6 @@ ) __all__ = [ - "MemRefManager", "get_packed_arg", "memref_to_ctype", "memrefs_to_packed_args", From 58aac91e61dec6cbbd7e9882887e33b52c2d1026 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Wed, 10 Dec 2025 12:46:32 +0100 Subject: [PATCH 4/4] Use workload util for MLIR engine --- examples/mlir/memref_management.py | 5 ++--- lighthouse/workload/__init__.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/mlir/memref_management.py b/examples/mlir/memref_management.py index 1733e27..68ce99b 100644 --- a/examples/mlir/memref_management.py +++ b/examples/mlir/memref_management.py @@ -6,11 +6,11 @@ from mlir import ir from mlir.dialects import func, memref from mlir.runtime import np_to_memref -from mlir.execution_engine import ExecutionEngine from mlir.passmanager import PassManager from lighthouse import runtime as lh_runtime from lighthouse import utils as lh_utils +from lighthouse.workload import get_engine def create_mlir_module(shape: list[int]) -> ir.Module: @@ -76,8 +76,7 @@ def main(): # Create and compile test module. kernel = create_mlir_module(shape) lower_to_llvm(kernel.operation) - eng = ExecutionEngine(kernel, opt_level=3) - eng.initialize() + eng = get_engine(kernel) # Validate passing memrefs between Python and jitted module. print("...copy test...") diff --git a/lighthouse/workload/__init__.py b/lighthouse/workload/__init__.py index 4738604..0103153 100644 --- a/lighthouse/workload/__init__.py +++ b/lighthouse/workload/__init__.py @@ -1,4 +1,4 @@ from .workload import Workload -from .runner import execute, benchmark +from .runner import get_engine, execute, benchmark -__all__ = ["Workload", "benchmark", "execute"] +__all__ = ["Workload", "benchmark", "execute", "get_engine"]