[examples][mlir] Basic MLIR compilation and execution example #10
Merged

Commits (7):
- 9647a7f [examples][mlir] Basic MLIR compilation and execution example (adam-smnk)
- 8c72b51 Fix return type (adam-smnk)
- 26749bf Split lowering and add pass pipeline example (adam-smnk)
- 10d066b Further abstract with torch utils (adam-smnk)
- 6b477c4 Add eng init (adam-smnk)
- de87169 Address feedback (adam-smnk)
- 9048e49 Expand libs docs (adam-smnk)
New file (197 lines):

```python
import torch
import os

from mlir import ir
from mlir.dialects import transform
from mlir.dialects.transform import structured
from mlir.dialects.transform import interpreter
from mlir.execution_engine import ExecutionEngine
from mlir.passmanager import PassManager

from lighthouse import utils as lh_utils


def create_kernel(ctx: ir.Context) -> ir.Module:
    """
    Create an MLIR module containing a function to execute.

    Args:
        ctx: MLIR context.
    """
    with ctx:
        module = ir.Module.parse(
            r"""
            // Compute element-wise addition.
            func.func @add(%a: memref<16x32xf32>, %b: memref<16x32xf32>, %out: memref<16x32xf32>) {
              linalg.add ins(%a, %b : memref<16x32xf32>, memref<16x32xf32>)
                         outs(%out : memref<16x32xf32>)
              return
            }
            """
        )
    return module


def create_schedule(ctx: ir.Context) -> ir.Module:
    """
    Create an MLIR module containing a transformation schedule.
    The schedule provides partial lowering to scalar operations.

    Args:
        ctx: MLIR context.
    """
    with ctx, ir.Location.unknown(context=ctx):
        # Create transform module.
        schedule = ir.Module.create()
        schedule.operation.attributes["transform.with_named_sequence"] = (
            ir.UnitAttr.get()
        )

        # For simplicity, use generic matchers without requiring specific types.
        anytype = transform.any_op_t()

        # Create entry point transformation sequence.
        with ir.InsertionPoint(schedule.body):
            named_seq = transform.NamedSequenceOp(
                sym_name="__transform_main",
                input_types=[anytype],
                result_types=[],
                arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}],
            )

        # Create the schedule.
        with ir.InsertionPoint(named_seq.body):
            # Find the kernel's function op.
            func = structured.MatchOp.match_op_names(
                named_seq.bodyTarget, ["func.func"]
            )
            # Use C interface wrappers - required to make the function
            # executable after JITting.
            func = transform.apply_registered_pass(
                anytype, func, "llvm-request-c-wrappers"
            )

            # Find the kernel's module op.
            mod = transform.get_parent_op(
                anytype, func, op_name="builtin.module", deduplicate=True
            )
            # Naive lowering to loops.
            mod = transform.apply_registered_pass(
                anytype, mod, "convert-linalg-to-loops"
            )
            # Cleanup.
            transform.apply_cse(mod)
            with ir.InsertionPoint(transform.ApplyPatternsOp(mod).patterns):
                transform.apply_patterns_canonicalization()

            # Terminate the schedule.
            transform.yield_([])
    return schedule


def apply_schedule(kernel: ir.Module, schedule: ir.Module) -> None:
    """
    Apply a transformation schedule to a kernel module.
    The kernel is modified in place.

    Args:
        kernel: A module with the payload function.
        schedule: A module with the transform schedule.
    """
    interpreter.apply_named_sequence(
        payload_root=kernel,
        transform_root=schedule.body.operations[0],
        transform_module=schedule,
    )


def create_pass_pipeline(ctx: ir.Context) -> PassManager:
    """
    Create an MLIR pass pipeline.
    The pipeline lowers operations further down to the LLVM dialect.

    Args:
        ctx: MLIR context.
    """
    with ctx:
        # Create a pass manager that applies passes to the whole module.
        pm = PassManager("builtin.module")
        # Lower to LLVM.
        pm.add("convert-scf-to-cf")
        pm.add("convert-to-llvm")
        pm.add("reconcile-unrealized-casts")
        # Cleanup.
        pm.add("cse")
        pm.add("canonicalize")
    return pm


# The example's entry point.
def main():
    ### Baseline computation ###
    # Create inputs.
    a = torch.randn(16, 32, dtype=torch.float32)
    b = torch.randn(16, 32, dtype=torch.float32)

    # Compute baseline result to verify numerical correctness.
    out_ref = torch.add(a, b)

    ### MLIR payload preparation ###
    # Create payload kernel.
    ctx = ir.Context()
    kernel = create_kernel(ctx)

    # Create a transform schedule and apply initial lowering.
    schedule = create_schedule(ctx)
    apply_schedule(kernel, schedule)

    # Create a pass pipeline and lower the kernel to LLVM dialect.
    pm = create_pass_pipeline(ctx)
    pm.run(kernel.operation)

    ### Compilation ###
    # External shared libraries, containing MLIR runner utilities, are generally
    # required to execute the compiled module.
    # In this case, the MLIR runner utils libraries are expected:
    #   - libmlir_runner_utils.so
    #   - libmlir_c_runner_utils.so
    #
    # Get paths to the MLIR runner shared libraries through an environment variable.
    # The execution engine requires full paths to the libraries.
    # For example, the env variable can be set as:
    #   LIGHTHOUSE_SHARED_LIBS=$PATH_TO_LLVM/build/lib/lib1.so:$PATH_TO_LLVM/build/lib/lib2.so
    mlir_libs = os.environ.get("LIGHTHOUSE_SHARED_LIBS", default="").split(":")

    # JIT the kernel.
    eng = ExecutionEngine(kernel, opt_level=2, shared_libs=mlir_libs)

    # Initialize the JIT engine.
    #
    # The deferred initialization executes global constructors that might have
    # been created by the module during engine creation (for example, when
    # `gpu.module` is present) or registered afterwards.
    #
    # Initialization is not strictly necessary in this case.
    # However, it is good practice to perform it regardless.
    eng.initialize()

    # Get the kernel function.
    add_func = eng.lookup("add")

    ### Execution ###
    # Create an empty buffer to hold results.
    out = torch.empty_like(out_ref)

    # Execute the kernel.
    args = lh_utils.torch_to_packed_args([a, b, out])
    add_func(args)

    ### Verification ###
    # Check numerical correctness.
    if not torch.allclose(out_ref, out, rtol=0.01, atol=0.01):
        print("Error! Result mismatch!")
    else:
        print("Result matched!")


if __name__ == "__main__":
    main()
```
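For reference, the packing performed by `lh_utils.torch_to_packed_args` mirrors what the upstream bindings' `ExecutionEngine.invoke` helper does internally: `lookup` resolves the `_mlir_ciface_`-prefixed wrapper (hence the `llvm-request-c-wrappers` pass above), and `invoke` casts each ctypes pointer into the packed `void*` array. A sketch of the execution step written against that API instead, assuming the same `eng`, `a`, `b`, and `out` as in `main` (details may vary across LLVM versions):

```python
import ctypes

from mlir.runtime.np_to_memref import get_ranked_memref_descriptor

# Build a pointer-to-pointer to each tensor's memref descriptor, then let
# ExecutionEngine.invoke pack them into the void** the C wrapper expects.
memref_ptrs = [
    ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(t.numpy())))
    for t in (a, b, out)
]
eng.invoke("add", *memref_ptrs)
```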
New file (9 lines):

```python
"""A collection of utility tools"""

from .runtime_args import (
    get_packed_arg,
    memref_to_ctype,
    memrefs_to_packed_args,
    torch_to_memref,
    torch_to_packed_args,
)
```
New file (62 lines):

```python
import ctypes
import torch

from mlir.runtime.np_to_memref import (
    get_ranked_memref_descriptor,
)


def get_packed_arg(ctypes_args) -> list[ctypes.c_void_p]:
    """
    Return a list of packed ctype arguments compatible with
    a jitted MLIR function's interface.

    Args:
        ctypes_args: A list of ctype pointer arguments.
    """
    packed_args = (ctypes.c_void_p * len(ctypes_args))()
    for arg_num, arg in enumerate(ctypes_args):
        packed_args[arg_num] = ctypes.cast(arg, ctypes.c_void_p)
    return packed_args


def memref_to_ctype(memref_desc) -> ctypes._Pointer:
    """
    Convert a memref descriptor into a ctype argument.

    Args:
        memref_desc: An MLIR memref descriptor.
    """
    return ctypes.pointer(ctypes.pointer(memref_desc))


def memrefs_to_packed_args(memref_descs) -> list[ctypes.c_void_p]:
    """
    Convert a list of memref descriptors into packed ctype arguments.

    Args:
        memref_descs: A list of memref descriptors.
    """
    ctype_args = [memref_to_ctype(memref) for memref in memref_descs]
    return get_packed_arg(ctype_args)


def torch_to_memref(input: torch.Tensor) -> ctypes.Structure:
    """
    Convert a PyTorch tensor into a memref descriptor.

    Args:
        input: PyTorch tensor.
    """
    return get_ranked_memref_descriptor(input.numpy())


def torch_to_packed_args(inputs: list[torch.Tensor]) -> list[ctypes.c_void_p]:
    """
    Convert a list of PyTorch tensors into packed ctype arguments.

    Args:
        inputs: A list of PyTorch tensors.
    """
    memrefs = [torch_to_memref(input) for input in inputs]
    return memrefs_to_packed_args(memrefs)
```
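A minimal usage sketch of these helpers (arbitrary shapes; `add_func` stands in for a function obtained via `ExecutionEngine.lookup`, as in the example above):

```python
import torch
from lighthouse import utils as lh_utils

a = torch.randn(4, 4)
b = torch.randn(4, 4)
out = torch.empty(4, 4)

# Tensors -> memref descriptors -> packed void* array, matching the
# calling convention of a C-interface-wrapped, JIT-compiled function.
packed_args = lh_utils.torch_to_packed_args([a, b, out])
add_func(packed_args)  # hypothetical jitted function taking (a, b, out)
```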
Review discussion:

I believe `transform.named_sequence` should already work.
AFAIK, the snake-case wrappers expose and call only the base class of the ops. In many cases that is enough, and I think it's good to prioritize their usage overall. However, for more complex ops it means we miss out on the QoL overloads and extra initialization steps.

For named sequences, creating `transform.NamedSequenceOp` gives direct access to easier-to-use APIs and also initializes the region's blocks. So, I think in such cases it's better to use the camel-case ops instead. A similar thing goes for `ApplyPatternsOp` or `TileUsingForOp`.
This is only the case if the person adding a derived class forgot to also add a new wrapper alongside their derived class (the old wrapper will indeed only wrap the base class). This is a bug and should be fixed upstream. Whenever we encounter this, let's track all instances in an issue and I will fix it.

I don't follow, though. In Python, `transform.named_sequence` is defined as a thin wrapper (sketched below). Could you explain how using `transform.NamedSequenceOp` gives you more flexibility? (In my experience this hasn't been the case; that is, I have been using `transform.named_sequence(...).body` etc. without issue in my code.)

(I know some of the wrappers return `TransformOp(...).result` rather than `TransformOp(...)`. I believe we should treat that as a bug as well and track and fix the instances we encounter.)
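The wrapper in question, roughly (a paraphrased sketch of the generated `_transform_ops_gen.py` code, not verbatim; signatures vary across LLVM versions):

```python
# Generated snake-case wrapper: note it instantiates the *generated base*
# NamedSequenceOp, not the hand-written derived class that
# mlir.dialects.transform also exports under the same name.
def named_sequence(sym_name, function_type, *, sym_visibility=None,
                   arg_attrs=None, res_attrs=None, loc=None, ip=None):
    return NamedSequenceOp(
        sym_name=sym_name,
        function_type=function_type,
        sym_visibility=sym_visibility,
        arg_attrs=arg_attrs,
        res_attrs=res_attrs,
        loc=loc,
        ip=ip,
    )
```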
When I tried using `transform.named_sequence`, it directs to the above wrapper as you show. When I follow further calls, this constructor seems to call the base version instead of the derived one, whose initializer takes care of building the `function_type` (less verbose user code), initializes the region (`self.regions[0].blocks.append(*input_types)`), and gives convenient access methods (sketched below).
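The derived class being referred to, roughly (a paraphrased sketch of the hand-written extension in `mlir/dialects/transform/__init__.py`, not verbatim):

```python
class NamedSequenceOp(NamedSequenceOp):  # extends the generated base class
    def __init__(self, sym_name, input_types, result_types, *,
                 sym_visibility=None, arg_attrs=None, res_attrs=None):
        # Builds the function type from plain type lists...
        function_type = ir.FunctionType.get(input_types, result_types)
        super().__init__(
            sym_name=sym_name,
            function_type=ir.TypeAttr.get(function_type),
            sym_visibility=sym_visibility,
            arg_attrs=arg_attrs,
            res_attrs=res_attrs,
        )
        # ...and initializes the region's entry block.
        self.regions[0].blocks.append(*input_types)

    @property
    def body(self):
        return self.regions[0].blocks[0]

    @property
    def bodyTarget(self):
        return self.body.arguments[0]
```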
I'm still pretty green to MLIR's Python world, so it might be that I'm just doing something really stupid 😅
Another example: in the case of tiling, it's just verbosity; compare the derived class with the generated wrapper in the sketch below.
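A rough comparison (assumed signatures based on the upstream bindings; details may differ by LLVM version; `matmul` is a hypothetical transform handle):

```python
# Derived class: compact, result types inferred from `sizes`.
tile_op = structured.TileUsingForOp(matmul, sizes=[32, 32])

# Generated snake-case wrapper: result types and the static/dynamic
# size split must be spelled out explicitly.
tile_op = structured.tile_using_for(
    transform.any_op_t(),                          # tiled op result type
    [transform.any_op_t(), transform.any_op_t()],  # one type per loop
    matmul,
    dynamic_sizes=[],
    static_sizes=[32, 32],
    scalable_sizes=[False, False],
)
```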
You are right! `transform.named_sequence` is only wrapping the base class version! My mistake, I should have checked (my code does access `body`, but does so as a Region, as that is what the base class exposes) -- sorry.

I will put up a PR upstream to fix this for the main transform ops, such as `named_sequence`, ASAP. Here's an upstream issue to track this: llvm/llvm-project#166765. Let's update it whenever we come across these issues. Please add other ops you encounter such issues with!
Thanks, will do 👍

I'm still all in for pushing for snake case. Camel-case usage in lighthouse will also serve as a small tracker of what needs to be tweaked.
Follow-up: llvm/llvm-project#166871