8 changes: 4 additions & 4 deletions python/tvm/relay/op/strategy/cuda.py
@@ -145,7 +145,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
if layout == "NCHW":
assert kernel_layout == "OIHW"
if (
(target.kind.name in ["cuda", "vulkan"])
(target.kind.name in ["cuda", "vulkan", "rocm"])
and data.dtype in ("int8", "uint8")
and kernel.dtype in ("int8", "uint8")
):
@@ -297,7 +297,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
Need to satisfy tensor core schedule."
)
elif (
(target.kind.name in ["cuda", "vulkan"])
(target.kind.name in ["cuda", "vulkan", "rocm"])
and layout == "NCHW4c"
and data.dtype in ["int8", "uint8"]
):
@@ -376,7 +376,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
ic_chunk = in_channels // 4

if (
(target.kind.name in ["cuda", "vulkan"])
(target.kind.name in ["cuda", "vulkan", "rocm"])
and data.dtype in ["int8", "uint8"]
and kernel.dtype in ["int8", "uint8"]
and channels % groups == 0
@@ -836,7 +836,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
b, i = get_const_tuple(data.shape)
o, _ = get_const_tuple(weights.shape)
if (
target.kind.name in ["cuda", "vulkan"]
target.kind.name in ["cuda", "vulkan", "rocm"]
and data.dtype == "int8"
and weights.dtype == "int8"
and out_type.dtype == "int32"
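For context, the change above lets ROCm targets reach the same int8 branches that were previously reserved for cuda/vulkan. A minimal sketch of the kind of quantized conv2d that would now take this path follows; the shapes and names are illustrative assumptions, not part of this diff, and compiling it requires a ROCm-enabled TVM build.

# Hedged sketch: an int8 NCHW conv2d whose dtypes satisfy the extended
# target.kind.name checks above when compiling with target="rocm".
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 64, 56, 56), dtype="int8")
weight = relay.var("weight", shape=(64, 64, 3, 3), dtype="int8")
conv = relay.nn.conv2d(
    data, weight, kernel_size=(3, 3), padding=(1, 1), out_dtype="int32"
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

# With both operand dtypes int8 and target kind "rocm", conv2d_strategy_cuda
# now considers the same int8 implementations it already used for cuda/vulkan.
lib = relay.build(mod, target="rocm")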
172 changes: 20 additions & 152 deletions python/tvm/relay/op/strategy/rocm.py
@@ -17,175 +17,48 @@
"""Definition of ROCm operator strategy."""
# pylint: disable=invalid-name,unused-argument,unused-wildcard-import,wildcard-import
from tvm import topi
from tvm.auto_scheduler import is_auto_scheduler_enabled
from tvm.te import SpecializedCondition
from tvm.contrib.thrust import can_use_rocthrust
from tvm.contrib import miopen

from .generic import *
from .. import op as _op
from .cuda import judge_winograd, naive_schedule
from .cuda import batch_matmul_strategy_cuda, conv2d_strategy_cuda, dense_strategy_cuda


@conv2d_strategy.register("rocm")
def conv2d_strategy_rocm(attrs, inputs, out_type, target):
"""conv2d rocm strategy"""
strategy = _op.OpStrategy()
data, kernel = inputs
dilation_h, dilation_w = attrs.get_int_tuple("dilation")
groups = attrs.groups
layout = attrs.data_layout
stride_h, stride_w = attrs.get_int_tuple("strides")
kernel_layout = attrs.kernel_layout
padding = attrs.get_int_tuple("padding")
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")

if groups == 1:
if layout == "NCHW":
# TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8.
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nchw),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
name="conv2d_nchw.cuda",
)
_, _, kh, kw = get_const_tuple(kernel.shape)
if (
2 < kh < 8
and 2 < kw < 8
and kh == kw
and stride_h == 1
and stride_w == 1
and dilation_h == 1
and dilation_w == 1
):
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd),
name="conv2d_nchw_winograd.cuda",
plevel=5,
)
elif layout == "NHWC":
assert kernel_layout == "HWIO"
strategy.add_implementation(
wrap_compute_conv2d(topi.gpu.conv2d_nhwc),
wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc),
name="conv2d_nhwc.gpu",
)
N, H, W, _ = get_const_tuple(data.shape)
KH, KW, CI, CO = get_const_tuple(kernel.shape)

(_, judge_winograd_autotvm, judge_winograd_auto_scheduler,) = judge_winograd(
N,
H,
W,
KH,
KW,
CI,
CO,
padding,
stride_h,
stride_w,
dilation_h,
dilation_w,
data.dtype,
kernel.dtype,
pre_flag=False,
)
strategy = conv2d_strategy_cuda(attrs, inputs, out_type, target)

if judge_winograd_autotvm:
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_direct),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_winograd_direct),
name="conv2d_nhwc_winograd_direct.cuda",
plevel=5,
)
# add miopen implementation
if (
"miopen" in target.libs
and groups == 1
and layout == "NCHW"
and padding[0] == padding[2]
and padding[1] == padding[3]
):
strategy.add_implementation(
wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
name="conv2d_nchw_miopen.rocm",
plevel=50,
)

if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler:
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
naive_schedule, # this implementation should never be picked by autotvm
name="conv2d_nhwc.winograd",
plevel=15,
)
elif layout == "HWCN":
assert kernel_layout == "HWIO"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
name="conv2d_hwcn.cuda",
)
elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
assert kernel_layout == "OIHW4o4i"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
name="conv2d_NCHWc_int8.cuda",
)
else:
raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
# add miopen implementation
if (
"miopen" in target.libs
and layout == "NCHW"
and padding[0] == padding[2]
and padding[1] == padding[3]
):
strategy.add_implementation(
wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
name="conv2d_nchw_miopen.rocm",
plevel=15,
)
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.cuda",
)
elif layout == "NHWC":
assert kernel_layout == "HWOI"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc),
name="depthwise_conv2d_nhwc.cuda",
)
else:
raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
else: # group_conv2d
if layout == "NCHW":
# TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True),
wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw),
name="group_conv2d_nchw.cuda",
)
elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
assert kernel_layout == "OIHW4o4i"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8),
name="group_conv2d_NCHWc_int8.cuda",
)
else:
raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
return strategy


@dense_strategy.register("rocm")
def dense_strategy_rocm(attrs, inputs, out_type, target):
"""Dense strategy for ROCM"""
assert len(inputs[0].shape) == 2 and len(inputs[1].shape) == 2, "Only support 2-dim dense"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_dense(topi.rocm.dense),
wrap_topi_schedule(topi.rocm.schedule_dense),
name="dense.rocm",
)
strategy = dense_strategy_cuda(attrs, inputs, out_type, target)

if target.kind.name == "rocm" and "rocblas" in target.libs:
assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
strategy.add_implementation(
@@ -200,13 +73,8 @@ def dense_strategy_rocm(attrs, inputs, out_type, target):
@batch_matmul_strategy.register("rocm")
def batch_matmul_strategy_rocm(attrs, inputs, out_type, target):
"""Batch matmul strategy for ROCM"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_batch_matmul(topi.cuda.batch_matmul),
wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
name="batch_matmul.cuda",
plevel=10,
)
strategy = batch_matmul_strategy_cuda(attrs, inputs, out_type, target)

if target.kind.name == "rocm" and "rocblas" in target.libs:
assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
strategy.add_implementation(
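The MIOpen and rocBLAS implementations layered on top of the reused CUDA strategies are gated on entries in target.libs. A small sketch of how such a target is constructed (illustrative only):

# Hedged sketch: the library-backed fallbacks above are only registered when
# the corresponding names appear in target.libs.
import tvm

target = tvm.target.Target("rocm -libs=miopen,rocblas")
print(target.kind.name)   # rocm
print(list(target.libs))  # ['miopen', 'rocblas']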
22 changes: 5 additions & 17 deletions python/tvm/relay/qnn/op/legalizations.py
@@ -24,6 +24,7 @@
from tvm.relay.qnn.op.canonicalizations import create_integer_lookup_op

from ....topi.x86.utils import target_has_sse42
from ....topi.utils import is_target
from .. import op as reg

#################################################
@@ -387,18 +388,6 @@ def is_aarch64_arm():
return "aarch64" in target.attrs.get("mtriple", "")


def is_vulkan():
"""Checks whether we are compiling for a vulkan/spirv target."""
target = tvm.target.Target.current(allow_none=False)
return "vulkan" in target.keys


def is_cuda():
"""Checks whether we are compiling for a cuda target."""
target = tvm.target.Target.current(allow_none=False)
return "cuda" in target.keys


########################
# ARM CPU legalizations.
########################
@@ -456,22 +445,21 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):

@qnn_conv2d_legalize.register(["cuda", "gpu"])
def _qnn_conv2d_legalize_cuda(attrs, inputs, types):
if is_vulkan():
if is_target("vulkan"):
# prefers the dtypes to be same. Mixed type is not yet supported.
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
if is_cuda():
if is_target(["cuda", "rocm"]):
# CUDA prefers both datatypes to be int8.
return helper_change_dtypes_to_int8(attrs, inputs, types, relay.qnn.op.conv2d)
return None


@qnn_dense_legalize.register(["cuda", "gpu"])
def _qnn_dense_legalize_cuda(attrs, inputs, types):
if is_vulkan():
if is_target("vulkan"):
# prefers the dtypes to be same. Mixed type is not yet supported.
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
if is_cuda():
if is_target(["cuda", "rocm"]):
# CUDA prefers both datatypes to be the int8.
return helper_change_dtypes_to_int8(attrs, inputs, types, relay.qnn.op.dense)

return None
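The removed is_vulkan/is_cuda helpers are superseded by the is_target utility imported from ....topi.utils at the top of the file. Below is a hedged sketch of its behavior, inferred from the helpers it replaces; the authoritative implementation lives in python/tvm/topi/utils.py.

# Hedged sketch of the is_target semantics, inferred from the removed
# is_vulkan()/is_cuda() helpers: match the current target's keys against
# one or more target names.
import tvm


def is_target_sketch(names):
    names = [names] if isinstance(names, str) else names
    target = tvm.target.Target.current(allow_none=False)
    return any(name in target.keys for name in names)


with tvm.target.Target("rocm"):
    assert is_target_sketch(["cuda", "rocm"])  # "rocm" is among the target keys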
2 changes: 2 additions & 0 deletions python/tvm/tir/tensor_intrin/__init__.py
@@ -18,3 +18,5 @@
"""Intrinsics for tensorization."""
from .x86 import *
from .arm_cpu import *
from .dot_product_common import *
from .rocm import *
55 changes: 55 additions & 0 deletions python/tvm/tir/tensor_intrin/dot_product_common.py
@@ -0,0 +1,55 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,missing-function-docstring
"""Dot product related intrinsics."""
from tvm.script import tir as T
from .. import TensorIntrin


@T.prim_func
def dp4a_desc(
A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
) -> None:
with T.block("root"):
T.reads(C[0], A[0:4], B[0:4])
T.writes(C[0])
for i in range(0, 4):
with T.block("update"):
vi = T.axis.remap("R", [i])
C[0] = C[0] + T.cast(A[vi], "int32") * T.cast(B[vi], "int32")


@T.prim_func
def dp4a_impl(
A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
) -> None:
with T.block("root"):
T.reads(C[0], A[0:4], B[0:4])
T.writes(C[0])

C[0] += T.call_pure_extern(
"__dp4a", A.vload([0], "int8x4"), B.vload([0], "int8x4"), T.int32(0), dtype="int32"
)


DP4A_INTRIN = "dp4a"

TensorIntrin.register(DP4A_INTRIN, dp4a_desc, dp4a_impl)
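A hedged usage sketch of the registered intrinsic: tensorizing a 4-element int8 reduction whose block matches dp4a_desc via the TensorIR schedule API. The workload below is an illustrative assumption rather than part of this diff; its buffer scopes and offset factors are chosen to line up with the description above, and in a real kernel the operands would typically be staged into shared/local memory first.

# Hedged sketch (not part of this diff): replace the scalar update loop with
# the "dp4a" intrinsic registered above.
import tvm
from tvm.script import tir as T
from tvm.tir.tensor_intrin.dot_product_common import DP4A_INTRIN


@T.prim_func
def int8_dot_product(
    A: T.Buffer((4,), "int8", offset_factor=1, scope="shared"),
    B: T.Buffer((4,), "int8", offset_factor=1, scope="shared"),
    C: T.Buffer((1,), "int32", offset_factor=1, scope="local"),
) -> None:
    for i in range(0, 4):
        with T.block("update"):
            vi = T.axis.remap("R", [i])
            C[0] = C[0] + T.cast(A[vi], "int32") * T.cast(B[vi], "int32")


sch = tvm.tir.Schedule(int8_dot_product)
(i,) = sch.get_loops(sch.get_block("update"))
sch.tensorize(i, DP4A_INTRIN)
print(sch.mod.script())  # the update block now calls __dp4a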