diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 6732db6eaa7..d89967516a8 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -208,6 +208,7 @@ if(BINDING_TYPE STREQUAL "nanobind")
                       ${CMAKE_CURRENT_BINARY_DIR}/nanobind)
 endif()
 
+# include as system to suppress warnings
 include_directories(
   SYSTEM
@@ -249,6 +250,15 @@ if(${CUDAToolkit_VERSION} VERSION_GREATER_EQUAL "12.8")
   )
 endif()
 
+if(${CUDAToolkit_VERSION} VERSION_GREATER_EQUAL "13.0")
+  message(
+    STATUS
+      "CUDAToolkit_VERSION ${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR} is greater or equal than 13.0, setting CMAKE_CUDA_RUNTIME_LIBRARY to Shared"
+  )
+  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --device-entity-has-hidden-visibility=false -cudart=shared")
+  set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
+endif()
+
 if(ENABLE_MULTI_DEVICE)
   # MPI
   # MPI isn't used until tensorrt_llm/CMakeLists.txt is invoked. However, if
   # it's not called before "CMAKE_CXX_FLAGS" is set, it breaks on Windows for
@@ -365,6 +375,7 @@ if(NVCC_TIMING)
   set(CMAKE_CUDA_FLAGS
       "${CMAKE_CUDA_FLAGS} --time ${CMAKE_CURRENT_BINARY_DIR}/nvcc-timing.csv")
 endif()
+message("CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
 
 set(COMMON_HEADER_DIRS ${PROJECT_SOURCE_DIR} ${CUDAToolkit_INCLUDE_DIR})
diff --git a/cpp/include/tensorrt_llm/deep_gemm/tma_utils.cuh b/cpp/include/tensorrt_llm/deep_gemm/tma_utils.cuh
index 411d7447600..33ddfd31ec3 100644
--- a/cpp/include/tensorrt_llm/deep_gemm/tma_utils.cuh
+++ b/cpp/include/tensorrt_llm/deep_gemm/tma_utils.cuh
@@ -95,7 +95,7 @@ constexpr CUtensorMapDataType get_CUtensorMapDataType()
     }
 }
 
-PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled()
+PFN_cuTensorMapEncodeTiled_v12000 get_cuTensorMapEncodeTiled()
 {
     // Get pointer to `cuTensorMapEncodeTiled`
     cudaDriverEntryPointQueryResult driver_status;
@@ -110,12 +110,12 @@ PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled()
     if (driver_status != cudaDriverEntryPointSuccess)
         throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
 
-    return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
+    return reinterpret_cast<PFN_cuTensorMapEncodeTiled_v12000>(cuTensorMapEncodeTiled_ptr);
 }
 
 template <typename T>
 CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2], uint64_t stride_in_bytes,
