-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[TIR, TVMScript] Add TIR - Triton integration #17395
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
03fb1a4
[TIR, TVMScript] Add TIR - Triton integration
vinx13 529ccae
update test
vinx13 ebf0e78
refactor
vinx13 4baec55
dedup
vinx13 eb77789
lint
vinx13 8c75dbd
lint
vinx13 47b6f71
lint
vinx13 08b3fa8
lint
vinx13 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,141 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
| """External kernel integration fro TIR""" | ||
| import json | ||
| import logging | ||
| import tempfile | ||
| from typing import Any, Dict, List, Tuple, Union | ||
|
|
||
| from tvm import __version__ as tvm_version | ||
| from tvm import tir | ||
| from tvm.runtime import Module, load_module | ||
|
|
||
|
|
||
| class BaseKernel: | ||
| """Base class for external kernels.""" | ||
|
|
||
| def compile_to_device_module( | ||
| self, launch_args, *args, **kwargs | ||
| ) -> Tuple[str, Module, List[Any]]: | ||
| """Compile the kernel to a device module.""" | ||
| raise NotImplementedError() | ||
|
|
||
| def _format_tvm_module_metadata(self, kernel_name, arg_types, launch_param_tags): | ||
| """Format the TVM module metadata.""" | ||
| tvm_metadata = """{{ | ||
| "tvm_version": "{version}", | ||
| "func_info": {{ | ||
| "{kernel_name}": {{ | ||
| "name": "", | ||
| "arg_types": {arg_types}, | ||
| "launch_param_tags": {launch_param_tags} | ||
| }} | ||
| }} | ||
| }}""".format_map( | ||
| { | ||
| "version": tvm_version, | ||
| "kernel_name": kernel_name, | ||
| "arg_types": json.dumps(arg_types), | ||
| "launch_param_tags": json.dumps(launch_param_tags), | ||
| } | ||
| ) | ||
| return tvm_metadata | ||
|
|
||
| def _create_cuda_module(self, ptx, kernel_arg_types, launch_param_tags, kernel_name): | ||
| """ | ||
| Create a CUDA module from PTX and metadata. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| ptx : str | ||
| The PTX code of the kernel. | ||
|
|
||
| kernel_arg_types : List[str] | ||
| The types of the kernel arguments. | ||
|
|
||
| launch_param_tags : List[str] | ||
| The tags of the launch parameters. | ||
|
|
||
| kernel_name : str | ||
| The name of the kernel. | ||
|
|
||
| Returns | ||
| ------- | ||
| kernel_module : Module | ||
| The CUDA module. | ||
| """ | ||
| tvm_metadata = self._format_tvm_module_metadata( | ||
| kernel_name, kernel_arg_types, launch_param_tags | ||
| ) | ||
| with tempfile.TemporaryDirectory() as temp_dir: | ||
| ptx_path = f"{temp_dir}/{kernel_name}.ptx" | ||
| with open(ptx_path, "w") as f: | ||
| f.write(ptx) | ||
| with open(f"{temp_dir}/{kernel_name}.tvm_meta.json", "w") as f: | ||
| f.write(tvm_metadata) | ||
| kernel_module = load_module(ptx_path) | ||
| return kernel_module | ||
|
|
||
|
|
||
| def call_kernel( | ||
| kernel, | ||
| launch_args: List[Union[int, tir.PrimExpr, List[Union[int, tir.PrimExpr]]]], | ||
| *args: List[Any], | ||
| **kwargs: Dict[str, Any], | ||
| ): | ||
| """ | ||
| Call an external kernel. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| kernel : Any | ||
| The external kernel to call. | ||
|
|
||
| launch_args : List[Union[int, tir.PrimExpr, List[Union[int, tir.PrimExpr]]]] | ||
| The launch arguments. A list of integers for grid size, block size, and shared memory size. | ||
| The actual requirements depend on the kernel. | ||
|
|
||
| args : List[tir.PrimExpr] | ||
| The arguments to pass to the kernel. | ||
|
|
||
| kwargs : Dict[str, Any] | ||
| Additional keyword arguments to pass to the kernel or compilation. | ||
| """ | ||
| from ..ir import module_get_attr, module_set_attr # pylint: disable=import-outside-toplevel | ||
| from .ir import call_packed # pylint: disable=import-outside-toplevel | ||
|
|
||
| kernel_type = f"{type(kernel).__module__}.{type(kernel).__qualname__}" | ||
| if kernel_type == "triton.runtime.jit.JITFunction": | ||
| from .triton import TritonKernel # pylint: disable=import-outside-toplevel | ||
|
|
||
| kernel = TritonKernel(kernel) | ||
| else: | ||
| raise ValueError("Unsupported kernel type {}".format(kernel_type)) | ||
|
|
||
| kernel_name, kernel_module, runtime_args = kernel.compile_to_device_module( | ||
| launch_args, *args, **kwargs | ||
| ) | ||
|
|
||
| # Attach the kernel module to the current IRModule | ||
| external_mods: List[Module] = module_get_attr("external_mods") or [] | ||
| kernel_exists = any([mod.implements_function(kernel_name) for mod in external_mods]) | ||
| if kernel_exists: | ||
| logging.debug("Kernel %s already exists in the IRModule", kernel_name) | ||
| else: | ||
| external_mods.append(kernel_module) | ||
| module_set_attr("external_mods", external_mods, True) | ||
| return call_packed(kernel_name, *runtime_args) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
| """Triton kernel integration with TIR""" | ||
|
|
||
| from typing import Tuple, List, Union, Any, Dict | ||
|
|
||
| import triton | ||
| from triton.runtime.jit import type_canonicalisation_dict | ||
| from tvm import tir | ||
| from tvm.topi.utils import get_const_int | ||
| from tvm.runtime import Module | ||
| from .external_kernel import BaseKernel | ||
|
|
||
|
|
||
| class TritonKernel(BaseKernel): | ||
| """A kernel from Triton JIT function. | ||
|
|
||
| This class bridges the Triton kernel with TVM runtime. The compilation includes the following | ||
| steps: | ||
| - Deduce the kernel signature and generate the Triton kernel | ||
| - Embed the compiled kernel into the current IRModule as an external module | ||
| - Generate a call to the Triton kernel following its calling convention via call_packed. | ||
| """ | ||
|
|
||
| def __init__(self, func): | ||
| self.func = func | ||
|
|
||
| def compile_to_device_module( | ||
| self, | ||
| launch_args: List[Union[int, tir.PrimExpr]], | ||
| *args: List[Any], | ||
| **kwargs: Dict[str, Any], | ||
| ) -> Tuple[str, Module, List[Any]]: | ||
| """Compile the kernel to a device module. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| launch_args : List[int] | ||
| The grid size of the kernel. A list of one to three expressions, representing the number | ||
| of | ||
| "blockIdx.x", "blockIdx.y", and "blockIdx.z" respectively. | ||
|
|
||
| args : List[Any] | ||
| Arguments to the kernel function. | ||
|
|
||
| kwargs : Dict[str, Any] | ||
| Additional options for the kernel compilation. | ||
| """ | ||
| triton_kernel, kernel_args = self._generate_triton_kernel(self.func, *args, **kwargs) | ||
| kernel_metadata = triton_kernel.metadata | ||
| ptx = triton_kernel.asm["ptx"] | ||
| assert kernel_metadata.num_ctas == 1, "Cluster is not supported" | ||
| num_warps = kernel_metadata.num_warps | ||
| grid = launch_args | ||
| launch_param_tags = ["threadIdx.x"] + ["blockIdx.x", "blockIdx.y", "blockIdx.z"][ | ||
| : len(grid) | ||
| ] | ||
| launch_args = [num_warps * 32] + list(grid) | ||
| kernel_arg_types = [arg.dtype for arg in kernel_args] | ||
| if triton_kernel.metadata.shared > 0: | ||
| # Add shared memory size to the launch arguments | ||
| launch_param_tags.append("tir.use_dyn_shared_memory") | ||
| launch_args.append(triton_kernel.metadata.shared) | ||
|
|
||
| kernel_module = self._create_cuda_module( | ||
| ptx, kernel_arg_types, launch_param_tags, triton_kernel.name | ||
| ) | ||
|
|
||
| return triton_kernel.name, kernel_module, kernel_args + launch_args | ||
|
|
||
| def _generate_triton_kernel( | ||
| self, func, *args, **kwargs | ||
| ) -> Tuple["triton.compiler.CompiledKernel", List[tir.PrimExpr]]: | ||
| """Deduce the kernel signature and generate the Triton kernel""" | ||
|
|
||
| kernel_params = func.params | ||
| assert len(kernel_params) == len( | ||
| args | ||
| ), f"Number of arguments does not match, expected {len(kernel_params)}, got {len(args)}" | ||
|
|
||
| signature = {} | ||
| constants = {} | ||
| kernel_args = [] # Arguments to invoke the kernel | ||
| for i, arg in enumerate(args): | ||
| if kernel_params[i].is_constexpr: | ||
| constants[kernel_params[i].name] = get_const_int(arg) | ||
| continue | ||
| if arg.dtype == "handle": | ||
| assert isinstance(arg, tir.Var) | ||
| elem_type = arg.type_annotation.element_type.dtype | ||
| pointer_type = "*" + type_canonicalisation_dict[elem_type] | ||
| signature[kernel_params[i].name] = pointer_type | ||
| else: | ||
| signature[kernel_params[i].name] = type_canonicalisation_dict[arg.dtype] | ||
| kernel_args.append(arg) | ||
|
|
||
| # TODO: Support default argument in the kernel | ||
| # TODO: Add specialization for aligned buffer pointers | ||
| source = triton.compiler.ASTSource(fn=func, constants=constants, signature=signature) | ||
| compiled = triton.compiler.compile(source, options=kwargs) | ||
| return compiled, kernel_args |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as a followup, add a function to check if the module is_device_module, this should include cuda, rocm, webgpu, vulkan, opencl