diff --git a/python/examples/mlir/compile_and_run.py b/python/examples/mlir/compile_and_run.py new file mode 100644 index 0000000..345251a --- /dev/null +++ b/python/examples/mlir/compile_and_run.py @@ -0,0 +1,197 @@ +import torch +import os + +from mlir import ir +from mlir.dialects import transform +from mlir.dialects.transform import structured +from mlir.dialects.transform import interpreter +from mlir.execution_engine import ExecutionEngine +from mlir.passmanager import PassManager + +from lighthouse import utils as lh_utils + + +def create_kernel(ctx: ir.Context) -> ir.Module: + """ + Create an MLIR module containing a function to execute. + + Args: + ctx: MLIR context. + """ + with ctx: + module = ir.Module.parse( + r""" + // Compute element-wise addition. + func.func @add(%a: memref<16x32xf32>, %b: memref<16x32xf32>, %out: memref<16x32xf32>) { + linalg.add ins(%a, %b : memref<16x32xf32>, memref<16x32xf32>) + outs(%out : memref<16x32xf32>) + return + } +""" + ) + return module + + +def create_schedule(ctx: ir.Context) -> ir.Module: + """ + Create an MLIR module containing transformation schedule. + The schedule provides partial lowering to scalar operations. + + Args: + ctx: MLIR context. + """ + with ctx, ir.Location.unknown(context=ctx): + # Create transform module. + schedule = ir.Module.create() + schedule.operation.attributes["transform.with_named_sequence"] = ( + ir.UnitAttr.get() + ) + + # For simplicity, use generic matchers without requiring specific types. + anytype = transform.any_op_t() + + # Create entry point transformation sequence. + with ir.InsertionPoint(schedule.body): + named_seq = transform.NamedSequenceOp( + sym_name="__transform_main", + input_types=[anytype], + result_types=[], + arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}], + ) + + # Create the schedule. + with ir.InsertionPoint(named_seq.body): + # Find the kernel's function op. + func = structured.MatchOp.match_op_names( + named_seq.bodyTarget, ["func.func"] + ) + # Use C interface wrappers - required to make function executable after jitting. + func = transform.apply_registered_pass( + anytype, func, "llvm-request-c-wrappers" + ) + + # Find the kernel's module op. + mod = transform.get_parent_op( + anytype, func, op_name="builtin.module", deduplicate=True + ) + # Naive lowering to loops. + mod = transform.apply_registered_pass( + anytype, mod, "convert-linalg-to-loops" + ) + # Cleanup. + transform.apply_cse(mod) + with ir.InsertionPoint(transform.ApplyPatternsOp(mod).patterns): + transform.apply_patterns_canonicalization() + + # Terminate the schedule. + transform.yield_([]) + return schedule + + +def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None: + """ + Apply transformation schedule to a kernel module. + The kernel is modified in-place. + + Args: + kernel: A module with payload function. + schedule: A module with transform schedule. + """ + interpreter.apply_named_sequence( + payload_root=kernel, + transform_root=schedule.body.operations[0], + transform_module=schedule, + ) + + +def create_pass_pipeline(ctx: ir.Context) -> PassManager: + """ + Create an MLIR pass pipeline. + The pipeline lowers operations further down to LLVM dialect. + + Args: + ctx: MLIR context. + """ + with ctx: + # Create a pass manager that applies passes to the whole module. + pm = PassManager("builtin.module") + # Lower to LLVM. + pm.add("convert-scf-to-cf") + pm.add("convert-to-llvm") + pm.add("reconcile-unrealized-casts") + # Cleanup + pm.add("cse") + pm.add("canonicalize") + return pm + + +# The example's entry point. +def main(): + ### Baseline computation ### + # Create inputs. + a = torch.randn(16, 32, dtype=torch.float32) + b = torch.randn(16, 32, dtype=torch.float32) + + # Compute baseline result to verify numerical correctness. + out_ref = torch.add(a, b) + + ### MLIR payload preparation ### + # Create payload kernel. + ctx = ir.Context() + kernel = create_kernel(ctx) + + # Create a transform schedule and apply initial lowering. + schedule = create_schedule(ctx) + apply_schedule(kernel, schedule) + + # Create a pass pipeline and lower the kernel to LLVM dialect. + pm = create_pass_pipeline(ctx) + pm.run(kernel.operation) + + ### Compilation ### + # External shared libraries, containing MLIR runner utilities, are generally + # required to execute the compiled module. + # In this case, MLIR runner utils libraries are expected: + # - libmlir_runner_utils.so + # - libmlir_c_runner_utils.so + # + # Get paths to MLIR runner shared libraries through an environment variable. + # The execution engine requires full paths to the libraries. + # For example, the env variable can be set as: + # LIGHTHOUSE_SHARED_LIBS=$PATH_TO_LLVM/build/lib/lib1.so:$PATH_TO_LLVM/build/lib/lib2.so + mlir_libs = os.environ.get("LIGHTHOUSE_SHARED_LIBS", default="").split(":") + + # JIT the kernel. + eng = ExecutionEngine(kernel, opt_level=2, shared_libs=mlir_libs) + + # Initialize the JIT engine. + # + # The deferred initialization executes global constructors that might have been + # created by the module during engine creation (for example, when `gpu.module` + # is present) or registered afterwards. + # + # Initialization is not strictly necessary in this case. + # However, it is a good practice to perform it regardless. + eng.initialize() + + # Get the kernel function. + add_func = eng.lookup("add") + + ### Execution ### + # Create an empty buffer to hold results. + out = torch.empty_like(out_ref) + + # Execute the kernel. + args = lh_utils.torch_to_packed_args([a, b, out]) + add_func(args) + + ### Verification ### + # Check numerical correctness. + if not torch.allclose(out_ref, out, rtol=0.01, atol=0.01): + print("Error! Result mismatch!") + else: + print("Result matched!") + + +if __name__ == "__main__": + main() diff --git a/python/lighthouse/utils/__init__.py b/python/lighthouse/utils/__init__.py new file mode 100644 index 0000000..22799cc --- /dev/null +++ b/python/lighthouse/utils/__init__.py @@ -0,0 +1,9 @@ +"""A collection of utility tools""" + +from .runtime_args import ( + get_packed_arg, + memref_to_ctype, + memrefs_to_packed_args, + torch_to_memref, + torch_to_packed_args, +) diff --git a/python/lighthouse/utils/runtime_args.py b/python/lighthouse/utils/runtime_args.py new file mode 100644 index 0000000..6719896 --- /dev/null +++ b/python/lighthouse/utils/runtime_args.py @@ -0,0 +1,62 @@ +import ctypes +import torch + +from mlir.runtime.np_to_memref import ( + get_ranked_memref_descriptor, +) + + +def get_packed_arg(ctypes_args) -> list[ctypes.c_void_p]: + """ + Return a list of packed ctype arguments compatible with + jitted MLIR function's interface. + + Args: + ctypes_args: A list of ctype pointer arguments. + """ + packed_args = (ctypes.c_void_p * len(ctypes_args))() + for argNum in range(len(ctypes_args)): + packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p) + return packed_args + + +def memref_to_ctype(memref_desc) -> ctypes._Pointer: + """ + Convert a memref descriptor into a ctype argument. + + Args: + memref_desc: An MLIR memref descriptor. + """ + return ctypes.pointer(ctypes.pointer(memref_desc)) + + +def memrefs_to_packed_args(memref_descs) -> list[ctypes.c_void_p]: + """ + Convert a list of memref descriptors into packed ctype arguments. + + Args: + memref_descs: A list of memref descriptors. + """ + ctype_args = [memref_to_ctype(memref) for memref in memref_descs] + return get_packed_arg(ctype_args) + + +def torch_to_memref(input: torch.Tensor) -> ctypes.Structure: + """ + Convert a PyTorch tensor into a memref descriptor. + + Args: + input: PyTorch tensor. + """ + return get_ranked_memref_descriptor(input.numpy()) + + +def torch_to_packed_args(inputs: list[torch.Tensor]) -> list[ctypes.c_void_p]: + """ + Convert a list of PyTorch tensors into packed ctype arguments. + + Args: + inputs: A list of PyTorch tensors. + """ + memrefs = [torch_to_memref(input) for input in inputs] + return memrefs_to_packed_args(memrefs)