197 changes: 197 additions & 0 deletions python/examples/mlir/compile_and_run.py
@@ -0,0 +1,197 @@
import os

import torch

from mlir import ir
from mlir.dialects import transform
from mlir.dialects.transform import structured
from mlir.dialects.transform import interpreter
from mlir.execution_engine import ExecutionEngine
from mlir.passmanager import PassManager

from lighthouse import utils as lh_utils


def create_kernel(ctx: ir.Context) -> ir.Module:
    """
    Create an MLIR module containing a function to execute.

    Args:
        ctx: MLIR context.
    """
    with ctx:
        module = ir.Module.parse(
            r"""
            // Compute element-wise addition.
            func.func @add(%a: memref<16x32xf32>, %b: memref<16x32xf32>, %out: memref<16x32xf32>) {
                linalg.add ins(%a, %b : memref<16x32xf32>, memref<16x32xf32>)
                           outs(%out : memref<16x32xf32>)
                return
            }
            """
        )
    return module
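

# Note: the module can be inspected at any point during lowering simply by
# printing it, e.g. `print(kernel)` after any of the steps in main() below -
# a handy debugging aid when experimenting with this example.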


def create_schedule(ctx: ir.Context) -> ir.Module:
    """
    Create an MLIR module containing a transformation schedule.
    The schedule performs a partial lowering to scalar operations.

    Args:
        ctx: MLIR context.
    """
    with ctx, ir.Location.unknown(context=ctx):
        # Create the transform module.
        schedule = ir.Module.create()
        schedule.operation.attributes["transform.with_named_sequence"] = (
            ir.UnitAttr.get()
        )

        # For simplicity, use generic matchers without requiring specific types.
        anytype = transform.any_op_t()

        # Create the entry point transformation sequence.
        with ir.InsertionPoint(schedule.body):
            named_seq = transform.NamedSequenceOp(

[Inline review thread]

Contributor (@rolfmorel):
I believe transform.named_sequence should already work.

Contributor Author (@adam-smnk):
AFAIK, the snake_case helpers expose and call only the base class of the ops.
In many cases that is enough, and I think it's good to prioritize their usage overall.
However, for more complex ops it means we miss out on the QoL overloads and extra initialization steps.

For named sequence, creating transform.NamedSequenceOp gives direct access to easier-to-use APIs and also initializes the region's blocks. So I think in such cases it's better to use the CamelCase ops instead.
A similar thing goes for ApplyPatternsOp or TileUsingForOp.

Contributor (@rolfmorel):
> AFAIK, the snake_case helpers expose and call only the base class of the ops.

This is only the case if the person adding a derived class forgot to also add a new wrapper alongside their derived class (the old wrapper will indeed only wrap the base class). This is a bug and should be fixed upstream. Whenever we encounter this, let's track all instances in an issue and I will fix it.

> For named sequence, creating transform.NamedSequenceOp gives direct access to easier-to-use APIs and also initializes the region's blocks. So I think in such cases it's better to use the CamelCase ops instead.

I don't follow. In Python, transform.named_sequence is defined as follows:

    def named_sequence(sym_name, function_type, *, sym_visibility=None, arg_attrs=None, res_attrs=None, loc=None, ip=None) -> NamedSequenceOp:
        return NamedSequenceOp(sym_name=sym_name, function_type=function_type, sym_visibility=sym_visibility, arg_attrs=arg_attrs, res_attrs=res_attrs, loc=loc, ip=ip)

Could you explain how using transform.NamedSequenceOp gives you more flexibility? (In my experience this hasn't been the case; that is, I have been using transform.named_sequence(...).body etc. without issue in my code.)

(I know some of the wrappers return TransformOp(...).result rather than TransformOp(...). I believe we should treat that as a bug as well, and track and fix the instances we encounter.)

Contributor Author (@adam-smnk, Nov 6, 2025):
When I tried using transform.named_sequence, it goes through the wrapper you show above. Following the calls further, this constructor resolves to the base version:

    @_ods_cext.register_operation(_Dialect)
    class NamedSequenceOp(_ods_ir.OpView):

instead of the derived one:

    @_ods_cext.register_operation(_Dialect, replace=True)
    class NamedSequenceOp(NamedSequenceOp):

whose initializer takes care of building function_type (less verbose user code), initializes the region (self.regions[0].blocks.append(*input_types)), and provides convenient access methods:

    @property
    def body(self) -> Block:
        return self.regions[0].blocks[0]

    @property
    def bodyTarget(self) -> Value:
        return self.body.arguments[0]

    @property
    def bodyExtraArgs(self) -> BlockArgumentList:
        return self.body.arguments[1:]

Contributor Author (@adam-smnk):
I'm still pretty green in MLIR's Python world, so it might be that I'm just doing something really stupid 😅

Contributor Author (@adam-smnk):
Another example: in the case of tiling, it's just verbosity.

    structured.TileUsingForOp(op, sizes=[1, 32])

vs

    structured.structured_tile_using_for(
        anytype,
        [anytype, anytype],
        op,
        dynamic_sizes=[],
        static_sizes=[1, 32],
        scalable_sizes=[False, False],
    )

Contributor (@rolfmorel, Nov 6, 2025):
You are right! transform.named_sequence is only wrapping the base class version! My mistake, should have checked (my code does access body, but does so as a Region, as that is what the base class exposes) -- sorry.

I will put up a PR upstream to fix this for the main transform ops, such as named_sequence, ASAP. Here's an upstream issue to track this: llvm/llvm-project#166765. Let's update it whenever we come across these issues. Please add other ops you encounter such issues with!

Contributor Author (@adam-smnk):
Thanks, will do 👍

I'm still all in for pushing for snake_case.
The CamelCase usage in lighthouse will also serve as a small tracker of what needs to be tweaked.

[End of review thread]


                sym_name="__transform_main",
                input_types=[anytype],
                result_types=[],
                arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}],
            )

            # Create the schedule.
            with ir.InsertionPoint(named_seq.body):
                # Find the kernel's function op.
                func = structured.MatchOp.match_op_names(
                    named_seq.bodyTarget, ["func.func"]
                )
                # Use C interface wrappers - required to make the function
                # callable after JIT-compilation.
                func = transform.apply_registered_pass(
                    anytype, func, "llvm-request-c-wrappers"
                )

                # Find the kernel's module op.
                mod = transform.get_parent_op(
                    anytype, func, op_name="builtin.module", deduplicate=True
                )
                # Naive lowering to loops.
                mod = transform.apply_registered_pass(
                    anytype, mod, "convert-linalg-to-loops"
                )
                # Cleanup.
                transform.apply_cse(mod)
                with ir.InsertionPoint(transform.ApplyPatternsOp(mod).patterns):
                    transform.apply_patterns_canonicalization()

                # Terminate the schedule.
                transform.yield_([])
    return schedule


def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None:
    """
    Apply a transformation schedule to a kernel module.
    The kernel is modified in place.

    Args:
        kernel: A module containing the payload function.
        schedule: A module containing the transform schedule.
    """
    interpreter.apply_named_sequence(
        payload_root=kernel,
        transform_root=schedule.body.operations[0],
        transform_module=schedule,
    )
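

# Note: when payload and schedule live in a single module, a similar effect
# can be achieved with the upstream `transform-interpreter` pass (a sketch;
# `combined` is a hypothetical module holding both, with the schedule using
# the default entry point name `__transform_main`):
#
#   PassManager.parse("builtin.module(transform-interpreter)").run(combined.operation)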


def create_pass_pipeline(ctx: ir.Context) -> PassManager:
    """
    Create an MLIR pass pipeline.
    The pipeline lowers operations further down to the LLVM dialect.

    Args:
        ctx: MLIR context.
    """
    with ctx:
        # Create a pass manager that applies passes to the whole module.
        pm = PassManager("builtin.module")
        # Lower to LLVM.
        pm.add("convert-scf-to-cf")
        pm.add("convert-to-llvm")
        pm.add("reconcile-unrealized-casts")
        # Cleanup.
        pm.add("cse")
        pm.add("canonicalize")
    return pm
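

# Note: the same pipeline can also be built from a single textual description
# instead of pass by pass (a sketch using the pass names above):
#
#   pm = PassManager.parse(
#       "builtin.module(convert-scf-to-cf,convert-to-llvm,"
#       "reconcile-unrealized-casts,cse,canonicalize)"
#   )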


# The example's entry point.
def main():
    ### Baseline computation ###
    # Create inputs.
    a = torch.randn(16, 32, dtype=torch.float32)
    b = torch.randn(16, 32, dtype=torch.float32)

    # Compute the baseline result to verify numerical correctness against.
    out_ref = torch.add(a, b)

    ### MLIR payload preparation ###
    # Create the payload kernel.
    ctx = ir.Context()
    kernel = create_kernel(ctx)

    # Create a transform schedule and apply the initial lowering.
    schedule = create_schedule(ctx)
    apply_schedule(kernel, schedule)

    # Create a pass pipeline and lower the kernel to the LLVM dialect.
    pm = create_pass_pipeline(ctx)
    pm.run(kernel.operation)

    ### Compilation ###
    # External shared libraries, containing MLIR runner utilities, are generally
    # required to execute the compiled module.
    # In this case, the MLIR runner utils libraries are expected:
    #   - libmlir_runner_utils.so
    #   - libmlir_c_runner_utils.so
    #
    # Get the paths to the MLIR runner shared libraries through an environment
    # variable. The execution engine requires full paths to the libraries.
    # For example, the env variable can be set as:
    #   LIGHTHOUSE_SHARED_LIBS=$PATH_TO_LLVM/build/lib/lib1.so:$PATH_TO_LLVM/build/lib/lib2.so
    #
    # Empty entries are filtered out so that an unset variable does not produce
    # an invalid library path.
    mlir_libs = [
        lib
        for lib in os.environ.get("LIGHTHOUSE_SHARED_LIBS", default="").split(":")
        if lib
    ]

    # JIT-compile the kernel.
    eng = ExecutionEngine(kernel, opt_level=2, shared_libs=mlir_libs)

    # Initialize the JIT engine.
    #
    # The deferred initialization executes global constructors that might have
    # been created by the module during engine creation (for example, when
    # `gpu.module` is present) or registered afterwards.
    #
    # Initialization is not strictly necessary in this case.
    # However, it is good practice to perform it regardless.
    eng.initialize()

    # Get the kernel function.
    add_func = eng.lookup("add")

    ### Execution ###
    # Create an empty buffer to hold the results.
    out = torch.empty_like(out_ref)

    # Execute the kernel.
    args = lh_utils.torch_to_packed_args([a, b, out])
    add_func(args)
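
    # Note: as an alternative to packing the arguments manually, the execution
    # engine's `invoke` helper performs the packing itself (a sketch of the
    # equivalent call, using the same tensors):
    #
    #   eng.invoke(
    #       "add",
    #       *[lh_utils.memref_to_ctype(lh_utils.torch_to_memref(t)) for t in (a, b, out)],
    #   )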

    ### Verification ###
    # Check numerical correctness.
    if not torch.allclose(out_ref, out, rtol=0.01, atol=0.01):
        print("Error! Result mismatch!")
    else:
        print("Result matched!")


if __name__ == "__main__":
    main()
9 changes: 9 additions & 0 deletions python/lighthouse/utils/__init__.py
@@ -0,0 +1,9 @@
"""A collection of utility tools"""

from .runtime_args import (
get_packed_arg,
memref_to_ctype,
memrefs_to_packed_args,
torch_to_memref,
torch_to_packed_args,
)
62 changes: 62 additions & 0 deletions python/lighthouse/utils/runtime_args.py
@@ -0,0 +1,62 @@
import ctypes

import torch

from mlir.runtime.np_to_memref import (
    get_ranked_memref_descriptor,
)


def get_packed_arg(ctypes_args) -> ctypes.Array:
    """
    Return an array of packed ctypes arguments compatible with a
    jitted MLIR function's interface.

    Args:
        ctypes_args: A list of ctypes pointer arguments.
    """
    packed_args = (ctypes.c_void_p * len(ctypes_args))()
    for i, arg in enumerate(ctypes_args):
        packed_args[i] = ctypes.cast(arg, ctypes.c_void_p)
    return packed_args
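

# Note: each entry of the packed array is a type-erased pointer to a single
# argument. For a memref argument, the entry points at a pointer to its
# descriptor, which is what memref_to_ctype below produces.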


def memref_to_ctype(memref_desc) -> ctypes._Pointer:
    """
    Convert a memref descriptor into a ctypes argument.

    Args:
        memref_desc: An MLIR memref descriptor.
    """
    return ctypes.pointer(ctypes.pointer(memref_desc))


def memrefs_to_packed_args(memref_descs) -> ctypes.Array:
    """
    Convert a list of memref descriptors into packed ctypes arguments.

    Args:
        memref_descs: A list of memref descriptors.
    """
    ctype_args = [memref_to_ctype(memref) for memref in memref_descs]
    return get_packed_arg(ctype_args)


def torch_to_memref(input: torch.Tensor) -> ctypes.Structure:
    """
    Convert a PyTorch tensor into a memref descriptor.

    Args:
        input: A PyTorch tensor. The tensor must reside in CPU memory,
            as the conversion goes through a NumPy view.
    """
    return get_ranked_memref_descriptor(input.numpy())


def torch_to_packed_args(inputs: list[torch.Tensor]) -> ctypes.Array:
    """
    Convert a list of PyTorch tensors into packed ctypes arguments.

    Args:
        inputs: A list of PyTorch tensors.
    """
    memrefs = [torch_to_memref(input) for input in inputs]
    return memrefs_to_packed_args(memrefs)
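

A minimal usage sketch for these helpers (hypothetical shapes; mirrors the invocation in compile_and_run.py above):

    import torch
    from lighthouse import utils as lh_utils

    # Inputs and an output buffer for a jitted function taking three memrefs.
    a = torch.randn(4, 4)
    b = torch.randn(4, 4)
    out = torch.empty(4, 4)

    # One call packs the tensors into the void** layout the engine expects.
    packed = lh_utils.torch_to_packed_args([a, b, out])

    # `packed` can then be passed directly to a callable obtained from
    # ExecutionEngine.lookup(...).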