@@ -48,34 +48,39 @@ class CacheState final
         kMLA = 1,
     };

-    CacheState(ModelConfig modelConfig, runtime::WorldConfig const& worldConfig, nvinfer1::DataType dataType,
+    CacheState(ModelConfig modelConfig, runtime::WorldConfig const& worldConfig,
+        std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
         AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2)
         : mModelConfig(std::move(modelConfig))
         , mParallelConfig{worldConfig.getTensorParallelism(), worldConfig.getPipelineParallelism(),
             worldConfig.getContextParallelism(), worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(),
-            worldConfig.getTensorParallelism()}
+            worldConfig.getTensorParallelism(), attentionLayerNumPerPP}
         , mDataType{dataType}
         , mAttentionConfig(attentionType, kvFactor)
     {
     }

     CacheState(std::vector<SizeType32> nbKvHeadPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
         SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
-        nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2,
-        bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0)
+        std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
+        AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
+        int DPrank = 0, int DPsize = 0)
         : mModelConfig{std::move(nbKvHeadPerLayer), sizePerHead, tokensPerBlock}
-        , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize}
+        , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize,
+            attentionLayerNumPerPP}
         , mDataType{dataType}
         , mAttentionConfig(attentionType, kvFactor)
     {
     }

     CacheState(SizeType32 nbAttentionLayers, SizeType32 nbKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
         SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
-        nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2,
-        bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0)
+        std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
+        AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
+        int DPrank = 0, int DPsize = 0)
         : mModelConfig{std::vector(nbAttentionLayers, nbKvHeads), sizePerHead, tokensPerBlock}
-        , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize}
+        , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize,
+            attentionLayerNumPerPP}
         , mDataType{dataType}
         , mAttentionConfig(attentionType, kvFactor)
     {
@@ -108,12 +113,16 @@ class CacheState final
         bool mEnableAttentionDP;
         SizeType32 mDPrank;
         SizeType32 mDPsize;
+        // Number of attention layers on each pipeline-parallel rank; the vector's size equals
+        // the pipeline parallelism size.
+        std::vector<SizeType32> mAttentionLayerNumPerPP;

         [[nodiscard]] bool operator==(ParallelConfig const& other) const noexcept
         {
             return mTensorParallelism == other.mTensorParallelism && mPipelineParallelism == other.mPipelineParallelism
                 && mContextParallelism == other.mContextParallelism && mEnableAttentionDP == other.mEnableAttentionDP
-                && mDPrank == other.mDPrank && mDPsize == other.mDPsize;
+                && mDPrank == other.mDPrank && mDPsize == other.mDPsize
+                && mAttentionLayerNumPerPP == other.mAttentionLayerNumPerPP;
         }
     };

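For context, a minimal sketch of how the new attentionLayerNumPerPP argument could be built for the value-based constructors, assuming an even split of attention layers across pipeline ranks with earlier ranks absorbing the remainder. The helper name makeAttentionLayerNumPerPP and the split policy are illustrative assumptions, not part of this diff:

#include <cstdint>
#include <vector>

using SizeType32 = std::int32_t; // matches tensorrt_llm::runtime::SizeType32

// Hypothetical helper: split nbAttentionLayers across ppSize pipeline ranks.
// Each rank gets nbAttentionLayers / ppSize layers; the first
// nbAttentionLayers % ppSize ranks get one extra layer, so the result's size
// equals the pipeline parallelism size, as the new member's comment requires.
std::vector<SizeType32> makeAttentionLayerNumPerPP(SizeType32 nbAttentionLayers, SizeType32 ppSize)
{
    std::vector<SizeType32> layersPerRank(ppSize, nbAttentionLayers / ppSize);
    for (SizeType32 rank = 0; rank < nbAttentionLayers % ppSize; ++rank)
    {
        ++layersPerRank[rank]; // earlier ranks absorb the remainder
    }
    return layersPerRank;
}

// Example: 10 attention layers over pp=4 yields {3, 3, 2, 2}, which would be
// passed as the attentionLayerNumPerPP argument of the constructors above.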