Skip to content

Commit 65b8bcf

Browse files
committed
[WIP] adding DP4A support to rocm
1 parent 4f8f308 commit 65b8bcf

File tree

6 files changed

+18
-12
lines changed

6 files changed

+18
-12
lines changed

python/tvm/relay/op/strategy/cuda.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
145145
if layout == "NCHW":
146146
assert kernel_layout == "OIHW"
147147
if (
148-
(target.kind.name in ["cuda", "vulkan"])
148+
(target.kind.name in ["cuda", "vulkan", "rocm"])
149149
and data.dtype in ("int8", "uint8")
150150
and kernel.dtype in ("int8", "uint8")
151151
):
@@ -297,7 +297,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
297297
Need to satisfy tensor core schedule."
298298
)
299299
elif (
300-
(target.kind.name in ["cuda", "vulkan"])
300+
(target.kind.name in ["cuda", "vulkan", "rocm"])
301301
and layout == "NCHW4c"
302302
and data.dtype in ["int8", "uint8"]
303303
):
@@ -376,7 +376,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
376376
ic_chunk = in_channels // 4
377377

378378
if (
379-
(target.kind.name in ["cuda", "vulkan"])
379+
(target.kind.name in ["cuda", "vulkan", "rocm"])
380380
and data.dtype in ["int8", "uint8"]
381381
and kernel.dtype in ["int8", "uint8"]
382382
and channels % groups == 0
@@ -836,7 +836,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
836836
b, i = get_const_tuple(data.shape)
837837
o, _ = get_const_tuple(weights.shape)
838838
if (
839-
target.kind.name in ["cuda", "vulkan"]
839+
target.kind.name in ["cuda", "vulkan", "rocm"]
840840
and data.dtype == "int8"
841841
and weights.dtype == "int8"
842842
and out_type.dtype == "int32"

python/tvm/relay/qnn/op/legalizations.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,12 @@ def is_aarch64_arm():
387387
return "aarch64" in target.attrs.get("mtriple", "")
388388

389389

390+
def is_rocm():
391+
"""Checks whether we are compiling for a rocm/spirv target."""
392+
target = tvm.target.Target.current(allow_none=False)
393+
return "rocm" in target.keys
394+
395+
390396
def is_vulkan():
391397
"""Checks whether we are compiling for a vulkan/spirv target."""
392398
target = tvm.target.Target.current(allow_none=False)
@@ -456,7 +462,7 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
456462

457463
@qnn_conv2d_legalize.register(["cuda", "gpu"])
458464
def _qnn_conv2d_legalize_cuda(attrs, inputs, types):
459-
if is_vulkan():
465+
if is_vulkan() or is_rocm():
460466
# prefers the dtypes to be same. Mixed type is not yet supported.
461467
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
462468
if is_cuda():
@@ -467,7 +473,7 @@ def _qnn_conv2d_legalize_cuda(attrs, inputs, types):
467473

468474
@qnn_dense_legalize.register(["cuda", "gpu"])
469475
def _qnn_dense_legalize_cuda(attrs, inputs, types):
470-
if is_vulkan():
476+
if is_vulkan() or is_rocm():
471477
# prefers the dtypes to be same. Mixed type is not yet supported.
472478
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
473479
if is_cuda():

python/tvm/topi/cuda/batch_matmul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ def _schedule_batch_matmul_int8(cfg, s, output):
372372
target = tvm.target.Target.current(allow_none=False)
373373
do_tensorize = True
374374

375-
if "vulkan" in target.keys:
375+
if "vulkan" in target.keys or "rocm" in target.keys:
376376
do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product
377377

378378
if do_tensorize:

python/tvm/topi/cuda/conv2d_alter_op.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
@nn.conv2d_alter_layout.register(["cuda", "gpu"])
3535
def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
3636
target = tvm.target.Target.current(allow_none=False)
37-
doit = "vulkan" in target.keys or "cuda" in target.keys
37+
doit = "vulkan" in target.keys or "cuda" in target.keys or "rocm" in target.keys
3838
if not doit:
3939
return None
4040
dispatch_ctx = autotvm.task.DispatchContext.current
@@ -87,7 +87,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
8787
if cfg.is_fallback: # if is fallback, clear query cache and return None
8888
autotvm.task.clear_fallback_cache(target, workload)
8989
do_new_layout = False
90-
if "vulkan" in target.keys:
90+
if "vulkan" in target.keys or "rocm" in target.keys:
9191
do_new_layout = "+dotprod" in target.mattr or target.supports_integer_dot_product
9292
if not do_new_layout:
9393
return None
@@ -351,7 +351,7 @@ def _conv2d_legalize(attrs, inputs, arg_types):
351351
"""
352352

353353
target = tvm.target.Target.current(allow_none=False)
354-
doit = "vulkan" in target.keys or "cuda" in target.keys
354+
doit = "vulkan" in target.keys or "cuda" in target.keys or "rocm" in target.keys
355355
if not doit:
356356
return None
357357
# Dilation not supported yet. Return None if dilation is not (1, 1)

python/tvm/topi/cuda/conv2d_int8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def _schedule_conv2d_NCHWc_int8(cfg, s, output):
312312
_, rc_block = s[conv].split(rc_block, factor=4)
313313
target = tvm.target.Target.current(allow_none=False)
314314
do_tensorize = True
315-
if "vulkan" in target.keys:
315+
if "vulkan" in target.keys or "rocm" in target.keys:
316316
do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product
317317
if do_tensorize:
318318
dtypes = (pad_data.dtype, packed_kernel.dtype)

python/tvm/topi/cuda/dense.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def _schedule_dense_int8(cfg, s, output):
173173
ko, kt = cfg["tile_k"].apply(s, CC, ko)
174174
target = tvm.target.Target.current(allow_none=False)
175175
do_tensorize = True
176-
if "vulkan" in target.keys:
176+
if "vulkan" in target.keys or "rocm" in target.keys:
177177
do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product
178178
if do_tensorize:
179179
dtypes = (data.dtype, weight.dtype)

0 commit comments

Comments
 (0)