Skip to content

Commit 401a820

Browse files
committed
skip cuda compile on old gpu
1 parent 90e01fd commit 401a820

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -77,8 +77,8 @@ def maybe_swap(i, j):
77 77

78 78
def is_ampere_or_newer():
79 79
arch = tvm.contrib.nvcc.get_target_compute_version()
80-
major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
81-
return major * 10 + minor >= 80
80+
major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
81+
return major >= 8
82 82

83 83

84 84
def run_test(
@@ -187,11 +187,11 @@ def tile_wmma_fragment(block_read, height, width):
187 187
sch.tensorize(sch.get_loops(block_init_c)[-2], mma_fill_intrin)
188 188
sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin)
189 189

190-
f = tvm.build(sch.mod["main"], target="cuda", name="dense")
191-
192 190
if not is_ampere_or_newer():
193 191
return None
194 192

193+
f = tvm.build(sch.mod["main"], target="cuda", name="dense")
194+
195 195
dev = tvm.device("cuda", 0)
196 196

197 197
if in_dtype == "float16":

0 commit comments

Comments
 (0)