25 changes: 19 additions & 6 deletions cpp/include/tensorrt_llm/runtime/modelConfig.h
@@ -167,23 +167,36 @@ class ModelConfig
LayerType layerType, SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
{
TLLM_CHECK_WITH_INFO(pipelineParallelism > 0, "Invalid pipelineParallelism: %d", pipelineParallelism);
auto const numLocalLayers = mNbLayers / pipelineParallelism; // WARNING: assume no remainder
auto const firstLocalLayerIt = mLayerTypes.cbegin() + (numLocalLayers * pipelineParallelismRank);
auto const firstLocalLayer = getFirstLocalLayer(pipelineParallelism, pipelineParallelismRank);
auto const numLocalLayers = getNbLayers(pipelineParallelism, pipelineParallelismRank);
auto const firstLocalLayerIt = mLayerTypes.cbegin() + firstLocalLayer;
return std::count(firstLocalLayerIt, firstLocalLayerIt + numLocalLayers, layerType);
}

[[nodiscard]] SizeType32 getFirstLocalLayer(
SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
{
auto const numBaseLayers = mNbLayers / pipelineParallelism;
auto const numExtraLayers = mNbLayers % pipelineParallelism;
// If num_layers % pp_size = n != 0, first n ranks get one extra layer
return pipelineParallelismRank * numBaseLayers + std::min(pipelineParallelismRank, numExtraLayers);
}

[[nodiscard]] SizeType32 countLowerRankLayers(
LayerType layerType, SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
{
auto const numLocalLayers = mNbLayers / pipelineParallelism; // WARNING: assume no remainder
auto const firstLocalLayer = numLocalLayers * pipelineParallelismRank;
auto const firstLocalLayer = getFirstLocalLayer(pipelineParallelism, pipelineParallelismRank);
// count number of previous non-local attention layers
return std::count(mLayerTypes.cbegin(), mLayerTypes.cbegin() + firstLocalLayer, layerType);
}

[[nodiscard]] SizeType32 getNbLayers(SizeType32 pipelineParallelism = 1) const
[[nodiscard]] SizeType32 getNbLayers(
SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
{
return mNbLayers / pipelineParallelism; // WARNING: assume no remainder
auto const numBaseLayers = mNbLayers / pipelineParallelism;
auto const numExtraLayers = mNbLayers % pipelineParallelism;
// If num_layers % pp_size = n != 0, first n ranks get one extra layer
return numBaseLayers + (pipelineParallelismRank < numExtraLayers ? 1 : 0);
}

[[nodiscard]] SizeType32 getNbAttentionLayers(
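For readers skimming the header change above: the new getFirstLocalLayer()/getNbLayers() pair distributes the remainder layers to the lowest pipeline ranks instead of assuming an even split. A minimal Python sketch of the same arithmetic (the function names and the 30-layer/4-stage values are illustrative, not part of the PR):

def first_local_layer(num_layers, pp_size, pp_rank):
    base, extra = divmod(num_layers, pp_size)
    # The first (num_layers % pp_size) ranks each take one extra layer.
    return pp_rank * base + min(pp_rank, extra)

def num_local_layers(num_layers, pp_size, pp_rank):
    base, extra = divmod(num_layers, pp_size)
    return base + (1 if pp_rank < extra else 0)

# 30 layers over 4 pipeline stages: sizes 8, 8, 7, 7 starting at layers 0, 8, 16, 23.
assert [num_local_layers(30, 4, r) for r in range(4)] == [8, 8, 7, 7]
assert [first_local_layer(30, 4, r) for r in range(4)] == [0, 8, 16, 23]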
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/batch_manager/loraBuffers.cpp
@@ -26,7 +26,8 @@ namespace tensorrt_llm::batch_manager
LoraBuffers::LoraBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, runtime::TllmRuntime const& tllmRuntime,
runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig)
{
auto const localNbLayers = modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism());
auto const localNbLayers
= modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
auto const firstLayerId = worldConfig.getPipelineParallelRank() * localNbLayers;

auto nbModelConfigs = static_cast<SizeType32>(modelConfig.getLoraModules().size());
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp
@@ -102,7 +102,8 @@ PeftCacheManager::getPageManagerConfig(PeftCacheManagerConfig const& config, run

auto const tpSize = worldConfig.getTensorParallelism();
auto const ppSize = worldConfig.getPipelineParallelism();
auto const numLocalLayers = modelConfig.getNbAttentionLayers(ppSize);
auto const ppRank = worldConfig.getPipelineParallelRank();
auto const numLocalLayers = modelConfig.getNbAttentionLayers(ppSize, ppRank);
uint64_t min1dModSize = std::numeric_limits<uint64_t>::max(); // used to setup the pageWidth
uint64_t total1dModSize = 0;
uint64_t total1lSlots = 0; // the slots we need for each layer, summing the slots of all modules
7 changes: 4 additions & 3 deletions cpp/tensorrt_llm/batch_manager/rnnStateManager.cpp
@@ -41,7 +41,8 @@ RnnStateManager::RnnStateManager(SizeType32 maxNumSequences, tensorrt_llm::runti
auto const rnnHiddenSize = rnnConfig->rnnHiddenSize;
auto const rnnHeadSize = rnnConfig->rnnHeadSize;
auto const rnnConvDimSize = rnnConfig->rnnConvDimSize;
auto const localNbLayers = modelConfig.getNbRnnLayers(worldConfig.getPipelineParallelism());
auto const localNbLayers
= modelConfig.getNbRnnLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
auto const dataType = modelConfig.getDataType();

auto const rnnStateShape = [&]()
@@ -84,8 +85,8 @@ RnnStateManager::RnnStateManager(SizeType32 maxNumSequences, tensorrt_llm::runti
void RnnStateManager::getPtrBuffers(
TensorMap& inputBuffers, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig) const
{
auto const localNbLayers = modelConfig.getNbRnnLayers(worldConfig.getPipelineParallelism());
auto const firstLayerId = worldConfig.getPipelineParallelRank() * localNbLayers;
auto const firstLayerId
= modelConfig.getFirstLocalLayer(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
auto const& layerTypes = modelConfig.getLayerTypes();

utils::insertTensorVector(
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
@@ -1717,7 +1717,8 @@ void TrtGptModelInflightBatching::executeStep(
// TODO: support layer-wise cross kv cache in encoder-decoder models
if (!layerWiseRequests.empty() && !mModelConfig.useCrossAttention())
{
int const numLayers = mModelConfig.getNbAttentionLayers(mWorldConfig.getPipelineParallelism());
int const numLayers = mModelConfig.getNbAttentionLayers(
mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank());
progress = std::make_shared<ContextProgress>(numLayers);
}
bufferCast<void*>(*mBuffers[bufferId]->transformerBuffers->contextProgressHost)[0] = progress.get();
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/pybind/bindings.cpp
@@ -317,7 +317,8 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
py::arg("num_heads"), py::arg("hidden_size"), py::arg("data_type"))
.def_property_readonly("vocab_size", &tr::ModelConfig::getVocabSize)
.def("vocab_size_padded", &tr::ModelConfig::getVocabSizePadded, py::arg("world_size"))
.def("num_layers", &tr::ModelConfig::getNbLayers, py::arg("pipeline_parallelism") = 1)
.def("num_layers", &tr::ModelConfig::getNbLayers, py::arg("pipeline_parallelism") = 1,
py::arg("pipeline_parallelism_rank") = 0)
.def("num_attention_layers", &tr::ModelConfig::getNbAttentionLayers, py::arg("pipeline_parallelism") = 1,
py::arg("pipeline_parallelism_rank") = 0)
.def("num_rnn_layers", &tr::ModelConfig::getNbRnnLayers, py::arg("pipeline_parallelism") = 1,
5 changes: 3 additions & 2 deletions cpp/tensorrt_llm/runtime/loraCache.cpp
@@ -454,7 +454,8 @@ SizeType32 LoraCache::determineNumPages(TaskIdType taskId) const
SizeType32 LoraCache::determineNumPages(TensorPtr loraConfig) const
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto const localNumLayers = mModelConfig.getNbAttentionLayers(mWorldConfig.getPipelineParallelism());
auto const localNumLayers = mModelConfig.getNbAttentionLayers(
mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank());
auto const firstLayerId = mWorldConfig.getPipelineParallelRank() * localNumLayers;
auto const lastLayerId = firstLayerId + localNumLayers;

@@ -579,7 +580,7 @@ std::vector<LoraCache::TaskLayerModuleConfig> LoraCache::copyToPages(TensorPtr s
auto const ppSize = worldConfig.getPipelineParallelism();
auto const ppRank = worldConfig.getPipelineParallelRank();
// TODO(oargov): why *attention* layers?
auto const localNumLayers = modelConfig.getNbAttentionLayers(ppSize);
auto const localNumLayers = modelConfig.getNbAttentionLayers(ppSize, ppRank);
auto const firstLayerId = ppRank * localNumLayers;
auto const lastLayerId = firstLayerId + localNumLayers;

5 changes: 3 additions & 2 deletions cpp/tensorrt_llm/runtime/loraManager.cpp
@@ -72,7 +72,7 @@ void LoraManager::fillInputTensors(TensorPtr weightsPtrs, TensorPtr adapterSizes

auto const ppSize = worldConfig.getPipelineParallelism();
auto const ppRank = worldConfig.getPipelineParallelRank();
auto const localNumLayers = modelConfig.getNbAttentionLayers(ppSize);
auto const localNumLayers = modelConfig.getNbAttentionLayers(ppSize, ppRank);
auto const firstLayerId = ppRank * localNumLayers;

auto weightsPointersPtr = bufferCast<int64_t>(*weightsPtrs);
@@ -123,7 +123,8 @@ void LoraManager::insertInputTensors(TensorMap& inputTensors, TensorPtr weightsP
ModelConfig const& modelConfig, WorldConfig const& worldConfig) const
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto localNbLayers = modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism());
auto localNbLayers
= modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
auto firstLayerId = worldConfig.getPipelineParallelRank() * localNbLayers;

for (auto const& [modId, mod] : mModuleIdToModule)
4 changes: 2 additions & 2 deletions cpp/tests/unit_tests/runtime/loraManagerTest.cpp
@@ -157,8 +157,8 @@ static void checkLoraTensors(LoraManager const& loraManager, std::vector<int64_t
auto weightsPtrsPtr = bufferCast<int64_t>(*weightsPtrs);
ASSERT_EQ(targetPtrs.size(), weightsPtrs->getSize());
ASSERT_EQ(targetAdapterSizes.size(), adapterSizes->getSize());
auto firstLayerId = modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism())
* worldConfig.getPipelineParallelRank();
auto firstLayerId
= modelConfig.getFirstLocalLayer(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
LoraManager::TensorMap expectedTensors;

for (SizeType32 m = 0; m < numModules; ++m)
11 changes: 5 additions & 6 deletions tensorrt_llm/mapping.py
@@ -14,6 +14,8 @@
# limitations under the License.
from typing import List

import torch


class Mapping(object):
'''
@@ -420,12 +422,9 @@ def has_moe_ep(self):
return self.moe_ep_size > 1

def pp_layers(self, num_layers: int) -> List[int]:
base_layers = num_layers // self.pp_size
extra_layers = num_layers % self.pp_size
start_idx = self.pp_rank * base_layers + min(self.pp_rank, extra_layers)
layers_in_stage = base_layers + (1
if self.pp_rank < extra_layers else 0)
return list(range(start_idx, start_idx + layers_in_stage))
# If num_layers % pp_size = n != 0, first n ranks get one extra layer
return torch.tensor_split(torch.arange(num_layers),
self.pp_size)[self.pp_rank].tolist()

def ep_experts(self, num_experts: int) -> List[int]:
assert self.cp_size == 1
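An illustrative cross-check of the mapping.py change above (the 30-layer/4-stage values are assumed, not from the PR): torch.tensor_split yields the same first-ranks-get-one-extra split that the C++ helpers compute, which keeps the Python and C++ sides agreeing on layer ownership.

import torch

num_layers, pp_size = 30, 4
splits = [t.tolist() for t in torch.tensor_split(torch.arange(num_layers), pp_size)]
# Chunk sizes 8, 8, 7, 7 with start layers 0, 8, 16, 23 -- matching getNbLayers()
# and getFirstLocalLayer() in modelConfig.h for the same inputs.
assert [len(s) for s in splits] == [8, 8, 7, 7]
assert [s[0] for s in splits] == [0, 8, 16, 23]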
15 changes: 6 additions & 9 deletions tensorrt_llm/models/generation_mixin.py
@@ -277,8 +277,7 @@ def prepare_attention_inputs(
local_attn_layers = [i for i in layers_range if i in attn_layer_idx]
# number of attention layers local to previous pp ranks
num_attn_layers_lower_ranks = attn_layer_idx.index(local_attn_layers[0])
num_attn_layers_prev_rank = num_attn_layers_lower_ranks // mapping.pp_rank if mapping.pp_rank != 0 else len(
local_attn_layers)
num_attn_layers = len(local_attn_layers)
num_layers_prev_rank = layers_range[
0] // mapping.pp_rank if mapping.pp_rank != 0 else len(layers_range)
past_key_value = []
@@ -376,11 +375,10 @@ def prepare_attention_inputs(
host_kv_cache_pool_mapping = Tensor(
name=f'host_kv_cache_pool_mapping',
dtype=trt.int32,
shape=[num_attn_layers_prev_rank,
shape=[num_attn_layers,
2], # 2: (Index of pool, Index of layer within pool)
dim_range=OrderedDict([
('pools_mapping',
[num_attn_layers_prev_rank] * num_profiles),
('pools_mapping', [num_attn_layers] * num_profiles),
('layer_cache_pool_locator', [2] * num_profiles)
]))

@@ -467,10 +465,9 @@ def prepare_attention_inputs(
host_max_attention_window_sizes = Tensor(
name=f'host_max_attention_window_sizes',
dtype=trt.int32,
shape=[num_attn_layers_prev_rank],
dim_range=OrderedDict([
('num_layers', [num_attn_layers_prev_rank] * num_profiles)
]))
shape=[num_attn_layers],
dim_range=OrderedDict([('num_layers',
[num_attn_layers] * num_profiles)]))

host_sink_token_length = Tensor(name='host_sink_token_length',
dtype=trt.int32,
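To make the generation_mixin.py change above concrete, a hypothetical sketch (the layer counts and the attention/RNN layout are assumed for illustration): num_attn_layers is now simply the number of attention layers owned by the current rank, which is what sizes host_kv_cache_pool_mapping and host_max_attention_window_sizes.

# Hypothetical hybrid model: 10 layers, layers 0-7 attention, 8-9 RNN, pp_size = 2.
attn_layer_idx = list(range(8))
layers_range = list(range(5, 10))          # layers owned by pp_rank 1
local_attn_layers = [i for i in layers_range if i in attn_layer_idx]
num_attn_layers = len(local_attn_layers)   # 3 -> extent of the per-layer KV-cache tensors
assert local_attn_layers == [5, 6, 7] and num_attn_layers == 3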
4 changes: 3 additions & 1 deletion tensorrt_llm/runtime/model_runner_cpp.py
@@ -500,7 +500,9 @@ def num_heads(self) -> int:
@property
def num_layers(self) -> int:
return self.model_config.num_layers(
self.world_config.pipeline_parallelism)
self.world_config.pipeline_parallelism,
self.world_config.pipeline_parallel_rank,
)

@property
def max_sequence_length(self) -> int:
1 change: 0 additions & 1 deletion tests/integration/test_lists/waives.txt
@@ -469,7 +469,6 @@ triton_server/test_triton.py::test_gpt_speculative_decoding[gpt-speculative-deco
triton_server/test_triton.py::test_qwen2_vl[qwen2_vl] SKIP
triton_server/test_triton.py::test_gpt_ib_speculative_decoding_bls[gpt-ib-speculative-decoding-bls] SKIP
triton_server/test_triton_llm.py::test_mistral_v1_multi_models[False-1-False-True-False-0-128-enableDecoupleMode-inflight_fused_batching-disableTrtOverlap-max_utilization-4096-1-1-1-False-ensemble] SKIP
accuracy/test_cli_flow.py::TestTinyLlama1_1BChat::test_pp4 SKIP (https://nvbugs/5287097)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5286795)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5286795)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5286795)