diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml
index 991d947..f35fe55 100644
--- a/.github/workflows/examples.yml
+++ b/.github/workflows/examples.yml
@@ -34,3 +34,7 @@ jobs:
       - name: Run apply basic schedule to basic payload
         run: |-
           uv run python/examples/schedule/transform_a_payload_according_to_a_schedule.py
+
+      - name: Run MemRef Management
+        run: |-
+          uv run python/examples/mlir/memref_management.py
diff --git a/python/examples/mlir/memref_management.py b/python/examples/mlir/memref_management.py
new file mode 100644
index 0000000..68eb874
--- /dev/null
+++ b/python/examples/mlir/memref_management.py
@@ -0,0 +1,115 @@
+import torch
+import ctypes
+
+from mlir import ir
+from mlir.dialects import func, memref
+from mlir.runtime import np_to_memref
+from mlir.execution_engine import ExecutionEngine
+from mlir.passmanager import PassManager
+
+import lighthouse.utils as lh_utils
+
+
+def create_mlir_module(shape: list[int]) -> ir.Module:
+    module = ir.Module.create()
+    with ir.InsertionPoint(module.body):
+        mem_type = ir.MemRefType.get(shape, ir.F32Type.get())
+
+        # Return a new buffer initialized with input's data.
+        @func.func(mem_type)
+        def copy(input):
+            new_buf = memref.alloc(mem_type, [], [])
+            memref.copy(input, new_buf)
+            return new_buf
+
+        # Free given buffer.
+        @func.func(mem_type)
+        def module_dealloc(input):
+            memref.dealloc(input)
+
+    return module
+
+
+def lower_to_llvm(operation: ir.Operation) -> None:
+    pm = PassManager("builtin.module")
+    pm.add("func.func(llvm-request-c-wrappers)")
+    pm.add("convert-to-llvm")
+    pm.add("reconcile-unrealized-casts")
+    pm.add("cse")
+    pm.add("canonicalize")
+    pm.run(operation)
+
+
+def main():
+    # Validate basic functionality.
+    print("Testing memref allocator...")
+    mem = lh_utils.MemRefManager()
+    # Check allocation.
+    buf = mem.alloc(32, 8, 16, ctype=ctypes.c_float)
+    assert buf.allocated != 0, "Invalid allocation"
+    assert list(buf.shape) == [32, 8, 16], "Invalid shape"
+    assert list(buf.strides) == [128, 16, 1], "Invalid strides"
+    # Check deallocation.
+    mem.dealloc(buf)
+    assert buf.allocated == 0, "Failed deallocation"
+    # Double free must not crash.
+    mem.dealloc(buf)
+
+    # Zero rank buffer.
+    buf = mem.alloc(ctype=ctypes.c_float)
+    mem.dealloc(buf)
+    # Small buffer.
+    buf = mem.alloc(8, ctype=ctypes.c_int8)
+    mem.dealloc(buf)
+    # Large buffer.
+    buf = mem.alloc(1024, 1024, ctype=ctypes.c_int32)
+    mem.dealloc(buf)
+
+    # Validate functionality across Python-MLIR boundary.
+    print("Testing JIT module memory management...")
+    # Buffer shape for testing.
+    shape = [16, 32]
+
+    # Create and compile test module.
+    kernel = create_mlir_module(shape)
+    lower_to_llvm(kernel.operation)
+    eng = ExecutionEngine(kernel, opt_level=3)
+    eng.initialize()
+
+    # Validate passing memrefs between Python and jitted module.
+    print("...copy test...")
+    fn_copy = eng.lookup("copy")
+
+    # Alloc buffer in Python and initialize it.
+    in_mem = mem.alloc(*shape, ctype=ctypes.c_float)
+    in_np = np_to_memref.ranked_memref_to_numpy([in_mem])
+    assert not in_np.flags.owndata, "Expected non-owning memref conversion"
+    in_tensor = torch.from_numpy(in_np)
+    torch.randn(in_tensor.shape, out=in_tensor)
+
+    out_mem = np_to_memref.make_nd_memref_descriptor(in_tensor.dim(), ctypes.c_float)()
+    out_mem.allocated = 0
+
+    args = lh_utils.memrefs_to_packed_args([out_mem, in_mem])
+    fn_copy(args)
+    assert out_mem.allocated != 0, "Invalid buffer returned"
+
+    out_tensor = torch.from_numpy(np_to_memref.ranked_memref_to_numpy([out_mem]))
+    torch.testing.assert_close(out_tensor, in_tensor)
+
+    mem.dealloc(out_mem)
+    assert out_mem.allocated == 0, "Failed to dealloc returned buffer"
+    mem.dealloc(in_mem)
+
+    # Validate external allocation with deallocation from within jitted module.
+    print("...dealloc test...")
+    fn_mlir_dealloc = eng.lookup("module_dealloc")
+    buf_mem = mem.alloc(*shape, ctype=ctypes.c_float)
+    fn_mlir_dealloc(lh_utils.memrefs_to_packed_args([buf_mem]))
+
+    print("SUCCESS")
+
+
+if __name__ == "__main__":
+    with ir.Context(), ir.Location.unknown():
+        main()
diff --git a/python/lighthouse/utils/__init__.py b/python/lighthouse/utils/__init__.py
index 22799cc..9e39171 100644
--- a/python/lighthouse/utils/__init__.py
+++ b/python/lighthouse/utils/__init__.py
@@ -1,5 +1,7 @@
 """A collection of utility tools"""
 
