From 4946fb6b29d2040026f56fdb8654641002e2d93d Mon Sep 17 00:00:00 2001
From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Date: Mon, 19 May 2025 02:13:36 +0000
Subject: [PATCH 1/3] fix: Align PP layer distribution between pytorch and TRT flow.

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
---
 .../tensorrt_llm/runtime/modelConfig.h      | 29 +++++++++++++++----
 .../batch_manager/loraBuffers.cpp           |  3 +-
 .../batch_manager/peftCacheManager.cpp      |  3 +-
 .../batch_manager/rnnStateManager.cpp       |  7 +++--
 .../trtGptModelInflightBatching.cpp         |  3 +-
 cpp/tensorrt_llm/pybind/bindings.cpp        |  3 +-
 cpp/tensorrt_llm/runtime/loraCache.cpp      |  5 ++--
 cpp/tensorrt_llm/runtime/loraManager.cpp    |  5 ++--
 .../unit_tests/runtime/loraManagerTest.cpp  |  4 +--
 tensorrt_llm/mapping.py                     | 11 ++++---
 tensorrt_llm/models/generation_mixin.py     | 15 ++++------
 tensorrt_llm/runtime/model_runner_cpp.py    |  4 ++-
 tests/integration/test_lists/waives.txt     |  1 -
 13 files changed, 57 insertions(+), 36 deletions(-)

diff --git a/cpp/include/tensorrt_llm/runtime/modelConfig.h b/cpp/include/tensorrt_llm/runtime/modelConfig.h
index daf8bd78175..d5b99cc0992 100644
--- a/cpp/include/tensorrt_llm/runtime/modelConfig.h
+++ b/cpp/include/tensorrt_llm/runtime/modelConfig.h
@@ -167,23 +167,40 @@ class ModelConfig
         LayerType layerType, SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
     {
         TLLM_CHECK_WITH_INFO(pipelineParallelism > 0, "Invalid pipelineParallelism: %d", pipelineParallelism);
-        auto const numLocalLayers = mNbLayers / pipelineParallelism; // WARNING: assume no remainder
-        auto const firstLocalLayerIt = mLayerTypes.cbegin() + (numLocalLayers * pipelineParallelismRank);
+        auto const numBaseLayers = mNbLayers / pipelineParallelism;
+        auto const numExtraLayers = mNbLayers % pipelineParallelism;
+        auto const firstLocalLayer
+            = pipelineParallelismRank * numBaseLayers + std::min(pipelineParallelismRank, numExtraLayers);
+        // If num_layers % pp_size = n != 0, first n ranks get one extra layer
+        auto const numLocalLayers = numBaseLayers + (pipelineParallelismRank < numExtraLayers ? 1 : 0);
+        auto const firstLocalLayerIt = mLayerTypes.cbegin() + firstLocalLayer;
         return std::count(firstLocalLayerIt, firstLocalLayerIt + numLocalLayers, layerType);
     }
 
+    [[nodiscard]] SizeType32 getFirstLocalLayer(
+        SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
+    {
+        auto const numBaseLayers = mNbLayers / pipelineParallelism;
+        auto const numExtraLayers = mNbLayers % pipelineParallelism;
+        // If num_layers % pp_size = n != 0, first n ranks get one extra layer
+        return pipelineParallelismRank * numBaseLayers + std::min(pipelineParallelismRank, numExtraLayers);
+    }
+
     [[nodiscard]] SizeType32 countLowerRankLayers(
         LayerType layerType, SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
     {
-        auto const numLocalLayers = mNbLayers / pipelineParallelism; // WARNING: assume no remainder
-        auto const firstLocalLayer = numLocalLayers * pipelineParallelismRank;
+        auto const firstLocalLayer = getFirstLocalLayer(pipelineParallelism, pipelineParallelismRank);
         // count number of previous non-local attention layers
        return std::count(mLayerTypes.cbegin(), mLayerTypes.cbegin() + firstLocalLayer, layerType);
     }
 
-    [[nodiscard]] SizeType32 getNbLayers(SizeType32 pipelineParallelism = 1) const
+    [[nodiscard]] SizeType32 getNbLayers(
+        SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
     {
-        return mNbLayers / pipelineParallelism; // WARNING: assume no remainder
+        auto const numBaseLayers = mNbLayers / pipelineParallelism;
+        auto const numExtraLayers = mNbLayers % pipelineParallelism;
+        // If num_layers % pp_size = n != 0, first n ranks get one extra layer
+        return numBaseLayers + (pipelineParallelismRank < numExtraLayers ? 1 : 0);
     }
 
     [[nodiscard]] SizeType32 getNbAttentionLayers(
diff --git a/cpp/tensorrt_llm/batch_manager/loraBuffers.cpp b/cpp/tensorrt_llm/batch_manager/loraBuffers.cpp
index aa16c8505e0..b67b72f6c49 100644
--- a/cpp/tensorrt_llm/batch_manager/loraBuffers.cpp
+++ b/cpp/tensorrt_llm/batch_manager/loraBuffers.cpp
@@ -26,7 +26,8 @@ namespace tensorrt_llm::batch_manager
 LoraBuffers::LoraBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, runtime::TllmRuntime const& tllmRuntime,
     runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig)
 {
-    auto const localNbLayers = modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism());
+    auto const localNbLayers
+        = modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
     auto const firstLayerId = worldConfig.getPipelineParallelRank() * localNbLayers;
 
     auto nbModelConfigs = static_cast<SizeType32>(modelConfig.getLoraModules().size());
diff --git a/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp
index 75f53c33b43..8eeca23df35 100644
--- a/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp
+++ b/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp
@@ -102,7 +102,8 @@ PeftCacheManager::getPageManagerConfig(PeftCacheManagerConfig const& config, run
     auto const tpSize = worldConfig.getTensorParallelism();
     auto const ppSize = worldConfig.getPipelineParallelism();
-    auto const numLocalLayers = modelConfig.getNbAttentionLayers(ppSize);
+    auto const ppRank = worldConfig.getPipelineParallelRank();
+    auto const numLocalLayers = modelConfig.getNbAttentionLayers(ppSize, ppRank);
 
     uint64_t min1dModSize = std::numeric_limits<uint64_t>::max(); // used to setup the pageWidth
     uint64_t total1dModSize = 0;
     uint64_t total1lSlots = 0; // the slots we need for each layer, summing the slots of all modules
diff --git a/cpp/tensorrt_llm/batch_manager/rnnStateManager.cpp b/cpp/tensorrt_llm/batch_manager/rnnStateManager.cpp
index 598250d9a59..736458d98d1 100644
--- a/cpp/tensorrt_llm/batch_manager/rnnStateManager.cpp
+++ b/cpp/tensorrt_llm/batch_manager/rnnStateManager.cpp
@@ -41,7 +41,8 @@ RnnStateManager::RnnStateManager(SizeType32 maxNumSequences, tensorrt_llm::runti
     auto const rnnHiddenSize = rnnConfig->rnnHiddenSize;
     auto const rnnHeadSize = rnnConfig->rnnHeadSize;
     auto const rnnConvDimSize = rnnConfig->rnnConvDimSize;
-    auto const localNbLayers = modelConfig.getNbRnnLayers(worldConfig.getPipelineParallelism());
+    auto const localNbLayers
+        = modelConfig.getNbRnnLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
     auto const dataType = modelConfig.getDataType();
 
     auto const rnnStateShape = [&]()
@@ -84,8 +85,8 @@ RnnStateManager::RnnStateManager(SizeType32 maxNumSequences, tensorrt_llm::runti
 void RnnStateManager::getPtrBuffers(
     TensorMap& inputBuffers, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig) const
 {
-    auto const localNbLayers = modelConfig.getNbRnnLayers(worldConfig.getPipelineParallelism());
-    auto const firstLayerId = worldConfig.getPipelineParallelRank() * localNbLayers;
+    auto const firstLayerId
+        = modelConfig.getFirstLocalLayer(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
     auto const& layerTypes = modelConfig.getLayerTypes();
 
     utils::insertTensorVector(
diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
index f5b90e279a5..a57bfb05364 100644
--- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
+++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
@@ -1717,7 +1717,8 @@ void TrtGptModelInflightBatching::executeStep(
         // TODO: support layer-wise cross kv cache in encoder-decoder models
         if (!layerWiseRequests.empty() && !mModelConfig.useCrossAttention())
         {
-            int const numLayers = mModelConfig.getNbAttentionLayers(mWorldConfig.getPipelineParallelism());
+            int const numLayers = mModelConfig.getNbAttentionLayers(
+                mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank());
             progress = std::make_shared<ContextProgress>(numLayers);
         }
         bufferCast(*mBuffers[bufferId]->transformerBuffers->contextProgressHost)[0] = progress.get();
diff --git a/cpp/tensorrt_llm/pybind/bindings.cpp b/cpp/tensorrt_llm/pybind/bindings.cpp
index 8369a6af023..c6f040bfa62 100644
--- a/cpp/tensorrt_llm/pybind/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/bindings.cpp
@@ -317,7 +317,8 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
             py::arg("num_heads"), py::arg("hidden_size"), py::arg("data_type"))
         .def_property_readonly("vocab_size", &tr::ModelConfig::getVocabSize)
         .def("vocab_size_padded", &tr::ModelConfig::getVocabSizePadded, py::arg("world_size"))
-        .def("num_layers", &tr::ModelConfig::getNbLayers, py::arg("pipeline_parallelism") = 1)
+        .def("num_layers", &tr::ModelConfig::getNbLayers, py::arg("pipeline_parallelism") = 1,
+            py::arg("pipeline_parallelism_rank") = 0)
         .def("num_attention_layers", &tr::ModelConfig::getNbAttentionLayers, py::arg("pipeline_parallelism") = 1,
             py::arg("pipeline_parallelism_rank") = 0)
         .def("num_rnn_layers", &tr::ModelConfig::getNbRnnLayers, py::arg("pipeline_parallelism") = 1,
diff --git a/cpp/tensorrt_llm/runtime/loraCache.cpp b/cpp/tensorrt_llm/runtime/loraCache.cpp
index 5b6ab1cc400..43a4d3d0642 100644
--- a/cpp/tensorrt_llm/runtime/loraCache.cpp
+++ b/cpp/tensorrt_llm/runtime/loraCache.cpp
@@ -454,7 +454,8 @@ SizeType32 LoraCache::determineNumPages(TaskIdType taskId) const
 SizeType32 LoraCache::determineNumPages(TensorPtr loraConfig) const
 {
     TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
-    auto const localNumLayers = mModelConfig.getNbAttentionLayers(mWorldConfig.getPipelineParallelism());
+    auto const localNumLayers = mModelConfig.getNbAttentionLayers(
+        mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank());
     auto const firstLayerId = mWorldConfig.getPipelineParallelRank() * localNumLayers;
     auto const lastLayerId = firstLayerId + localNumLayers;
 
@@ -579,7 +580,7 @@ std::vector<LoraCache::TaskLayerModuleConfig> LoraCache::copyToPages(TensorPtr s
     auto const ppSize = worldConfig.getPipelineParallelism();
     auto const ppRank = worldConfig.getPipelineParallelRank();
     // TODO(oargov): why *attention* layers?
-    auto const localNumLayers = modelConfig.getNbAttentionLayers(ppSize);
+    auto const localNumLayers = modelConfig.getNbAttentionLayers(ppSize, ppRank);
     auto const firstLayerId = ppRank * localNumLayers;
     auto const lastLayerId = firstLayerId + localNumLayers;
 
diff --git a/cpp/tensorrt_llm/runtime/loraManager.cpp b/cpp/tensorrt_llm/runtime/loraManager.cpp
index b38c7f389ea..7e81c848f29 100644
--- a/cpp/tensorrt_llm/runtime/loraManager.cpp
+++ b/cpp/tensorrt_llm/runtime/loraManager.cpp
@@ -72,7 +72,7 @@ void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes
     auto const ppSize = worldConfig.getPipelineParallelism();
     auto const ppRank = worldConfig.getPipelineParallelRank();
-    auto const localNumLayers = modelConfig.getNbAttentionLayers(ppSize);
+    auto const localNumLayers = modelConfig.getNbAttentionLayers(ppSize, ppRank);
     auto const firstLayerId = ppRank * localNumLayers;
 
     auto weightsPointersPtr = bufferCast<int64_t>(*weightsPtrs);
@@ -123,7 +123,8 @@ void LoraManager::insertInputTensors(TensorMap& inputTensors, TensorPtr weightsP
     ModelConfig const& modelConfig, WorldConfig const& worldConfig) const
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
-    auto localNbLayers = modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism());
+    auto localNbLayers
+        = modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
     auto firstLayerId = worldConfig.getPipelineParallelRank() * localNbLayers;
 
     for (auto const& [modId, mod] : mModuleIdToModule)
diff --git a/cpp/tests/unit_tests/runtime/loraManagerTest.cpp b/cpp/tests/unit_tests/runtime/loraManagerTest.cpp
index 0f0efa30188..6910719da76 100644
--- a/cpp/tests/unit_tests/runtime/loraManagerTest.cpp
+++ b/cpp/tests/unit_tests/runtime/loraManagerTest.cpp
@@ -157,8 +157,8 @@ static void checkLoraTensors(LoraManager const& loraManager, std::vector(*weightsPtrs);
     ASSERT_EQ(targetPtrs.size(), weightsPtrs->getSize());
     ASSERT_EQ(targetAdapterSizes.size(), adapterSizes->getSize());
-    auto firstLayerId = modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism())
-        * worldConfig.getPipelineParallelRank();
+    auto firstLayerId
+        = modelConfig.getFirstLocalLayer(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
     LoraManager::TensorMap expectedTensors;
     for (SizeType32 m = 0; m < numModules; ++m)
diff --git a/tensorrt_llm/mapping.py b/tensorrt_llm/mapping.py
index bef1f85b37a..ddf4cce877f 100644
--- a/tensorrt_llm/mapping.py
+++ b/tensorrt_llm/mapping.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 from typing import List
 
+import torch
+
 class Mapping(object):
     '''
@@ -420,12 +422,9 @@ def has_moe_ep(self):
         return self.moe_ep_size > 1
 
     def pp_layers(self, num_layers: int) -> List[int]:
-        base_layers = num_layers // self.pp_size
-        extra_layers = num_layers % self.pp_size
-        start_idx = self.pp_rank * base_layers + min(self.pp_rank, extra_layers)
-        layers_in_stage = base_layers + (1
-                                         if self.pp_rank < extra_layers else 0)
-        return list(range(start_idx, start_idx + layers_in_stage))
+        # If num_layers % pp_size = n != 0, first n ranks get one extra layer
+        return torch.tensor_split(torch.arange(num_layers),
+                                  self.pp_size)[self.pp_rank].tolist()
 
     def ep_experts(self, num_experts: int) -> List[int]:
         assert self.cp_size == 1
diff --git a/tensorrt_llm/models/generation_mixin.py b/tensorrt_llm/models/generation_mixin.py
index 049f177de65..f97b8d436b7 100644
--- a/tensorrt_llm/models/generation_mixin.py
+++ b/tensorrt_llm/models/generation_mixin.py
@@ -277,8 +277,7 @@ def prepare_attention_inputs(
         local_attn_layers = [i for i in layers_range if i in attn_layer_idx]
         # number of attention layers local to previous pp ranks
         num_attn_layers_lower_ranks = attn_layer_idx.index(local_attn_layers[0])
-        num_attn_layers_prev_rank = num_attn_layers_lower_ranks // mapping.pp_rank if mapping.pp_rank != 0 else len(
-            local_attn_layers)
+        num_attn_layers = len(local_attn_layers)
         num_layers_prev_rank = layers_range[
             0] // mapping.pp_rank if mapping.pp_rank != 0 else len(layers_range)
         past_key_value = []
@@ -376,11 +375,10 @@ def prepare_attention_inputs(
             host_kv_cache_pool_mapping = Tensor(
                 name=f'host_kv_cache_pool_mapping',
                 dtype=trt.int32,
-                shape=[num_attn_layers_prev_rank,
+                shape=[num_attn_layers,
                        2],  # 2: (Index of pool, Index of layer within pool)
                 dim_range=OrderedDict([
-                    ('pools_mapping',
-                     [num_attn_layers_prev_rank] * num_profiles),
+                    ('pools_mapping', [num_attn_layers] * num_profiles),
                     ('layer_cache_pool_locator', [2] * num_profiles)
                 ]))
 
@@ -467,10 +465,9 @@ def prepare_attention_inputs(
             host_max_attention_window_sizes = Tensor(
                 name=f'host_max_attention_window_sizes',
                 dtype=trt.int32,
-                shape=[num_attn_layers_prev_rank],
-                dim_range=OrderedDict([
-                    ('num_layers', [num_attn_layers_prev_rank] * num_profiles)
-                ]))
+                shape=[num_attn_layers],
+                dim_range=OrderedDict([('num_layers',
+                                        [num_attn_layers] * num_profiles)]))
 
             host_sink_token_length = Tensor(name='host_sink_token_length',
                                             dtype=trt.int32,
diff --git a/tensorrt_llm/runtime/model_runner_cpp.py b/tensorrt_llm/runtime/model_runner_cpp.py
index 8eef7a8b555..0ef28ffcc59 100644
--- a/tensorrt_llm/runtime/model_runner_cpp.py
+++ b/tensorrt_llm/runtime/model_runner_cpp.py
@@ -500,7 +500,9 @@ def num_heads(self) -> int:
     @property
     def num_layers(self) -> int:
         return self.model_config.num_layers(
-            self.world_config.pipeline_parallelism)
+            self.world_config.pipeline_parallelism,
+            self.world_config.pipeline_parallel_rank,
+        )
 
     @property
     def max_sequence_length(self) -> int:
diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index e5fc5f81be4..6ac7b7a7320 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -471,7 +471,6 @@ triton_server/test_triton.py::test_gpt_speculative_decoding[gpt-speculative-deco
 triton_server/test_triton.py::test_qwen2_vl[qwen2_vl] SKIP
 triton_server/test_triton.py::test_gpt_ib_speculative_decoding_bls[gpt-ib-speculative-decoding-bls] SKIP
 triton_server/test_triton_llm.py::test_mistral_v1_multi_models[False-1-False-True-False-0-128-enableDecoupleMode-inflight_fused_batching-disableTrtOverlap-max_utilization-4096-1-1-1-False-ensemble] SKIP
-accuracy/test_cli_flow.py::TestTinyLlama1_1BChat::test_pp4 SKIP (https://nvbugs/5287097)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False] SKIP (https://nvbugs/5286795)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False] SKIP (https://nvbugs/5286795)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False] SKIP (https://nvbugs/5286795)

From 3574acc26dc8c2ff503657c5543224c754826ce2 Mon Sep 17 00:00:00 2001
From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Date: Mon, 19 May 2025 02:14:15 +0000
Subject: [PATCH 2/3] Address comments.

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
---
 cpp/include/tensorrt_llm/runtime/modelConfig.h | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/cpp/include/tensorrt_llm/runtime/modelConfig.h b/cpp/include/tensorrt_llm/runtime/modelConfig.h
index d5b99cc0992..5a0726cd7c8 100644
--- a/cpp/include/tensorrt_llm/runtime/modelConfig.h
+++ b/cpp/include/tensorrt_llm/runtime/modelConfig.h
@@ -169,10 +169,8 @@ class ModelConfig
         TLLM_CHECK_WITH_INFO(pipelineParallelism > 0, "Invalid pipelineParallelism: %d", pipelineParallelism);
         auto const numBaseLayers = mNbLayers / pipelineParallelism;
         auto const numExtraLayers = mNbLayers % pipelineParallelism;
-        auto const firstLocalLayer
-            = pipelineParallelismRank * numBaseLayers + std::min(pipelineParallelismRank, numExtraLayers);
-        // If num_layers % pp_size = n != 0, first n ranks get one extra layer
-        auto const numLocalLayers = numBaseLayers + (pipelineParallelismRank < numExtraLayers ? 1 : 0);
+        auto const firstLocalLayer = getFirstLocalLayer(pipelineParallelism, pipelineParallelismRank);
+        auto const numLocalLayers = getNbLayers(pipelineParallelism, pipelineParallelismRank);
         auto const firstLocalLayerIt = mLayerTypes.cbegin() + firstLocalLayer;
         return std::count(firstLocalLayerIt, firstLocalLayerIt + numLocalLayers, layerType);
     }

From dccc457fcfab4c274affdd8b85908f1bb3eb80c1 Mon Sep 17 00:00:00 2001
From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Date: Mon, 19 May 2025 04:23:02 +0000
Subject: [PATCH 3/3] Fix CI error.
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
---
 cpp/include/tensorrt_llm/runtime/modelConfig.h | 2 --
 1 file changed, 2 deletions(-)

diff --git a/cpp/include/tensorrt_llm/runtime/modelConfig.h b/cpp/include/tensorrt_llm/runtime/modelConfig.h
index 5a0726cd7c8..bcd466fe2ad 100644
--- a/cpp/include/tensorrt_llm/runtime/modelConfig.h
+++ b/cpp/include/tensorrt_llm/runtime/modelConfig.h
@@ -167,8 +167,6 @@ class ModelConfig
         LayerType layerType, SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
     {
         TLLM_CHECK_WITH_INFO(pipelineParallelism > 0, "Invalid pipelineParallelism: %d", pipelineParallelism);
-        auto const numBaseLayers = mNbLayers / pipelineParallelism;
-        auto const numExtraLayers = mNbLayers % pipelineParallelism;
         auto const firstLocalLayer = getFirstLocalLayer(pipelineParallelism, pipelineParallelismRank);
         auto const numLocalLayers = getNbLayers(pipelineParallelism, pipelineParallelismRank);
         auto const firstLocalLayerIt = mLayerTypes.cbegin() + firstLocalLayer;
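
Note on the layer-distribution rule (an illustration, not part of the patch series): all three patches implement the same convention, namely that when num_layers % pp_size = n != 0, the first n pipeline ranks each own one extra layer. Below is a minimal standalone Python sketch, assuming only PyTorch is installed, showing that the arithmetic in ModelConfig::getFirstLocalLayer / ModelConfig::getNbLayers and the torch.tensor_split call now used by Mapping.pp_layers produce identical per-rank layer ranges. The helper names first_local_layer and num_local_layers are hypothetical, introduced only for this check.

    import torch

    def first_local_layer(num_layers: int, pp_size: int, pp_rank: int) -> int:
        # Mirrors ModelConfig::getFirstLocalLayer from PATCH 1/3.
        base, extra = divmod(num_layers, pp_size)
        return pp_rank * base + min(pp_rank, extra)

    def num_local_layers(num_layers: int, pp_size: int, pp_rank: int) -> int:
        # Mirrors ModelConfig::getNbLayers: the first (num_layers % pp_size)
        # ranks get one extra layer.
        base, extra = divmod(num_layers, pp_size)
        return base + (1 if pp_rank < extra else 0)

    for num_layers in (30, 32, 37):
        for pp_size in (1, 2, 4, 8):
            for pp_rank in range(pp_size):
                start = first_local_layer(num_layers, pp_size, pp_rank)
                count = num_local_layers(num_layers, pp_size, pp_rank)
                # Mapping.pp_layers after this series: tensor_split over arange(num_layers).
                split = torch.tensor_split(torch.arange(num_layers),
                                           pp_size)[pp_rank].tolist()
                assert split == list(range(start, start + count))

The two sides agree because torch.tensor_split gives the first num_layers % pp_size sections one extra element, which is exactly the remainder rule the new C++ helpers encode.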