Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,7 @@ jobs:
- name: Run apply basic schedule to basic payload
run: |-
uv run python/examples/schedule/transform_a_payload_according_to_a_schedule.py

- name: Run MemRef Management
run: |-
uv run python/examples/mlir/memref_management.py
115 changes: 115 additions & 0 deletions python/examples/mlir/memref_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import torch
import ctypes

from mlir import ir
from mlir.dialects import func, memref
from mlir.runtime import np_to_memref
from mlir.execution_engine import ExecutionEngine
from mlir.passmanager import PassManager

import lighthouse.utils as lh_utils


def create_mlir_module(shape: list[int]) -> ir.Module:
module = ir.Module.create()
with ir.InsertionPoint(module.body):
mem_type = ir.MemRefType.get(shape, ir.F32Type.get())

# Return a new buffer initialized with input's data.
@func.func(mem_type)
def copy(input):
new_buf = memref.alloc(mem_type, [], [])
memref.copy(input, new_buf)
return new_buf

# Free given buffer.
@func.func(mem_type)
def module_dealloc(input):
memref.dealloc(input)

return module


def lower_to_llvm(operation: ir.Operation) -> None:
pm = PassManager("builtin.module")
pm.add("func.func(llvm-request-c-wrappers)")
pm.add("convert-to-llvm")
pm.add("reconcile-unrealized-casts")
pm.add("cse")
pm.add("canonicalize")
pm.run(operation)


def main():
# Validate basic functionality.
print("Testing memref allocator...")
mem = lh_utils.MemRefManager()
# Check allocation.
buf = mem.alloc(32, 8, 16, ctype=ctypes.c_float)
assert buf.allocated != 0, "Invalid allocation"
assert list(buf.shape) == [32, 8, 16], "Invalid shape"
assert list(buf.strides) == [128, 16, 1], "Invalid strides"
# Check deallocation.
mem.dealloc(buf)
assert buf.allocated == 0, "Failed deallocation"
# Double free must not crash.
mem.dealloc(buf)

# Zero rank buffer.
buf = mem.alloc(ctype=ctypes.c_float)
mem.dealloc(buf)
# Small buffer.
buf = mem.alloc(8, ctype=ctypes.c_int8)
mem.dealloc(buf)
# Large buffer.
buf = mem.alloc(1024, 1024, ctype=ctypes.c_int32)
mem.dealloc(buf)

# Validate functionality across Python-MLIR boundary.
print("Testing JIT module memory management...")
# Buffer shape for testing.
shape = [16, 32]

# Create and compile test module.
kernel = create_mlir_module(shape)
lower_to_llvm(kernel.operation)
eng = ExecutionEngine(kernel, opt_level=3)
eng.initialize()

# Validate passing memrefs between Python and jitted module.
print("...copy test...")
fn_copy = eng.lookup("copy")

# Alloc buffer in Python and initialize it.
in_mem = mem.alloc(*shape, ctype=ctypes.c_float)
in_np = np_to_memref.ranked_memref_to_numpy([in_mem])
assert not in_np.flags.owndata, "Expected non-owning memref conversion"
in_tensor = torch.from_numpy(in_np)
torch.randn(in_tensor.shape, out=in_tensor)

out_mem = np_to_memref.make_nd_memref_descriptor(in_tensor.dim(), ctypes.c_float)()
out_mem.allocated = 0

args = lh_utils.memrefs_to_packed_args([out_mem, in_mem])
fn_copy(args)
assert out_mem.allocated != 0, "Invalid buffer returned"

out_tensor = torch.from_numpy(np_to_memref.ranked_memref_to_numpy([out_mem]))
torch.testing.assert_close(out_tensor, in_tensor)

mem.dealloc(out_mem)
assert out_mem.allocated == 0, "Failed to dealloc returned buffer"
mem.dealloc(in_mem)

# Validate external allocation with deallocation from within jitted module.
print("...dealloc test...")
fn_mlir_dealloc = eng.lookup("module_dealloc")
buf_mem = mem.alloc(*shape, ctype=ctypes.c_float)
fn_mlir_dealloc(lh_utils.memrefs_to_packed_args([buf_mem]))