-    uint32_t smem_dim[2], CUtensorMapSwizzle swizzle_type, PFN_cuTensorMapEncodeTiled encode_func = nullptr)
+    uint32_t smem_dim[2], CUtensorMapSwizzle swizzle_type, PFN_cuTensorMapEncodeTiled_v12000 encode_func = nullptr)
 {
     CUtensorMap tensor_map{};
     constexpr uint32_t rank = 2;
diff --git a/cpp/tensorrt_llm/deep_ep/CMakeLists.txt b/cpp/tensorrt_llm/deep_ep/CMakeLists.txt
index 088391aef4f..ec90d7056f4 100644
--- a/cpp/tensorrt_llm/deep_ep/CMakeLists.txt
+++ b/cpp/tensorrt_llm/deep_ep/CMakeLists.txt
@@ -36,8 +36,100 @@ if(NOT DEEP_EP_CUDA_ARCHITECTURES)
   return()
 endif()
 
+# TODO: restore patched nvshmem for CUDA12
+if(${CUDAToolkit_VERSION} VERSION_GREATER_EQUAL "13.0")
+  set(NVSHMEM_INSTALL_PREFIX "${TORCH_INSTALL_PREFIX}/../nvidia/nvshmem")
+  find_path(NVSHMEM_INCLUDE_DIR nvshmem.h HINTS ${NVSHMEM_INSTALL_PREFIX}/include)
+  find_library(NVSHMEM_DEVICE_LIBRARY nvshmem_device HINTS ${NVSHMEM_INSTALL_PREFIX}/lib)
+  find_library(NVSHMEM_HOST_LIBRARY libnvshmem_host.so.3 HINTS ${NVSHMEM_INSTALL_PREFIX}/lib)
+else()
+  set(NVSHMEM_INSTALL_PREFIX "$ORIGIN/libs/nvshmem")
+  # Delete stale nvshmem on patch update
+  set(NVSHMEM_STAMP_FILE ${CMAKE_CURRENT_BINARY_DIR}/nvshmem_stamp.txt)
+  file(SHA256 ${DEEP_EP_SOURCE_DIR}/third-party/nvshmem.patch NVSHMEM_PATCH_HASH)
+  file(SHA256 ${CMAKE_CURRENT_SOURCE_DIR}/nvshmem_fast_build.patch
+       NVSHMEM_PATCH_2_HASH)
+  set(NVSHMEM_STAMP_CONTENT "${NVSHMEM_URL_HASH}")
+  string(APPEND NVSHMEM_STAMP_CONTENT " PATCH_COMMAND v1")
+  string(APPEND NVSHMEM_STAMP_CONTENT " ${NVSHMEM_PATCH_HASH}")
+  string(APPEND NVSHMEM_STAMP_CONTENT " 103")
+  string(APPEND NVSHMEM_STAMP_CONTENT " ${NVSHMEM_PATCH_2_HASH}")
+  set(OLD_NVSHMEM_STAMP_CONTENT "")
+  if(EXISTS ${NVSHMEM_STAMP_FILE})
+    file(READ ${NVSHMEM_STAMP_FILE} OLD_NVSHMEM_STAMP_CONTENT)
+  endif()
+  if(NOT OLD_NVSHMEM_STAMP_CONTENT STREQUAL NVSHMEM_STAMP_CONTENT)
+    file(REMOVE_RECURSE ${CMAKE_CURRENT_BINARY_DIR}/nvshmem_project-prefix)
+    file(WRITE ${NVSHMEM_STAMP_FILE} "${NVSHMEM_STAMP_CONTENT}")
+  endif()
+  set_property(
+    DIRECTORY APPEND
+    PROPERTY CMAKE_CONFIGURE_DEPENDS
+             ${DEEP_EP_SOURCE_DIR}/third-party/nvshmem.patch
+             ${CMAKE_CURRENT_SOURCE_DIR}/nvshmem_fast_build.patch)
+
+  # Add NVSHMEM
+  # ===========
+
+  # NVSHMEM only works with GCC. Building NVSHMEM with Clang results in
+  # compilation errors. Using NVSHMEM with Clang results in slow builds and device
+  # link issues.
+  if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+    set(CMAKE_C_COMPILER gcc)
+    set(CMAKE_CXX_COMPILER g++)
+    set(CMAKE_CUDA_HOST_COMPILER g++)
+  endif()
+
+  # Add nvshmem external project
+  include(ExternalProject)
+  ExternalProject_Add(
+    nvshmem_project
+    URL file://${CMAKE_CURRENT_SOURCE_DIR}/nvshmem_src_3.2.5-1.txz
+    URL_HASH ${NVSHMEM_URL_HASH}
+    PATCH_COMMAND patch -p1 --forward --batch -i
+                  ${DEEP_EP_SOURCE_DIR}/third-party/nvshmem.patch
+    COMMAND sed "s/TRANSPORT_VERSION_MAJOR 3/TRANSPORT_VERSION_MAJOR 103/" -i
+            src/CMakeLists.txt
+    COMMAND patch -p1 --forward --batch -i
+            ${CMAKE_CURRENT_SOURCE_DIR}/nvshmem_fast_build.patch
+    CMAKE_CACHE_ARGS
+      -DCMAKE_C_COMPILER:STRING=${CMAKE_C_COMPILER}
+      -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER}
+      -DCMAKE_CXX_COMPILER:STRING=${CMAKE_CXX_COMPILER}
+      -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER}
+      -DCMAKE_CUDA_ARCHITECTURES:STRING=${DEEP_EP_CUDA_ARCHITECTURES}
+      -DCMAKE_CUDA_HOST_COMPILER:STRING=${CMAKE_CUDA_HOST_COMPILER}
+      -DCMAKE_CUDA_COMPILER_LAUNCHER:STRING=${CMAKE_CUDA_COMPILER_LAUNCHER}
+      -DNVSHMEM_BUILD_EXAMPLES:BOOL=0
+      -DNVSHMEM_BUILD_PACKAGES:BOOL=0
+      -DNVSHMEM_BUILD_TESTS:BOOL=0
+      -DNVSHMEM_IBGDA_SUPPORT:BOOL=1
+      -DNVSHMEM_IBRC_SUPPORT:BOOL=0
+      -DNVSHMEM_MPI_SUPPORT:BOOL=0
+      -DNVSHMEM_PMIX_SUPPORT:BOOL=0
+      -DNVSHMEM_SHMEM_SUPPORT:BOOL=0
+      -DNVSHMEM_TIMEOUT_DEVICE_POLLING:BOOL=0
+      -DNVSHMEM_UCX_SUPPORT:BOOL=0
+      -DNVSHMEM_USE_GDRCOPY:BOOL=0
+      -DNVSHMEM_USE_NCCL:BOOL=0
+    INSTALL_COMMAND ""
+    BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/nvshmem-build
+    BUILD_BYPRODUCTS
+      ${CMAKE_CURRENT_BINARY_DIR}/nvshmem-build/src/lib/libnvshmem.a)
+  add_library(nvshmem_project::nvshmem STATIC IMPORTED)
+  add_dependencies(nvshmem_project::nvshmem nvshmem_project)
+  file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/nvshmem-build/src/include)
+  set_target_properties(
+    nvshmem_project::nvshmem
+    PROPERTIES IMPORTED_LOCATION
+               ${CMAKE_CURRENT_BINARY_DIR}/nvshmem-build/src/lib/libnvshmem.a
+               INTERFACE_INCLUDE_DIRECTORIES
+               ${CMAKE_CURRENT_BINARY_DIR}/nvshmem-build/src/include)
+endif()
+
 # Ensure that dependent libraries are installed
