@@ -70,7 +70,8 @@ def test_something():
7070import tvm .tir
7171import tvm .te
7272import tvm ._ffi
73- from tvm .contrib import nvcc
73+
74+ from tvm .contrib import nvcc , cudnn
7475from tvm .error import TVMError
7576
7677
@@ -375,11 +376,12 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap):
375376def _get_targets (target_str = None ):
376377 if target_str is None :
377378 target_str = os .environ .get ("TVM_TEST_TARGETS" , "" )
379+ # Use dict instead of set for de-duplication so that the
380+ # targets stay in the order specified.
381+ target_names = list ({t .strip (): None for t in target_str .split (";" ) if t .strip ()})
378382
379- if len (target_str ) == 0 :
380- target_str = DEFAULT_TEST_TARGETS
381-
382- target_names = set (t .strip () for t in target_str .split (";" ) if t .strip ())
383+ if len (target_names ) == 0 :
384+ target_names = DEFAULT_TEST_TARGETS
383385
384386 targets = []
385387 for target in target_names :
@@ -413,10 +415,19 @@ def _get_targets(target_str=None):
413415 return targets
414416
415417
416- DEFAULT_TEST_TARGETS = (
417- "llvm;cuda;opencl;metal;rocm;vulkan -from_device=0;nvptx;"
418- "llvm -device=arm_cpu;opencl -device=mali,aocl_sw_emu"
419- )
418+ DEFAULT_TEST_TARGETS = [
419+ "llvm" ,
420+ "llvm -device=arm_cpu" ,
421+ "cuda" ,
422+ "cuda -model=unknown -libs=cudnn" ,
423+ "nvptx" ,
424+ "vulkan -from_device=0" ,
425+ "opencl" ,
426+ "opencl -device=mali,aocl_sw_emu" ,
427+ "opencl -device=intel_graphics" ,
428+ "metal" ,
429+ "rocm" ,
430+ ]
420431
421432
422433def device_enabled (target ):
@@ -548,6 +559,24 @@ def requires_cuda(*args):
548559 return _compose (args , _requires_cuda )
549560
550561
562+ def requires_cudnn (* args ):
563+ """Mark a test as requiring the cuDNN library.
564+
565+ This also marks the test as requiring a cuda gpu.
566+
567+ Parameters
568+ ----------
569+ f : function
570+ Function to mark
571+ """
572+
573+ requirements = [
574+        pytest.mark.skipif(not cudnn.exists(), reason="cuDNN library not enabled, or not installed"),
575+        *requires_cuda(),
576+ ]
577+ return _compose (args , requirements )
578+
579+
551580def requires_nvptx (* args ):
552581 """Mark a test as requiring the NVPTX compilation on the CUDA runtime
553582
@@ -730,20 +759,27 @@ def requires_rpc(*args):
730759
731760
732761def _target_to_requirement (target ):
762+ if isinstance (target , str ):
763+ target = tvm .target .Target (target )
764+
733765 # mapping from target to decorator
734- if target .startswith ("cuda" ):
735- return requires_cuda ()
736- if target .startswith ("rocm" ):
766+ if target .kind .name == "cuda" :
767+ if "cudnn" in target .attrs .get ("libs" , []):
768+ return requires_cudnn ()
769+ else :
770+ return requires_cuda ()
771+
772+ if target .kind .name == "rocm" :
737773 return requires_rocm ()
738- if target .startswith ( "vulkan" ) :
774+ if target .kind . name == "vulkan" :
739775 return requires_vulkan ()
740- if target .startswith ( "nvptx" ) :
776+ if target .kind . name == "nvptx" :
741777 return requires_nvptx ()
742- if target .startswith ( "metal" ) :
778+ if target .kind . name == "metal" :
743779 return requires_metal ()
744- if target .startswith ( "opencl" ) :
780+ if target .kind . name == "opencl" :
745781 return requires_opencl ()
746- if target .startswith ( "llvm" ) :
782+ if target .kind . name == "llvm" :
747783 return requires_llvm ()
748784 return []
749785
0 commit comments