print("SUCCESS")


if __name__ == "__main__":
with ir.Context(), ir.Location.unknown():
main()
2 changes: 2 additions & 0 deletions python/lighthouse/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""A collection of utility tools"""

from .memref_manager import MemRefManager

from .runtime_args import (
get_packed_arg,
memref_to_ctype,
Expand Down
98 changes: 98 additions & 0 deletions python/lighthouse/utils/memref_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import ctypes

from itertools import accumulate
from functools import reduce
import operator

import mlir.runtime.np_to_memref as np_mem


class MemRefManager:
"""
A utility class for manual management of MLIR memrefs.

When used together with memref operation from within a jitted MLIR module,
it is assumed that Memref dialect allocations and deallocation are performed
through standard runtime `malloc` and `free` functions.

Custom allocators are currently not supported. For more details, see:
https://mlir.llvm.org/docs/TargetLLVMIR/#generic-alloction-and-deallocation-functions
"""

def __init__(self) -> None:
# Library name is left unspecified to allow for symbol search
# in the global symbol table of the current process.
# For more details, see:
# https://github.com/python/cpython/issues/78773
self.dll = ctypes.CDLL(name=None)
self.fn_malloc = self.dll.malloc
self.fn_malloc.argtypes = [ctypes.c_size_t]
self.fn_malloc.restype = ctypes.c_void_p
self.fn_free = self.dll.free
self.fn_free.argtypes = [ctypes.c_void_p]
self.fn_free.restype = None

def alloc(self, *shape: int, ctype: ctypes._SimpleCData) -> ctypes.Structure:
"""
Allocate an empty memory buffer.
Returns an MLIR ranked memref descriptor.

Args:
shape: A sequence of integers defining the buffer's shape.
ctype: A C type of buffer's elements.
"""
assert issubclass(ctype, ctypes._SimpleCData), "Expected a simple data ctype"
size_bytes = reduce(operator.mul, shape, ctypes.sizeof(ctype))
buf = self.fn_malloc(size_bytes)
assert buf, "Failed to allocate memory"

rank = len(shape)
if rank == 0:
desc = np_mem.make_zero_d_memref_descriptor(ctype)()
desc.allocated = buf
desc.aligned = ctypes.cast(buf, ctypes.POINTER(ctype))
desc.offset = ctypes.c_longlong(0)
return desc

desc = np_mem.make_nd_memref_descriptor(rank, ctype)()
desc.allocated = buf
desc.aligned = ctypes.cast(buf, ctypes.POINTER(ctype))
desc.offset = ctypes.c_longlong(0)
shape_ctype_t = ctypes.c_longlong * rank
desc.shape = shape_ctype_t(*shape)

strides = list(accumulate(reversed(shape[1:]), func=operator.mul))
strides.reverse()
strides.append(1)
desc.strides = shape_ctype_t(*strides)
return desc

def dealloc(self, memref_desc: ctypes.Structure) -> None:
"""
Free underlying memory buffer.

Args:
memref_desc: An MLIR memref descriptor.
"""
# TODO: Expose upstream MemrefDescriptor classes for easier handling
assert memref_desc.__class__.__name__ == "MemRefDescriptor" or isinstance(
memref_desc, np_mem.UnrankedMemRefDescriptor
), "Invalid memref descriptor"

if isinstance(memref_desc, np_mem.UnrankedMemRefDescriptor):
# Unranked memref holds the underlying descriptor as an opaque pointer.
# Cast the descriptor to a zero ranked memref with an arbitrary type to
# access the base allocated memory pointer.
ranked_desc_type = np_mem.make_zero_d_memref_descriptor(ctypes.c_char)
ranked_desc = ctypes.cast(
memref_desc.descriptor, ctypes.POINTER(ranked_desc_type)
)
memref_desc = ranked_desc[0]

alloc_ptr = memref_desc.allocated
if alloc_ptr == 0:
return

c_ptr = ctypes.cast(alloc_ptr, ctypes.c_void_p)
self.fn_free(c_ptr)
memref_desc.allocated = 0