-find_library(MLX5_lib NAMES mlx5 REQUIRED)
+find_library(MLX5_lib NAMES mlx5 libmlx5.so.1 REQUIRED)
+
 # Prepare files
 # =============
@@ -81,87 +173,6 @@ foreach(_f IN LISTS _files)
              PROPERTY CMAKE_CONFIGURE_DEPENDS ${_src})
 endforeach()
 
-# Delete stale nvshmem on patch update
-set(NVSHMEM_STAMP_FILE ${CMAKE_CURRENT_BINARY_DIR}/nvshmem_stamp.txt)
-file(SHA256 ${DEEP_EP_SOURCE_DIR}/third-party/nvshmem.patch NVSHMEM_PATCH_HASH)
-file(SHA256 ${CMAKE_CURRENT_SOURCE_DIR}/nvshmem_fast_build.patch
-     NVSHMEM_PATCH_2_HASH)
-set(NVSHMEM_STAMP_CONTENT "${NVSHMEM_URL_HASH}")
-string(APPEND NVSHMEM_STAMP_CONTENT " PATCH_COMMAND v1")
-string(APPEND NVSHMEM_STAMP_CONTENT " ${NVSHMEM_PATCH_HASH}")
-string(APPEND NVSHMEM_STAMP_CONTENT " 103")
-string(APPEND NVSHMEM_STAMP_CONTENT " ${NVSHMEM_PATCH_2_HASH}")
-set(OLD_NVSHMEM_STAMP_CONTENT "")
-if(EXISTS ${NVSHMEM_STAMP_FILE})
-  file(READ ${NVSHMEM_STAMP_FILE} OLD_NVSHMEM_STAMP_CONTENT)
-endif()
-if(NOT OLD_NVSHMEM_STAMP_CONTENT STREQUAL NVSHMEM_STAMP_CONTENT)
-  file(REMOVE_RECURSE ${CMAKE_CURRENT_BINARY_DIR}/nvshmem_project-prefix)
-  file(WRITE ${NVSHMEM_STAMP_FILE} "${NVSHMEM_STAMP_CONTENT}")
-endif()
-set_property(
-  DIRECTORY APPEND
-  PROPERTY CMAKE_CONFIGURE_DEPENDS
-           ${DEEP_EP_SOURCE_DIR}/third-party/nvshmem.patch
-           ${CMAKE_CURRENT_SOURCE_DIR}/nvshmem_fast_build.patch)
-
-# Add NVSHMEM
-# ===========
-
-# NVSHMEM only works with GCC. Building NVSHMEM with Clang results in
-# compilation errors. Using NVSHMEM with Clang results in slow builds and device
-# link issues.
-if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
-  set(CMAKE_C_COMPILER gcc)
-  set(CMAKE_CXX_COMPILER g++)
-  set(CMAKE_CUDA_HOST_COMPILER g++)
-endif()
-
-# Add nvshmem external project
-include(ExternalProject)
-ExternalProject_Add(
-  nvshmem_project
-  URL file://${CMAKE_CURRENT_SOURCE_DIR}/nvshmem_src_3.2.5-1.txz
-  URL_HASH ${NVSHMEM_URL_HASH}
-  PATCH_COMMAND patch -p1 --forward --batch -i
-                ${DEEP_EP_SOURCE_DIR}/third-party/nvshmem.patch
-  COMMAND sed "s/TRANSPORT_VERSION_MAJOR 3/TRANSPORT_VERSION_MAJOR 103/" -i
-          src/CMakeLists.txt
-  COMMAND patch -p1 --forward --batch -i
-          ${CMAKE_CURRENT_SOURCE_DIR}/nvshmem_fast_build.patch
-  CMAKE_CACHE_ARGS
-    -DCMAKE_C_COMPILER:STRING=${CMAKE_C_COMPILER}
-    -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER}
-    -DCMAKE_CXX_COMPILER:STRING=${CMAKE_CXX_COMPILER}
-    -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER}
-    -DCMAKE_CUDA_ARCHITECTURES:STRING=${DEEP_EP_CUDA_ARCHITECTURES}
-    -DCMAKE_CUDA_HOST_COMPILER:STRING=${CMAKE_CUDA_HOST_COMPILER}
-    -DCMAKE_CUDA_COMPILER_LAUNCHER:STRING=${CMAKE_CUDA_COMPILER_LAUNCHER}
-    -DNVSHMEM_BUILD_EXAMPLES:BOOL=0
-    -DNVSHMEM_BUILD_PACKAGES:BOOL=0
-    -DNVSHMEM_BUILD_TESTS:BOOL=0
-    -DNVSHMEM_IBGDA_SUPPORT:BOOL=1
-    -DNVSHMEM_IBRC_SUPPORT:BOOL=0
-    -DNVSHMEM_MPI_SUPPORT:BOOL=0
-    -DNVSHMEM_PMIX_SUPPORT:BOOL=0
-    -DNVSHMEM_SHMEM_SUPPORT:BOOL=0
-    -DNVSHMEM_TIMEOUT_DEVICE_POLLING:BOOL=0
-    -DNVSHMEM_UCX_SUPPORT:BOOL=0
-    -DNVSHMEM_USE_GDRCOPY:BOOL=0
-    -DNVSHMEM_USE_NCCL:BOOL=0
-  INSTALL_COMMAND ""
-  BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/nvshmem-build
-  BUILD_BYPRODUCTS
-    ${CMAKE_CURRENT_BINARY_DIR}/nvshmem-build/src/lib/libnvshmem.a)
-add_library(nvshmem_project::nvshmem STATIC IMPORTED)
-add_dependencies(nvshmem_project::nvshmem nvshmem_project)
-file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/nvshmem-build/src/include)
-set_target_properties(
-  nvshmem_project::nvshmem
-  PROPERTIES IMPORTED_LOCATION
-             ${CMAKE_CURRENT_BINARY_DIR}/nvshmem-build/src/lib/libnvshmem.a
-             INTERFACE_INCLUDE_DIRECTORIES
-             ${CMAKE_CURRENT_BINARY_DIR}/nvshmem-build/src/include)
 
 # Add DeepEP cpp
 # ==============
