Skip to content

Commit 04a36d4

Browse files
committed
[UnitTests] Added cuDNN target to default test targets
Some unit tests explicitly test cudnn in addition to tvm.testing.enabled_targets(). This moved the cudnn checks into the same framework as all other targets, and adds it to the default list of targets to be run. Also, added `@tvm.testing.requires_cudnn` for tests specific to cudnn.
1 parent faadb7d commit 04a36d4

File tree

1 file changed

+53
-17
lines changed

1 file changed

+53
-17
lines changed

python/tvm/testing.py

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ def test_something():
7070
import tvm.tir
7171
import tvm.te
7272
import tvm._ffi
73-
from tvm.contrib import nvcc
73+
74+
from tvm.contrib import nvcc, cudnn
7475
from tvm.error import TVMError
7576

7677

@@ -375,11 +376,12 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap):
375376
def _get_targets(target_str=None):
376377
if target_str is None:
377378
target_str = os.environ.get("TVM_TEST_TARGETS", "")
379+
# Use dict instead of set for de-duplication so that the
380+
# targets stay in the order specified.
381+
target_names = list({t.strip(): None for t in target_str.split(";") if t.strip()})
378382

379-
if len(target_str) == 0:
380-
target_str = DEFAULT_TEST_TARGETS
381-
382-
target_names = set(t.strip() for t in target_str.split(";") if t.strip())
383+
if len(target_names) == 0:
384+
target_names = DEFAULT_TEST_TARGETS
383385

384386
targets = []
385387
for target in target_names:
@@ -413,10 +415,19 @@ def _get_targets(target_str=None):
413415
return targets
414416

415417

416-
DEFAULT_TEST_TARGETS = (
417-
"llvm;cuda;opencl;metal;rocm;vulkan -from_device=0;nvptx;"
418-
"llvm -device=arm_cpu;opencl -device=mali,aocl_sw_emu"
419-
)
418+
DEFAULT_TEST_TARGETS = [
419+
"llvm",
420+
"llvm -device=arm_cpu",
421+
"cuda",
422+
"cuda -model=unknown -libs=cudnn",
423+
"nvptx",
424+
"vulkan -from_device=0",
425+
"opencl",
426+
"opencl -device=mali,aocl_sw_emu",
427+
"opencl -device=intel_graphics",
428+
"metal",
429+
"rocm",
430+
]
420431

421432

422433
def device_enabled(target):
@@ -548,6 +559,24 @@ def requires_cuda(*args):
548559
return _compose(args, _requires_cuda)
549560

550561

562+
def requires_cudnn(*args):
563+
"""Mark a test as requiring the cuDNN library.
564+
565+
This also marks the test as requiring a cuda gpu.
566+
567+
Parameters
568+
----------
569+
f : function
570+
Function to mark
571+
"""
572+
573+
requirements = [
574+
pytest.mark.skipif(not cudnn.exists(), reason="cuDNN library not enabled, or not installed")
575+
* requires_cuda(),
576+
]
577+
return _compose(args, requirements)
578+
579+
551580
def requires_nvptx(*args):
552581
"""Mark a test as requiring the NVPTX compilation on the CUDA runtime
553582
@@ -730,20 +759,27 @@ def requires_rpc(*args):
730759

731760

732761
def _target_to_requirement(target):
762+
if isinstance(target, str):
763+
target = tvm.target.Target(target)
764+
733765
# mapping from target to decorator
734-
if target.startswith("cuda"):
735-
return requires_cuda()
736-
if target.startswith("rocm"):
766+
if target.kind.name == "cuda":
767+
if "cudnn" in target.attrs.get("libs", []):
768+
return requires_cudnn()
769+
else:
770+
return requires_cuda()
771+
772+
if target.kind.name == "rocm":
737773
return requires_rocm()
738-
if target.startswith("vulkan"):
774+
if target.kind.name == "vulkan":
739775
return requires_vulkan()
740-
if target.startswith("nvptx"):
776+
if target.kind.name == "nvptx":
741777
return requires_nvptx()
742-
if target.startswith("metal"):
778+
if target.kind.name == "metal":
743779
return requires_metal()
744-
if target.startswith("opencl"):
780+
if target.kind.name == "opencl":
745781
return requires_opencl()
746-
if target.startswith("llvm"):
782+
if target.kind.name == "llvm":
747783
return requires_llvm()
748784
return []
749785

0 commit comments

Comments
 (0)