
Commit 6da4948

chuangz0 authored and Wong4j committed
[TRTLLM-7361][feat] KV cache transfer for uneven pp (NVIDIA#7117)
Signed-off-by: Chuang Zhu <[email protected]>
1 parent 845c608 commit 6da4948

25 files changed: +868 −420 lines


cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h

Lines changed: 3 additions & 3 deletions
@@ -72,20 +72,20 @@ class CacheTransceiver : public BaseCacheTransceiver
 public:
     CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager,
         executor::kv_cache::CacheState::ModelConfig const& cacheStateModelCfg, runtime::WorldConfig const& worldConfig,
-        nvinfer1::DataType dataType,
+        std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
         executor::kv_cache::CacheState::AttentionType attentionType
         = executor::kv_cache::CacheState::AttentionType::kDEFAULT,
         std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig = std::nullopt);

     CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager, std::vector<SizeType32> numKvHeadsPerLayer,
         SizeType32 sizePerHead, SizeType32 tokensPerBlock, runtime::WorldConfig const& worldConfig,
-        nvinfer1::DataType dataType,
+        std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
         executor::kv_cache::CacheState::AttentionType attentionType
         = executor::kv_cache::CacheState::AttentionType::kDEFAULT,
         std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig = std::nullopt)
         : CacheTransceiver(cacheManager,
             executor::kv_cache::CacheState::ModelConfig{numKvHeadsPerLayer, sizePerHead, tokensPerBlock}, worldConfig,
-            dataType, attentionType, cacheTransceiverConfig)
+            attentionLayerNumPerPP, dataType, attentionType, cacheTransceiverConfig)
     {
     }
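The new attentionLayerNumPerPP argument tells the transceiver how many attention layers each pipeline-parallel rank owns, which is what allows KV cache transfer between unevenly partitioned pipelines. As a rough illustration only (not code from this commit; the helper name and the "remainder goes to the earliest ranks" policy are assumptions), a caller might build the vector like this when the layer count does not divide evenly:

// Hypothetical helper: build a per-PP-rank attention layer count,
// e.g. 30 layers over 4 ranks -> {8, 8, 7, 7}.
#include <cstdint>
#include <iostream>
#include <vector>

using SizeType32 = std::int32_t;

std::vector<SizeType32> makeAttentionLayerNumPerPP(SizeType32 numAttentionLayers, SizeType32 ppSize)
{
    std::vector<SizeType32> layersPerRank(ppSize, numAttentionLayers / ppSize);
    // Assign the remainder to the first ranks so every layer is owned by exactly one rank.
    for (SizeType32 r = 0; r < numAttentionLayers % ppSize; ++r)
    {
        ++layersPerRank[r];
    }
    return layersPerRank;
}

int main()
{
    for (SizeType32 n : makeAttentionLayerNumPerPP(30, 4))
    {
        std::cout << n << ' '; // prints: 8 8 7 7
    }
    std::cout << '\n';
}

The resulting vector has one entry per pipeline rank and sums to the total number of attention layers, matching the shape the new constructor parameter expects.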

cpp/include/tensorrt_llm/executor/dataTransceiverState.h

Lines changed: 18 additions & 9 deletions
@@ -48,34 +48,39 @@ class CacheState final
         kMLA = 1,
     };

-    CacheState(ModelConfig modelConfig, runtime::WorldConfig const& worldConfig, nvinfer1::DataType dataType,
+    CacheState(ModelConfig modelConfig, runtime::WorldConfig const& worldConfig,
+        std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
         AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2)
         : mModelConfig(std::move(modelConfig))
         , mParallelConfig{worldConfig.getTensorParallelism(), worldConfig.getPipelineParallelism(),
             worldConfig.getContextParallelism(), worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(),
-            worldConfig.getTensorParallelism()}
+            worldConfig.getTensorParallelism(), attentionLayerNumPerPP}
         , mDataType{dataType}
         , mAttentionConfig(attentionType, kvFactor)
     {
     }

     CacheState(std::vector<SizeType32> nbKvHeadPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
         SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
-        nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2,
-        bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0)
+        std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
+        AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
+        int DPrank = 0, int DPsize = 0)
         : mModelConfig{std::move(nbKvHeadPerLayer), sizePerHead, tokensPerBlock}
-        , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize}
+        , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize,
+            attentionLayerNumPerPP}
         , mDataType{dataType}
         , mAttentionConfig(attentionType, kvFactor)
     {
     }

     CacheState(SizeType32 nbAttentionLayers, SizeType32 nbKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
         SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
-        nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2,
-        bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0)
+        std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
+        AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
+        int DPrank = 0, int DPsize = 0)
         : mModelConfig{std::vector(nbAttentionLayers, nbKvHeads), sizePerHead, tokensPerBlock}
-        , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize}
+        , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize,
+            attentionLayerNumPerPP}
         , mDataType{dataType}
         , mAttentionConfig(attentionType, kvFactor)
     {
@@ -108,12 +113,16 @@ class CacheState final
         bool mEnableAttentionDP;
         SizeType32 mDPrank;
         SizeType32 mDPsize;
+        // number of attention layers per pipeline parallelism rank, the size of the vector is equal to the pipeline
+        // parallelism size.
+        std::vector<SizeType32> mAttentionLayerNumPerPP;

         [[nodiscard]] bool operator==(ParallelConfig const& other) const noexcept
         {
             return mTensorParallelism == other.mTensorParallelism && mPipelineParallelism == other.mPipelineParallelism
                 && mContextParallelism == other.mContextParallelism && mEnableAttentionDP == other.mEnableAttentionDP
-                && mDPrank == other.mDPrank && mDPsize == other.mDPsize;
+                && mDPrank == other.mDPrank && mDPsize == other.mDPsize
+                && mAttentionLayerNumPerPP == other.mAttentionLayerNumPerPP;
         }
     };
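With mAttentionLayerNumPerPP stored in ParallelConfig, each side of a transfer can describe its own (possibly uneven) layer split, and two cache states only compare equal when those splits match. The sketch below is illustrative only (the helper locateLayer is not part of this commit): it shows how a per-rank layer count can be walked to find which pipeline rank owns a given global attention layer, which is the kind of lookup an uneven-PP KV cache transfer needs when context and generation sides split layers differently.

// Illustrative sketch: map a global attention layer index to its owning
// pipeline-parallel rank and rank-local index using attentionLayerNumPerPP.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

using SizeType32 = std::int32_t;

std::pair<SizeType32, SizeType32> locateLayer(
    std::vector<SizeType32> const& attentionLayerNumPerPP, SizeType32 globalLayerIdx)
{
    SizeType32 offset = 0;
    for (SizeType32 rank = 0; rank < static_cast<SizeType32>(attentionLayerNumPerPP.size()); ++rank)
    {
        if (globalLayerIdx < offset + attentionLayerNumPerPP[rank])
        {
            return {rank, globalLayerIdx - offset}; // {pp rank, local layer index}
        }
        offset += attentionLayerNumPerPP[rank];
    }
    return {-1, -1}; // layer index out of range
}

int main()
{
    std::vector<SizeType32> const layersPerRank{8, 8, 7, 7}; // uneven split of 30 layers over 4 ranks
    auto const [rank, local] = locateLayer(layersPerRank, 20);
    std::cout << "layer 20 lives on pp rank " << rank << ", local index " << local << '\n'; // rank 2, local 4
}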