@@ -188,7 +199,7 @@ set_target_properties(
              CUDA_SEPARABLE_COMPILATION ON
              CUDA_ARCHITECTURES "${DEEP_EP_CUDA_ARCHITECTURES}"
              LINK_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/deep_ep_cpp_tllm.version
-             INSTALL_RPATH "$ORIGIN/libs/nvshmem;${TORCH_INSTALL_PREFIX}/lib"
+             INSTALL_RPATH "${TORCH_INSTALL_PREFIX}/lib;${NVSHMEM_INSTALL_PREFIX}/lib"
              BUILD_WITH_INSTALL_RPATH TRUE)
 target_compile_options(
   deep_ep_cpp_tllm
@@ -197,8 +208,9 @@ target_compile_options(
   deep_ep_cpp_tllm
 target_compile_definitions(
   deep_ep_cpp_tllm PRIVATE DISABLE_AGGRESSIVE_PTX_INSTRS TORCH_EXTENSION_NAME=deep_ep_cpp_tllm)
+target_include_directories(deep_ep_cpp_tllm PRIVATE ${NVSHMEM_INCLUDE_DIR})
 target_link_libraries(
-  deep_ep_cpp_tllm PRIVATE nvshmem_project::nvshmem ${TORCH_LIBRARIES}
+  deep_ep_cpp_tllm PRIVATE ${NVSHMEM_DEVICE_LIBRARY} ${NVSHMEM_HOST_LIBRARY} ${TORCH_LIBRARIES}
   ${TORCH_PYTHON_LIB})
 target_link_options(
   deep_ep_cpp_tllm PRIVATE
@@ -207,4 +219,4 @@ target_link_options(
 
 # Set targets
 # ===========
-add_dependencies(deep_ep deep_ep_cpp_tllm nvshmem_project)
+add_dependencies(deep_ep deep_ep_cpp_tllm)
diff --git a/cpp/tensorrt_llm/kernels/beamSearchKernels.cu b/cpp/tensorrt_llm/kernels/beamSearchKernels.cu
index 97c35478bca..d606dfea164 100644
--- a/cpp/tensorrt_llm/kernels/beamSearchKernels.cu
+++ b/cpp/tensorrt_llm/kernels/beamSearchKernels.cu
@@ -134,32 +134,6 @@ void invokeUpdateCacheIndirection(int* tgtCI, int const* srcCI, BeamHypotheses&
     sync_check_cuda_error(stream);
 }
 
-template <typename T>
-__global__ void addCumLogProbs(T* __restrict pStage1LogProbs, float const* __restrict cumLogProbs,
-    FinishedState const* finished, int const* endIds, float const* diversityRates,
-    runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM)
-{
-    int const bid = blockIdx.x; // Index of request in batch
-    runtime::SizeType32 const slot = batchSlots[bid];
-    float const diversityRate{diversityRates[slot]};
-    T* pLocalLogProbs = pStage1LogProbs + bid * nBMIn * nBMOut * 2;
-
-    for (int i = threadIdx.x; i < nBMIn * nBMOut * 2; i += blockDim.x)
-    {
-        int const iBMIn = i / (nBMOut * 2);
-        if (finished[slot * nBMIn + iBMIn].isFinished())
-        {
-            pLocalLogProbs[i] += (i == endIds[slot]) ? 1.0f : 0.0f;
-        }
-        else
-        {
-            // nBM is used in VBWS since `cumLogProbs` is initialized with kMaxBeamWidth earlier than BeamSearchLayer
-            pLocalLogProbs[i] += cumLogProbs[slot * nBM + iBMIn] + diversityRate * iBMIn;
-        }
-    }
-    return;
-}
-
 template __global__ void addCumLogProbs(float* __restrict pStage1LogProbs, float const* __restrict cumLogProbs,
     FinishedState const* finished, int const* endIds, float const* diversityRates,
     runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM);
diff --git a/cpp/tensorrt_llm/kernels/beamSearchKernels.h b/cpp/tensorrt_llm/kernels/beamSearchKernels.h
index 10a285af900..c0ad49eb097 100644
--- a/cpp/tensorrt_llm/kernels/beamSearchKernels.h
+++ b/cpp/tensorrt_llm/kernels/beamSearchKernels.h
@@ -130,10 +130,34 @@ void invokeTopkBeamSearch(T const* logProbs, T const* bias, void* workspace, Bea
 void invokeUpdateCacheIndirection(int* tgtCI, int const* srcCI, BeamHypotheses& bh,
     runtime::SizeType32 const maxAttentionWindow, runtime::SizeType32 sinkTokenLength, cudaStream_t stream);
 
+#ifdef __CUDACC__
 template <typename T>
-__global__ void addCumLogProbs(T* __restrict pStage1Probs, float const* __restrict cumLogProbs,
+__global__ __attribute__((visibility("default"))) void addCumLogProbs(T* __restrict pStage1LogProbs, float const* __restrict cumLogProbs,
     FinishedState const* finished, int const* endIds, float const* diversityRates,
-    runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM);
+    runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM)
+{
+    int const bid = blockIdx.x; // Index of request in batch
+    runtime::SizeType32 const slot = batchSlots[bid];
+    float const diversityRate{diversityRates[slot]};
+    T* pLocalLogProbs = pStage1LogProbs + bid * nBMIn * nBMOut * 2;
+
+    for (int i = threadIdx.x; i < nBMIn * nBMOut * 2; i += blockDim.x)
+    {
+        int const iBMIn = i / (nBMOut * 2);
+        if (finished[slot * nBMIn + iBMIn].isFinished())
+        {
+            pLocalLogProbs[i] += (i == endIds[slot]) ? 1.0f : 0.0f;
+        }
+        else
+        {
+            // nBM is used in VBWS since `cumLogProbs` is initialized with kMaxBeamWidth earlier than BeamSearchLayer
+            pLocalLogProbs[i] += cumLogProbs[slot * nBM + iBMIn] + diversityRate * iBMIn;
+        }
+    }
+    return;
+}
+#endif
+
 __global__ void gatherId(int const* __restrict pStage1Id, int* __restrict pStage2Id, size_t const nBS,
     size_t const nBMIn, size_t const nBMOut, size_t const nV);
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt b/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt
index 7a02cdee73f..4a0e7d21c5a 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt
@@ -190,7 +190,7 @@ set_cuda_architectures(fb_gemm_src 89 90 120f)
 # ${INSTANTIATION_GENERATION_DIR}/fp8_rowwise_gemm)
 
 add_library(fp8_blockscale_gemm_src STATIC ${FP8_BLOCKSCALE_GEMM_SRC_CU})
-set_cuda_architectures(fp8_blockscale_gemm_src 89 90 100f)
+set_cuda_architectures(fp8_blockscale_gemm_src 90)
 
 set(GEMM_SWIGLU_SM90_SRC_CU
     ${CMAKE_CURRENT_SOURCE_DIR}/fused_gated_gemm/gemm_swiglu_e4m3.cu)
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_tma_utils.cuh b/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_tma_utils.cuh
index b105368af03..18911feb7c4 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_tma_utils.cuh
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_tma_utils.cuh
@@ -84,7 +84,7 @@ inline CUtensorMapDataType get_CUtensorMapDataType()
     }
 }
 
-PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled()
+PFN_cuTensorMapEncodeTiled_v12000 get_cuTensorMapEncodeTiled()
 {
     // Get pointer to cuTensorMapEncodeTiled
     cudaDriverEntryPointQueryResult driver_status;
@@ -101,12 +101,12 @@ PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled()
         throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
     }
 
-    return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
+    return reinterpret_cast<PFN_cuTensorMapEncodeTiled_v12000>(cuTensorMapEncodeTiled_ptr);
 }
 
 template <typename data_type>
 CUtensorMap make_2d_tma_copy_desc(data_type* global_address, uint64_t gmem_dim[2], uint64_t stride_in_bytes,
-    uint32_t smem_dim[2], CUtensorMapSwizzle swizzle_type, PFN_cuTensorMapEncodeTiled encode_func = nullptr)
+    uint32_t smem_dim[2], CUtensorMapSwizzle swizzle_type, PFN_cuTensorMapEncodeTiled_v12000 encode_func = nullptr)
 {
     CUtensorMap tensor_map{};
     constexpr uint32_t rank = 2;
diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h
index ccda8ce2042..43409dc4b4a 100644
--- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h
+++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h
@@ -2597,7 +2597,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
         __shared__ typename BlockReduce::TempStorage temp_storage;
         // Obtain a segment of consecutive items that are blocked across threads (final_max from above)
         // Compute the block-wide max for thread0
-        final_max = BlockReduce(temp_storage).Reduce(thread_partial_max, cub::Max(), gridDim.z);
+        final_max = BlockReduce(temp_storage).Reduce(thread_partial_max, cuda::maximum(), gridDim.z);
 
         __shared__ float final_max_smem;
         if (tidx == 0)
diff --git a/cpp/tensorrt_llm/kernels/sageAttentionKernels.cu b/cpp/tensorrt_llm/kernels/sageAttentionKernels.cu
index 80a12b41ce5..e45a7bb97f9 100644
--- a/cpp/tensorrt_llm/kernels/sageAttentionKernels.cu
+++ b/cpp/tensorrt_llm/kernels/sageAttentionKernels.cu
@@ -250,7 +250,7 @@ __global__ void sage_quant_kernel(void const* q, void const* k, void const* v, i
 
     // Compute the block-wide max for thread0
    // cuda::maximum<>{}
-    float aggregate = BlockReduce(temp_storage).Reduce(local_amax, cub::Max{});
+    float aggregate = BlockReduce(temp_storage).Reduce(local_amax, cuda::maximum{});
 
     if (row_id == 0 && col_id == 0)
         s_block_amax = static_cast(aggregate);
@@ -429,7 +429,7 @@ __global__ void sage_quant_kernel(void const* q, void const* k, void const* v, i
 
     // Compute the block-wide max for thread0
    // cuda::maximum<>{}
-    float aggregate = BlockReduce(temp_storage).Reduce(local_amax, cub::Max{});
+    float aggregate = BlockReduce(temp_storage).Reduce(local_amax, cuda::maximum{});
 
     if (row_id == 0 && col_id == 0)
         s_block_amax = static_cast(aggregate);
diff --git a/cpp/tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.cu b/cpp/tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.cu
index b3a90bea5f8..e963033855b 100644
--- a/cpp/tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.cu
+++ b/cpp/tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.cu
@@ -504,7 +504,7 @@ __global__ void prepareGenEagleNetInputsKernel(SizeType32* nextSequenceLengths,
     BlockScan(tempStorage.scan).ExclusiveSum(numNextLogits, outputLastIndicesBase);
     // Sync because tempStorage is reused.
     __syncthreads();
-    auto const maxGenLength = BlockReduce(tempStorage.reduce).Reduce(nextDraftLen, cub::Max());
+    auto const maxGenLength = BlockReduce(tempStorage.reduce).Reduce(nextDraftLen, cuda::maximum());
 
     // Thread 0 has the result.
     if (bid == 0)
diff --git a/cpp/tensorrt_llm/kernels/topkLastDim.cu b/cpp/tensorrt_llm/kernels/topkLastDim.cu
index 3371ab4a0f2..3006f0f6466 100644
--- a/cpp/tensorrt_llm/kernels/topkLastDim.cu
+++ b/cpp/tensorrt_llm/kernels/topkLastDim.cu
@@ -25,6 +25,8 @@
 #include "topkLastDim.h"
 #include
 #include
+#include
+#include
 
 namespace tensorrt_llm
 {
@@ -1221,9 +1223,9 @@ void standalone_stable_radix_topk_(void* buf, size_t& buf_size, T const* in, Idx
     IdxT* sort_in_idx = nullptr;
 
     air_topk_stable::ComputeOffset computeoffset(k);
-    cub::CountingInputIterator counting_iter(0);
-    cub::TransformInputIterator, cub::CountingInputIterator>
-        transform_iter(counting_iter, computeoffset);
+    auto counting_iter = thrust::make_counting_iterator(0);
+    auto transform_iter = thrust::make_transform_iterator(counting_iter, computeoffset);
+
     cub::DeviceSegmentedSort::SortPairs(NULL, temp_storage_bytes, out_idx, out_idx, out, out, k * batch_size,
         batch_size, transform_iter, transform_iter + 1, stream);
     if (sorted)
@@ -1348,9 +1350,8 @@ void standalone_stable_radix_topk_one_block_(void* buf, size_t& buf_size, T cons
     const IdxT buf_len = air_topk_stable::calc_buf_len(len);
 
     air_topk_stable::ComputeOffset computeoffset(k);
-    cub::CountingInputIterator counting_iter(0);
-    cub::TransformInputIterator, cub::CountingInputIterator>
-        transform_iter(counting_iter, computeoffset);
+    auto counting_iter = thrust::make_counting_iterator(0);
+    auto transform_iter = thrust::make_transform_iterator(counting_iter, computeoffset);
 
     cub::DeviceSegmentedSort::SortPairs(NULL, temp_storage_bytes, out_idx, out_idx, out, out, k * batch_size,
         batch_size, transform_iter, transform_iter + 1, stream);
diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.cu
index ad5cd15fdda..ba850c45a2f 100644
--- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.cu
+++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.cu
@@ -154,7 +154,7 @@ __global__ void activationDeepSeekKernel(KernelParams params)
         float constexpr E4m3MaxVal{448.f};
 
         // Compute the absolute max
-        float aMax = BlockReduce(temp_storage).Reduce(fabsf(out), cub::Max());
+        float aMax = BlockReduce(temp_storage).Reduce(fabsf(out), cuda::maximum());
         if (threadIdx.x == 0)
         {
             s_scaleOut = aMax / E4m3MaxVal;
@@ -657,7 +657,7 @@ __global__ void finalizeDeepSeekKernel(KernelParams params)
         float constexpr E4m3MaxVal{448.f};
 
         // Compute the absolute max
-        float aMax = BlockReduce(temp_storage).Reduce(fabsf(acc), cub::Max());
+        float aMax = BlockReduce(temp_storage).Reduce(fabsf(acc), cuda::maximum());
         if (threadIdx.x == 0)
         {
diff --git a/cpp/tensorrt_llm/runtime/moeLoadBalancer/hostAccessibleDeviceAllocator.cpp b/cpp/tensorrt_llm/runtime/moeLoadBalancer/hostAccessibleDeviceAllocator.cpp
index d41aa157c50..a384f845d6f 100644
--- a/cpp/tensorrt_llm/runtime/moeLoadBalancer/hostAccessibleDeviceAllocator.cpp
+++ b/cpp/tensorrt_llm/runtime/moeLoadBalancer/hostAccessibleDeviceAllocator.cpp
@@ -364,7 +364,8 @@ void* HostAccessibleDeviceAllocator::allocate(size_t memorySize)
         TLLM_CHECK_WITH_INFO(
             mAllowManagedFallback, "HostAccessibleDeviceAllocator is not supported on the current system.");
         TLLM_CUDA_CHECK(cudaMallocManaged(&devPtr, memorySize));
-        TLLM_CUDA_CHECK(cudaMemAdvise(devPtr, memorySize, cudaMemAdviseSetPreferredLocation, currentDevId));
+        cudaMemLocation location {cudaMemLocationTypeDevice, currentDevId};
+        TLLM_CUDA_CHECK(cudaMemAdvise(devPtr, memorySize, cudaMemAdviseSetPreferredLocation, location));
         hostPtr = devPtr;
     }
     recordAllocation(devPtr, memorySize, hostPtr, memDesc);
diff --git a/cpp/tensorrt_llm/runtime/utils/debugUtils.cu b/cpp/tensorrt_llm/runtime/utils/debugUtils.cu
index 7f1c8d8dfc6..661dacd9a7a 100644
--- a/cpp/tensorrt_llm/runtime/utils/debugUtils.cu
+++ b/cpp/tensorrt_llm/runtime/utils/debugUtils.cu
@@ -54,7 +54,7 @@ __global__ void checkTensorInvalidKernel(T const* data, std::size_t size, int* f
     __shared__ typename BlockReduceT::TempStorage tempStorage;
 
     // Compute block-wide maximum
-    int blockFound = BlockReduceT(tempStorage).Reduce(found, cub::Max());
+    int blockFound = BlockReduceT(tempStorage).Reduce(found, cuda::maximum());
 
     // Have thread 0 write out block's result
     if (threadIdx.x == 0)
diff --git a/requirements.txt b/requirements.txt
index 16c1e4b5f8c..c5db93a00f3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,9 +1,9 @@
---extra-index-url https://download.pytorch.org/whl/cu128
+--extra-index-url https://download.pytorch.org/whl/cu130
 -c constraints.txt
 accelerate>=0.25.0
 build
 colored
-cuda-python # Do not override the custom version of cuda-python installed in the NGC PyTorch image.
+cuda-python~=13.0.0 # Do not override the custom version of cuda-python installed in the NGC PyTorch image.
 diffusers>=0.27.0
 lark
 mpi4py
@@ -13,27 +13,30 @@ onnx_graphsurgeon>=0.5.2
 openai
 polygraphy
 psutil
-nvidia-ml-py>=12,<13
+nvidia-ml-py
+# >=12,<13
 # Just a wrapper since nvidia-modelopt requires pynvml
-pynvml==12.0.0
+pynvml
 pulp
 pandas
-h5py==3.12.1
+h5py>=3.12.1
 StrEnum
 sentencepiece>=0.1.99
-tensorrt~=10.11.0
+tensorrt~=10.13.0
 # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-25-05.html#rel-25-05 uses 2.8.0a0.
-torch>=2.7.1,<=2.8.0a0
+torch>=2.7.1,<=2.9.0a0
 torchvision
 nvidia-modelopt[torch]~=0.33.0
-nvidia-nccl-cu12
-nvidia-cuda-nvrtc-cu12
-transformers==4.53.1
+nvidia-nccl-cu13
+nvidia-nvshmem-cu13
+nvidia-cuda-nvrtc-cu13
+transformers~=4.55.0
 pydantic>=2.9.1
 pydantic-settings[yaml]
 omegaconf
-pillow==10.3.0
-wheel<=0.45.1
+pillow>=10.3.0
+wheel
+#<=0.45.1
 optimum
 # evaluate needs datasets>=2.0.0 which triggers datasets>3.1.0 which is not stable: https://github.com/huggingface/datasets/issues/7467
 datasets==3.1.0
@@ -43,15 +46,15 @@ click
 click_option_group
 aenum
 pyzmq
-fastapi==0.115.4
+fastapi>=0.115.4
 uvicorn
 setuptools<80
 ordered-set
 peft
 einops
-flashinfer-python==0.2.5
+### flashinfer-python>=0.2.5 ### installs triton
 opencv-python-headless
-xgrammar==0.1.19
+### xgrammar>=0.1.19 ### installs triton
 backoff
 nvtx
 matplotlib # FIXME: this is added to make nvtx happy
@@ -59,5 +62,5 @@ meson
 ninja
 etcd3
 blake3
-llguidance==0.7.29
+llguidance>=0.7.29
 soundfile
diff --git a/scripts/build_wheel.py b/scripts/build_wheel.py
index 3fdaa93febb..7c0fe3d1b01 100755
--- a/scripts/build_wheel.py
+++ b/scripts/build_wheel.py
@@ -669,18 +669,21 @@ def get_binding_lib(subdirectory, name):
                             "deep_ep",
                             deep_ep_dir,
                             dirs_exist_ok=True)
+        (lib_dir / "nvshmem").mkdir(exist_ok=True)
-        install_file(
-            build_dir / "tensorrt_llm/deep_ep/nvshmem-build/License.txt",
-            lib_dir / "nvshmem")
-        install_file(
-            build_dir /
-            "tensorrt_llm/deep_ep/nvshmem-build/src/lib/nvshmem_bootstrap_uid.so.3",
-            lib_dir / "nvshmem")
-        install_file(
-            build_dir /
-            "tensorrt_llm/deep_ep/nvshmem-build/src/lib/nvshmem_transport_ibgda.so.103",
-            lib_dir / "nvshmem")
+        nvshmem_license = build_dir / "tensorrt_llm/deep_ep/nvshmem-build/License.txt"
+        if nvshmem_license.exists():
+            install_file(
+                build_dir / "tensorrt_llm/deep_ep/nvshmem-build/License.txt",
+                lib_dir / "nvshmem")
+            install_file(
+                build_dir /
+                "tensorrt_llm/deep_ep/nvshmem-build/src/lib/nvshmem_bootstrap_uid.so.3",
+                lib_dir / "nvshmem")
+            install_file(
+                build_dir /
+                "tensorrt_llm/deep_ep/nvshmem-build/src/lib/nvshmem_transport_ibgda.so.103",
+                lib_dir / "nvshmem")
 
     if not skip_stubs:
         with working_directory(project_dir):
            if binding_type == "nanobind":