+from .memref_manager import MemRefManager
+
 from .runtime_args import (
     get_packed_arg,
     memref_to_ctype,
diff --git a/python/lighthouse/utils/memref_manager.py b/python/lighthouse/utils/memref_manager.py
new file mode 100644
index 0000000..e10243d
--- /dev/null
+++ b/python/lighthouse/utils/memref_manager.py
@@ -0,0 +1,98 @@
+import ctypes
+
+from itertools import accumulate
+from functools import reduce
+import operator
+
+import mlir.runtime.np_to_memref as np_mem
+
+
+class MemRefManager:
+    """
+    A utility class for manual management of MLIR memrefs.
+
+    When used together with memref operations from within a jitted MLIR module,
+    it is assumed that MemRef dialect allocations and deallocations are performed
+    through the standard runtime `malloc` and `free` functions.
+
+    Custom allocators are currently not supported. For more details, see:
+    https://mlir.llvm.org/docs/TargetLLVMIR/#generic-alloction-and-deallocation-functions
+    """
+
+    def __init__(self) -> None:
+        # Library name is left unspecified to allow for symbol search
+        # in the global symbol table of the current process.
+        # For more details, see:
+        # https://github.com/python/cpython/issues/78773
+        self.dll = ctypes.CDLL(name=None)
+        self.fn_malloc = self.dll.malloc
+        self.fn_malloc.argtypes = [ctypes.c_size_t]
+        self.fn_malloc.restype = ctypes.c_void_p
+        self.fn_free = self.dll.free
+        self.fn_free.argtypes = [ctypes.c_void_p]
+        self.fn_free.restype = None
+
+    def alloc(self, *shape: int, ctype: ctypes._SimpleCData) -> ctypes.Structure:
+        """
+        Allocate an uninitialized memory buffer.
+        Returns an MLIR ranked memref descriptor.
+
+        Args:
+            shape: A sequence of integers defining the buffer's shape.
+            ctype: The C type of the buffer's elements.
+ """ + assert issubclass(ctype, ctypes._SimpleCData), "Expected a simple data ctype" + size_bytes = reduce(operator.mul, shape, ctypes.sizeof(ctype)) + buf = self.fn_malloc(size_bytes) + assert buf, "Failed to allocate memory" + + rank = len(shape) + if rank == 0: + desc = np_mem.make_zero_d_memref_descriptor(ctype)() + desc.allocated = buf + desc.aligned = ctypes.cast(buf, ctypes.POINTER(ctype)) + desc.offset = ctypes.c_longlong(0) + return desc + + desc = np_mem.make_nd_memref_descriptor(rank, ctype)() + desc.allocated = buf + desc.aligned = ctypes.cast(buf, ctypes.POINTER(ctype)) + desc.offset = ctypes.c_longlong(0) + shape_ctype_t = ctypes.c_longlong * rank + desc.shape = shape_ctype_t(*shape) + + strides = list(accumulate(reversed(shape[1:]), func=operator.mul)) + strides.reverse() + strides.append(1) + desc.strides = shape_ctype_t(*strides) + return desc + + def dealloc(self, memref_desc: ctypes.Structure) -> None: + """ + Free underlying memory buffer. + + Args: + memref_desc: An MLIR memref descriptor. + """ + # TODO: Expose upstream MemrefDescriptor classes for easier handling + assert memref_desc.__class__.__name__ == "MemRefDescriptor" or isinstance( + memref_desc, np_mem.UnrankedMemRefDescriptor + ), "Invalid memref descriptor" + + if isinstance(memref_desc, np_mem.UnrankedMemRefDescriptor): + # Unranked memref holds the underlying descriptor as an opaque pointer. + # Cast the descriptor to a zero ranked memref with an arbitrary type to + # access the base allocated memory pointer. + ranked_desc_type = np_mem.make_zero_d_memref_descriptor(ctypes.c_char) + ranked_desc = ctypes.cast( + memref_desc.descriptor, ctypes.POINTER(ranked_desc_type) + ) + memref_desc = ranked_desc[0] + + alloc_ptr = memref_desc.allocated + if alloc_ptr == 0: + return + + c_ptr = ctypes.cast(alloc_ptr, ctypes.c_void_p) + self.fn_free(c_ptr) + memref_desc.allocated = 0