diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index 36efc881958e..9334b94b7cf9 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -582,7 +582,13 @@ def _check(): @register_func def tvm_callback_cuda_compile(code): """use nvcc to generate ptx code for better optimization""" - ptx = nvcc.compile_cuda(code, target="ptx", arch=AutotvmGlobalScope.current.cuda_target_arch) + curr_cuda_target_arch = AutotvmGlobalScope.current.cuda_target_arch + # e.g., target arch could be [ + # "-gencode", "arch=compute_52,code=sm_52", + # "-gencode", "arch=compute_70,code=sm_70" + # ] + target = "fatbin" if isinstance(curr_cuda_target_arch, list) else "ptx" + ptx = nvcc.compile_cuda(code, target=target, arch=AutotvmGlobalScope.current.cuda_target_arch) return ptx @@ -591,8 +597,10 @@ def set_cuda_target_arch(arch): Parameters ---------- - arch: str + arch: str or list The argument of nvcc -arch. (e.g. "sm_51", "sm_62") + It can also be a list of gencode arguments passed to the nvcc command line, + e.g., ["-gencode", "arch=compute_52,code=sm_52", "-gencode", "arch=compute_70,code=sm_70"] """ AutotvmGlobalScope.current.cuda_target_arch = arch diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 05ee338a26cb..0e97ac14f8b1 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -74,7 +74,10 @@ def compile_cuda(code, file_target = path_target if path_target else temp_target cmd = ["nvcc"] cmd += ["--%s" % target, "-O3"] - cmd += ["-arch", arch] + if isinstance(arch, list): + cmd += arch + else: + cmd += ["-arch", arch] if options: if isinstance(options, str):