diff --git a/CODING_GUIDELINES.md b/CODING_GUIDELINES.md index a311f1bbff9..dfa177c01d4 100644 --- a/CODING_GUIDELINES.md +++ b/CODING_GUIDELINES.md @@ -298,7 +298,7 @@ for (int i = 0; i < static_cast(mTensors.size()); ++i) 1. C headers should not be used directly. - Example: Use `` instead of `` 2. Do not use C library functions, whenever possible. - * Use brace initialization or `std::fill_n()` instead of `memset()`. This is especially important when dealing with non-[POD types](http://en.cppreference.com/w/cpp/concept/PODType). In the example below, using `memset()` will corrupt the vtable of `Foo:` + * Use brace initialization or `std::fill_n()` instead of `memset()`. This is especially important when dealing with non-[POD types](https://en.cppreference.com/w/cpp/named_req/PODType). In the example below, using `memset()` will corrupt the vtable of `Foo:` ```cpp struct Foo { virtual int getX() { return x; } diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz index 93d69e6bf79..1e3a7ee2da4 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:626761456288897fb021c78145fa4e56140890be62055db16491353f319cd455 -size 63672012 +oid sha256:9550450b1caceb6d04b2e5db9ca987b69a0f31859be399b9d0fc89aec619d1a6 +size 63677028 diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt index f4e8a3ed29b..edb023bf0e2 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt @@ -1,2 +1,2 @@ -4a23fc1883a3d35ae8b47c6bac3d8e7e85adc0e5615a9afd46acf8b84c8d62a8 libtensorrt_llm_internal_cutlass_kernels_static.a -commit 0971a156bd13fe781c1c16659c94bc69e0ca77e1 +b34f70bfe1f97af0b38c1866858d519fbf1c383f97b3f375c2a59985336e1fa2 libtensorrt_llm_internal_cutlass_kernels_static.a +commit 59ea816b923df8ece603c6eb06cbbf80d82557c9 diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz index f257496ee5b..8cf8f3da05b 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz +++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/tensorrt_llm_internal_cutlass_kernels_static.tar.xz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ae0b3678821f4c2415f8c6ae4922cf12f87d4b9d3ac3503fe5f83dc3eca84320 -size 63178880 +oid sha256:8c68ca94f31b429f163c8f9af1e1a3e307e31bd76b57c760572eef6f5a79c150 +size 63182304 diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt index adee7af80a3..e7e3c6a0bc9 100644 --- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt +++ 
b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt @@ -1,2 +1,2 @@ -ea9382af97a51c33ee4edb438e9add4e52fbdfcea68c6e420204992213013727 libtensorrt_llm_internal_cutlass_kernels_static.a -commit 0971a156bd13fe781c1c16659c94bc69e0ca77e1 +eb5d38d57e05d9d6702d29b99e961ace2927bf95356cd44c79b61aa3d450c0ee libtensorrt_llm_internal_cutlass_kernels_static.a +commit 59ea816b923df8ece603c6eb06cbbf80d82557c9 diff --git a/cpp/tensorrt_llm/kernels/preQuantScaleKernel.cu b/cpp/tensorrt_llm/kernels/preQuantScaleKernel.cu index e94fa9d362a..fef285931a0 100644 --- a/cpp/tensorrt_llm/kernels/preQuantScaleKernel.cu +++ b/cpp/tensorrt_llm/kernels/preQuantScaleKernel.cu @@ -40,8 +40,8 @@ struct Vec2Type<__nv_bfloat16> }; // namespace template -__global__ void apply_per_channel_scale( - T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows, int cols) +__global__ void apply_per_channel_scale(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows, + int cols, int64_t const* num_valid_tokens_ptr) { static constexpr int kElems = sizeof(AccessType) / sizeof(T_in); T_in scale[kElems], act_vec[kElems]; @@ -49,6 +49,8 @@ __global__ void apply_per_channel_scale( int row_offset = blockIdx.x; if (col_offset * kElems >= cols || row_offset * kProcessRows >= rows) return; + if (num_valid_tokens_ptr && (row_offset * kProcessRows >= *num_valid_tokens_ptr)) + return; act += row_offset * kProcessRows * cols; smoothed_act += row_offset * kProcessRows * cols; *reinterpret_cast(scale) = reinterpret_cast(per_channel_scale)[col_offset]; @@ -95,46 +97,46 @@ __global__ void apply_per_channel_scale( } template -void apply_per_channel_scale_kernel_launcher_( - T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows, int cols, cudaStream_t stream = 0) +void apply_per_channel_scale_kernel_launcher_(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, + int rows, int cols, int64_t const* num_valid_tokens_ptr = nullptr, cudaStream_t stream = 0) { static constexpr int kElems = sizeof(AccessType) / sizeof(T_in); dim3 block(128); dim3 grid((rows + kProcessRows - 1) / kProcessRows, (cols / kElems + block.x - 1) / block.x); apply_per_channel_scale - <<>>(smoothed_act, act, per_channel_scale, rows, cols); + <<>>(smoothed_act, act, per_channel_scale, rows, cols, num_valid_tokens_ptr); } template -void apply_per_channel_scale_kernel_launcher( - T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows, int cols, cudaStream_t stream) +void apply_per_channel_scale_kernel_launcher(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, + int rows, int cols, int64_t const* num_valid_tokens_ptr, cudaStream_t stream) { uint64_t elems = static_cast(rows) * static_cast(cols); if (elems < 2048 * 2048) { apply_per_channel_scale_kernel_launcher_( - smoothed_act, act, per_channel_scale, rows, cols, stream); + smoothed_act, act, per_channel_scale, rows, cols, num_valid_tokens_ptr, stream); } else if (elems < 4096 * 4096) { apply_per_channel_scale_kernel_launcher_( - smoothed_act, act, per_channel_scale, rows, cols, stream); + smoothed_act, act, per_channel_scale, rows, cols, num_valid_tokens_ptr, stream); } else if (elems < 8192 * 8192) { apply_per_channel_scale_kernel_launcher_( - smoothed_act, act, per_channel_scale, rows, cols, stream); + smoothed_act, act, per_channel_scale, rows, cols, num_valid_tokens_ptr, stream); } else { apply_per_channel_scale_kernel_launcher_( - smoothed_act, act, per_channel_scale, rows, 
cols, stream); + smoothed_act, act, per_channel_scale, rows, cols, num_valid_tokens_ptr, stream); } } #define INSTANTIATE_PREQUANT_SCALE(T_in, T_out) \ - template void apply_per_channel_scale_kernel_launcher( \ - T_out * smoothed_act, const T_in* act, const T_in* per_channel_scale, int rows, int cols, cudaStream_t stream) + template void apply_per_channel_scale_kernel_launcher(T_out * smoothed_act, const T_in* act, \ + const T_in* per_channel_scale, int rows, int cols, int64_t const* num_valid_tokens_ptr, cudaStream_t stream) INSTANTIATE_PREQUANT_SCALE(half, half); #if defined(ENABLE_FP8) @@ -148,128 +150,5 @@ INSTANTIATE_PREQUANT_SCALE(__nv_bfloat16, __nv_fp8_e4m3); #endif #endif -template -__global__ void apply_per_expert_scale(T_out* smoothed_act, T_in const* act, T_in const* per_expert_scale, - int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr, int rows, int cols) -{ - static constexpr int kElems = sizeof(AccessType) / sizeof(T_in); - T_in act_vec[kElems]; - int col_offset = blockIdx.x * blockDim.x + threadIdx.x; - int row_offset = blockIdx.y; - int expert_idx = permuted_token_selected_experts[row_offset]; - T_in scale = per_expert_scale[expert_idx]; - if (col_offset * kElems >= cols || row_offset * kProcessRows >= rows) - return; - if (num_valid_tokens_ptr && (row_offset * kProcessRows >= *num_valid_tokens_ptr)) - return; - act += row_offset * kProcessRows * cols; - smoothed_act += row_offset * kProcessRows * cols; -#pragma unroll - for (int i = 0; i < kProcessRows; ++i) - { - *reinterpret_cast(act_vec) = reinterpret_cast(act + i * cols)[col_offset]; - if constexpr ((std::is_same_v -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) - || std::is_same_v -#endif - ) &&(kElems % 2 == 0)) - { - using Vec2 = typename Vec2Type::type; -#pragma unroll - for (int j = 0; j < kElems; j += 2) - { - if constexpr (std::is_same_v) - { - *reinterpret_cast(act_vec + j) - = __hmul2(*reinterpret_cast(act_vec + j), __half2half2(scale)); - } - else - { - *reinterpret_cast(act_vec + j) - = __hmul2(*reinterpret_cast(act_vec + j), __bfloat162bfloat162(scale)); - } - } - } - else - { -#pragma unroll - for (int j = 0; j < kElems; ++j) - { - act_vec[j] = static_cast(static_cast(act_vec[j]) * static_cast(scale)); - } - } - if constexpr (std::is_same_v) - { - reinterpret_cast(smoothed_act + i * cols)[col_offset] - = *reinterpret_cast(act_vec); - } - else - { -#pragma unroll - for (int j = 0; j < kElems; ++j) - { - (smoothed_act + i * cols)[col_offset * kElems + j] = static_cast(act_vec[j]); - } - } - } -} - -template -void apply_per_expert_scale_kernel_launcher_(T_out* smoothed_act, T_in const* act, T_in const* per_expert_scale, - int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr, int rows, int cols, - cudaStream_t stream = 0) -{ - static constexpr int kElems = sizeof(AccessType) / sizeof(T_in); - dim3 block(128); - dim3 grid((cols / kElems + block.x - 1) / block.x, (rows + kProcessRows - 1) / kProcessRows); - apply_per_expert_scale<<>>( - smoothed_act, act, per_expert_scale, permuted_token_selected_experts, num_valid_tokens_ptr, rows, cols); -} - -template -void apply_per_expert_scale_kernel_launcher(T_out* smoothed_act, T_in const* act, T_in const* per_expert_scale, - int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr, int rows, int cols, - cudaStream_t stream) -{ - int elems = rows * cols; - if (elems < 2048 * 2048) - { - apply_per_expert_scale_kernel_launcher_(smoothed_act, act, 
per_expert_scale, - permuted_token_selected_experts, num_valid_tokens_ptr, rows, cols, stream); - } - else if (elems < 4096 * 4096) - { - apply_per_expert_scale_kernel_launcher_(smoothed_act, act, per_expert_scale, - permuted_token_selected_experts, num_valid_tokens_ptr, rows, cols, stream); - } - else if (elems < 8192 * 8192) - { - apply_per_expert_scale_kernel_launcher_(smoothed_act, act, per_expert_scale, - permuted_token_selected_experts, num_valid_tokens_ptr, rows, cols, stream); - } - else - { - apply_per_expert_scale_kernel_launcher_(smoothed_act, act, per_expert_scale, - permuted_token_selected_experts, num_valid_tokens_ptr, rows, cols, stream); - } -} - -#define INSTANTIATE_PEREXPERT_SCALE(T_in, T_out) \ - template void apply_per_expert_scale_kernel_launcher(T_out * smoothed_act, T_in const* act, \ - T_in const* per_expert_scale, int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr, \ - int rows, int cols, cudaStream_t stream) - -INSTANTIATE_PEREXPERT_SCALE(half, half); -#if defined(ENABLE_FP8) -INSTANTIATE_PEREXPERT_SCALE(half, __nv_fp8_e4m3); -#endif - -#if defined(ENABLE_BF16) -INSTANTIATE_PEREXPERT_SCALE(__nv_bfloat16, __nv_bfloat16); -#if defined(ENABLE_FP8) -INSTANTIATE_PEREXPERT_SCALE(__nv_bfloat16, __nv_fp8_e4m3); -#endif -#endif - } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/preQuantScaleKernel.h b/cpp/tensorrt_llm/kernels/preQuantScaleKernel.h index 63798373313..11cf0f193a8 100644 --- a/cpp/tensorrt_llm/kernels/preQuantScaleKernel.h +++ b/cpp/tensorrt_llm/kernels/preQuantScaleKernel.h @@ -36,13 +36,8 @@ namespace kernels { template -void apply_per_channel_scale_kernel_launcher( - T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows, int cols, cudaStream_t stream = 0); - -template -void apply_per_expert_scale_kernel_launcher(T_out* smoothed_act, T_in const* act, T_in const* per_expert_scale, - int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr, int rows, int cols, - cudaStream_t stream = 0); +void apply_per_channel_scale_kernel_launcher(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, + int rows, int cols, int64_t const* num_valid_tokens_ptr = nullptr, cudaStream_t stream = 0); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp b/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp index 5dc16460879..c9b779f4f33 100644 --- a/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp @@ -394,13 +394,13 @@ void pre_quant_scale_for_act(int const m, int const k, int const mQuantAlgo, int { tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher( reinterpret_cast<__nv_fp8_e4m3*>(workspace), reinterpret_cast(inputs[0]), - reinterpret_cast(inputs[mPreQuantScaleInputIdx]), m, k, stream); + reinterpret_cast(inputs[mPreQuantScaleInputIdx]), m, k, nullptr, stream); } else { tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher( reinterpret_cast(workspace), reinterpret_cast(inputs[0]), - reinterpret_cast(inputs[mPreQuantScaleInputIdx]), m, k, stream); + reinterpret_cast(inputs[mPreQuantScaleInputIdx]), m, k, nullptr, stream); } } diff --git a/cpp/tests/unit_tests/kernels/weightOnly/weightOnlyKernelTest.cpp 
b/cpp/tests/unit_tests/kernels/weightOnly/weightOnlyKernelTest.cpp index 81e0fce059e..3f22d594e2a 100644 --- a/cpp/tests/unit_tests/kernels/weightOnly/weightOnlyKernelTest.cpp +++ b/cpp/tests/unit_tests/kernels/weightOnly/weightOnlyKernelTest.cpp @@ -220,7 +220,7 @@ void exec_cutlass_kernel( { tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher( reinterpret_cast(scaled_act), reinterpret_cast(params.act), - reinterpret_cast(params.act_scale), params.m, params.k, stream); + reinterpret_cast(params.act_scale), params.m, params.k, nullptr, stream); act = scaled_act; } if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY) diff --git a/docs/source/architecture/core-concepts.md b/docs/source/architecture/core-concepts.md index 4534eccf3f2..3f7cfd558d8 100644 --- a/docs/source/architecture/core-concepts.md +++ b/docs/source/architecture/core-concepts.md @@ -4,24 +4,24 @@ TensorRT-LLM has a Model Definition API that can be used to define Large Language Models. This API is built on top of the powerful -[TensorRT Python API](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/index.html#) +[TensorRT Python API](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/index.html) to create graph representations of deep neural networks in TensorRT. To become familiar with the core concepts of the TensorRT API, refer to the -[Core Concepts](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/coreConcepts.html) +[Core Concepts](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/coreConcepts.html) section of the TensorRT documentation before proceeding further. In TensorRT-LLM, the [`tensorrt_llm.Builder`](source:tensorrt_llm/builder.py) class contains a -[`tensorrt.Builder`](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Core/Builder.html#tensorrt.Builder) +[`tensorrt.Builder`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/Core/Builder.html#id1) object. That instance is used in the `tensorrt_llm.Builder.create_network` method to create an instance of the -[`tensorrt.INetworkDefinition`](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Graph/Network.html#tensorrt.INetworkDefinition) +[`tensorrt.INetworkDefinition`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/Graph/Network.html#tensorrt.INetworkDefinition) class. The `INetworkDefinition` object can then be populated using the free functions defined in the [`tensorrt_llm.functional`](source:tensorrt_llm/functional.py). A simple example of such a free function is `tensorrt_llm.activation` that inserts a -[`tensorrt.IActivationLayer`](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Graph/Layers.html#tensorrt.IActivationLayer) +[`tensorrt.IActivationLayer`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/Graph/Layers.html#tensorrt.IActivationLayer) node in the graph of the model: ```python @@ -56,23 +56,23 @@ def silu(input: Tensor) -> Tensor: When the TensorRT-LLM's Model Definition API is utilized, a graph of the network is assembled. The graph can later be traversed or transformed using the graph traversal API exposed by the -[`tensorrt.ILayer`](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Graph/LayerBase.html#tensorrt.ILayer) +[`tensorrt.ILayer`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/Graph/LayerBase.html#tensorrt.ILayer) class. 
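As a hedged illustration of that traversal API (assuming only the `tensorrt` Python package and an already-populated `tensorrt.INetworkDefinition`; the helper name `print_network_layers` is illustrative, not part of TensorRT-LLM):

```python
import tensorrt as trt


def print_network_layers(network: trt.INetworkDefinition) -> None:
    """Walk a populated INetworkDefinition and print every ILayer it contains."""
    for i in range(network.num_layers):
        layer = network.get_layer(i)  # returns a tensorrt.ILayer
        outputs = [layer.get_output(j).name for j in range(layer.num_outputs)]
        print(f"{i:3d}: {layer.name} ({layer.type}) -> {outputs}")
```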
That graph will also be optimized by TensorRT during the compilation of the engine, as explained in the next section. # Compilation Once populated, the instance of the -[`tensorrt.INetworkDefinition`](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Graph/Network.html#tensorrt.INetworkDefinition), +[`tensorrt.INetworkDefinition`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/Graph/Network.html#tensorrt.INetworkDefinition) can be compiled into an efficient engine by the -[`tensorrt.Builder`](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Core/Builder.html#tensorrt.Builder) +[`tensorrt.Builder`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/Core/Builder.html#id1). In TensorRT-LLM, it is done through the `build_engine` member function of the `tensorrt_llm.Builder` class that calls the -[`build_serialized_network`](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Core/Builder.html#tensorrt.Builder.build_serialized_network) +[`build_serialized_network`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/Core/Builder.html#tensorrt.Builder.build_serialized_network) method of the -[`tensorrt.Builder`](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Core/Builder.html#tensorrt.Builder) +[`tensorrt.Builder`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/Core/Builder.html#id1) object. That call, if everything works as expected, produces an instance of the -[`tensorrt.IHostMemory`](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/FoundationalTypes/HostMemory.html#tensorrt.IHostMemory) +[`tensorrt.IHostMemory`](https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/python-api/infer/FoundationalTypes/HostMemory.html#tensorrt.IHostMemory) class. That object is an optimized TensorRT engine that can be stored as a binary file. diff --git a/docs/source/blogs/H100vsA100.md b/docs/source/blogs/H100vsA100.md index bdffe3fe74c..bd87dc718a3 100644 --- a/docs/source/blogs/H100vsA100.md +++ b/docs/source/blogs/H100vsA100.md @@ -4,7 +4,7 @@ # H100 has 4.6x A100 Performance in TensorRT-LLM, achieving 10,000 tok/s at 100ms to first token -TensorRT-LLM evaluated on both Hopper and Ampere shows **H100 FP8 is up to 4.6x max throughput and 4.4x faster 1st token latency than A100**. H100 FP8 is able to achieve over 10,000 output tok/s at [peak throughput](https://nvidia.github.io/TensorRT-LLM/performance.html#h100-gpus-fp8) for 64 concurrent requests, while maintaining a 1st token latency of 100ms. For [min-latency](https://nvidia.github.io/TensorRT-LLM/performance.html#id1) applications, TRT-LLM H100 can achieve less than 10ms to 1st token latency. +TensorRT-LLM evaluated on both Hopper and Ampere shows **H100 FP8 is up to 4.6x max throughput and 4.4x faster 1st token latency than A100**. H100 FP8 is able to achieve over 10,000 output tok/s at peak throughput for 64 concurrent requests, while maintaining a 1st token latency of 100ms. For min-latency applications, TRT-LLM H100 can achieve less than 10ms to 1st token latency.
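To make the compilation path described in the Compilation section above concrete, the following is a minimal sketch that drives the raw TensorRT Python API directly rather than TensorRT-LLM's `tensorrt_llm.Builder` wrapper; the toy one-layer network and the function name `build_and_save` are illustrative assumptions, not code from this repository:

```python
import tensorrt as trt


def build_and_save(engine_path: str) -> None:
    """INetworkDefinition -> build_serialized_network -> IHostMemory -> file on disk."""
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    # EXPLICIT_BATCH is the classic idiom for TensorRT 8.x/9.x; newer releases default to it.
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

    # A trivial graph: one input, one activation layer, one marked output.
    x = network.add_input("x", trt.float32, (1, 8))
    relu = network.add_activation(x, trt.ActivationType.RELU)
    network.mark_output(relu.get_output(0))

    config = builder.create_builder_config()
    serialized = builder.build_serialized_network(network, config)  # tensorrt.IHostMemory, or None on failure
    with open(engine_path, "wb") as f:
        f.write(serialized)
```

Because `build_serialized_network` returns a `tensorrt.IHostMemory` buffer, the engine can be written to disk as plain bytes and loaded later by a runtime.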
max throughput @@ -28,7 +28,7 @@ TensorRT-LLM evaluated on both Hopper and Ampere shows **H100 FP8 is up to 4.6x FP8 H100, FP16 A100, SXM 80GB GPUs, TP1, ISL/OSL's provided, TensorRT-LLM v0.5.0., TensorRT 9.1 -The full data behind these charts & tables and including larger models with higher TP values can be found in TensorRT-LLM's [Performance Documentation](https://nvidia.github.io/TensorRT-LLM/performance.html#performance-of-tensorrt-llm) +The full data behind these charts & tables, including larger models with higher TP values, can be found in TensorRT-LLM's [Performance Documentation](https://nvidia.github.io/TensorRT-LLM/latest/performance/perf-overview.html) Stay tuned for a highlight on Llama coming soon! diff --git a/docs/source/blogs/H200launch.md b/docs/source/blogs/H200launch.md index 58f5c087819..baa4905613d 100644 --- a/docs/source/blogs/H200launch.md +++ b/docs/source/blogs/H200launch.md @@ -21,7 +21,7 @@ TensorRT-LLM evaluation of the [new H200 GPU](https://nvidianews.nvidia.com/news *(1) Largest batch supported on given TP configuration by power of 2.* *(2) TP = Tensor Parallelism* -Additional Performance data is available on the [NVIDIA Data Center Deep Learning Product Performance](https://developer.nvidia.com/deep-learning-performance-training-inference/ai-inference) page, & soon in [TensorRT-LLM's Performance Documentation](https://nvidia.github.io/TensorRT-LLM/performance.html). +Additional performance data is available on the [NVIDIA Data Center Deep Learning Product Performance](https://developer.nvidia.com/deep-learning-performance-training-inference/ai-inference) page, and soon in [TensorRT-LLM's Performance Documentation](https://nvidia.github.io/TensorRT-LLM/latest/performance/perf-overview.html). ### H200 vs H100 diff --git a/examples/llm-api/README.md b/examples/llm-api/README.md index 665cec8a02a..10a338a59e6 100644 --- a/examples/llm-api/README.md +++ b/examples/llm-api/README.md @@ -1,3 +1,3 @@ # LLM API Examples -Please refer to the [official documentation](https://nvidia.github.io/TensorRT-LLM/llm-api/), [examples](https://nvidia.github.io/TensorRT-LLM/llm-api-examples/llm_api_examples.html) and [customization](https://nvidia.github.io/TensorRT-LLM/examples/customization.html) for detailed information and usage guidelines regarding the LLM API. +Please refer to the [official documentation](https://nvidia.github.io/TensorRT-LLM/llm-api/), [examples](https://nvidia.github.io/TensorRT-LLM/latest/examples/llm_api_examples.html) and [customization](https://nvidia.github.io/TensorRT-LLM/examples/customization.html) for detailed information and usage guidelines regarding the LLM API. diff --git a/examples/models/core/gpt/README.md b/examples/models/core/gpt/README.md index 764323532aa..dd74fad2bb9 100644 --- a/examples/models/core/gpt/README.md +++ b/examples/models/core/gpt/README.md @@ -694,7 +694,7 @@ python3 ../../../run.py --engine_dir gpt-next-2B/trt_engines/bf16/1-gpu \ ### Prompt-tuning For efficient fine-tuning, the NeMo framework allows you to learn virtual tokens to accomplish a downstream task. For more details, please read the -NeMo documentation [here](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/nlp/nemo_megatron/prompt_learning.html). +NeMo documentation [here](https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html). TensorRT-LLM supports inference with those virtual tokens. To enable it, pass the prompt embedding table's maximum size at build time with `--max_prompt_embedding_table_size N`.
For example: diff --git a/examples/models/core/multimodal/README.md b/examples/models/core/multimodal/README.md index 94965e832bf..a45e3fc724b 100644 --- a/examples/models/core/multimodal/README.md +++ b/examples/models/core/multimodal/README.md @@ -831,7 +831,7 @@ Note that for instruct Vision model, please set the `max_encoder_input_len` as ` ## NeVA -[NeVA](https://docs.nvidia.com/nemo-framework/user-guide/latest/multimodalmodels/neva/index.html) is a groundbreaking addition to the NeMo Multimodal ecosystem. This model seamlessly integrates large language-centric models with a vision encoder, that can be deployed in TensorRT-LLM. +[NeVA](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/multimodal/mllm/neva.html) is a groundbreaking addition to the NeMo Multimodal ecosystem. This model seamlessly integrates large language-centric models with a vision encoder, which can be deployed in TensorRT-LLM. 1. Generate TRT-LLM engine for NVGPT following example in `examples/models/core/gpt/README.md`. To adhere to the NVGPT conventions of the conversion script, some layer keys have to be remapped using `--nemo_rename_key`. diff --git a/examples/sample_weight_stripping/README.md b/examples/sample_weight_stripping/README.md index 622f86abc1f..bd28a60b840 100644 --- a/examples/sample_weight_stripping/README.md +++ b/examples/sample_weight_stripping/README.md @@ -241,7 +241,7 @@ python3 ../summarize.py --engine_dir engines/llama2-70b-hf-fp8-tp2.refit \ ## Experimental ### Checkpoint Pruner -The checkpoint pruner allows you to strip `Conv` and `Gemm` weights out of a TensorRT-LLM [checkpoint](https://nvidia.github.io/TensorRT-LLM/new_workflow.html). Since these make up the vast majority of weights, the pruner will decrease the size of your checkpoint up to 99%. +The checkpoint pruner allows you to strip `Conv` and `Gemm` weights out of a TensorRT-LLM [checkpoint](https://nvidia.github.io/TensorRT-LLM/latest/architecture/checkpoint.html). Since these make up the vast majority of weights, the pruner will decrease the size of your checkpoint by up to 99%. When building an engine with a pruned checkpoint, TensorRT-LLM fills in the missing weights with random ones. These weights should later be [refit](#engine-refitter) with the original weights to preserve the intended behavior. diff --git a/jenkins/L0_MergeRequest.groovy b/jenkins/L0_MergeRequest.groovy index 4583772ab6f..462e4f23280 100644 --- a/jenkins/L0_MergeRequest.groovy +++ b/jenkins/L0_MergeRequest.groovy @@ -785,8 +785,8 @@ def collectTestResults(pipeline, testFilter) sh "cd cov && coverage combine" sh "cd cov && find . -type f" - sh "cd cov && coverage report" - sh "cd cov && coverage html -d test_coverage_html" + sh "cd cov && coverage report -i" // -i: ignore errors, such as source files that cannot be found.
+ sh "cd cov && coverage html -d test_coverage_html -i" trtllm_utils.uploadArtifacts("cov/test_coverage_html/*", "${UPLOAD_PATH}/test-results/coverage-report/") echo "Test coverage report: https://urm.nvidia.com/artifactory/${UPLOAD_PATH}/test-results/coverage-report/index.html" } // Test coverage diff --git a/tensorrt_llm/_torch/modules/fused_moe.py b/tensorrt_llm/_torch/modules/fused_moe.py index 90fc1040041..4b7a065bae4 100755 --- a/tensorrt_llm/_torch/modules/fused_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe.py @@ -1,7 +1,7 @@ import copy import math import os -import threading +from concurrent.futures import ThreadPoolExecutor from enum import Enum, IntEnum from typing import Dict, List, NamedTuple, Optional, Union @@ -895,13 +895,17 @@ def create_weights(self): self.hidden_size, self.intermediate_size_per_partition // 2) - fc31_act_scale = nn.Parameter(torch.empty( - self.expert_size_per_partition, 1, dtype=self.dtype), + fc31_act_scale = nn.Parameter(torch.empty(1, + self.hidden_size, + dtype=self.dtype), requires_grad=False) self.register_parameter("fc31_act_scale", fc31_act_scale) fc2_act_scale = nn.Parameter(torch.empty( - self.expert_size_per_partition, 1, dtype=self.dtype), + 1, + self.intermediate_size_per_partition, + 1, + dtype=self.dtype), requires_grad=False) self.register_parameter("fc2_act_scale", fc2_act_scale) @@ -1625,47 +1629,57 @@ def load_expert_w2_weight(w2_weight, # Even though CPython has global interpreter lock (GIL), # it's still faster to load weights in parallel because it can utilize # CPU memory bandwidth better. - threads = [] - - for local_slot_id, expert_id in enumerate( - self.initial_local_expert_ids): - # expert_idx is the local slot index of current rank - expert_idx = local_slot_id + max_workers = min( + (self.expert_end - self.expert_start) * 2, + os.cpu_count() * 2, + 16, + ) - if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA: - w1_weight = weights[f"{expert_id}.w1.weight"] - w3_weight = weights[f"{expert_id}.w3.weight"] - w2_weight = weights[f"{expert_id}.w2.weight"] - elif self.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: - w1_w3_weight = weights["gate_up_proj"][expert_id].transpose( - 0, 1) - w1_weight, w3_weight = w1_w3_weight.chunk(2, dim=0) - w2_weight = weights["down_proj"][expert_id].transpose( - 0, 1).contiguous() - else: - raise NotImplementedError( - f"Unknown weight loading mode in MoE: {self.weight_loading_mode}" - ) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + + for local_slot_id, expert_id in enumerate( + self.initial_local_expert_ids): + # expert_idx is the local slot index of current rank + expert_idx = local_slot_id + + if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA: + w1_weight = weights[f"{expert_id}.w1.weight"] + w3_weight = weights[f"{expert_id}.w3.weight"] + w2_weight = weights[f"{expert_id}.w2.weight"] + elif self.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: + w1_w3_weight = weights["gate_up_proj"][expert_id].transpose( + 0, 1) + w1_weight, w3_weight = w1_w3_weight.chunk(2, dim=0) + w2_weight = weights["down_proj"][expert_id].transpose( + 0, 1).contiguous() + else: + raise NotImplementedError( + f"Unknown weight loading mode in MoE: {self.weight_loading_mode}" + ) - is_trtllm_nvfp4 = self.is_trtllm( - ) and self.quant_config.quant_mode.has_nvfp4() + is_trtllm_nvfp4 = self.is_trtllm( + ) and self.quant_config.quant_mode.has_nvfp4() - thread = threading.Thread(target=load_expert_w3_w1_weight, - args=(w1_weight, w3_weight, 
- self.w3_w1_weight.data[expert_idx], - is_trtllm_nvfp4)) - thread.start() - threads.append(thread) + future_w3_w1 = executor.submit( + load_expert_w3_w1_weight, + w1_weight, + w3_weight, + self.w3_w1_weight.data[expert_idx], + is_trtllm_nvfp4, + ) + futures.append(future_w3_w1) - thread = threading.Thread(target=load_expert_w2_weight, - args=(w2_weight, - self.w2_weight.data[expert_idx], - is_trtllm_nvfp4)) - thread.start() - threads.append(thread) + future_w2 = executor.submit( + load_expert_w2_weight, + w2_weight, + self.w2_weight.data[expert_idx], + is_trtllm_nvfp4, + ) + futures.append(future_w2) - for thread in threads: - thread.join() + for future in futures: + future.result() if self.quant_config and self.quant_config.quant_mode.has_any_quant( exclude_kv_cache=True): @@ -2042,12 +2056,14 @@ def _load_int4_groupwise_scales(self, weights: Dict): load_weight_shard(weights[f"{expert_id}.w1.input_scale"]) for expert_id in self.initial_local_expert_ids ] - all_w3_w1_input_scales = torch.max(torch.stack(all_w3_input_scales), - torch.stack(all_w1_input_scales)) - all_w3_w1_input_scales = torch.ones_like( - all_w3_w1_input_scales) * all_w3_w1_input_scales.max() - self.fc31_act_scale.data.copy_(1 / all_w3_w1_input_scales) - self.fc31_alpha.data.copy_(all_w3_w1_input_scales.float()) + all_w3_w1_input_scales_max = torch.max( + torch.stack(all_w3_input_scales), + torch.stack(all_w1_input_scales)).max() + self.fc31_act_scale.data.copy_( + torch.ones_like(self.fc31_act_scale) * + (1 / all_w3_w1_input_scales_max)) + self.fc31_alpha.data.copy_((torch.ones_like(self.fc31_alpha) * + all_w3_w1_input_scales_max).float()) all_w3_scales = [ load_weight_shard(weights[f"{expert_id}.w3.weight_scale_inv"], @@ -2083,11 +2099,12 @@ def _load_int4_groupwise_scales(self, weights: Dict): load_weight_shard(weights[f"{expert_id}.w2.input_scale"]) for expert_id in self.initial_local_expert_ids ] - all_w2_input_scales = torch.stack(all_w2_input_scales).to(self.dtype) - all_w2_input_scales = torch.ones_like( - all_w2_input_scales) * all_w2_input_scales.max() - self.fc2_act_scale.data.copy_(1 / all_w2_input_scales) - self.fc2_alpha.data.copy_(all_w2_input_scales.float()) + all_w2_input_scales_max = torch.stack(all_w2_input_scales).to( + self.dtype).max() + self.fc2_act_scale.data.copy_( + torch.ones_like(self.fc2_act_scale) * (1 / all_w2_input_scales_max)) + self.fc2_alpha.data.copy_( + (torch.ones_like(self.fc2_alpha) * all_w2_input_scales_max).float()) all_w2_scales = [ load_weight_shard(weights[f"{expert_id}.w2.weight_scale_inv"], diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 2cc3bca8e5a..7245eba4b64 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -90,7 +90,8 @@ def cal_max_tokens(peak_memory, total_gpu_memory, fraction, model_config, logger.info( f"Peak memory during memory usage profiling (torch + non-torch): {peak_memory / (GB):.2f} GiB, " f"available KV cache memory when calculating max tokens: {available_kv_mem / (GB):.2f} GiB, " - f"fraction is set {fraction}, kv size is {kv_size_per_token}") + f"fraction is set {fraction}, kv size is {kv_size_per_token}, device total memory {total_gpu_memory / (GB):.2f} GiB, " + f", tmp kv_mem { (alloc_kv_tokens * kv_size_per_token) / (GB):.2f} GiB") max_tokens = int((available_kv_mem) // kv_size_per_token) max_tokens = max(max_tokens, 0) return max_tokens @@ -155,10 +156,15 @@ def estimate_max_kv_cache_tokens( torch.cuda.empty_cache() 
torch.cuda.reset_peak_memory_stats() + end, total_gpu_memory = torch.cuda.mem_get_info() + total_used_bytes = total_gpu_memory - end model_bytes = torch.cuda.memory_stats()["allocated_bytes.all.current"] logger.info( f"Memory used after loading model weights (inside torch) in memory usage profiling: {model_bytes / (GB):.2f} GiB" ) + logger.info( + f"Memory used after loading model weights (outside torch) in memory usage profiling: {((total_used_bytes - model_bytes) if total_used_bytes > model_bytes else 0) / (GB):.2f} GiB" + ) py_executor.set_gather_responses(True) origin_iter_stats = py_executor.enable_iter_perf_stats diff --git a/tensorrt_llm/layers/moe.py b/tensorrt_llm/layers/moe.py index 6b55cb4643b..46e48f9d96a 100755 --- a/tensorrt_llm/layers/moe.py +++ b/tensorrt_llm/layers/moe.py @@ -500,8 +500,8 @@ def __init__(self, in_features: int, out_features: int, else: self.register_parameter('zero', None) if groupwise_quant_algo & GroupwiseQuantAlgo.PRE_QUANT_SCALE: - self.prequant_scaling_factor = Parameter( - shape=(experts_per_node, 1), dtype=dtype) + self.prequant_scaling_factor = Parameter(shape=(1, in_features), + dtype=dtype) else: self.register_parameter('prequant_scaling_factor', None) if groupwise_quant_algo & GroupwiseQuantAlgo.W4A8_ALPHA: diff --git a/tensorrt_llm/runtime/multimodal_model_runner.py b/tensorrt_llm/runtime/multimodal_model_runner.py index 495006e80a5..73518dde33d 100644 --- a/tensorrt_llm/runtime/multimodal_model_runner.py +++ b/tensorrt_llm/runtime/multimodal_model_runner.py @@ -909,9 +909,6 @@ def preprocess(self, pre_prompt, post_prompt, image, other_vision_inputs, elif self.model_type == 'pixtral': # Hold on to pixel_values and input_ids. dtype = str_dtype_to_torch(self.vision_precision) - pixel_values = image["pixel_values"].to(device="cuda", dtype=dtype) - input_ids = image["input_ids"].to(device="cuda") - # Shape of pixel values from the processor varies with the raw image. # So we create a new tensor with a fixed shape as expected by the vision # encoder and create a corresponding attention mask. 
@@ -919,19 +916,30 @@ def preprocess(self, pre_prompt, post_prompt, image, other_vision_inputs, patch_size = self.patch_size d_min = torch.finfo(dtype).min num_patches = (image_size // patch_size) - image = torch.full((1, 3, image_size, image_size), - fill_value=0, - dtype=dtype, - device="cuda") - attention_mask = torch.full((1, num_patches, num_patches), - fill_value=d_min, - dtype=dtype, - device="cuda") - h, w = pixel_values.shape[-2:] - image[..., :h, :w] = pixel_values - attention_mask[..., :h // patch_size, :w // patch_size] = 0 + padded_image = torch.full( + (self.args.batch_size, 3, image_size, image_size), + fill_value=0, + dtype=dtype, + device="cuda") + padded_attention_mask = torch.full( + (self.args.batch_size, num_patches, num_patches), + fill_value=d_min, + dtype=dtype, + device="cuda") + h, w, input_ids = [], [], [] + for img_idx in range(self.args.batch_size): + pixel_values = image["pixel_values"][img_idx] + img_h, img_w = pixel_values.shape[-2:] + padded_image[img_idx, :, :img_h, :img_w] = pixel_values + padded_attention_mask[img_idx, :img_h // patch_size, :img_w // + patch_size] = 0 + input_ids.append(image["input_ids"][img_idx]) + h.append(img_h) + w.append(img_w) + + image = padded_image other_vision_inputs = { - "attention_mask": attention_mask, + "attention_mask": padded_attention_mask, } elif self.model_type == 'llava_next': input = image @@ -1150,12 +1158,29 @@ def preprocess(self, pre_prompt, post_prompt, image, other_vision_inputs, elif self.model_type == 'pixtral': relevant_patch_size = self.patch_size * self.spatial_merge_size output_img_size = self.image_size // relevant_patch_size - visual_features = visual_features.reshape( - output_img_size, output_img_size, - -1)[:h // relevant_patch_size, :w // - relevant_patch_size].flatten(0, 1) + # Note: max_h * max_w shall serve as the `tokens_per_task` in ptuning prompt table. + max_h = max(h) // relevant_patch_size + max_w = max(w) // relevant_patch_size + visual_embed_dim = visual_features.shape[-1] + relevant_visual_features = torch.zeros(self.args.batch_size, + max_h * max_w, + visual_embed_dim) + for img_idx in range(self.args.batch_size): + complete_features = visual_features[img_idx] + complete_features = complete_features.reshape( + output_img_size, output_img_size, visual_embed_dim) + relevant_h = h[img_idx] // relevant_patch_size + relevant_w = w[img_idx] // relevant_patch_size + flattened_features = complete_features[:relevant_h, : + relevant_w, :].flatten( + 0, 1) + relevant_visual_features[img_idx, :relevant_h * + relevant_w, :] = flattened_features + visual_features = relevant_visual_features input_ids = self.ptuning_setup_pixtral(input_ids=input_ids) - length = input_ids.shape[1] + # Note: length is not used for pixtral model downstream. Setting it to a list + # of length of input_ids causes errors downstream. So, supplying a placeholder. + length = input_ids[0].shape[0] elif self.model_type == 'llava_next': visual_features = LlavaNextUtils.rearrange_image_features( @@ -2027,16 +2052,19 @@ def ptuning_setup_fuyu(self, input_ids, image_patches_indices): def ptuning_setup_pixtral(self, input_ids): # input_ids obtained from processor has token_ids for text as well as image tokens - # where each image token is represented the same image_token_index (10 for this model). + # where each image token is represented by the same image_token_index. image_token_index = self.image_token_index vocab_size = self.vocab_size # Replace all image tokens with a unique token_id > text_vacab_size. 
# This shall be used to lookup the prompt table. - replacer = vocab_size - for i in range(len(input_ids[0])): - if input_ids[0][i] == image_token_index: - input_ids[0][i] = replacer - replacer += 1 + for img_idx in range(self.args.batch_size): + # Note: We reset replacer to text_vocab_size for each sample. This is as opposed to doing `replacer = vocab_size + img_idx * tokens_per_task`. + # That part of the look-up manipulation is done by the `task_ids` input to PromptEmbedding forward. + replacer = vocab_size + for token_idx in range(len(input_ids[img_idx])): + if input_ids[img_idx][token_idx] == image_token_index: + input_ids[img_idx][token_idx] = replacer + replacer += 1 return input_ids def ptuning_setup_llava_next(self, visual_features, pre_prompt, @@ -2166,7 +2194,24 @@ def load_images(image_paths): if isinstance(image_path, str): image_path = image_path.split(self.args.path_sep) images = load_images(image_path) - + elif "pixtral" in self.model_type: + if image_path is None: + image_urls = [ + "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png", + "https://www.ilankelman.org/stopsigns/australia.jpg", + "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.png", + "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + ] + while len(image_urls) < self.args.batch_size: + image_urls *= 2 + image_urls = image_urls[:self.args.batch_size] + self.args.image_path = ",".join(image_urls) + images = load_images(image_urls) + else: + if isinstance(image_path, str): + image_path = image_path.split(self.args.path_sep) + images = load_images(image_path) + images = [images] if not isinstance(images, list) else images elif "nougat" in self.model_type: filepath = hf_hub_download( repo_id="hf-internal-testing/fixtures_docvqa", @@ -2413,9 +2458,15 @@ def setup_inputs(self, input_text, raw_image, raw_audio=None): post_prompt = "[/INST]" prompt = pre_prompt + input_text + post_prompt dtype = str_dtype_to_torch(self.vision_precision) - image = self.processor(text=prompt, - images=[raw_image], - return_tensors="pt").to(dtype) + image = {'pixel_values': [], 'input_ids': []} + for img_idx in range(self.args.batch_size): + image_info = self.processor(text=prompt, + images=[raw_image[img_idx]], + return_tensors="pt").to(dtype) + image['pixel_values'].append(image_info['pixel_values'].to( + self.device)) + image['input_ids'].append(image_info['input_ids'][0].to( + self.device)) elif 'internvl' in self.model_type: pre_prompt = "<|system|>\n你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。<|end|><|user|>\n\n" @@ -2619,7 +2670,9 @@ def setup_inputs(self, input_text, raw_image, raw_audio=None): image = image.expand( min(self.args.batch_size, len(input_text)), -1, -1, -1).contiguous() - if image is not None: + # Note: For pixtral model, image is a dict with each value being a list of tensors. + # Moving to device is handled above. So, it's safe to skip this for pixtral. 
+ if image is not None and 'pixtral' not in self.model_type: image = image.to(self.device) # Generate decoder_input_ids for enc-dec models # Custom prompts can be added as: diff --git a/tensorrt_llm/tools/multimodal_builder.py b/tensorrt_llm/tools/multimodal_builder.py index 95c4b066d2f..d4a8e8287b5 100644 --- a/tensorrt_llm/tools/multimodal_builder.py +++ b/tensorrt_llm/tools/multimodal_builder.py @@ -1627,8 +1627,12 @@ def attn_forward(self, cos, sin = position_embeddings q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0) + # attention_mask is of shape [batch, patches]. + mask = attention_mask[:, None, None, :] + attn_output = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=attention_mask).transpose(1, 2).contiguous() + q, k, v, attn_mask=mask).transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch, patches, -1) attn_output = self.o_proj(attn_output) diff --git a/tests/integration/defs/accuracy/test_cli_flow.py b/tests/integration/defs/accuracy/test_cli_flow.py index 4e1a5292d2d..32d838a0b28 100644 --- a/tests/integration/defs/accuracy/test_cli_flow.py +++ b/tests/integration/defs/accuracy/test_cli_flow.py @@ -1002,7 +1002,8 @@ def test_tp2(self): @pytest.mark.parametrize( "moe_tp_size", [1, 4, 8], ids=['expert_parallel', 'mixed_parallel', 'tensor_parallel']) - def test_ootb_except_mha_tp8(self, moe_tp_size): + def test_ootb_except_mha_tp8(self, moe_tp_size, mocker): + mocker.patch.object(CnnDailymail, "MAX_BATCH_SIZE", 1) self.run(tp_size=8, extra_convert_args=[ f"--moe_tp_size={moe_tp_size}", diff --git a/tests/integration/defs/accuracy/test_llm_api.py b/tests/integration/defs/accuracy/test_llm_api.py index d97f518616e..8e64ffacf87 100644 --- a/tests/integration/defs/accuracy/test_llm_api.py +++ b/tests/integration/defs/accuracy/test_llm_api.py @@ -113,6 +113,7 @@ class TestMixtral8x7B(LlmapiAccuracyTestHarness): MODEL_NAME = "mistralai/Mixtral-8x7B-v0.1" MODEL_PATH = f"{llm_models_root()}/Mixtral-8x7B-v0.1" + @pytest.mark.skip_less_device_memory(80000) @pytest.mark.skip_less_device(2) def test_tp2(self): with LLM(self.MODEL_PATH, tensor_parallel_size=2) as llm: diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index d39fc823aee..c1bceb763b2 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -1003,6 +1003,7 @@ class TestLlama3_3NemotronSuper49Bv1(LlmapiAccuracyTestHarness): MODEL_PATH = f"{llm_models_root()}/nemotron-nas/Llama-3_3-Nemotron-Super-49B-v1" @pytest.mark.skip_less_device(2) + @pytest.mark.skip_less_device_memory(80000) def test_auto_dtype_tp2(self): with LLM(self.MODEL_PATH, tensor_parallel_size=2) as llm: task = MMLU(self.MODEL_NAME) diff --git a/tests/integration/defs/examples/test_eagle.py b/tests/integration/defs/examples/test_eagle.py index 5d9d6bb2a4a..a7f347f442c 100644 --- a/tests/integration/defs/examples/test_eagle.py +++ b/tests/integration/defs/examples/test_eagle.py @@ -305,6 +305,7 @@ def test_mistral_eagle_1gpu(llm_mistral_model_root, @skip_post_blackwell @skip_pre_ada +@pytest.mark.skip_less_device_memory(80000) @pytest.mark.parametrize("use_dynamic_tree", [False, True], ids=['eagle1', 'eagle2']) @pytest.mark.parametrize("mistral_nemo_model_root", ['Mistral-Nemo-12b-Base'], diff --git a/tests/integration/defs/examples/test_multimodal.py b/tests/integration/defs/examples/test_multimodal.py index 36bc5d758cf..bf6ac91d45b 100644 --- 
a/tests/integration/defs/examples/test_multimodal.py +++ b/tests/integration/defs/examples/test_multimodal.py @@ -81,6 +81,8 @@ def _test_llm_multimodal_general(llm_venv, if "neva-22b" in tllm_model_name and get_device_memory() < 80000: pytest.skip("GPU memory is insufficient.") + if "Mistral-Small" in tllm_model_name and get_device_memory() < 80000: + pytest.skip("GPU memory is insufficient.") print("Converting huggingface model into binary format...") # ckpt from llm_models/ --> cmodels// diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index b8a0f059008..bc8f128754f 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -51,8 +51,33 @@ def _get_mem_info_from_log(file, ranks_num): # only when TLLM_LOG_LEVEL=INFO pattern = re.compile(r"\[MemUsageChange] Allocated ([\d]+\.[\d]+) GiB ") fraction_pattern = re.compile(r"fraction is set ([\d]+\.[\d]+), ") + total_mem_pattern = re.compile(r"device total memory ([\d]+\.[\d]+) GiB") + peak_mem_pattern = re.compile( + r"Peak memory during memory usage profiling \(torch \+ non-torch\): ([\d]+\.[\d]+) GiB" + ) + extra_mem_pattern = re.compile( + r"Memory used outside torch \(e\.g\., NCCL and CUDA graphs\) in memory usage profiling: ([\d]+\.[\d]+) GiB" + ) + activation_pattern = re.compile( + r"Memory dynamically allocated during inference \(inside torch\) in memory usage profiling: ([\d]+\.[\d]+) GiB" + ) + model_pattern = re.compile( + r"Memory used after loading model weights \(inside torch\) in memory usage profiling: ([\d]+\.[\d]+) GiB" + ) + tmp_kv_patterm = re.compile(r"tmp kv_mem ([\d]+\.[\d]+) GiB") + start_time_mem_pattern = re.compile( + r"Memory used after loading model weights \(outside torch\) in memory usage profiling: ([\d]+\.[\d]+) GiB" + ) + fraction = 0.90 kv_mem_size = [] + total_memory = [] + peak_memory = [] + extra_memory = [] + activation_memory = [] + model_memory = [] + tmp_kv = [] + start_time_mem = [] file.seek(0) lines = file.readlines() for line in lines: @@ -62,38 +87,67 @@ def _get_mem_info_from_log(file, ranks_num): match = fraction_pattern.findall(line) if len(match) > 0: fraction = float(match[0]) + match = total_mem_pattern.findall(line) + if len(match) > 0: + total_memory.append(float(match[0])) + match = peak_mem_pattern.findall(line) + if len(match) > 0: + peak_memory.append(float(match[0])) + match = extra_mem_pattern.findall(line) + if len(match) > 0: + extra_memory.append(float(match[0])) + match = activation_pattern.findall(line) + if len(match) > 0: + activation_memory.append(float(match[0])) + match = model_pattern.findall(line) + if len(match) > 0: + model_memory.append(float(match[0])) + match = tmp_kv_patterm.findall(line) + if len(match) > 0: + tmp_kv.append(float(match[0])) + match = start_time_mem_pattern.findall(line) + if len(match) > 0: + start_time_mem.append(float(match[0])) + assert len( kv_mem_size) % 2 == 0, "no enough memory usage information in log" kv_mem_size = kv_mem_size[len(kv_mem_size) // 2:] - return 0, 0, sum(kv_mem_size) / ranks_num, 0, fraction + return peak_memory, model_memory, sum( + kv_mem_size + ) / ranks_num, extra_memory, fraction, total_memory, activation_memory, sum( + tmp_kv) / ranks_num, sum(start_time_mem) - ranks_num -def _get_kv_mem_size_candidate(used_Gib, fraction): - import torch - _, total = torch.cuda.mem_get_info() - return (total / (1 << 30) - used_Gib) * fraction +def _get_kv_mem_size_candidate(total_Gib, used_Gib, fraction): + return (total_Gib - used_Gib) * fraction def 
_check_mem_usage(file, mem_info, ranks_num=1): if file is None or not TEST_MEM_USAGE: return delta = 0.2 # 0.2 GB as buffer - peak, model_size, kv_mem_size, extra, fraction = _get_mem_info_from_log( + peak, model_size, kv_mem_size, extra, fraction, total_memory, activation_memory, tmp_kv, start_time_mem = _get_mem_info_from_log( file, ranks_num) + peak = max(peak) + min_total = min(total_memory) e_peak, e_model_size, e_kv_mem_size, e_extra = mem_info - e_kv_mem_size = _get_kv_mem_size_candidate(e_peak, fraction) + import torch + _, total = torch.cuda.mem_get_info() + e_kv_mem_size = _get_kv_mem_size_candidate(min_total, + (e_peak + start_time_mem), + fraction) print( - f"Expected memory usage: peak mem {e_peak}, model mem {e_model_size}, kv mem {e_kv_mem_size}, extra {e_extra}" + f"Expected memory usage: peak mem {e_peak + start_time_mem}, model mem {e_model_size}, kv mem {e_kv_mem_size:.2f}, extra {e_extra}, total {total / (1 << 30):.2f}" ) print( - f"Running memory information: peak mem {peak}, model mem {model_size}, kv mem {kv_mem_size}, extra {extra}" + f"Running memory information: peak mem {peak}, model mem {model_size}, kv mem {kv_mem_size}, extra {extra}, total {min_total}, activation {activation_memory}, tmp_kv {tmp_kv}, fraction {fraction}, none-torch memory at starttime {start_time_mem}" ) - assert peak <= e_peak + delta, f"peak memory {peak} is larger than expected {e_peak}" - assert model_size <= e_model_size + delta, f"model memory {model_size} is larger than expected {e_model_size}" + assert peak - tmp_kv <= e_peak + start_time_mem + delta, f"peak memory {peak} is larger than expected {e_peak}" assert kv_mem_size >= e_kv_mem_size - delta, f"kv memory size {kv_mem_size} is smaller than expected {e_kv_mem_size}" - assert extra <= e_extra + delta, f"extra memory size {extra} is larger than expected {e_extra}" + # assert model_size <= e_model_size + delta, f"model memory {model_size} is larger than expected {e_model_size}" + # assert max(extra) <= e_extra + delta, f"extra memory size {extra} is larger than expected {e_extra}" def test_gpt3_175b_1layers_build_only(llm_root, llm_venv, engine_dir): @@ -1463,7 +1517,6 @@ def test_ptp_quickstart(llm_root, llm_venv): ("Llama3.2-11B-BF16", "llama-3.2-models/Llama-3.2-11B-Vision"), ("Nemotron4_4B-BF16", "nemotron/Minitron-4B-Base"), ("Nemotron-H-8B", "Nemotron-H-8B-Base-8K"), - ("Qwen3-30B-A3B", "Qwen3/Qwen3-30B-A3B"), pytest.param('Llama3.1-8B-NVFP4', 'nvfp4-quantized/Meta-Llama-3.1-8B', marks=skip_pre_blackwell), @@ -1488,6 +1541,9 @@ def test_ptp_quickstart(llm_root, llm_venv): pytest.param('Mixtral-8x7B-FP8', 'Mixtral-8x7B-Instruct-v0.1-fp8', marks=skip_pre_blackwell), + pytest.param('Qwen3-30B-A3B', + 'Qwen3/Qwen3-30B-A3B', + marks=pytest.mark.skip_less_device_memory(80000)), ]) def test_ptp_quickstart_advanced(llm_root, llm_venv, model_name, model_path): print(f"Testing {model_name}.") @@ -1549,7 +1605,7 @@ def test_ptq_quickstart_advanced_mtp(llm_root, llm_venv, model_name, f"{llm_models_root()}/{model_path}", ], stdout=running_log) - _check_mem_usage(running_log, [54.50, 0, 0, 0]) + _check_mem_usage(running_log, [54.60, 0, 0, 0]) @pytest.mark.skip_less_device_memory(80000) diff --git a/tests/integration/test_lists/qa/examples_test_list.txt b/tests/integration/test_lists/qa/examples_test_list.txt index f7a1b916fa9..33303578ad5 100644 --- a/tests/integration/test_lists/qa/examples_test_list.txt +++ b/tests/integration/test_lists/qa/examples_test_list.txt @@ -185,7 +185,9 @@ 
examples/test_multimodal.py::test_llm_multimodal_general[llava-v1.6-mistral-7b-h examples/test_multimodal.py::test_llm_multimodal_general[llava-v1.6-mistral-7b-hf-vision-trtllm-pp:1-tp:2-float16-bs:1-cpp_e2e:False-nb:1] examples/test_multimodal.py::test_llm_multimodal_general[llava-onevision-qwen2-7b-ov-hf-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] examples/test_multimodal.py::test_llm_multimodal_general[llava-onevision-qwen2-7b-ov-hf-video-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] -examples/test_multimodal.py::test_llm_multimodal_general[Mistral-Small-3.1-24B-Instruct-2503-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] +examples/test_multimodal.py::test_llm_multimodal_general[Mistral-Small-3.1-24B-Instruct-2503-pp:1-tp:1-bfloat16-bs:8-cpp_e2e:False-nb:1] +examples/test_multimodal.py::test_llm_multimodal_general[neva-22b-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] +examples/test_multimodal.py::test_llm_multimodal_general[neva-22b-pp:1-tp:1-bfloat16-bs:8-cpp_e2e:False-nb:1] examples/test_multimodal.py::test_llm_multimodal_general[nougat-base-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] examples/test_multimodal.py::test_llm_multimodal_general[nougat-base-pp:1-tp:1-bfloat16-bs:8-cpp_e2e:False-nb:1] examples/test_multimodal.py::test_llm_multimodal_general[video-neva-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 56f5d769e1d..42e17092ebf 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -241,7 +241,7 @@ l0_h100: - examples/test_multimodal.py::test_llm_multimodal_general[Phi-3-vision-128k-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] - examples/test_multimodal.py::test_llm_multimodal_general[Phi-3.5-vision-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] - examples/test_multimodal.py::test_llm_multimodal_general[Phi-4-multimodal-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] - - examples/test_multimodal.py::test_llm_multimodal_general[Mistral-Small-3.1-24B-Instruct-2503-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] + - examples/test_multimodal.py::test_llm_multimodal_general[Mistral-Small-3.1-24B-Instruct-2503-pp:1-tp:1-bfloat16-bs:8-cpp_e2e:False-nb:1] - examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] # 10 mins - examples/test_enc_dec.py::test_llm_enc_dec_mmlu[flan-t5-small-float32-tp:1-pp:1-nb:1-enable_fp8] # 7 mins - examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-bart-large-cnn-float16-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1-enable_fp8] # 13 mins diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 1248dc3285b..e53b8d096ad 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -411,7 +411,6 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mt accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5294983) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5239087) 
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5239087) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5234002) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[-] SKIP (https://nvbugs/5234002) examples/test_gemma.py::test_llm_hf_gemma_quantization_1gpu[gemma-2-27b-it-fp8-bfloat16-8] SKIP (https://nvbugs/5234164) full::GH200/examples/test_commandr.py::test_llm_commandr_v01_single_gpu_summary[disable_weight_only] SKIP (https://nvbugs/5250460) @@ -475,11 +474,6 @@ test_e2e.py::test_ptp_quickstart_advanced_8gpus[Llama3.1-70B-BF16-llama-3.1-mode test_e2e.py::test_ptp_quickstart_advanced_8gpus[Mixtral-8x7B-BF16-Mixtral-8x7B-v0.1] SKIP (https://nvbugs/5136994) test_e2e.py::test_ptp_quickstart_advanced_8gpus[Llama3.1-70B-FP8-llama-3.1-model/Llama-3.1-70B-Instruct-FP8] SKIP (https://nvbugs/5289909) test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B] SKIP (https://nvbugs/5289910) -disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5289912) -disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5232406) -disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5232406) -disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5232406) -disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one_mtp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5232406) full:B200/examples/test_gemma.py::test_llm_gemma_1gpu_summary_vswa[gemma-3-1b-it-other-bfloat16-8] SKIP (https://nvbugs/5292737) perf/test_perf.py::test_perf[mixtral_8x22b_v0.1-bench-float16-input_output_len:512,512-quant:fp8-tp:4] SKIP (https://nvbugspro.nvidia.com/bug/5274894) perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-input_output_len:128,128-gpus:4] SKIP (https://nvbugspro.nvidia.com/bug/5274894) diff --git a/tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py b/tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py index 757f769d04e..c48326e205d 100644 --- a/tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py +++ b/tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py @@ -154,14 +154,14 @@ def _woq_moe_groupwise_matmul(self, 2**31, (num_experts, n, k // num_weights_in_32_bits), dtype=torch.int32, device="cuda") - pre_quant_scale_1 = torch.ones(num_experts, - 1, - dtype=activation_dtype, - device="cuda") - pre_quant_scale_2 = torch.ones(num_experts, - 1, - dtype=activation_dtype, - device="cuda") + pre_quant_scale_1 = torch.randn(1, + k, + dtype=activation_dtype, + device="cuda") + pre_quant_scale_2 = torch.randn(1, + n, + dtype=activation_dtype, + device="cuda") scale_1 = torch.randn(num_experts, k // group_size, n * 2, @@ -182,8 +182,10 @@ def _woq_moe_groupwise_matmul(self, k, dtype=activation_dtype, device="cuda") * 0.01 - alpha_1 = 1 / pre_quant_scale_1.float() - alpha_2 = 1 / pre_quant_scale_2.float() + alpha_1 = torch.randn( + num_experts, 1, 
dtype=torch.float32, device="cuda") * 0.1 + alpha_2 = torch.randn( + num_experts, 1, dtype=torch.float32, device="cuda") * 0.1 preprocessor = tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8 @@ -236,7 +238,7 @@ def _woq_moe_groupwise_matmul(self, input = inputs_merged[i, :] fc1_qd = ref_weight_1[expert].cuda().float() if has_pre_quant: - input = input * pre_quant_scale_1[expert] + input = input * pre_quant_scale_1.squeeze() if has_alpha: input[input > 448.0] = 448.0 input = input.to(torch.float8_e4m3fn).float() @@ -248,7 +250,7 @@ def _woq_moe_groupwise_matmul(self, fc1 = fc1 * torch.nn.functional.silu(gate) fc2_qd = ref_weight_2[expert].cuda().float() if has_pre_quant: - fc1 = fc1 * pre_quant_scale_2[expert] + fc1 = fc1 * pre_quant_scale_2.squeeze() if has_alpha: fc1[fc1 > 448.0] = 448.0 fc1 = fc1.to(torch.float8_e4m3fn).float()
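Stepping back from the test details above: the kernel-facing change earlier in this diff gives `apply_per_channel_scale` a `num_valid_tokens_ptr` argument so that rows past the valid-token count are skipped, and the pre-quant scales used here are per-channel tensors of shape `(1, k)` / `(1, n)` rather than per-expert `(num_experts, 1)`. A rough PyTorch emulation of that behavior (ignoring the kernel's `kProcessRows` blocking and vectorized accesses; the helper name is an assumption, not an API from this repository) might look like:

```python
from typing import Optional

import torch


def apply_per_channel_scale_ref(act: torch.Tensor,
                                per_channel_scale: torch.Tensor,
                                num_valid_tokens: Optional[int] = None) -> torch.Tensor:
    """Multiply each valid row of act by a (1, cols) per-channel scale.

    Rows at or beyond num_valid_tokens are left unscaled here; the CUDA kernel
    simply skips them instead of writing them.
    """
    rows, cols = act.shape
    assert per_channel_scale.shape == (1, cols)
    out = act.clone()
    n = rows if num_valid_tokens is None else min(num_valid_tokens, rows)
    out[:n] = act[:n] * per_channel_scale  # broadcasts the scale over the row dimension
    return out


# Shapes mirroring the unit test: a single (1, k) scale shared by all tokens.
act = torch.randn(4, 16, dtype=torch.float16)
scale = torch.randn(1, 16, dtype=torch.float16)
smoothed = apply_per_channel_scale_ref(act, scale, num_valid_tokens=3)
```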