14 changes: 13 additions & 1 deletion python/tvm/relax/vm_build.py
@@ -243,13 +243,25 @@ def _vmlink(
if ext_libs is None:
ext_libs = []
lib = None
relax_ext_libs = []
tir_ext_libs = []
if tir_mod is not None and len(tir_mod.get_global_vars()) > 0:
lib = tvm.build(
tir_mod,
target=target,
runtime=_autodetect_system_lib_req(target, system_lib),
)
return Executable(_ffi_api.VMLink(builder, target, lib, ext_libs, params)) # type: ignore
for ext_mod in ext_libs:
if ext_mod.type_key == "cuda":
Review comment: as a follow-up, add a function to check if the module is a device module (is_device_module); this should include cuda, rocm, webgpu, vulkan, and opencl.

tir_ext_libs.append(ext_mod)
else:
relax_ext_libs.append(ext_mod)
if lib is not None:
for mod in tir_ext_libs:
lib.import_module(mod)
elif len(tir_ext_libs) > 0:
print("Warning: No TIR module is found, but external modules for TIR are provided.")
return Executable(_ffi_api.VMLink(builder, target, lib, relax_ext_libs, params)) # type: ignore


def build(
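The review comment above suggests factoring this check into a helper. A minimal sketch of such a follow-up, assuming the classification keys off runtime.Module.type_key (the name is_device_module and the key set come from the comment, not from this PR):

def is_device_module(mod) -> bool:
    """Return True if an external module targets a device runtime."""
    # Device modules must be imported into the TIR library rather than
    # linked directly into the Relax executable.
    return mod.type_key in ("cuda", "rocm", "webgpu", "vulkan", "opencl")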
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/ir/__init__.py
@@ -21,6 +21,8 @@
def_function,
ir_module,
module_attrs,
module_get_attr,
module_set_attr,
module_global_infos,
lookup_vdevice,
vdevice,
58 changes: 55 additions & 3 deletions python/tvm/script/ir_builder/ir/ir.py
@@ -16,7 +16,7 @@
# under the License.
"""Package tvm.script.ir_builder.ir.ir"""

from typing import Dict, List
from typing import Dict, List, Optional

from tvm.ir import BaseFunc, GlobalVar, GlobalInfo, VDevice, DummyGlobalInfo
from tvm.runtime import Object as tvm_Object
@@ -77,14 +77,66 @@ def def_function(func_name: str, func: BaseFunc) -> None:
return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member


def module_attrs(attrs: Dict[str, tvm_Object]) -> None:
def module_attrs(attrs: Dict[str, tvm_Object], allow_overwrite=False) -> None:
"""Specify the attrs of the ir_module frame.
Parameters
----------
attrs: Dict[str, Object]
The module attrs.
allow_overwrite: bool
Whether to allow overwriting the existing attrs.
"""
return _ffi_api.ModuleAttrs(attrs) # type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.ModuleAttrs(attrs, allow_overwrite) # type: ignore[attr-defined] # pylint: disable=no-member


def current_ir_module() -> IRModuleFrame:
"""Get the current ir_module frame.
Returns
-------
frame: IRModuleFrame
The current frame.
"""
return _ffi_api.CurrentIRModule() # type: ignore[attr-defined] # pylint: disable=no-member


def module_get_attrs() -> Dict[str, tvm_Object]:
"""Get the attrs of the ir_module frame.
Returns
-------
attrs: Dict[str, Object]
The module attrs.
"""
return _ffi_api.ModuleGetAttrs() # type: ignore[attr-defined] # pylint: disable=no-member


def module_get_attr(attr_key: str) -> Optional[tvm_Object]:
"""Get the specified attr of the ir_module frame.
Parameters
----------
attr_key: str
The key of the attr to be retrieved.
Returns
-------
attr: Optional[Object]
The specified module attr or None if not found.
"""
return _ffi_api.ModuleGetAttr(attr_key) # type: ignore[attr-defined] # pylint: disable=no-member


def module_set_attr(
attr_key: str, attr_value: Optional[tvm_Object], allow_overwrite: bool = False
) -> None:
"""Set the specified attr of the ir_module frame.
Parameters
----------
attr_key: str
The key of the attr to be set.
attr_value: Optional[Object]
The value of the attr to be set.
allow_overwrite: bool
Whether to allow overwriting the existing attr.
"""
return _ffi_api.ModuleSetAttr(attr_key, attr_value, allow_overwrite) # type: ignore[attr-defined] # pylint: disable=no-member


def module_global_infos(global_infos: Dict[str, List[GlobalInfo]]) -> None:
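A minimal sketch of how these accessors might be used inside an IRBuilder scope; the attribute key and values below are illustrative, and the overall pattern follows the existing ir_builder usage:

from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import ir as I

with IRBuilder() as builder:
    with I.ir_module():
        # Set an initial module-level attribute.
        I.module_attrs({"runtime": "cpp"})
        # Read a single attribute back; returns None if the key is absent.
        assert I.module_get_attr("runtime") is not None
        # Replacing an existing value requires allow_overwrite=True.
        I.module_set_attr("runtime", "c", allow_overwrite=True)
mod = builder.get()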
141 changes: 141 additions & 0 deletions python/tvm/script/ir_builder/tir/external_kernel.py
@@ -0,0 +1,141 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""External kernel integration fro TIR"""
import json
import logging
import tempfile
from typing import Any, Dict, List, Tuple, Union

from tvm import __version__ as tvm_version
from tvm import tir
from tvm.runtime import Module, load_module


class BaseKernel:
"""Base class for external kernels."""

def compile_to_device_module(
self, launch_args, *args, **kwargs
) -> Tuple[str, Module, List[Any]]:
"""Compile the kernel to a device module."""
raise NotImplementedError()

def _format_tvm_module_metadata(self, kernel_name, arg_types, launch_param_tags):
"""Format the TVM module metadata."""
tvm_metadata = """{{
"tvm_version": "{version}",
"func_info": {{
"{kernel_name}": {{
"name": "",
"arg_types": {arg_types},
"launch_param_tags": {launch_param_tags}
}}
}}
}}""".format_map(
{
"version": tvm_version,
"kernel_name": kernel_name,
"arg_types": json.dumps(arg_types),
"launch_param_tags": json.dumps(launch_param_tags),
}
)
return tvm_metadata

def _create_cuda_module(self, ptx, kernel_arg_types, launch_param_tags, kernel_name):
"""
Create a CUDA module from PTX and metadata.

Parameters
----------
ptx : str
The PTX code of the kernel.

kernel_arg_types : List[str]
The types of the kernel arguments.

launch_param_tags : List[str]
The tags of the launch parameters.

kernel_name : str
The name of the kernel.

Returns
-------
kernel_module : Module
The CUDA module.
"""
tvm_metadata = self._format_tvm_module_metadata(
kernel_name, kernel_arg_types, launch_param_tags
)
with tempfile.TemporaryDirectory() as temp_dir:
ptx_path = f"{temp_dir}/{kernel_name}.ptx"
with open(ptx_path, "w") as f:
f.write(ptx)
with open(f"{temp_dir}/{kernel_name}.tvm_meta.json", "w") as f:
f.write(tvm_metadata)
kernel_module = load_module(ptx_path)
return kernel_module


def call_kernel(
kernel,
launch_args: List[Union[int, tir.PrimExpr, List[Union[int, tir.PrimExpr]]]],
*args: List[Any],
**kwargs: Dict[str, Any],
):
"""
Call an external kernel.

Parameters
----------
kernel : Any
The external kernel to call.

launch_args : List[Union[int, tir.PrimExpr, List[Union[int, tir.PrimExpr]]]]
The launch arguments. A list of integers for grid size, block size, and shared memory size.
The actual requirements depend on the kernel.

args : List[tir.PrimExpr]
The arguments to pass to the kernel.

kwargs : Dict[str, Any]
Additional keyword arguments to pass to the kernel or compilation.
"""
from ..ir import module_get_attr, module_set_attr # pylint: disable=import-outside-toplevel
from .ir import call_packed # pylint: disable=import-outside-toplevel

kernel_type = f"{type(kernel).__module__}.{type(kernel).__qualname__}"
if kernel_type == "triton.runtime.jit.JITFunction":
from .triton import TritonKernel # pylint: disable=import-outside-toplevel

kernel = TritonKernel(kernel)
else:
raise ValueError("Unsupported kernel type {}".format(kernel_type))

kernel_name, kernel_module, runtime_args = kernel.compile_to_device_module(
launch_args, *args, **kwargs
)

# Attach the kernel module to the current IRModule
external_mods: List[Module] = module_get_attr("external_mods") or []
kernel_exists = any([mod.implements_function(kernel_name) for mod in external_mods])
if kernel_exists:
logging.debug("Kernel %s already exists in the IRModule", kernel_name)
else:
external_mods.append(kernel_module)
module_set_attr("external_mods", external_mods, True)
return call_packed(kernel_name, *runtime_args)
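For context, a sketch of how this is reached from TVMScript as T.call_kernel; the Triton kernel, buffer shapes, and block size are illustrative, and the example assumes Triton is installed and TVM is built with CUDA:

import triton
import triton.language as tl
from tvm.script import ir as I
from tvm.script import tir as T

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

@I.ir_module
class Module:
    @T.prim_func
    def main(
        x: T.Buffer((128,), "float32"),
        y: T.Buffer((128,), "float32"),
        out: T.Buffer((128,), "float32"),
    ):
        # Compiles add_kernel to PTX, attaches it to the IRModule via the
        # "external_mods" attribute, and emits a call_packed to the kernel.
        T.call_kernel(
            add_kernel,
            (T.ceildiv(128, 64),),
            x.data, y.data, out.data, T.int32(128), 64,
        )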
3 changes: 2 additions & 1 deletion python/tvm/script/ir_builder/tir/ir.py
@@ -83,6 +83,7 @@
from tvm.tir.generic import cast

from . import _ffi_api, frame
from .external_kernel import call_kernel

# pylint: enable=unused-import

@@ -1943,7 +1944,6 @@ def wrapped(*args, **kwargs):
tvm_call_packed_lowered = call_packed_lowered
tvm_call_cpacked_lowered = call_cpacked_lowered


# pylint: enable=invalid-name


@@ -2255,4 +2255,5 @@ def wrapped(*args, **kwargs):
"Range",
"vscale",
"get_active_lane_mask",
"call_kernel",
]
115 changes: 115 additions & 0 deletions python/tvm/script/ir_builder/tir/triton.py
@@ -0,0 +1,115 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Triton kernel integration with TIR"""

from typing import Tuple, List, Union, Any, Dict

import triton
from triton.runtime.jit import type_canonicalisation_dict
from tvm import tir
from tvm.topi.utils import get_const_int
from tvm.runtime import Module
from .external_kernel import BaseKernel


class TritonKernel(BaseKernel):
"""A kernel from Triton JIT function.

This class bridges the Triton kernel with TVM runtime. The compilation includes the following
steps:
- Deduce the kernel signature and generate the Triton kernel
- Embed the compiled kernel into the current IRModule as an external module
- Generate a call to the Triton kernel following its calling convention via call_packed.
"""

def __init__(self, func):
self.func = func

def compile_to_device_module(
self,
launch_args: List[Union[int, tir.PrimExpr]],
*args: List[Any],
**kwargs: Dict[str, Any],
) -> Tuple[str, Module, List[Any]]:
"""Compile the kernel to a device module.

Parameters
----------
launch_args : List[Union[int, tir.PrimExpr]]
The grid size of the kernel. A list of one to three expressions, representing the number of
blocks along "blockIdx.x", "blockIdx.y", and "blockIdx.z" respectively.

args : List[Any]
Arguments to the kernel function.

kwargs : Dict[str, Any]
Additional options for the kernel compilation.
"""
triton_kernel, kernel_args = self._generate_triton_kernel(self.func, *args, **kwargs)
kernel_metadata = triton_kernel.metadata
ptx = triton_kernel.asm["ptx"]
assert kernel_metadata.num_ctas == 1, "Cluster is not supported"
num_warps = kernel_metadata.num_warps
grid = launch_args
launch_param_tags = ["threadIdx.x"] + ["blockIdx.x", "blockIdx.y", "blockIdx.z"][
: len(grid)
]
launch_args = [num_warps * 32] + list(grid)
kernel_arg_types = [arg.dtype for arg in kernel_args]
if triton_kernel.metadata.shared > 0:
# Add shared memory size to the launch arguments
launch_param_tags.append("tir.use_dyn_shared_memory")
launch_args.append(triton_kernel.metadata.shared)

kernel_module = self._create_cuda_module(
ptx, kernel_arg_types, launch_param_tags, triton_kernel.name
)

return triton_kernel.name, kernel_module, kernel_args + launch_args

def _generate_triton_kernel(
self, func, *args, **kwargs
) -> Tuple["triton.compiler.CompiledKernel", List[tir.PrimExpr]]:
"""Deduce the kernel signature and generate the Triton kernel"""

kernel_params = func.params
assert len(kernel_params) == len(
args
), f"Number of arguments does not match, expected {len(kernel_params)}, got {len(args)}"

signature = {}
constants = {}
kernel_args = [] # Arguments to invoke the kernel
for i, arg in enumerate(args):
if kernel_params[i].is_constexpr:
constants[kernel_params[i].name] = get_const_int(arg)
continue
if arg.dtype == "handle":
assert isinstance(arg, tir.Var)
elem_type = arg.type_annotation.element_type.dtype
pointer_type = "*" + type_canonicalisation_dict[elem_type]
signature[kernel_params[i].name] = pointer_type
else:
signature[kernel_params[i].name] = type_canonicalisation_dict[arg.dtype]
kernel_args.append(arg)

# TODO: Support default argument in the kernel
# TODO: Add specialization for aligned buffer pointers
source = triton.compiler.ASTSource(fn=func, constants=constants, signature=signature)
compiled = triton.compiler.compile(source, options=kwargs)
return compiled, kernel_args
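To make the mapping above concrete: for a kernel compiled with num_warps=4, a one-dimensional grid of 32 blocks, and no dynamic shared memory, the values produced here would be (numbers illustrative):

# num_warps = 4, grid = (32,), metadata.shared == 0
launch_param_tags = ["threadIdx.x", "blockIdx.x"]  # thread extent first, then grid dims
launch_args = [4 * 32, 32]                         # 128 threads per block, 32 blocks
# If metadata.shared > 0, "tir.use_dyn_shared_memory" and the byte count are appended.
# These values follow the kernel's own arguments in the call_packed call that
# call_kernel emits.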