Skip to content
12 changes: 10 additions & 2 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ option(ENABLE_MULTI_DEVICE
"Enable building with multi device support (requires NCCL, MPI,...)" ON)
option(ENABLE_UCX "Enable building with UCX (Uniform Communication X) support"
ON)
option(NVRTC_DYNAMIC_LINKING "Link against the dynamic NVRTC libraries" OFF)
option(USING_OSS_CUTLASS_LOW_LATENCY_GEMM
"Using open sourced Cutlass low latency gemm kernel" ON)
option(USING_OSS_CUTLASS_FP4_GEMM "Using open sourced Cutlass fp4 gemm kernel"
Expand Down Expand Up @@ -147,10 +148,17 @@ set(CURAND_LIB CUDA::curand)
set(CUDA_DRV_LIB CUDA::cuda_driver)
set(CUDA_NVML_LIB CUDA::nvml)
set(CUDA_RT_LIB CUDA::cudart_static)
set(NVRTC_LIB CUDA::nvrtc_static)
set(NVRTC_BUILTINS_LIB CUDA::nvrtc_builtins_static)
set(NVPTX_LIB CUDA::nvptxcompiler_static)
set(CMAKE_CUDA_RUNTIME_LIBRARY Static)

if(NVRTC_DYNAMIC_LINKING)
set(NVRTC_LIB CUDA::nvrtc)
set(NVRTC_BUILTINS_LIB CUDA::nvrtc_builtins)
else()
set(NVRTC_LIB CUDA::nvrtc_static)
set(NVRTC_BUILTINS_LIB CUDA::nvrtc_builtins_static)
endif()

resolve_dirs(CUDAToolkit_INCLUDE_DIRS "${CUDAToolkit_INCLUDE_DIRS}")

message(STATUS "CUDA library status:")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ add_dependencies(nvrtc_wrapper_src xqa_sources_h)
target_include_directories(nvrtc_wrapper_src
PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/xqa_build)
target_link_libraries(
nvrtc_wrapper_src PUBLIC ${NVRTC_LIB} ${NVRTC_BUILTINS_LIB} ${CUDA_DRV_LIB}
${CUDA_RT_LIB})
nvrtc_wrapper_src PUBLIC ${NVPTX_LIB} ${NVRTC_LIB} ${NVRTC_BUILTINS_LIB}
${CUDA_DRV_LIB} ${CUDA_RT_LIB})
set_property(TARGET nvrtc_wrapper_src PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET nvrtc_wrapper_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
11 changes: 10 additions & 1 deletion scripts/build_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,8 @@ def main(*,
nvtx: bool = False,
skip_stubs: bool = False,
generate_fmha: bool = False,
no_venv: bool = False):
no_venv: bool = False,
nvrtc_dynamic_linking: bool = False):

if clean:
clean_wheel = True
Expand Down Expand Up @@ -404,6 +405,9 @@ def main(*,
if fast_build:
cmake_def_args.append(f"-DFAST_BUILD=ON")

if nvrtc_dynamic_linking:
cmake_def_args.append(f"-DNVRTC_DYNAMIC_LINKING=ON")

targets = ["tensorrt_llm", "nvinfer_plugin_tensorrt_llm"]

if cpp_only:
Expand Down Expand Up @@ -787,6 +791,11 @@ def add_arguments(parser: ArgumentParser):
help=
"Use the current Python interpreter without creating a virtual environment."
)
parser.add_argument(
"--nvrtc_dynamic_linking",
action="store_true",
help="Link against the dynamic NVRTC libraries and not the static ones."
)


if __name__ == "__main__":
Expand Down