diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 08da62e640e1..4253d93f6500 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -145,7 +145,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): if layout == "NCHW": assert kernel_layout == "OIHW" if ( - (target.kind.name in ["cuda", "vulkan"]) + (target.kind.name in ["cuda", "vulkan", "rocm"]) and data.dtype in ("int8", "uint8") and kernel.dtype in ("int8", "uint8") ): @@ -297,7 +297,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): Need to satisfy tensor core schedule." ) elif ( - (target.kind.name in ["cuda", "vulkan"]) + (target.kind.name in ["cuda", "vulkan", "rocm"]) and layout == "NCHW4c" and data.dtype in ["int8", "uint8"] ): @@ -376,7 +376,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): ic_chunk = in_channels // 4 if ( - (target.kind.name in ["cuda", "vulkan"]) + (target.kind.name in ["cuda", "vulkan", "rocm"]) and data.dtype in ["int8", "uint8"] and kernel.dtype in ["int8", "uint8"] and channels % groups == 0 @@ -836,7 +836,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target): b, i = get_const_tuple(data.shape) o, _ = get_const_tuple(weights.shape) if ( - target.kind.name in ["cuda", "vulkan"] + target.kind.name in ["cuda", "vulkan", "rocm"] and data.dtype == "int8" and weights.dtype == "int8" and out_type.dtype == "int32" diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index 1453128eeb67..6e91101826c9 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -17,162 +17,39 @@ """Definition of ROCm operator strategy.""" # pylint: disable=invalid-name,unused-argument,unused-wildcard-import,wildcard-import from tvm import topi -from tvm.auto_scheduler import is_auto_scheduler_enabled from tvm.te import SpecializedCondition from tvm.contrib.thrust import can_use_rocthrust from tvm.contrib import miopen from .generic import * from .. import op as _op -from .cuda import judge_winograd, naive_schedule +from .cuda import batch_matmul_strategy_cuda, conv2d_strategy_cuda, dense_strategy_cuda @conv2d_strategy.register("rocm") def conv2d_strategy_rocm(attrs, inputs, out_type, target): """conv2d rocm strategy""" - strategy = _op.OpStrategy() - data, kernel = inputs - dilation_h, dilation_w = attrs.get_int_tuple("dilation") groups = attrs.groups layout = attrs.data_layout - stride_h, stride_w = attrs.get_int_tuple("strides") - kernel_layout = attrs.kernel_layout padding = attrs.get_int_tuple("padding") - if dilation_h < 1 or dilation_w < 1: - raise ValueError("dilation should be positive value") - - if groups == 1: - if layout == "NCHW": - # TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8. - assert kernel_layout == "OIHW" - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_nchw), - wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw), - name="conv2d_nchw.cuda", - ) - _, _, kh, kw = get_const_tuple(kernel.shape) - if ( - 2 < kh < 8 - and 2 < kw < 8 - and kh == kw - and stride_h == 1 - and stride_w == 1 - and dilation_h == 1 - and dilation_w == 1 - ): - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd), - wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd), - name="conv2d_nchw_winograd.cuda", - plevel=5, - ) - elif layout == "NHWC": - assert kernel_layout == "HWIO" - strategy.add_implementation( - wrap_compute_conv2d(topi.gpu.conv2d_nhwc), - wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc), - name="conv2d_nhwc.gpu", - ) - N, H, W, _ = get_const_tuple(data.shape) - KH, KW, CI, CO = get_const_tuple(kernel.shape) - (_, judge_winograd_autotvm, judge_winograd_auto_scheduler,) = judge_winograd( - N, - H, - W, - KH, - KW, - CI, - CO, - padding, - stride_h, - stride_w, - dilation_h, - dilation_w, - data.dtype, - kernel.dtype, - pre_flag=False, - ) + strategy = conv2d_strategy_cuda(attrs, inputs, out_type, target) - if judge_winograd_autotvm: - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_direct), - wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_winograd_direct), - name="conv2d_nhwc_winograd_direct.cuda", - plevel=5, - ) + # add miopen implementation + if ( + "miopen" in target.libs + and groups == 1 + and layout == "NCHW" + and padding[0] == padding[2] + and padding[1] == padding[3] + ): + strategy.add_implementation( + wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True), + wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen), + name="conv2d_nchw_miopen.rocm", + plevel=50, + ) - if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler: - strategy.add_implementation( - wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc), - naive_schedule, # this implementation should never be picked by autotvm - name="conv2d_nhwc.winograd", - plevel=15, - ) - elif layout == "HWCN": - assert kernel_layout == "HWIO" - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_hwcn), - wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn), - name="conv2d_hwcn.cuda", - ) - elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]: - assert kernel_layout == "OIHW4o4i" - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True), - wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8), - name="conv2d_NCHWc_int8.cuda", - ) - else: - raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout)) - # add miopen implementation - if ( - "miopen" in target.libs - and layout == "NCHW" - and padding[0] == padding[2] - and padding[1] == padding[3] - ): - strategy.add_implementation( - wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True), - wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen), - name="conv2d_nchw_miopen.rocm", - plevel=15, - ) - elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): - if layout == "NCHW": - assert kernel_layout == "OIHW" - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw), - wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw), - name="depthwise_conv2d_nchw.cuda", - ) - elif layout == "NHWC": - assert kernel_layout == "HWOI" - strategy.add_implementation( - wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), - wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc), - name="depthwise_conv2d_nhwc.cuda", - ) - else: - raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout)) - else: # group_conv2d - if layout == "NCHW": - # TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8. - assert kernel_layout == "OIHW" - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True), - wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw), - name="group_conv2d_nchw.cuda", - ) - elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]: - assert kernel_layout == "OIHW4o4i" - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True), - wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8), - name="group_conv2d_NCHWc_int8.cuda", - ) - else: - raise RuntimeError("Unsupported group_conv2d layout {}".format(layout)) return strategy @@ -180,12 +57,8 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target): def dense_strategy_rocm(attrs, inputs, out_type, target): """Dense strategy for ROCM""" assert len(inputs[0].shape) == 2 and len(inputs[1].shape) == 2, "Only support 2-dim dense" - strategy = _op.OpStrategy() - strategy.add_implementation( - wrap_compute_dense(topi.rocm.dense), - wrap_topi_schedule(topi.rocm.schedule_dense), - name="dense.rocm", - ) + strategy = dense_strategy_cuda(attrs, inputs, out_type, target) + if target.kind.name == "rocm" and "rocblas" in target.libs: assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported." strategy.add_implementation( @@ -200,13 +73,8 @@ def dense_strategy_rocm(attrs, inputs, out_type, target): @batch_matmul_strategy.register("rocm") def batch_matmul_strategy_rocm(attrs, inputs, out_type, target): """Batch matmul strategy for ROCM""" - strategy = _op.OpStrategy() - strategy.add_implementation( - wrap_compute_batch_matmul(topi.cuda.batch_matmul), - wrap_topi_schedule(topi.cuda.schedule_batch_matmul), - name="batch_matmul.cuda", - plevel=10, - ) + strategy = batch_matmul_strategy_cuda(attrs, inputs, out_type, target) + if target.kind.name == "rocm" and "rocblas" in target.libs: assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported." strategy.add_implementation( diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 93b1ad7a44a8..e669e14032f9 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -24,6 +24,7 @@ from tvm.relay.qnn.op.canonicalizations import create_integer_lookup_op from ....topi.x86.utils import target_has_sse42 +from ....topi.utils import is_target from .. import op as reg ################################################# @@ -387,18 +388,6 @@ def is_aarch64_arm(): return "aarch64" in target.attrs.get("mtriple", "") -def is_vulkan(): - """Checks whether we are compiling for a vulkan/spirv target.""" - target = tvm.target.Target.current(allow_none=False) - return "vulkan" in target.keys - - -def is_cuda(): - """Checks whether we are compiling for a cuda target.""" - target = tvm.target.Target.current(allow_none=False) - return "cuda" in target.keys - - ######################## # ARM CPU legalizations. ######################## @@ -456,10 +445,10 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types): @qnn_conv2d_legalize.register(["cuda", "gpu"]) def _qnn_conv2d_legalize_cuda(attrs, inputs, types): - if is_vulkan(): + if is_target("vulkan"): # prefers the dtypes to be same. Mixed type is not yet supported. return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d) - if is_cuda(): + if is_target(["cuda", "rocm"]): # CUDA prefers both datatypes to be int8. return helper_change_dtypes_to_int8(attrs, inputs, types, relay.qnn.op.conv2d) return None @@ -467,11 +456,10 @@ def _qnn_conv2d_legalize_cuda(attrs, inputs, types): @qnn_dense_legalize.register(["cuda", "gpu"]) def _qnn_dense_legalize_cuda(attrs, inputs, types): - if is_vulkan(): + if is_target("vulkan"): # prefers the dtypes to be same. Mixed type is not yet supported. return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense) - if is_cuda(): + if is_target(["cuda", "rocm"]): # CUDA prefers both datatypes to be the int8. return helper_change_dtypes_to_int8(attrs, inputs, types, relay.qnn.op.dense) - return None diff --git a/python/tvm/tir/tensor_intrin/__init__.py b/python/tvm/tir/tensor_intrin/__init__.py index 62159851b3d4..4115c3b90070 100644 --- a/python/tvm/tir/tensor_intrin/__init__.py +++ b/python/tvm/tir/tensor_intrin/__init__.py @@ -18,3 +18,5 @@ """Intrinsics for tensorization.""" from .x86 import * from .arm_cpu import * +from .dot_product_common import * +from .rocm import * diff --git a/python/tvm/tir/tensor_intrin/dot_product_common.py b/python/tvm/tir/tensor_intrin/dot_product_common.py new file mode 100644 index 000000000000..c531b80380e3 --- /dev/null +++ b/python/tvm/tir/tensor_intrin/dot_product_common.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,missing-function-docstring +"""Dot product related intrinsics.""" +from tvm.script import tir as T +from .. import TensorIntrin + + +@T.prim_func +def dp4a_desc( + A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"), + B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"), + C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"), +) -> None: + with T.block("root"): + T.reads(C[0], A[0:4], B[0:4]) + T.writes(C[0]) + for i in range(0, 4): + with T.block("update"): + vi = T.axis.remap("R", [i]) + C[0] = C[0] + T.cast(A[vi], "int32") * T.cast(B[vi], "int32") + + +@T.prim_func +def dp4a_impl( + A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"), + B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"), + C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"), +) -> None: + with T.block("root"): + T.reads(C[0], A[0:4], B[0:4]) + T.writes(C[0]) + + C[0] += T.call_pure_extern( + "__dp4a", A.vload([0], "int8x4"), B.vload([0], "int8x4"), T.int32(0), dtype="int32" + ) + + +DP4A_INTRIN = "dp4a" + +TensorIntrin.register(DP4A_INTRIN, dp4a_desc, dp4a_impl) diff --git a/python/tvm/tir/tensor_intrin/rocm.py b/python/tvm/tir/tensor_intrin/rocm.py new file mode 100644 index 000000000000..7a989d0bccaa --- /dev/null +++ b/python/tvm/tir/tensor_intrin/rocm.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,missing-function-docstring +"""Intrinsics for AMDGPU tensorization.""" +from tvm.script import tir as T +from .. import TensorIntrin +from .dot_product_common import dp4a_desc + + +@T.prim_func +def sdot4( + A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"), + B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"), + C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"), +) -> None: + with T.block("root"): + T.reads(C[0], A[0:4], B[0:4]) + T.writes(C[0]) + + C[0] += T.call_llvm_pure_intrin( + T.llvm_lookup_intrinsic_id("llvm.amdgcn.sdot4"), + T.uint32(4), + T.reinterpret(A.vload([0], "int8x4"), dtype="int32"), + T.reinterpret(B.vload([0], "int8x4"), dtype="int32"), + T.int32(0), + T.bool(1), + dtype="int32", + ) + + +AMDGPU_SDOT4_INTRIN = "sdot4" + +TensorIntrin.register(AMDGPU_SDOT4_INTRIN, dp4a_desc, sdot4) diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index 5fce9d7a3f5d..ff625d6d714c 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -22,7 +22,7 @@ from tvm.contrib import cublas from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity from .. import nn, generic -from ..utils import traverse_inline, get_const_tuple, get_max_power2_factor +from ..utils import traverse_inline, get_const_tuple, get_max_power2_factor, is_target from .tensor_intrin import dp4a @@ -333,9 +333,6 @@ def _callback(op): return s -_dp4a = dp4a("shared", "shared", "local") - - def _schedule_batch_matmul_int8(cfg, s, output): input_x, input_y = s[output].op.input_tensors if len(input_y.op.input_tensors) == 1 and input_y.op.input_tensors[0] == input_x: @@ -372,7 +369,7 @@ def _schedule_batch_matmul_int8(cfg, s, output): target = tvm.target.Target.current(allow_none=False) do_tensorize = True - if "vulkan" in target.keys: + if is_target(["vulkan", "rocm"]): do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product if do_tensorize: diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py index eaafe15e9600..35d50eb3673c 100644 --- a/python/tvm/topi/cuda/conv2d_alter_op.py +++ b/python/tvm/topi/cuda/conv2d_alter_op.py @@ -22,7 +22,7 @@ from tvm import te, relay, autotvm from .. import nn -from ..utils import get_const_tuple +from ..utils import get_const_tuple, is_target from .conv2d_winograd import _infer_tile_size from .tensorcore_alter_op import pad_to_tensorcore from ..nn import conv2d_legalize @@ -34,8 +34,7 @@ @nn.conv2d_alter_layout.register(["cuda", "gpu"]) def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): target = tvm.target.Target.current(allow_none=False) - doit = "vulkan" in target.keys or "cuda" in target.keys - if not doit: + if not is_target(["vulkan", "rocm", "cuda"]): return None dispatch_ctx = autotvm.task.DispatchContext.current @@ -87,7 +86,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): if cfg.is_fallback: # if is fallback, clear query cache and return None autotvm.task.clear_fallback_cache(target, workload) do_new_layout = False - if "vulkan" in target.keys: + if is_target(["vulkan", "rocm"]): do_new_layout = "+dotprod" in target.mattr or target.supports_integer_dot_product if not do_new_layout: return None @@ -349,10 +348,7 @@ def _conv2d_legalize(attrs, inputs, arg_types): result : tvm.relay.Expr The legalized expr """ - - target = tvm.target.Target.current(allow_none=False) - doit = "vulkan" in target.keys or "cuda" in target.keys - if not doit: + if not is_target(["vulkan", "rocm", "cuda"]): return None # Dilation not supported yet. Return None if dilation is not (1, 1) dilation = attrs.get_int_tuple("dilation") diff --git a/python/tvm/topi/cuda/conv2d_int8.py b/python/tvm/topi/cuda/conv2d_int8.py index 15120f6a2532..a8b21a1deca0 100644 --- a/python/tvm/topi/cuda/conv2d_int8.py +++ b/python/tvm/topi/cuda/conv2d_int8.py @@ -26,7 +26,7 @@ from ..nn.pad import pad from ..nn.conv2d import unpack_NCHWc_to_nchw from ..nn.utils import get_pad_tuple -from ..utils import get_const_tuple, traverse_inline +from ..utils import get_const_tuple, traverse_inline, is_target def conv2d_nchw_int8(data, kernel, strides, padding, dilation, out_dtype="int32"): @@ -312,7 +312,7 @@ def _schedule_conv2d_NCHWc_int8(cfg, s, output): _, rc_block = s[conv].split(rc_block, factor=4) target = tvm.target.Target.current(allow_none=False) do_tensorize = True - if "vulkan" in target.keys: + if is_target(["vulkan", "rocm"]): do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product if do_tensorize: dtypes = (pad_data.dtype, packed_kernel.dtype) diff --git a/python/tvm/topi/cuda/dense.py b/python/tvm/topi/cuda/dense.py index 862e7b5bc59d..859f6c1097c6 100644 --- a/python/tvm/topi/cuda/dense.py +++ b/python/tvm/topi/cuda/dense.py @@ -24,7 +24,7 @@ from .tensor_intrin import dp4a from .. import tag from .. import generic -from ..utils import traverse_inline, get_const_tuple +from ..utils import traverse_inline, get_const_tuple, is_target logger = logging.getLogger("topi") @@ -173,8 +173,9 @@ def _schedule_dense_int8(cfg, s, output): ko, kt = cfg["tile_k"].apply(s, CC, ko) target = tvm.target.Target.current(allow_none=False) do_tensorize = True - if "vulkan" in target.keys: + if is_target(["vulkan", "rocm"]): do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product + if do_tensorize: dtypes = (data.dtype, weight.dtype) s[CC].tensorize(ki, dp4a("shared", "shared", "local", dtypes)) diff --git a/python/tvm/topi/cuda/tensor_intrin.py b/python/tvm/topi/cuda/tensor_intrin.py index c0596fc43262..0a504906c053 100644 --- a/python/tvm/topi/cuda/tensor_intrin.py +++ b/python/tvm/topi/cuda/tensor_intrin.py @@ -18,6 +18,7 @@ """Tensor intrinsics on CUDA.""" import tvm from tvm import te +from ..utils import is_target def dp4a(x_scope="local", y_scope="local", z_scope="local", dtypes=("int8", "int8")): @@ -71,7 +72,27 @@ def _instr(index): vec_y = yy.vload(0, dtype=vec_y_dtype) prev_z = 0 if index == 0 else zz.vload(0) - new_z = tvm.tir.call_pure_extern(zz_dtype, "__dp4a", vec_x, vec_y, prev_z) + if is_target("rocm"): + # TODO(masahi): Here we are assuming that we are compiling for gfx10 or later + # We can refine the specification for dot product on rocm if needed later. + + # We can just use "llvm.amdgcn.udot4" for u8u8u32, but it is not tested. + assert ( + dtypes[0] == "int8" and dtypes[0] == "int8" + ), "u8u8u32 dot product for rocm not supported yet" + + new_z = tvm.tir.call_llvm_pure_intrin( + zz_dtype, + "llvm.amdgcn.sdot4", + tvm.tir.const(4, "uint32"), + tvm.tir.call_intrin("int32", "tir.reinterpret", vec_x), + tvm.tir.call_intrin("int32", "tir.reinterpret", vec_y), + prev_z, + True, + ) + else: + new_z = tvm.tir.call_pure_extern(zz_dtype, "__dp4a", vec_x, vec_y, prev_z) + ib.emit(zz.vstore(0, new_z)) return ib.get() diff --git a/python/tvm/topi/rocm/dense.py b/python/tvm/topi/rocm/dense.py index 2f3ce77cc7ba..983f235f0ec8 100644 --- a/python/tvm/topi/rocm/dense.py +++ b/python/tvm/topi/rocm/dense.py @@ -19,85 +19,8 @@ from tvm import te from tvm import autotvm from tvm.contrib import rocblas -from .. import generic, nn +from .. import generic from .. import tag -from ..utils import traverse_inline - - -@autotvm.register_topi_compute("dense.rocm") -def dense(cfg, data, weight, bias=None, out_dtype=None): - """Dense operator for rocm backend. - - Parameters - ---------- - data : tvm.te.Tensor - 2-D with shape [batch, in_dim] - - weight : tvm.te.Tensor - 2-D with shape [out_dim, in_dim] - - bias : tvm.te.Tensor, optional - 1-D with shape [out_dim] - - out_dtype : str - The output type. This is used for mixed precision. - - Returns - ------- - output : tvm.te.Tensor - 2-D with shape [batch, out_dim] - """ - assert len(data.shape) == 2 and len(weight.shape) == 2, "only support 2-dim dense" - if bias is not None: - assert len(bias.shape) == 1 - if out_dtype is None: - out_dtype = data.dtype - return nn.dense(data, weight, bias, out_dtype) - - -@autotvm.register_topi_schedule("dense.rocm") -def schedule_dense(cfg, outs): - """Schedule for dense operator. - - Parameters - ---------- - outs: Array of Tensor - The computation graph description of dense - in the format of an array of tensors. - - Returns - ------- - s: Schedule - The computation schedule for dense. - """ - outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - s = te.create_schedule([x.op for x in outs]) - - def _callback(op): - if op.tag == "dense": - Dense = op.output(0) - num_thread = 64 - k = Dense.op.reduce_axis[0] - ko, kf = s[Dense].split(k, factor=num_thread) - DenseF = s.rfactor(Dense, kf) - - if Dense.op in s.outputs: - Out = Dense - else: - Out = outs[0].op.output(0) - s[Dense].compute_at(s[Out], s[Out].op.axis[1]) - s[Out].bind(s[Out].op.axis[0], te.thread_axis("blockIdx.y")) - s[Out].bind(s[Out].op.axis[1], te.thread_axis("blockIdx.x")) - - tx = s[Dense].op.reduce_axis[0] - thread_x = te.thread_axis("threadIdx.x") - s[Dense].bind(tx, thread_x) - s[DenseF].compute_at(s[Dense], tx) - s[Dense].set_store_predicate(thread_x.var.equal(0)) - s[Out].set_store_predicate(thread_x.var.equal(0)) - - traverse_inline(s, outs[0].op, _callback) - return s @autotvm.register_topi_compute("dense_rocblas.rocm") diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index af68ee905e56..f1c6fb5aa4f4 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -524,3 +524,10 @@ def ceil_div(a, b): def swap(arr, axis): """swap arr[axis] and arr[-1]""" return arr[:axis] + [arr[-1]] + arr[axis + 1 : -1] + [arr[axis]] + + +def is_target(names): + """Return True if the name of the current target is one of provided names""" + names = [names] if isinstance(names, str) else names + target = tvm.target.Target.current(allow_none=False) + return any(name in target.keys for name in names) diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 6fef8b48c396..96c193d34aa1 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -306,6 +306,7 @@ TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA) TVM_REGISTER_TARGET_KIND("rocm", kDLROCM) .add_attr_option("mcpu") .add_attr_option("mtriple") + .add_attr_option>("mattr") .add_attr_option("system-lib") .add_attr_option("max_num_threads", Integer(256)) .add_attr_option("thread_warp_size", Integer(64)) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index c7aceb685bcf..d4238f81e01b 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -676,5 +676,43 @@ def test_dense_vnni(): np.testing.assert_equal(out, ref) +@pytest.mark.skip("Requires GFX10 AMDGPU") +def test_dense_rocm_sdot4(): + data_shape = (32, 96) + weight_shape = (128, 96) + + data_dtype = "int8" + data = relay.var("data", shape=data_shape, dtype=data_dtype) + weight = relay.var("weight", shape=weight_shape, dtype="int8") + bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32") + dense = relay.nn.dense(data, weight, out_dtype="int32") + out = relay.nn.bias_add(dense, bias) + mod = tvm.IRModule.from_expr(out) + + target = "rocm -mattr=+dotprod" + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target) + + asm = lib.lib.imported_modules[0].get_source("asm") + assert "v_dot4_i32_i8" in asm + + dev = tvm.device(target, 0) + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + + a = np.random.uniform(1, 10, size=data_shape).astype(data_dtype) + b = np.random.uniform(1, 10, size=weight_shape).astype("int8") + c = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32") + + runtime.set_input("data", a) + runtime.set_input("weight", b) + runtime.set_input("bias", c) + runtime.run() + + out = runtime.get_output(0).numpy() + ref = np.dot(a.astype("int32"), b.transpose().astype("int32")) + c + + np.testing.assert_equal(out, ref) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 85a3dd5636f1..8ee5adbb318d 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -447,6 +447,41 @@ def test_batch_matmul_vnni(): np.testing.assert_equal(out, ref) +@pytest.mark.skip("Requires GFX10 AMDGPU") +def test_batch_matmul_rocm_sdot4(): + x_shape = (16, 32, 96) + y_shape = (16, 128, 96) + + lhs_dtype = "int8" + x = relay.var("x", shape=x_shape, dtype=lhs_dtype) + y = relay.var("y", shape=y_shape, dtype="int8") + bmm = relay.nn.batch_matmul(x, y, out_dtype="int32") + + mod = tvm.IRModule.from_expr(bmm) + + target = "rocm -mattr=+dotprod" + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target) + + asm = lib.lib.imported_modules[0].get_source("asm") + assert "v_dot4_i32_i8" in asm + + dev = tvm.device(target, 0) + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + + x_np = np.random.uniform(1, 10, size=x_shape).astype(lhs_dtype) + y_np = np.random.uniform(1, 10, size=y_shape).astype("int8") + + runtime.set_input("x", x_np) + runtime.set_input("y", y_np) + runtime.run() + + out = runtime.get_output(0).numpy() + ref = tvm.topi.testing.batch_matmul(x_np, y_np, out_dtype="int32") + + np.testing.assert_equal(out, ref) + + @tvm.testing.uses_gpu def test_shape_of(): shape = (10, 5, 12) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index bd9536742a8b..7b261b0eb7cd 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1944,5 +1944,55 @@ def _test_correlation( ) +@pytest.mark.skip("Requires GFX10 AMDGPU") +def test_conv2d_rocm_sdot4(): + d_shape = (1, 64, 56, 56) + w_shape = (64, 64, 3, 3) + padding = (1, 1) + strides = (1, 1) + data_dtype = "int8" + weight_dtype = "int8" + out_dtype = "int32" + + data = relay.var("data", shape=d_shape, dtype=data_dtype) + weight = relay.var("weight", shape=w_shape, dtype=weight_dtype) + out_channel = w_shape[0] + conv2d = relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=w_shape[2:], + channels=out_channel, + padding=padding, + strides=strides, + out_dtype=out_dtype, + ) + + mod = tvm.IRModule.from_expr(conv2d) + + data_np = np.random.uniform(1, 10, d_shape).astype("int8") + weight_np = np.random.uniform(1, 10, size=w_shape).astype("int8") + + target = "rocm -mattr=+dotprod" + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target, params={"weight": weight_np}) + + asm = lib.lib.imported_modules[0].get_source("asm") + assert "v_dot4_i32_i8" in asm + + dev = tvm.device(target, 0) + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + + runtime.set_input("data", data_np) + runtime.run() + + out = runtime.get_output(0).numpy() + + ref = tvm.topi.testing.conv2d_nchw_python( + data_np.astype("int32"), weight_np.astype("int32"), strides, padding + ) + + np.testing.assert_equal(out, ref) + + if __name__ == "__main__": sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/topi/python/test_topi_conv2d_int8.py b/tests/python/topi/python/test_topi_conv2d_int8.py index 860118531e51..17c5573b2c70 100644 --- a/tests/python/topi/python/test_topi_conv2d_int8.py +++ b/tests/python/topi/python/test_topi_conv2d_int8.py @@ -376,15 +376,22 @@ def get_ref_data(): ) if in_dtype == "int8": - targets.append( + targets += [ ( "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon", topi.arm_cpu.conv2d_NCHWc_int8, topi.arm_cpu.schedule_conv2d_NCHWc_int8, 8, build_only_aarch64, - ) - ) + ), + ( + "rocm -mattr=+dotprod", + lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o), + topi.cuda.schedule_conv2d_NCHWc_int8, + 4, + False, + ), + ] for target, compute, schedule, oc_block_factor, build_only in targets: check_target(target, compute, schedule, oc_block_factor, build_only) diff --git a/tests/python/topi/python/test_topi_dense.py b/tests/python/topi/python/test_topi_dense.py index 8f58415da329..2826d70ba0ed 100644 --- a/tests/python/topi/python/test_topi_dense.py +++ b/tests/python/topi/python/test_topi_dense.py @@ -52,7 +52,6 @@ ], "mali": [(topi.mali.dense, topi.mali.schedule_dense)], "bifrost": [(topi.bifrost.dense, topi.bifrost.schedule_dense)], - "rocm": [(topi.rocm.dense, topi.rocm.schedule_dense)], "hls": [(topi.nn.dense, topi.hls.schedule_dense)], } diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py index 482d6f3db574..65dfa06eb6c1 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize.py +++ b/tests/python/unittest/test_tir_schedule_tensorize.py @@ -26,6 +26,8 @@ VNNI_DOT_16x4_INTRIN, ARM_DOT_4x4_i8_NEON_INTRIN, ARM_DOT_4x4_i8_SDOT_INTRIN, + AMDGPU_SDOT4_INTRIN, + DP4A_INTRIN, ) # fmt: off @@ -595,5 +597,53 @@ def test_tensorize_arm_dot(): verify_trace_roundtrip(sch=sch, mod=func) +def test_tensorize_dpa4(): + m, n, k = 128, 128, 128 + + X = te.placeholder((m, k), name="X", dtype="int8") + W = te.placeholder((n, k), name="W", dtype="int8") + ak = te.reduce_axis((0, k), name="k") + + matmul = te.compute( + (m, n), + lambda i, j: te.sum( + X[i, ak].astype("int32") + * W[j, ak].astype("int32"), + axis=ak, + ), + name="compute", + ) + + func = te.create_prim_func([X, W, matmul]) + + for intrin in [AMDGPU_SDOT4_INTRIN, DP4A_INTRIN]: + sch = tir.Schedule(func, debug_mask="all") + block = sch.get_block("compute") + i, j, k = sch.get_loops(block) + + by, ty, yi = sch.split(i, factors=sch.sample_perfect_tile(i, n=3)) + bx, tx, xi = sch.split(j, factors=sch.sample_perfect_tile(j, n=3)) + ko, ki = sch.split(k, [None, 4]) + ko, kt = sch.split(ko, factors=sch.sample_perfect_tile(ko, n=2)) + + sch.reorder(by, bx, ty, tx, yi, xi) + + CC = sch.cache_write(block, 0, "local") + sch.reverse_compute_at(CC, tx) + + def fetch_to_shared(block, idx): + block_read = sch.cache_read(block, idx, "shared") + sch.compute_at(block_read, ko, True) + return block_read + + fetch_to_shared(block, 0) + fetch_to_shared(block, 1) + + sch.decompose_reduction(block, ko) + sch.tensorize(ki, intrin) + + verify_trace_roundtrip(sch=sch, mod=func) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))