diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h
index 565b02c028d..740badd6370 100644
--- a/cpp/include/tensorrt_llm/executor/executor.h
+++ b/cpp/include/tensorrt_llm/executor/executor.h
@@ -71,6 +71,7 @@ class SamplingConfig
         std::optional<FloatType> const& repetitionPenalty = std::nullopt,
         std::optional<FloatType> const& presencePenalty = std::nullopt,
         std::optional<FloatType> const& frequencyPenalty = std::nullopt,
+        std::optional<SizeType32> const& promptIgnoreLength = std::nullopt,
         std::optional<FloatType> const& lengthPenalty = std::nullopt,
         std::optional<SizeType32> const& earlyStopping = std::nullopt,
         std::optional<SizeType32> const& noRepeatNgramSize = std::nullopt,
@@ -94,6 +95,7 @@ class SamplingConfig
     [[nodiscard]] std::optional<FloatType> getRepetitionPenalty() const;
     [[nodiscard]] std::optional<FloatType> getPresencePenalty() const;
     [[nodiscard]] std::optional<FloatType> getFrequencyPenalty() const;
+    [[nodiscard]] std::optional<SizeType32> getPromptIgnoreLength() const;
     [[nodiscard]] std::optional<FloatType> getLengthPenalty() const;
     [[nodiscard]] std::optional<SizeType32> getEarlyStopping() const;
     [[nodiscard]] std::optional<SizeType32> getNoRepeatNgramSize() const;
@@ -114,6 +116,7 @@ class SamplingConfig
     void setRepetitionPenalty(std::optional<FloatType> const& repetitionPenalty);
     void setPresencePenalty(std::optional<FloatType> const& presencePenalty);
     void setFrequencyPenalty(std::optional<FloatType> const& frequencyPenalty);
+    void setPromptIgnoreLength(std::optional<SizeType32> const& promptIgnoreLength);
     void setLengthPenalty(std::optional<FloatType> const& lengthPenalty);
     void setEarlyStopping(std::optional<SizeType32> const& earlyStopping);
     void setNoRepeatNgramSize(std::optional<SizeType32> const& noRepeatNgramSize);
@@ -133,6 +136,8 @@ class SamplingConfig
     static std::optional<FloatType> const& checkBeamSearchDiversityRate(
         std::optional<FloatType> const& beamSearchDiversityRate);
     static std::optional<FloatType> const& checkRepetitionPenalty(std::optional<FloatType> const& repetitionpenalty);
+    static std::optional<SizeType32> const& checkPromptIgnoreLength(
+        std::optional<SizeType32> const& promptIgnoreLength);
     static std::optional<FloatType> const& checkLengthPenalty(std::optional<FloatType> const& lengthPenalty);
     static std::optional<SizeType32> const& checkEarlyStopping(std::optional<SizeType32> const& earlyStopping);
     static std::optional<SizeType32> const& checkNoRepeatNgramSize(std::optional<SizeType32> const& noRepeatNgramSize);
@@ -174,6 +179,9 @@ class SamplingConfig
     /// @brief Used to penalize tokens already present in the sequence (dependent on the number of appearances). It can
     /// have any values. Values < 0.f encourage repetition, values > 0.f discourage it. Default is 0.f
     std::optional<FloatType> mFrequencyPenalty;
+    /// @brief Controls how many tokens to ignore from the prompt for presence and frequency penalties. Values <= 0 have
+    /// no effect. Values > input (prompt) length will be clamped. Default is 0.
+    std::optional<SizeType32> mPromptIgnoreLength;
     /// @brief Controls how to penalize longer sequences in beam search.
Default is 0.f std::optional mLengthPenalty; /// @brief Controls whether the generation process finishes once beamWidth sentences are generated (ends with diff --git a/cpp/include/tensorrt_llm/layers/defaultDecodingParams.h b/cpp/include/tensorrt_llm/layers/defaultDecodingParams.h index d35080d5588..3794556e283 100644 --- a/cpp/include/tensorrt_llm/layers/defaultDecodingParams.h +++ b/cpp/include/tensorrt_llm/layers/defaultDecodingParams.h @@ -56,6 +56,11 @@ class DefaultDecodingParams return 1; } + [[nodiscard]] __host__ __device__ static constexpr runtime::SizeType32 getPromptIgnoreLength() + { + return 0; + } + [[nodiscard]] __host__ __device__ static constexpr uint64_t getSeed() { return 0; diff --git a/cpp/include/tensorrt_llm/runtime/samplingConfig.h b/cpp/include/tensorrt_llm/runtime/samplingConfig.h index 099dce17312..03355167653 100644 --- a/cpp/include/tensorrt_llm/runtime/samplingConfig.h +++ b/cpp/include/tensorrt_llm/runtime/samplingConfig.h @@ -133,6 +133,9 @@ class SamplingConfig frequencyPenalty = fuseValues( configs, [&configs](size_t ci) { return configs[ci].frequencyPenalty; }, layers::DefaultDecodingParams::getFrequencyPenalty()); + promptIgnoreLength = fuseValues( + configs, [&configs](size_t ci) { return configs[ci].promptIgnoreLength; }, + layers::DefaultDecodingParams::getPromptIgnoreLength()); noRepeatNgramSize = fuseValues( configs, [&configs](size_t ci) { return configs[ci].noRepeatNgramSize; }, layers::DefaultDecodingParams::getNoRepeatNgramSize()); @@ -224,6 +227,7 @@ class SamplingConfig SET_FROM_OPTIONAL(repetitionPenalty, RepetitionPenalty, FloatType) SET_FROM_OPTIONAL(presencePenalty, PresencePenalty, FloatType) SET_FROM_OPTIONAL(frequencyPenalty, FrequencyPenalty, FloatType) + SET_FROM_OPTIONAL(promptIgnoreLength, PromptIgnoreLength, SizeType32) SET_FROM_OPTIONAL(lengthPenalty, LengthPenalty, FloatType) SET_FROM_OPTIONAL(earlyStopping, EarlyStopping, SizeType32) SET_FROM_OPTIONAL(noRepeatNgramSize, NoRepeatNgramSize, SizeType32) @@ -342,6 +346,7 @@ class SamplingConfig OptVec repetitionPenalty; // [1] or [batchSize] OptVec presencePenalty; // [1] or [batchSize] OptVec frequencyPenalty; // [1] or [batchSize] + OptVec promptIgnoreLength; // [1] or [batchSize] OptVec noRepeatNgramSize; // [1] or [batchSize] // probs @@ -377,13 +382,14 @@ class SamplingConfig && temperature == other.temperature && originalTemperature == other.originalTemperature && minLength == other.minLength && repetitionPenalty == other.repetitionPenalty && presencePenalty == other.presencePenalty && frequencyPenalty == other.frequencyPenalty - && noRepeatNgramSize == other.noRepeatNgramSize && topK == other.topK && topP == other.topP - && randomSeed == other.randomSeed && topPDecay == other.topPDecay && topPMin == other.topPMin - && topPResetIds == other.topPResetIds && beamSearchDiversityRate == other.beamSearchDiversityRate - && lengthPenalty == other.lengthPenalty && earlyStopping == other.earlyStopping - && draftAcceptanceThreshold == other.draftAcceptanceThreshold && topKMedusaHeads == other.topKMedusaHeads - && normalizeLogProbs == other.normalizeLogProbs && outputLogProbs == other.outputLogProbs - && cumLogProbs == other.cumLogProbs && minP == other.minP && beamWidthArray == other.beamWidthArray; + && promptIgnoreLength == other.promptIgnoreLength && noRepeatNgramSize == other.noRepeatNgramSize + && topK == other.topK && topP == other.topP && randomSeed == other.randomSeed + && topPDecay == other.topPDecay && topPMin == other.topPMin && topPResetIds == other.topPResetIds + && 
beamSearchDiversityRate == other.beamSearchDiversityRate && lengthPenalty == other.lengthPenalty + && earlyStopping == other.earlyStopping && draftAcceptanceThreshold == other.draftAcceptanceThreshold + && topKMedusaHeads == other.topKMedusaHeads && normalizeLogProbs == other.normalizeLogProbs + && outputLogProbs == other.outputLogProbs && cumLogProbs == other.cumLogProbs && minP == other.minP + && beamWidthArray == other.beamWidthArray; } SizeType32 getNumReturnBeams() const diff --git a/cpp/tensorrt_llm/executor/samplingConfig.cpp b/cpp/tensorrt_llm/executor/samplingConfig.cpp index 176865340e2..3e0ef63f929 100644 --- a/cpp/tensorrt_llm/executor/samplingConfig.cpp +++ b/cpp/tensorrt_llm/executor/samplingConfig.cpp @@ -34,9 +34,9 @@ SamplingConfig::SamplingConfig(SizeType32 beamWidth, OptSize32 const& topK, OptF OptFloat const& topPMin, std::optional const& topPResetIds, OptFloat const& topPDecay, std::optional const& seed, OptFloat const& temperature, OptSize32 const& minTokens, OptFloat const& beamSearchDiversityRate, OptFloat const& repetitionPenalty, OptFloat const& presencePenalty, - OptFloat const& frequencyPenalty, OptFloat const& lengthPenalty, OptSize32 const& earlyStopping, - OptSize32 const& noRepeatNgramSize, OptSize32 const& numReturnSequences, OptFloat const& minP, - OptVec const& beamWidthArray) + OptFloat const& frequencyPenalty, OptSize32 const& promptIgnoreLength, OptFloat const& lengthPenalty, + OptSize32 const& earlyStopping, OptSize32 const& noRepeatNgramSize, OptSize32 const& numReturnSequences, + OptFloat const& minP, OptVec const& beamWidthArray) : mBeamWidth(checkBeamWidth(beamWidth)) , mTopK(checkTopK(topK)) , mTopP(checkTopP(topP)) @@ -50,6 +50,7 @@ SamplingConfig::SamplingConfig(SizeType32 beamWidth, OptSize32 const& topK, OptF , mRepetitionPenalty(checkRepetitionPenalty(repetitionPenalty)) , mPresencePenalty(presencePenalty) , mFrequencyPenalty(frequencyPenalty) + , mPromptIgnoreLength(checkPromptIgnoreLength(promptIgnoreLength)) , mLengthPenalty(checkLengthPenalty(lengthPenalty)) , mEarlyStopping(checkEarlyStopping(earlyStopping)) , mNoRepeatNgramSize(checkNoRepeatNgramSize(noRepeatNgramSize)) @@ -67,9 +68,10 @@ bool SamplingConfig::operator==(SamplingConfig const& other) const && mTemperature == other.mTemperature && mMinTokens == other.mMinTokens && mBeamSearchDiversityRate == other.mBeamSearchDiversityRate && mRepetitionPenalty == other.mRepetitionPenalty && mPresencePenalty == other.mPresencePenalty && mFrequencyPenalty == other.mFrequencyPenalty - && mLengthPenalty == other.mLengthPenalty && mEarlyStopping == other.mEarlyStopping - && mNoRepeatNgramSize == other.mNoRepeatNgramSize && mNumReturnSequences == other.mNumReturnSequences - && mMinP == other.mMinP && mBeamWidthArray == other.mBeamWidthArray; + && mPromptIgnoreLength == other.mPromptIgnoreLength && mLengthPenalty == other.mLengthPenalty + && mEarlyStopping == other.mEarlyStopping && mNoRepeatNgramSize == other.mNoRepeatNgramSize + && mNumReturnSequences == other.mNumReturnSequences && mMinP == other.mMinP + && mBeamWidthArray == other.mBeamWidthArray; } // Getters @@ -143,6 +145,11 @@ OptFloat SamplingConfig::getFrequencyPenalty() const return mFrequencyPenalty; } +OptSize32 SamplingConfig::getPromptIgnoreLength() const +{ + return mPromptIgnoreLength; +} + OptFloat SamplingConfig::getLengthPenalty() const { return mLengthPenalty; @@ -240,6 +247,11 @@ void SamplingConfig::setFrequencyPenalty(OptFloat const& frequencyPenalty) mFrequencyPenalty = frequencyPenalty; } +void 
SamplingConfig::setPromptIgnoreLength(OptSize32 const& promptIgnoreLength) +{ + mPromptIgnoreLength = checkPromptIgnoreLength(promptIgnoreLength); +} + void SamplingConfig::setLengthPenalty(OptFloat const& lengthPenalty) { mLengthPenalty = lengthPenalty; // TODO: re-enable `checkLengthPenalty` later @@ -362,6 +374,15 @@ OptFloat const& SamplingConfig::checkRepetitionPenalty(OptFloat const& repetitio return repetitionpenalty; } +OptSize32 const& SamplingConfig::checkPromptIgnoreLength(OptSize32 const& promptIgnoreLength) +{ + if (promptIgnoreLength.has_value()) + { + TLLM_CHECK(promptIgnoreLength.value() >= 0); + } + return promptIgnoreLength; +} + OptFloat const& SamplingConfig::checkLengthPenalty(OptFloat const& lengthPenalty) { if (lengthPenalty.has_value()) diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp index 1786a43bdbe..b945db66b09 100644 --- a/cpp/tensorrt_llm/executor/serialization.cpp +++ b/cpp/tensorrt_llm/executor/serialization.cpp @@ -159,6 +159,7 @@ SamplingConfig Serialization::deserializeSamplingConfig(std::istream& is) auto repetitionPenalty = su::deserialize>(is); auto presencePenalty = su::deserialize>(is); auto frequencyPenalty = su::deserialize>(is); + auto promptIgnoreLength = su::deserialize>(is); auto lengthPenalty = su::deserialize>(is); auto earlyStopping = su::deserialize>(is); auto noRepeatNgramSize = su::deserialize>(is); @@ -167,8 +168,8 @@ SamplingConfig Serialization::deserializeSamplingConfig(std::istream& is) auto beamWidthArray = su::deserialize>>(is); return SamplingConfig{beamWidth, topK, topP, topPMin, topPResetIds, topPDecay, randomSeed, temperature, minLength, - beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty, lengthPenalty, earlyStopping, - noRepeatNgramSize, numReturnSequences, minP, beamWidthArray}; + beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty, promptIgnoreLength, + lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP, beamWidthArray}; } void Serialization::serialize(SamplingConfig const& config, std::ostream& os) @@ -186,6 +187,7 @@ void Serialization::serialize(SamplingConfig const& config, std::ostream& os) su::serialize(config.mRepetitionPenalty, os); su::serialize(config.mPresencePenalty, os); su::serialize(config.mFrequencyPenalty, os); + su::serialize(config.mPromptIgnoreLength, os); su::serialize(config.mLengthPenalty, os); su::serialize(config.mEarlyStopping, os); su::serialize(config.mNoRepeatNgramSize, os); @@ -210,6 +212,7 @@ size_t Serialization::serializedSize(SamplingConfig const& config) totalSize += su::serializedSize(config.mRepetitionPenalty); totalSize += su::serializedSize(config.mPresencePenalty); totalSize += su::serializedSize(config.mFrequencyPenalty); + totalSize += su::serializedSize(config.mPromptIgnoreLength); totalSize += su::serializedSize(config.mLengthPenalty); totalSize += su::serializedSize(config.mEarlyStopping); totalSize += su::serializedSize(config.mNoRepeatNgramSize); diff --git a/cpp/tensorrt_llm/kernels/penaltyKernels.cu b/cpp/tensorrt_llm/kernels/penaltyKernels.cu index 257ce8a51f2..a85f1742089 100644 --- a/cpp/tensorrt_llm/kernels/penaltyKernels.cu +++ b/cpp/tensorrt_llm/kernels/penaltyKernels.cu @@ -39,10 +39,10 @@ template __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits, T const* biases, TokenIdType* penaltyWorkspace, TokenIdType const* penaltyWorkspacePrev, float const* temperatures, float const* repetitionPenalties, float const* 
presencePenalties, float const* frequencyPenalties, - SizeType32 maxSeqLen, SizeType32 vocabSize, SizeType32 vocabSizePadded, TokenIdType const** outputIdsPtr, - SizeType32 const** parentIdsPtr, SizeType32 const* inputLengths, SizeType32 const* sequenceLengths, - SizeType32 const* minLengths, TokenIdType const* endIds, SizeType32 const* batchSlots, - SizeType32 const* tokensPerStep, FinishedState const* finished) + SizeType32 const* promptIgnoreLengths, SizeType32 maxSeqLen, SizeType32 vocabSize, SizeType32 vocabSizePadded, + TokenIdType const** outputIdsPtr, SizeType32 const** parentIdsPtr, SizeType32 const* inputLengths, + SizeType32 const* sequenceLengths, SizeType32 const* minLengths, TokenIdType const* endIds, + SizeType32 const* batchSlots, SizeType32 const* tokensPerStep, FinishedState const* finished) { auto const beamWidth = static_cast(gridDim.y); auto const maxTokensPerStep = static_cast(gridDim.z); @@ -73,6 +73,7 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits, float presencePenalty{layers::DefaultDecodingParams::getPresencePenalty()}; float frequencyPenalty{layers::DefaultDecodingParams::getFrequencyPenalty()}; SizeType32 minLength{layers::DefaultDecodingParams::getMinLength()}; + SizeType32 promptIgnoreLength{layers::DefaultDecodingParams::getPromptIgnoreLength()}; bool accumulateVocab{false}; bool hasTemperature{false}; bool hasMinLength{false}; @@ -103,27 +104,42 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits, minLength = minLengths[batchSlot]; hasMinLength |= (minLength > 0); } + if (promptIgnoreLengths != nullptr) + { + promptIgnoreLength = min(promptIgnoreLengths[batchSlot], inputLen); + } // Initialize or update the number of occurrences of tokens if (accumulateVocab) { - penaltyWorkspace += batchBeamStepIdx * vocabSize; + penaltyWorkspace += batchBeamStepIdx * 2 * vocabSize; if (currentStep <= inputLen) { // Context phase - for (auto index = static_cast(threadIdx.x); index < vocabSize; + for (auto index = static_cast(threadIdx.x); index < 2 * vocabSize; index += static_cast(blockDim.x)) { penaltyWorkspace[index] = 0; } __syncthreads(); - for (auto step = static_cast(threadIdx.x); step < inputLen; + for (auto step = static_cast(threadIdx.x); step < promptIgnoreLength; + step += static_cast(blockDim.x)) + { + // All beams in the context phase are identical + auto penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + step]; + if (penaltyIndex < vocabSize) + { + penaltyWorkspace[penaltyIndex] = 1; + } + } + + for (auto step = promptIgnoreLength + static_cast(threadIdx.x); step < inputLen; step += static_cast(blockDim.x)) { // All beams in the context phase are identical auto penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + step]; if (penaltyIndex < vocabSize) { - atomicAdd(&penaltyWorkspace[penaltyIndex], 1); + atomicAdd(&penaltyWorkspace[penaltyIndex + vocabSize], 1); } } } @@ -132,8 +148,9 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits, if (beamWidth > 1) { auto parentBeam = parentIdsPtr[batchSlot][beamIdx * maxSeqLen + currentStep - 1]; - penaltyWorkspacePrev += ((batchIdx * beamWidth + parentBeam) * maxTokensPerStep + stepIdx) * vocabSize; - for (auto index = static_cast(threadIdx.x); index < vocabSize; + penaltyWorkspacePrev + += ((batchIdx * beamWidth + parentBeam) * maxTokensPerStep + stepIdx) * (2 * vocabSize); + for (auto index = static_cast(threadIdx.x); index < 2 * vocabSize; index += static_cast(blockDim.x)) { penaltyWorkspace[index] = 
penaltyWorkspacePrev[index]; @@ -145,7 +162,7 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits, auto penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + currentStep - 1]; if (penaltyIndex < vocabSize) { - penaltyWorkspace[penaltyIndex] += 1; + penaltyWorkspace[penaltyIndex + vocabSize] += 1; } } } @@ -174,14 +191,19 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits, } if (accumulateVocab) { - SizeType32 numOccurences = penaltyWorkspace[index]; - if (numOccurences > 0) + SizeType32 numOccurences = penaltyWorkspace[index + vocabSize]; + SizeType32 ifPresenceInFullSeq = numOccurences | penaltyWorkspace[index]; + if (ifPresenceInFullSeq > 0) { // Repetition if (repetitionPenalties != nullptr) { logit = logit < 0.0f ? logit * repetitionPenalty : logit / repetitionPenalty; } + } + + if (numOccurences > 0) + { // Presence if (presencePenalties != nullptr) { @@ -230,9 +252,10 @@ void invokeBatchApplyPenalty(InvokeBatchApplyPenaltyParams const& params) dim3 grid(params.batchSize, params.beamWidth, params.maxTokensPerStep); batchApplyPenalty<<>>(params.inputLogits, params.outputLogits, params.biases, params.penaltyWorkspace, params.penaltyWorkspacePrev, params.temperatures, params.repetitionPenalties, - params.presencePenalties, params.frequencyPenalties, params.maxSeqLen, params.vocabSize, params.vocabSizePadded, - params.outputIdsPtr, params.parentIdsPtr, params.inputLengths, params.sequenceLengths, params.minLengths, - params.endIds, params.batchSlots, params.tokensPerStep, params.finished); + params.presencePenalties, params.frequencyPenalties, params.promptIgnoreLengths, params.maxSeqLen, + params.vocabSize, params.vocabSizePadded, params.outputIdsPtr, params.parentIdsPtr, params.inputLengths, + params.sequenceLengths, params.minLengths, params.endIds, params.batchSlots, params.tokensPerStep, + params.finished); } template void invokeBatchApplyPenalty(InvokeBatchApplyPenaltyParams const& params); diff --git a/cpp/tensorrt_llm/kernels/penaltyKernels.h b/cpp/tensorrt_llm/kernels/penaltyKernels.h index 152e87a761a..c6ab87951d8 100644 --- a/cpp/tensorrt_llm/kernels/penaltyKernels.h +++ b/cpp/tensorrt_llm/kernels/penaltyKernels.h @@ -35,6 +35,7 @@ struct InvokeBatchApplyPenaltyParams float const* repetitionPenalties; float const* presencePenalties; float const* frequencyPenalties; + runtime::SizeType32 const* promptIgnoreLengths; runtime::SizeType32 batchSize; runtime::SizeType32 beamWidth; runtime::SizeType32 maxSeqLen; diff --git a/cpp/tensorrt_llm/kernels/penaltyTypes.h b/cpp/tensorrt_llm/kernels/penaltyTypes.h index 45532c29bea..79ab634967f 100644 --- a/cpp/tensorrt_llm/kernels/penaltyTypes.h +++ b/cpp/tensorrt_llm/kernels/penaltyTypes.h @@ -29,11 +29,12 @@ namespace kernels enum class DecodingPenaltyType { - Temperature, // the temperature penalty - Repetition, // the repetition penalty - Presence, // the presence penalty - Frequency, // the frequency penalty - MinLength, // the min length penalty + Temperature, // the temperature penalty + Repetition, // the repetition penalty + Presence, // the presence penalty + Frequency, // the frequency penalty + MinLength, // the min length penalty + PromptIgnoreLength, // the prompt ignore length for presence/frequency penalty }; inline std::pair getLimitsPenalty(DecodingPenaltyType penaltyType) @@ -49,6 +50,7 @@ inline std::pair getLimitsPenalty(DecodingPenaltyType penaltyType) case DecodingPenaltyType::Presence: return std::make_pair(fltMin, fltMax); case 
DecodingPenaltyType::Frequency: return std::make_pair(fltMin, fltMax); case DecodingPenaltyType::MinLength: return std::make_pair(-fltEpsilon, fltMax); + case DecodingPenaltyType::PromptIgnoreLength: return std::make_pair(-fltEpsilon, fltMax); } TLLM_CHECK_WITH_INFO(false, "Unknown penalty type %d", static_cast(penaltyType)); return std::make_pair(fltMin, fltMax); diff --git a/cpp/tensorrt_llm/layers/decodingParams.h b/cpp/tensorrt_llm/layers/decodingParams.h index 4298dbb8903..af254bed8f9 100644 --- a/cpp/tensorrt_llm/layers/decodingParams.h +++ b/cpp/tensorrt_llm/layers/decodingParams.h @@ -128,11 +128,12 @@ class BaseSetupParams class PenaltySetupParams : public BaseSetupParams { public: - OptVec temperature; // [1] or [setupBatchSize] - OptVec minLength; // [1] or [setupBatchSize] - OptVec repetitionPenalty; // [1] or [setupBatchSize] - OptVec presencePenalty; // [1] or [setupBatchSize] - OptVec frequencyPenalty; // [1] or [setupBatchSize] + OptVec temperature; // [1] or [setupBatchSize] + OptVec minLength; // [1] or [setupBatchSize] + OptVec repetitionPenalty; // [1] or [setupBatchSize] + OptVec presencePenalty; // [1] or [setupBatchSize] + OptVec frequencyPenalty; // [1] or [setupBatchSize] + OptVec promptIgnoreLength; // [1] or [setupBatchSize] }; // Ban words layer diff --git a/cpp/tensorrt_llm/layers/penaltyLayer.cpp b/cpp/tensorrt_llm/layers/penaltyLayer.cpp index 3fb0741b3bb..c6c57ca5034 100644 --- a/cpp/tensorrt_llm/layers/penaltyLayer.cpp +++ b/cpp/tensorrt_llm/layers/penaltyLayer.cpp @@ -83,7 +83,7 @@ void PenaltyLayer::allocateWorkspace() { auto const workspaceSize = mDecoderDomain.getBatchSize() * mDecoderDomain.getMaxDecodingTokens() - * mConfiguredBeamWidth * mDecoderDomain.getVocabSize(); + * mConfiguredBeamWidth * mDecoderDomain.getVocabSize() * 2; mPenaltyWorkspaceDevice = mBufferManager->gpu(workspaceSize, nvinfer1::DataType::kINT32); if (mDecodingMode.isBeamSearch()) @@ -107,6 +107,7 @@ void PenaltyLayer::allocateBuffer() mPresencePenalty = mBufferManager->pinnedPool(batchSizeShape, TRTDataType::value); mFrequencyPenalty = mBufferManager->pinnedPool(batchSizeShape, TRTDataType::value); mMinLength = mBufferManager->pinnedPool(batchSizeShape, TRTDataType::value); + mPromptIgnoreLength = mBufferManager->pinnedPool(batchSizeShape, TRTDataType::value); if (mDecodingMode.isUseTemperature()) { @@ -128,6 +129,10 @@ void PenaltyLayer::allocateBuffer() { mMinLengthDevice = mBufferManager->gpu(batchSizeShape, nvinfer1::DataType::kINT32); } + if (mDecodingMode.isUseOccurrencePenalty()) + { + mPromptIgnoreLengthDevice = mBufferManager->gpu(batchSizeShape, nvinfer1::DataType::kINT32); + } auto const logitsPtrDeviceDesc = std::make_pair(batchSizeShape, TRTDataType::value); mWorkspaceSize = DecodingLayerWorkspace::calculateRequiredWorkspaceSize(logitsPtrDeviceDesc); @@ -169,6 +174,8 @@ void PenaltyLayer::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorCo bool const useFrequencyPenalty = mDecodingMode.isUseFrequencyPenalty() && penaltyParams->frequencyPenalty.has_value(); bool const useMinLength = mDecodingMode.isUseMinLength() && penaltyParams->minLength.has_value(); + bool const usePromptIgnoreLength + = mDecodingMode.isUseOccurrencePenalty() && penaltyParams->promptIgnoreLength.has_value(); // FIXME: once one of the requests has some penalty, we will always have to compute it. // To avoid that we need to scan through all active requests at each iteration. 
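For reference, the kernel changes above double the per-slot penalty workspace: the first vocabSize entries only flag tokens that occur inside the ignored prompt prefix, while the second vocabSize entries count occurrences of all other prompt and generated tokens. Repetition penalty keys off either half; presence and frequency penalties key off the counted half only. A minimal host-side sketch of that layout follows (the function name and the flat std::vector workspace are illustrative assumptions; this is not the CUDA kernel itself):

#include <algorithm>
#include <cstdint>
#include <vector>

// Mirrors the doubled workspace layout consumed by batchApplyPenalty (token ids assumed < vocabSize):
//   workspace[0, vocabSize)           -> token appeared inside the ignored prompt prefix
//   workspace[vocabSize, 2*vocabSize) -> occurrence count outside the ignored prefix
void applyOccurrencePenaltiesRef(std::vector<float>& logits, std::vector<int32_t> const& tokens,
    int32_t inputLen, int32_t promptIgnoreLength, float repetitionPenalty, float presencePenalty,
    float frequencyPenalty)
{
    auto const vocabSize = static_cast<int32_t>(logits.size());
    std::vector<int32_t> workspace(2 * vocabSize, 0);

    // Clamp to the prompt length, matching the kernel's min(promptIgnoreLengths[batchSlot], inputLen).
    promptIgnoreLength = std::min(std::max(promptIgnoreLength, 0), inputLen);

    for (int32_t step = 0; step < static_cast<int32_t>(tokens.size()); ++step)
    {
        auto const tokenId = tokens[step];
        if (step < promptIgnoreLength)
        {
            workspace[tokenId] = 1; // ignored prefix: remembered only for the repetition penalty
        }
        else
        {
            workspace[tokenId + vocabSize] += 1; // counted for presence and frequency penalties
        }
    }

    for (int32_t v = 0; v < vocabSize; ++v)
    {
        auto const count = workspace[v + vocabSize];
        if (count > 0 || workspace[v] > 0) // repetition penalty still sees the whole sequence
        {
            logits[v] = logits[v] < 0.f ? logits[v] * repetitionPenalty : logits[v] / repetitionPenalty;
        }
        if (count > 0) // presence/frequency skip the ignored prompt prefix
        {
            logits[v] -= presencePenalty + static_cast<float>(count) * frequencyPenalty;
        }
    }
}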
mUseTemperature |= useTemperature; @@ -176,6 +183,7 @@ void PenaltyLayer::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorCo mUsePresencePenalty |= usePresencePenalty; mUseFrequencyPenalty |= useFrequencyPenalty; mUseMinLength |= useMinLength; + mUsePromptIgnoreLength |= usePromptIgnoreLength; if (mUseTemperature) { @@ -203,10 +211,16 @@ void PenaltyLayer::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorCo fillBuffers(penaltyParams->minLength, DefaultDecodingParams::getMinLength(), mMinLength, mMinLengthDevice, batchSlots, getLimitsPenalty(DecodingPenaltyType::MinLength), "min length"); } + if (mUsePromptIgnoreLength) + { + fillBuffers(penaltyParams->promptIgnoreLength, DefaultDecodingParams::getPromptIgnoreLength(), + mPromptIgnoreLength, mPromptIgnoreLengthDevice, batchSlots, + getLimitsPenalty(DecodingPenaltyType::PromptIgnoreLength), "prompt ignore length"); + } // Reset penalty workspace auto const workspaceSizePerBatch - = mDecoderDomain.getMaxDecodingTokens() * mConfiguredBeamWidth * mDecoderDomain.getVocabSize(); + = mDecoderDomain.getMaxDecodingTokens() * mConfiguredBeamWidth * mDecoderDomain.getVocabSize() * 2; for (SizeType32 bi = 0; bi < batchSize; ++bi) { auto batchSlot = runtime::bufferCast(*batchSlots)[bi]; @@ -287,6 +301,7 @@ void PenaltyLayer::forwardAsync(std::shared_ptr const& b auto presencePenalties = GET_PENALTIES(PresencePenalty, float); auto frequencyPenalties = GET_PENALTIES(FrequencyPenalty, float); auto minLengths = GET_PENALTIES(MinLength, SizeType32); + auto promptIgnoreLengths = GET_PENALTIES(PromptIgnoreLength, SizeType32); #undef GET_PENALTIES @@ -316,6 +331,7 @@ void PenaltyLayer::forwardAsync(std::shared_ptr const& b penaltyParams.inputLengths = inputLengths; penaltyParams.sequenceLengths = bufferCast(*outputs->sequenceLength.value()); penaltyParams.minLengths = bufferCastOrNull(minLengths); + penaltyParams.promptIgnoreLengths = bufferCastOrNull(promptIgnoreLengths); penaltyParams.endIds = bufferCast(*params->endIds); penaltyParams.batchSlots = workspace->getDeviceBatchSlotsPtr(); penaltyParams.maxTokensPerStep = mDecoderDomain.getMaxDecodingTokens(); diff --git a/cpp/tensorrt_llm/layers/penaltyLayer.h b/cpp/tensorrt_llm/layers/penaltyLayer.h index 16294fdf9c8..7fa9e3a3895 100644 --- a/cpp/tensorrt_llm/layers/penaltyLayer.h +++ b/cpp/tensorrt_llm/layers/penaltyLayer.h @@ -67,18 +67,21 @@ class PenaltyLayer : public BaseLayer TensorPtr mPresencePenaltyDevice; TensorPtr mFrequencyPenaltyDevice; TensorPtr mMinLengthDevice; + TensorPtr mPromptIgnoreLengthDevice; TensorPtr mTemperature; TensorPtr mRepetitionPenalty; TensorPtr mPresencePenalty; TensorPtr mFrequencyPenalty; TensorPtr mMinLength; + TensorPtr mPromptIgnoreLength; bool mUseTemperature{false}; bool mUseRepetitionPenalty{false}; bool mUsePresencePenalty{false}; bool mUseFrequencyPenalty{false}; bool mUseMinLength{false}; + bool mUsePromptIgnoreLength{false}; runtime::SizeType32 mCyclicStep{0}; runtime::SizeType32 mRuntimeMaxSeqLen{0}; diff --git a/cpp/tensorrt_llm/nanobind/bindings.cpp b/cpp/tensorrt_llm/nanobind/bindings.cpp index b1e9e3094fd..8d0585bfc71 100644 --- a/cpp/tensorrt_llm/nanobind/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/bindings.cpp @@ -370,14 +370,14 @@ NB_MODULE(TRTLLM_NB_MODULE, m) auto SamplingConfigGetState = [](tr::SamplingConfig const& config) -> nb::tuple { return nb::make_tuple(config.beamWidth, config.temperature, config.minLength, config.repetitionPenalty, - config.presencePenalty, config.frequencyPenalty, config.topK, config.topP, config.randomSeed, - 
config.topPDecay, config.topPMin, config.topPResetIds, config.beamSearchDiversityRate, config.lengthPenalty, - config.earlyStopping, config.noRepeatNgramSize, config.numReturnSequences, config.minP, - config.beamWidthArray); + config.presencePenalty, config.frequencyPenalty, config.promptIgnoreLength, config.topK, config.topP, + config.randomSeed, config.topPDecay, config.topPMin, config.topPResetIds, config.beamSearchDiversityRate, + config.lengthPenalty, config.earlyStopping, config.noRepeatNgramSize, config.numReturnSequences, + config.minP, config.beamWidthArray); }; auto SamplingConfigSetState = [](tr::SamplingConfig& self, nb::tuple t) { - if (t.size() != 19) + if (t.size() != 20) { throw std::runtime_error("Invalid SamplingConfig state!"); } @@ -389,19 +389,20 @@ NB_MODULE(TRTLLM_NB_MODULE, m) config.repetitionPenalty = nb::cast>(t[3]); config.presencePenalty = nb::cast>(t[4]); config.frequencyPenalty = nb::cast>(t[5]); - config.topK = nb::cast>(t[6]); - config.topP = nb::cast>(t[7]); - config.randomSeed = nb::cast>(t[8]); - config.topPDecay = nb::cast>(t[9]); - config.topPMin = nb::cast>(t[10]); - config.topPResetIds = nb::cast>(t[11]); - config.beamSearchDiversityRate = nb::cast>(t[12]); - config.lengthPenalty = nb::cast>(t[13]); - config.earlyStopping = nb::cast>(t[14]); - config.noRepeatNgramSize = nb::cast>(t[15]); - config.numReturnSequences = nb::cast(t[16]); - config.minP = nb::cast>(t[17]); - config.beamWidthArray = nb::cast>>(t[18]); + config.promptIgnoreLength = nb::cast>(t[6]); + config.topK = nb::cast>(t[7]); + config.topP = nb::cast>(t[8]); + config.randomSeed = nb::cast>(t[9]); + config.topPDecay = nb::cast>(t[10]); + config.topPMin = nb::cast>(t[11]); + config.topPResetIds = nb::cast>(t[12]); + config.beamSearchDiversityRate = nb::cast>(t[13]); + config.lengthPenalty = nb::cast>(t[14]); + config.earlyStopping = nb::cast>(t[15]); + config.noRepeatNgramSize = nb::cast>(t[16]); + config.numReturnSequences = nb::cast(t[17]); + config.minP = nb::cast>(t[18]); + config.beamWidthArray = nb::cast>>(t[19]); new (&self) tr::SamplingConfig(config); }; @@ -416,6 +417,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m) .def_rw("repetition_penalty", &tr::SamplingConfig::repetitionPenalty) .def_rw("presence_penalty", &tr::SamplingConfig::presencePenalty) .def_rw("frequency_penalty", &tr::SamplingConfig::frequencyPenalty) + .def_rw("prompt_ignore_length", &tr::SamplingConfig::promptIgnoreLength) .def_rw("top_k", &tr::SamplingConfig::topK) .def_rw("top_p", &tr::SamplingConfig::topP) .def_rw("random_seed", &tr::SamplingConfig::randomSeed) diff --git a/cpp/tensorrt_llm/nanobind/executor/request.cpp b/cpp/tensorrt_llm/nanobind/executor/request.cpp index de9aa8a8c07..db05409d866 100644 --- a/cpp/tensorrt_llm/nanobind/executor/request.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/request.cpp @@ -76,12 +76,12 @@ void initRequestBindings(nb::module_& m) return nb::make_tuple(self.getBeamWidth(), self.getTopK(), self.getTopP(), self.getTopPMin(), self.getTopPResetIds(), self.getTopPDecay(), self.getSeed(), self.getTemperature(), self.getMinTokens(), self.getBeamSearchDiversityRate(), self.getRepetitionPenalty(), self.getPresencePenalty(), - self.getFrequencyPenalty(), self.getLengthPenalty(), self.getEarlyStopping(), self.getNoRepeatNgramSize(), - self.getNumReturnSequences(), self.getMinP(), self.getBeamWidthArray()); + self.getFrequencyPenalty(), self.getPromptIgnoreLength(), self.getLengthPenalty(), self.getEarlyStopping(), + self.getNoRepeatNgramSize(), self.getNumReturnSequences(), self.getMinP(), 
self.getBeamWidthArray()); }; auto samplingConfigSetstate = [](tle::SamplingConfig& samplingConfig, nb::tuple const& state) { - if (state.size() != 19) + if (state.size() != 20) { throw std::runtime_error("Invalid SamplingConfig state!"); } @@ -98,12 +98,13 @@ void initRequestBindings(nb::module_& m) nb::cast>(state[10]), // RepetitionPenalty nb::cast>(state[11]), // PresencePenalty nb::cast>(state[12]), // FrequencyPenalty - nb::cast>(state[13]), // LengthPenalty - nb::cast>(state[14]), // EarlyStopping - nb::cast>(state[15]), // NoRepeatNgramSize - nb::cast>(state[16]), // NumReturnSequences - nb::cast>(state[17]), // MinP - nb::cast>>(state[18]) // BeamWidthArray + nb::cast>(state[13]), // PromptIgnoreLength + nb::cast>(state[14]), // LengthPenalty + nb::cast>(state[15]), // EarlyStopping + nb::cast>(state[16]), // NoRepeatNgramSize + nb::cast>(state[17]), // NumReturnSequences + nb::cast>(state[18]), // MinP + nb::cast>>(state[19]) // BeamWidthArray ); }; nb::class_(m, "SamplingConfig") @@ -120,6 +121,7 @@ void initRequestBindings(nb::module_& m) std::optional const&, // repetitionPenalty std::optional const&, // presencePenalty std::optional const&, // frequencyPenalty + std::optional const&, // promptIgnoreLength std::optional const&, // lengthPenalty std::optional const&, // earlyStopping std::optional const&, // noRepeatNgramSize @@ -142,6 +144,7 @@ void initRequestBindings(nb::module_& m) nb::arg("repetition_penalty") = nb::none(), nb::arg("presence_penalty") = nb::none(), nb::arg("frequency_penalty") = nb::none(), + nb::arg("prompt_ignore_length") = nb::none(), nb::arg("length_penalty") = nb::none(), nb::arg("early_stopping") = nb::none(), nb::arg("no_repeat_ngram_size") = nb::none(), @@ -165,6 +168,8 @@ void initRequestBindings(nb::module_& m) [](tle::SamplingConfig& self, std::optional v) { self.setPresencePenalty(v); }) .def_prop_rw( "frequency_penalty", &tle::SamplingConfig::getFrequencyPenalty, &tle::SamplingConfig::setFrequencyPenalty) + .def_prop_rw("prompt_ignore_length", &tle::SamplingConfig::getPromptIgnoreLength, + &tle::SamplingConfig::setPromptIgnoreLength) .def_prop_rw("length_penalty", &tle::SamplingConfig::getLengthPenalty, &tle::SamplingConfig::setLengthPenalty) .def_prop_rw("early_stopping", &tle::SamplingConfig::getEarlyStopping, &tle::SamplingConfig::setEarlyStopping) .def_prop_rw("no_repeat_ngram_size", &tle::SamplingConfig::getNoRepeatNgramSize, diff --git a/cpp/tensorrt_llm/pybind/bindings.cpp b/cpp/tensorrt_llm/pybind/bindings.cpp index 3ed6b45054c..4b5415afd95 100644 --- a/cpp/tensorrt_llm/pybind/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/bindings.cpp @@ -361,14 +361,14 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) auto SamplingConfigGetState = [](tr::SamplingConfig const& config) -> py::tuple { return py::make_tuple(config.beamWidth, config.temperature, config.minLength, config.repetitionPenalty, - config.presencePenalty, config.frequencyPenalty, config.topK, config.topP, config.randomSeed, - config.topPDecay, config.topPMin, config.topPResetIds, config.beamSearchDiversityRate, config.lengthPenalty, - config.earlyStopping, config.noRepeatNgramSize, config.numReturnSequences, config.minP, - config.beamWidthArray); + config.presencePenalty, config.frequencyPenalty, config.promptIgnoreLength, config.topK, config.topP, + config.randomSeed, config.topPDecay, config.topPMin, config.topPResetIds, config.beamSearchDiversityRate, + config.lengthPenalty, config.earlyStopping, config.noRepeatNgramSize, config.numReturnSequences, + config.minP, config.beamWidthArray); 
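The pickled state tuples above grow from 19 to 20 entries, with prompt_ignore_length inserted at index 6, and the C++ executor serialization path is extended the same way. A minimal round-trip sketch through that path (header paths and the std::stringstream transport are assumptions; the setter, getter, and Serialization entry points are the ones added in this diff):

#include <cassert>
#include <sstream>

#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/serialization.h"

namespace tle = tensorrt_llm::executor;

void promptIgnoreLengthRoundTrip()
{
    tle::SamplingConfig config;
    config.setPresencePenalty(0.2f);
    config.setFrequencyPenalty(0.5f);
    config.setPromptIgnoreLength(16); // skip the first 16 prompt tokens for presence/frequency

    std::stringstream buffer;
    tle::Serialization::serialize(config, buffer);
    auto const restored = tle::Serialization::deserializeSamplingConfig(buffer);

    assert(restored.getPromptIgnoreLength().value() == 16);
    assert(restored == config); // operator== now also compares mPromptIgnoreLength
}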
}; auto SamplingConfigSetState = [](py::tuple t) -> tr::SamplingConfig { - if (t.size() != 19) + if (t.size() != 20) { throw std::runtime_error("Invalid SamplingConfig state!"); } @@ -380,19 +380,20 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) config.repetitionPenalty = t[3].cast>(); config.presencePenalty = t[4].cast>(); config.frequencyPenalty = t[5].cast>(); - config.topK = t[6].cast>(); - config.topP = t[7].cast>(); - config.randomSeed = t[8].cast>(); - config.topPDecay = t[9].cast>(); - config.topPMin = t[10].cast>(); - config.topPResetIds = t[11].cast>(); - config.beamSearchDiversityRate = t[12].cast>(); - config.lengthPenalty = t[13].cast>(); - config.earlyStopping = t[14].cast>(); - config.noRepeatNgramSize = t[15].cast>(); - config.numReturnSequences = t[16].cast(); - config.minP = t[17].cast>(); - config.beamWidthArray = t[18].cast>>(); + config.promptIgnoreLength = t[6].cast>(); + config.topK = t[7].cast>(); + config.topP = t[8].cast>(); + config.randomSeed = t[9].cast>(); + config.topPDecay = t[10].cast>(); + config.topPMin = t[11].cast>(); + config.topPResetIds = t[12].cast>(); + config.beamSearchDiversityRate = t[13].cast>(); + config.lengthPenalty = t[14].cast>(); + config.earlyStopping = t[15].cast>(); + config.noRepeatNgramSize = t[16].cast>(); + config.numReturnSequences = t[17].cast(); + config.minP = t[18].cast>(); + config.beamWidthArray = t[19].cast>>(); return config; }; @@ -407,6 +408,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) .def_readwrite("repetition_penalty", &tr::SamplingConfig::repetitionPenalty) .def_readwrite("presence_penalty", &tr::SamplingConfig::presencePenalty) .def_readwrite("frequency_penalty", &tr::SamplingConfig::frequencyPenalty) + .def_readwrite("prompt_ignore_length", &tr::SamplingConfig::promptIgnoreLength) .def_readwrite("top_k", &tr::SamplingConfig::topK) .def_readwrite("top_p", &tr::SamplingConfig::topP) .def_readwrite("random_seed", &tr::SamplingConfig::randomSeed) diff --git a/cpp/tensorrt_llm/pybind/executor/request.cpp b/cpp/tensorrt_llm/pybind/executor/request.cpp index 097f598557b..2e9dae860e0 100644 --- a/cpp/tensorrt_llm/pybind/executor/request.cpp +++ b/cpp/tensorrt_llm/pybind/executor/request.cpp @@ -72,12 +72,12 @@ void initRequestBindings(pybind11::module_& m) return py::make_tuple(self.getBeamWidth(), self.getTopK(), self.getTopP(), self.getTopPMin(), self.getTopPResetIds(), self.getTopPDecay(), self.getSeed(), self.getTemperature(), self.getMinTokens(), self.getBeamSearchDiversityRate(), self.getRepetitionPenalty(), self.getPresencePenalty(), - self.getFrequencyPenalty(), self.getLengthPenalty(), self.getEarlyStopping(), self.getNoRepeatNgramSize(), - self.getNumReturnSequences(), self.getMinP(), self.getBeamWidthArray()); + self.getFrequencyPenalty(), self.getPromptIgnoreLength(), self.getLengthPenalty(), self.getEarlyStopping(), + self.getNoRepeatNgramSize(), self.getNumReturnSequences(), self.getMinP(), self.getBeamWidthArray()); }; auto samplingConfigSetstate = [](py::tuple const& state) { - if (state.size() != 19) + if (state.size() != 20) { throw std::runtime_error("Invalid SamplingConfig state!"); } @@ -94,12 +94,13 @@ void initRequestBindings(pybind11::module_& m) state[10].cast>(), // RepetitionPenalty state[11].cast>(), // PresencePenalty state[12].cast>(), // FrequencyPenalty - state[13].cast>(), // LengthPenalty - state[14].cast>(), // EarlyStopping - state[15].cast>(), // NoRepeatNgramSize - state[16].cast>(), // NumReturnSequences - state[17].cast>(), // MinP - state[18].cast>>() // BeamWidthArray + 
state[13].cast>(), // PromptIgnoreLength + state[14].cast>(), // LengthPenalty + state[15].cast>(), // EarlyStopping + state[16].cast>(), // NoRepeatNgramSize + state[17].cast>(), // NumReturnSequences + state[18].cast>(), // MinP + state[19].cast>>() // BeamWidthArray ); }; py::class_(m, "SamplingConfig") @@ -116,6 +117,7 @@ void initRequestBindings(pybind11::module_& m) std::optional const&, // repetitionPenalty std::optional const&, // presencePenalty std::optional const&, // frequencyPenalty + std::optional const&, // promptIgnoreLength std::optional const&, // lengthPenalty std::optional const&, // earlyStopping std::optional const&, // noRepeatNgramSize @@ -138,6 +140,7 @@ void initRequestBindings(pybind11::module_& m) py::arg("repetition_penalty") = py::none(), py::arg("presence_penalty") = py::none(), py::arg("frequency_penalty") = py::none(), + py::arg("prompt_ignore_length") = py::none(), py::arg("length_penalty") = py::none(), py::arg("early_stopping") = py::none(), py::arg("no_repeat_ngram_size") = py::none(), @@ -161,6 +164,8 @@ void initRequestBindings(pybind11::module_& m) [](tle::SamplingConfig& self, std::optional v) { self.setPresencePenalty(v); }) .def_property( "frequency_penalty", &tle::SamplingConfig::getFrequencyPenalty, &tle::SamplingConfig::setFrequencyPenalty) + .def_property("prompt_ignore_length", &tle::SamplingConfig::getPromptIgnoreLength, + &tle::SamplingConfig::setPromptIgnoreLength) .def_property("length_penalty", &tle::SamplingConfig::getLengthPenalty, &tle::SamplingConfig::setLengthPenalty) .def_property("early_stopping", &tle::SamplingConfig::getEarlyStopping, &tle::SamplingConfig::setEarlyStopping) .def_property("no_repeat_ngram_size", &tle::SamplingConfig::getNoRepeatNgramSize, diff --git a/cpp/tensorrt_llm/runtime/gptDecoder.cpp b/cpp/tensorrt_llm/runtime/gptDecoder.cpp index 610eae11385..93087720646 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoder.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoder.cpp @@ -84,6 +84,7 @@ void GptDecoder::disableLookahead( penaltyParams->repetitionPenalty = mSamplingConfig.repetitionPenalty; penaltyParams->presencePenalty = mSamplingConfig.presencePenalty; penaltyParams->frequencyPenalty = mSamplingConfig.frequencyPenalty; + penaltyParams->promptIgnoreLength = mSamplingConfig.promptIgnoreLength; penaltyParams->temperature = mSamplingConfig.temperature; penaltyParams->minLength = mSamplingConfig.minLength; @@ -136,6 +137,7 @@ void GptDecoder::setup(SamplingConfig const& samplingConfig, size_t batchSize penaltyParams->repetitionPenalty = mSamplingConfig.repetitionPenalty; penaltyParams->presencePenalty = mSamplingConfig.presencePenalty; penaltyParams->frequencyPenalty = mSamplingConfig.frequencyPenalty; + penaltyParams->promptIgnoreLength = mSamplingConfig.promptIgnoreLength; penaltyParams->temperature = mSamplingConfig.temperature; penaltyParams->minLength = mSamplingConfig.minLength; diff --git a/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp b/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp index 26563fc9646..f9e0e76a46d 100644 --- a/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp +++ b/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp @@ -120,12 +120,12 @@ void FtDynamicDecode::setup(size_t const batch_size, size_t const beam_width, th::optional runtime_top_k_opt, th::optional runtime_top_p_opt, th::optional temperature_opt, th::optional repetition_penalty_opt, th::optional presence_penalty_opt, th::optional frequency_penalty_opt, - th::optional min_length_opt, th::optional length_penalty_opt, - th::optional early_stopping_opt, th::optional 
beam_search_diversity_rate_opt, - th::optional random_seed_opt, th::optional top_p_decay_opt, - th::optional top_p_min_opt, th::optional top_p_reset_ids_opt, - th::optional no_repeat_ngram_size_opt, th::optional min_p_opt, bool output_log_probs, - bool cum_log_probs) + th::optional prompt_ignore_length_opt, th::optional min_length_opt, + th::optional length_penalty_opt, th::optional early_stopping_opt, + th::optional beam_search_diversity_rate_opt, th::optional random_seed_opt, + th::optional top_p_decay_opt, th::optional top_p_min_opt, + th::optional top_p_reset_ids_opt, th::optional no_repeat_ngram_size_opt, + th::optional min_p_opt, bool output_log_probs, bool cum_log_probs) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); mBeamWidth = beam_width; @@ -137,6 +137,7 @@ void FtDynamicDecode::setup(size_t const batch_size, size_t const beam_width, safeInsert(repetition_penalty_opt, penaltyParams->repetitionPenalty); safeInsert(presence_penalty_opt, penaltyParams->presencePenalty); safeInsert(frequency_penalty_opt, penaltyParams->frequencyPenalty); + safeInsert(prompt_ignore_length_opt, penaltyParams->promptIgnoreLength); safeInsert(min_length_opt, penaltyParams->minLength); safeInsert(no_repeat_ngram_size_opt, banWordsParams->noRepeatNgramSize); if (beam_width == 1) @@ -328,10 +329,10 @@ void DynamicDecodeOp::createInstance() void DynamicDecodeOp::setup(int64_t const batchSize, int64_t const beamWidth, th::optional runtimeTopKOpt, th::optional runtimeTopPOpt, th::optional temperatureOpt, th::optional repetitionPenaltyOpt, th::optional presencePenaltyOpt, - th::optional frequencyPenaltyOpt, th::optional minLengthOpt, - th::optional lengthPenaltyOpt, th::optional earlyStoppingOpt, - th::optional beamSearchDiversityRateOpt, th::optional randomSeedOpt, - th::optional topPDecayOpt, th::optional topPMinOpt, + th::optional frequencyPenaltyOpt, th::optional promptIgnoreLengthOpt, + th::optional minLengthOpt, th::optional lengthPenaltyOpt, + th::optional earlyStoppingOpt, th::optional beamSearchDiversityRateOpt, + th::optional randomSeedOpt, th::optional topPDecayOpt, th::optional topPMinOpt, th::optional topPResetIdsOpt, th::optional noRepeatNgramSizeOpt, th::optional minPOpt, bool outputLogProbs, bool cumLogProbs) { @@ -343,6 +344,7 @@ void DynamicDecodeOp::setup(int64_t const batchSize, int64_t const beamWidth, th CHECK_OPTIONAL_CPU_INPUT(repetitionPenaltyOpt, torch::kFloat); CHECK_OPTIONAL_CPU_INPUT(presencePenaltyOpt, torch::kFloat); CHECK_OPTIONAL_CPU_INPUT(frequencyPenaltyOpt, torch::kFloat); + CHECK_OPTIONAL_CPU_INPUT(promptIgnoreLengthOpt, torch::kInt32); CHECK_OPTIONAL_CPU_INPUT(minLengthOpt, torch::kInt32); CHECK_OPTIONAL_CPU_INPUT(lengthPenaltyOpt, torch::kFloat); CHECK_OPTIONAL_CPU_INPUT(earlyStoppingOpt, torch::kInt32); @@ -356,8 +358,9 @@ void DynamicDecodeOp::setup(int64_t const batchSize, int64_t const beamWidth, th dynamicDecode_->setup(static_cast(batchSize), static_cast(beamWidth), runtimeTopKOpt, runtimeTopPOpt, temperatureOpt, repetitionPenaltyOpt, presencePenaltyOpt, frequencyPenaltyOpt, - minLengthOpt, lengthPenaltyOpt, earlyStoppingOpt, beamSearchDiversityRateOpt, randomSeedOpt, topPDecayOpt, - topPMinOpt, topPResetIdsOpt, noRepeatNgramSizeOpt, minPOpt, outputLogProbs, cumLogProbs); + promptIgnoreLengthOpt, minLengthOpt, lengthPenaltyOpt, earlyStoppingOpt, beamSearchDiversityRateOpt, + randomSeedOpt, topPDecayOpt, topPMinOpt, topPResetIdsOpt, noRepeatNgramSizeOpt, minPOpt, outputLogProbs, + cumLogProbs); } th::Tensor DynamicDecodeOp::forward( diff --git 
a/cpp/tensorrt_llm/thop/dynamicDecodeOp.h b/cpp/tensorrt_llm/thop/dynamicDecodeOp.h index ea2ea828aad..533066cc2a0 100644 --- a/cpp/tensorrt_llm/thop/dynamicDecodeOp.h +++ b/cpp/tensorrt_llm/thop/dynamicDecodeOp.h @@ -32,12 +32,13 @@ class IFtDynamicDecode virtual void setup(size_t const batch_size, size_t const beam_width, th::optional runtime_top_k_opt, th::optional runtime_top_p_opt, th::optional temperature_opt, th::optional repetition_penalty_opt, th::optional presence_penalty_opt, - th::optional frequency_penalty_opt, th::optional min_length_opt, - th::optional length_penalty_opt, th::optional early_stopping_opt, - th::optional beam_search_diversity_rate_opt, th::optional random_seed_opt, - th::optional top_p_decay_opt, th::optional top_p_min_opt, - th::optional top_p_reset_ids_opt, th::optional no_repeat_ngram_size_opt, - th::optional min_p_opt, bool output_log_probs, bool cum_log_probs) + th::optional frequency_penalty_opt, th::optional prompt_ignore_length_opt, + th::optional min_length_opt, th::optional length_penalty_opt, + th::optional early_stopping_opt, th::optional beam_search_diversity_rate_opt, + th::optional random_seed_opt, th::optional top_p_decay_opt, + th::optional top_p_min_opt, th::optional top_p_reset_ids_opt, + th::optional no_repeat_ngram_size_opt, th::optional min_p_opt, bool output_log_probs, + bool cum_log_probs) = 0; virtual void forward(th::Tensor const& logits, int const step, int const max_input_length, @@ -72,12 +73,13 @@ class FtDynamicDecode : public IFtDynamicDecode void setup(size_t const batch_size, size_t const beam_width, th::optional runtime_top_k_opt, th::optional runtime_top_p_opt, th::optional temperature_opt, th::optional repetition_penalty_opt, th::optional presence_penalty_opt, - th::optional frequency_penalty_opt, th::optional min_length_opt, - th::optional length_penalty_opt, th::optional early_stopping_opt, - th::optional beam_search_diversity_rate_opt, th::optional random_seed_opt, - th::optional top_p_decay_opt, th::optional top_p_min_opt, - th::optional top_p_reset_ids_opt, th::optional no_repeat_ngram_size_opt, - th::optional min_p_opt, bool output_log_probs, bool cum_log_probs) override; + th::optional frequency_penalty_opt, th::optional prompt_ignore_length_opt, + th::optional min_length_opt, th::optional length_penalty_opt, + th::optional early_stopping_opt, th::optional beam_search_diversity_rate_opt, + th::optional random_seed_opt, th::optional top_p_decay_opt, + th::optional top_p_min_opt, th::optional top_p_reset_ids_opt, + th::optional no_repeat_ngram_size_opt, th::optional min_p_opt, bool output_log_probs, + bool cum_log_probs) override; void forward(th::Tensor const& logits, int const step, int const max_input_length, int const max_attention_window, int const sink_token_length, uint64_t const ite, int const local_batch_size, th::Tensor end_id, @@ -115,12 +117,13 @@ class DynamicDecodeOp : public th::jit::CustomClassHolder void setup(int64_t const batch_size, int64_t const beam_width, th::optional runtime_top_k_opt, th::optional runtime_top_p_opt, th::optional temperature_opt, th::optional repetition_penalty_opt, th::optional presence_penalty_opt, - th::optional frequency_penalty_opt, th::optional min_length_opt, - th::optional length_penalty_opt, th::optional early_stopping_opt, - th::optional beam_search_diversity_rate_opt, th::optional random_seed_opt, - th::optional top_p_decay_opt, th::optional top_p_min_opt, - th::optional top_p_reset_ids_opt, th::optional no_repeat_ngram_size_opt, - th::optional min_p_opt, bool 
output_log_probs, bool cum_log_probs); + th::optional frequency_penalty_opt, th::optional prompt_ignore_length_opt, + th::optional min_length_opt, th::optional length_penalty_opt, + th::optional early_stopping_opt, th::optional beam_search_diversity_rate_opt, + th::optional random_seed_opt, th::optional top_p_decay_opt, + th::optional top_p_min_opt, th::optional top_p_reset_ids_opt, + th::optional no_repeat_ngram_size_opt, th::optional min_p_opt, bool output_log_probs, + bool cum_log_probs); th::Tensor forward(th::Tensor const& logits, int64_t const step, int64_t const max_input_length, int64_t const max_attention_window, int64_t const sink_token_length, int64_t const ite, diff --git a/cpp/tests/unit_tests/executor/samplingConfigTest.cpp b/cpp/tests/unit_tests/executor/samplingConfigTest.cpp index 20b1a71ef9d..ca9d4621d27 100644 --- a/cpp/tests/unit_tests/executor/samplingConfigTest.cpp +++ b/cpp/tests/unit_tests/executor/samplingConfigTest.cpp @@ -34,17 +34,18 @@ void test(bool const isTestValid, SizeType32 beamWidth = 1, std::optional randomSeed = no, std::optional temperature = no, std::optional minLength = no, std::optional beamSearchDiversityRate = no, std::optional repetitionPenalty = no, std::optional presencePenalty = no, - std::optional frequencyPenalty = no, std::optional lengthPenalty = no, - std::optional earlyStopping = no, std::optional noRepeatNgramSize = no, - std::optional numReturnSequences = no, std::optional minP = no, - std::optional> beamWidthArray = no) + std::optional frequencyPenalty = no, std::optional promptIgnoreLength = no, + std::optional lengthPenalty = no, std::optional earlyStopping = no, + std::optional noRepeatNgramSize = no, std::optional numReturnSequences = no, + std::optional minP = no, std::optional> beamWidthArray = no) { - // 19 parameters for SamplingConfig, from `beamWidth` to `beamWidthArray` + // 20 parameters for SamplingConfig, from `beamWidth` to `beamWidthArray` try { auto sc = SamplingConfig(beamWidth, topK, topP, topPMin, topPResetIds, topPDecay, randomSeed, temperature, - minLength, beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty, lengthPenalty, - earlyStopping, noRepeatNgramSize, numReturnSequences, minP, beamWidthArray); + minLength, beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty, + promptIgnoreLength, lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP, + beamWidthArray); // Come here if `sc` is valid if (!isTestValid) @@ -102,18 +103,20 @@ TEST(SamplingConfigTest, validInputs) test(true, 1, no, no, no, no, no, no, no, no, no, no, 1.f); // Frequency penalty test(true, 1, no, no, no, no, no, no, no, no, no, no, no, 1.f); + // Prompt ignore length + test(true, 1, no, no, no, no, no, no, no, no, no, no, no, no, 1); // Length penalty - test(true, 1, no, no, no, no, no, no, no, no, no, no, no, no, 1.f); - // Early stopping test(true, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, 1.f); + // Early stopping + test(true, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 1.f); // No repeat ngram size - test(true, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2); + test(true, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2); // NumReturnSequences - test(true, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2); + test(true, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2); // MinP - test(true, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 0.5f); + test(true, 1, 
no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 0.5f); // BeamWidthArray - test(true, 5, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, + test(true, 5, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, std::vector{2, 3, 4, 5}); } @@ -156,32 +159,35 @@ TEST(SamplingConfigTest, invalidInputs) // Skip presence penalty, frequency penalty, no test - // Neg length penalty + // Neg prompt ignore length test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, -1); - // Neg early stopping + // Neg length penalty test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, -1); - // Neg no repeat ngram size + // Neg early stopping test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, -1); + // Neg no repeat ngram size + test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, -1); + // Neg or zero numReturnSequences - test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 0); + test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 0); // numReturnSequences > beamWidth - test(false, 2, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 4); + test(false, 2, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 4); // Neg minP - test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, -1.f); + test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, -1.f); // Neg / Large minP - test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, -1.f); - test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, +2.f); + test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, -1.f); + test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, +2.f); // BeamWidthArray with neg / large beamWidth - test(false, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, + test(false, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, std::vector{2, 3, 4, -1}); - test(false, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, + test(false, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, std::vector{2, 3, 4, 65536}); } @@ -265,6 +271,12 @@ TEST(SamplingConfigTest, getterSetter) sc.setFrequencyPenalty(0.5f); EXPECT_EQ(sc.getFrequencyPenalty(), 0.5f); } + // Prompt ignore length + { + auto sc = SamplingConfig(); + sc.setPromptIgnoreLength(1); + EXPECT_EQ(sc.getPromptIgnoreLength(), 1); + } // Length penalty { auto sc = SamplingConfig(); diff --git a/cpp/tests/unit_tests/kernels/sampling/samplingPenaltyTest.cpp b/cpp/tests/unit_tests/kernels/sampling/samplingPenaltyTest.cpp index 1ae034827b0..8896dd005cf 100644 --- a/cpp/tests/unit_tests/kernels/sampling/samplingPenaltyTest.cpp +++ b/cpp/tests/unit_tests/kernels/sampling/samplingPenaltyTest.cpp @@ -161,7 +161,7 @@ class TemperaturePenaltyTest : public SamplingKernelTest mLogitsPtrs = BufferManager::pinned(ITensor::makeShape({mBatchSize}), ptrType); mPenaltyWorkspaceDevice = mBufferManager->gpu( - ITensor::makeShape({mMaxBatchSize, mMaxTokensPerStep, mVocabSizePadded}), nvinfer1::DataType::kINT32); + ITensor::makeShape({mMaxBatchSize, mMaxTokensPerStep, mVocabSize * 2}), nvinfer1::DataType::kINT32); mTokensPerStep = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); @@ -259,8 +259,8 @@ class TemperaturePenaltyTest : public SamplingKernelTest 
InvokeBatchApplyPenaltyParams penaltyParams{reinterpret_cast(bufferCast(*mLogitsPtrs)), bufferCast(*mOutLogitsDevice), bufferCast(*mBiasDevice), bufferCast(*mPenaltyWorkspaceDevice), nullptr, bufferCast(*mTemperaturesDevice), nullptr, - nullptr, nullptr, mBatchSize, 1, 1, mVocabSize, mVocabSizePadded, nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, bufferCast(*mBatchSlots), mMaxTokensPerStep, + nullptr, nullptr, nullptr, mBatchSize, 1, 1, mVocabSize, mVocabSizePadded, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, bufferCast(*mBatchSlots), mMaxTokensPerStep, bufferCast(*mTokensPerStep), nullptr, mStream->get()}; tk::invokeBatchApplyPenalty(penaltyParams); auto logitsOutHost = mBufferManager->copyFrom(*mOutLogitsDevice, MemoryType::kCPU); @@ -382,9 +382,11 @@ struct RepetitionPenaltyTestCase TensorPtr repetitionPenalties; TensorPtr presencePenalties; TensorPtr frequencyPenalties; + TensorPtr promptIgnoreLengths; int32_t repetitionPenaltiesSize; int32_t presencePenaltiesSize; int32_t frequencyPenaltiesSize; + int32_t promptIgnoreLengthsSize; int32_t maxTokensPerStep{1}; RepetitionPenaltyTestCase& setBatchSize(int32_t bs) @@ -423,6 +425,12 @@ struct RepetitionPenaltyTestCase return *this; } + RepetitionPenaltyTestCase& setPromptIgnoreLengths(TensorPtr pil) + { + promptIgnoreLengths = pil; + return *this; + } + RepetitionPenaltyTestCase& setRepetitionPenaltiesSize(int32_t rps) { repetitionPenaltiesSize = rps; @@ -441,6 +449,12 @@ struct RepetitionPenaltyTestCase return *this; } + RepetitionPenaltyTestCase& setPromptIgnoreLengthsSize(int32_t pils) + { + promptIgnoreLengthsSize = pils; + return *this; + } + RepetitionPenaltyTestCase& setMaxTokensPerStep(int32_t ts) { maxTokensPerStep = ts; @@ -451,11 +465,12 @@ struct RepetitionPenaltyTestCase { return tc::fmtstr( "RepetitionPenaltyTestCase[batch=%d, vocab=%d, maxInputLength=%d, " - "repetitionPenalties=%s, presencePenalties=%s, frequencyPenalties=%s]", + "repetitionPenalties=%s, presencePenalties=%s, frequencyPenalties=%s, promptIgnoreLengths=%s]", batchSize, vocabSize, maxInputLength, tc::arr2str(bufferCast(*repetitionPenalties), repetitionPenaltiesSize).c_str(), tc::arr2str(bufferCast(*presencePenalties), presencePenaltiesSize).c_str(), - tc::arr2str(bufferCast(*frequencyPenalties), frequencyPenaltiesSize).c_str()); + tc::arr2str(bufferCast(*frequencyPenalties), frequencyPenaltiesSize).c_str(), + tc::arr2str(bufferCast(*promptIgnoreLengths), promptIgnoreLengthsSize).c_str()); } }; @@ -499,6 +514,7 @@ class RepetitionPenaltyTest : public SamplingKernelTest TensorPtr mRepetitionPenaltiesDevice; TensorPtr mPresencePenaltiesDevice; TensorPtr mFrequencyPenaltiesDevice; + TensorPtr mPromptIgnoreLengthsDevice; TensorPtr mBatchSlots; void subsetup(RepetitionPenaltyTestCase param) @@ -525,7 +541,7 @@ class RepetitionPenaltyTest : public SamplingKernelTest mLogitsPtrs = BufferManager::pinned(ITensor::makeShape({mBatchSize}), ptrType); mPenaltyWorkspaceDevice = mBufferManager->gpu( - ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mVocabSize}), nvinfer1::DataType::kINT32); + ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mVocabSize * 2}), nvinfer1::DataType::kINT32); mTokensPerStep = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); @@ -588,24 +604,29 @@ class RepetitionPenaltyTest : public SamplingKernelTest ASSERT_EQ(param.repetitionPenaltiesSize, mMaxBatchSize) << "Invalid test configuration."; ASSERT_EQ(param.presencePenaltiesSize, mMaxBatchSize) << "Invalid test configuration."; 
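To make the reference math implemented in computeReference below concrete, here is a small worked example; the values are illustrative and not taken from the test, and temperature and bias are left at their defaults:

// vocabSize = 4, prompt = [2, 2, 3] (inputLen = 3), generated = [3], promptIgnoreLength = 2,
// repetitionPenalty = 2.0f, presencePenalty = 0.1f, frequencyPenalty = 0.1f, input logits all 1.0f.
//
//   token 2: appears only inside the ignored prefix -> repetition applies, presence/frequency do not:
//            logit[2] = 1.0 / 2.0               = 0.5
//   token 3: appears at steps 2 and 3 (outside the ignored prefix) -> all three penalties apply:
//            logit[3] = 1.0 / 2.0 - 0.1 - 2*0.1 = 0.2
//   tokens 0 and 1 never appear and keep logit 1.0.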
ASSERT_EQ(param.frequencyPenaltiesSize, mMaxBatchSize) << "Invalid test configuration."; + ASSERT_EQ(param.promptIgnoreLengthsSize, mMaxBatchSize) << "Invalid test configuration."; mRepetitionPenaltiesDevice = mBufferManager->gpu(ITensor::makeShape({param.repetitionPenaltiesSize}), nvinfer1::DataType::kFLOAT); mPresencePenaltiesDevice = mBufferManager->gpu(ITensor::makeShape({param.presencePenaltiesSize}), nvinfer1::DataType::kFLOAT); mFrequencyPenaltiesDevice = mBufferManager->gpu(ITensor::makeShape({param.frequencyPenaltiesSize}), nvinfer1::DataType::kFLOAT); + mPromptIgnoreLengthsDevice + = mBufferManager->gpu(ITensor::makeShape({param.promptIgnoreLengthsSize}), nvinfer1::DataType::kINT32); mBufferManager->copy(*param.repetitionPenalties, *mRepetitionPenaltiesDevice); mBufferManager->copy(*param.presencePenalties, *mPresencePenaltiesDevice); mBufferManager->copy(*param.frequencyPenalties, *mFrequencyPenaltiesDevice); + mBufferManager->copy(*param.promptIgnoreLengths, *mPromptIgnoreLengthsDevice); } void computeReference(T const* const inLogits, T* const outLogits, int32_t const* const outputIds, int32_t const* const sequenceLengths, float const* const repetitionPenalties, float const* const presencePenalties, float const* const frequencyPenalties, - int32_t const repetitionPenaltiesSize, int32_t const presencePenaltiesSize, - int32_t const frequencyPenaltiesSize) + int32_t const* const promptIgnoreLengths, int32_t const repetitionPenaltiesSize, + int32_t const presencePenaltiesSize, int32_t const frequencyPenaltiesSize, + int32_t const promptIgnoreLengthsSize) { - std::vector penalized(mVocabSize); + std::vector repetitionPenalized(mVocabSize), presencePenalized(mVocabSize); auto const batchSlotsPtr = bufferCast(*mBatchSlots); auto const tokensPerStepPtr = bufferCast(*mTokensPerStep); @@ -633,21 +654,47 @@ class RepetitionPenaltyTest : public SamplingKernelTest float presencePenalty = presencePenaltiesSize > 1 ? presencePenalties[batchSlot] : presencePenalties[0]; float frequencyPenalty = frequencyPenaltiesSize > 1 ? frequencyPenalties[batchSlot] : frequencyPenalties[0]; + int32_t promptIgnoreLength + = promptIgnoreLengthsSize > 1 ? promptIgnoreLengths[batchSlot] : promptIgnoreLengths[0]; - std::fill(penalized.begin(), penalized.end(), false); + std::fill(repetitionPenalized.begin(), repetitionPenalized.end(), false); + std::fill(presencePenalized.begin(), presencePenalized.end(), false); size_t offset = (bi * mMaxTokensPerStep + ti) * mVocabSizePadded; auto const step = sequenceLengths[batchSlot]; + + // clamping to the inputLength (set to same as sequenceLength) + promptIgnoreLength = std::min(promptIgnoreLength, step); + + std::vector numOccurences(mVocabSize, 0); for (int32_t t = 0; t < step; ++t) { auto tokenId = outputIds[batchSlot * mSequenceLength + t]; - if (!penalized[tokenId]) + + if (!repetitionPenalized[tokenId]) { auto logit = static_cast(outLogits[offset + tokenId]); - outLogits[offset + tokenId] = static_cast( - (logit < 0.0f ? logit * repetitionPenalty : logit / repetitionPenalty) - presencePenalty); - penalized[tokenId] = true; + outLogits[offset + tokenId] + = static_cast((logit < 0.0f ? 
logit * repetitionPenalty : logit / repetitionPenalty)); + repetitionPenalized[tokenId] = true; + } + + if (!(t < promptIgnoreLength)) + { + presencePenalized[tokenId] = true; + numOccurences[tokenId] += 1; + } + } + + for (int32_t vi = 0; vi < mVocabSize; ++vi) + { + if (presencePenalized[vi]) + { + outLogits[offset + vi] -= presencePenalty; + } + if (numOccurences[vi] > 0) + { + outLogits[offset + vi] -= numOccurences[vi] * frequencyPenalty; } - outLogits[offset + tokenId] -= frequencyPenalty; } } } @@ -661,7 +708,8 @@ class RepetitionPenaltyTest : public SamplingKernelTest InvokeBatchApplyPenaltyParams penaltyParams{reinterpret_cast(bufferCast(*mLogitsPtrs)), bufferCast(*mOutLogitsDevice), nullptr, bufferCast(*mPenaltyWorkspaceDevice), nullptr, nullptr, bufferCast(*mRepetitionPenaltiesDevice), bufferCast(*mPresencePenaltiesDevice), - bufferCast(*mFrequencyPenaltiesDevice), mBatchSize, 1, mSequenceLength, mVocabSize, mVocabSizePadded, + bufferCast(*mFrequencyPenaltiesDevice), bufferCast(*mPromptIgnoreLengthsDevice), mBatchSize, + 1, mSequenceLength, mVocabSize, mVocabSizePadded, reinterpret_cast(bufferCast(*mIdsPtrDevice)), nullptr, bufferCast(*mContextLengthDevice), bufferCast(*mSeqLengthDevice), nullptr, nullptr, bufferCast(*mBatchSlots), mMaxTokensPerStep, bufferCast(*mTokensPerStep), nullptr, @@ -673,8 +721,9 @@ class RepetitionPenaltyTest : public SamplingKernelTest computeReference(bufferCast(*mLogitsHost), bufferCast(*mLogitsRefHost), bufferCast(*mOutputIdsHost), bufferCast(*mSeqLengthHost), bufferCast(*param.repetitionPenalties), bufferCast(*param.presencePenalties), - bufferCast(*param.frequencyPenalties), param.repetitionPenaltiesSize, param.presencePenaltiesSize, - param.frequencyPenaltiesSize); + bufferCast(*param.frequencyPenalties), bufferCast(*param.promptIgnoreLengths), + param.repetitionPenaltiesSize, param.presencePenaltiesSize, param.frequencyPenaltiesSize, + param.promptIgnoreLengthsSize); mStream->synchronize(); @@ -696,11 +745,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchNoPenalty) = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr promptIgnoreLengthsHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 1.0f; bufferCast(*presencePenaltyHost)[i] = 0.0f; bufferCast(*frequencyPenaltyHost)[i] = 0.0f; + bufferCast(*promptIgnoreLengthsHost)[i] = 0; } this->runTest(RepetitionPenaltyTestCase() .setBatchSize(batchSize) @@ -709,9 +761,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchNoPenalty) .setRepetitionPenalties(repetitionPenaltyHost) .setPresencePenalties(presencePenaltyHost) .setFrequencyPenalties(frequencyPenaltyHost) + .setPromptIgnoreLengths(promptIgnoreLengthsHost) .setRepetitionPenaltiesSize(maxBatchSize) .setPresencePenaltiesSize(maxBatchSize) - .setFrequencyPenaltiesSize(maxBatchSize)); + .setFrequencyPenaltiesSize(maxBatchSize) + .setPromptIgnoreLengthsSize(maxBatchSize)); } TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionLessThanOne) @@ -724,11 +778,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionLessThanOne) = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr promptIgnoreLengthsHost + = 
BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 0.53f; bufferCast(*presencePenaltyHost)[i] = 0.0f; bufferCast(*frequencyPenaltyHost)[i] = 0.0f; + bufferCast(*promptIgnoreLengthsHost)[i] = 0; } this->runTest(RepetitionPenaltyTestCase() .setBatchSize(batchSize) @@ -737,9 +794,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionLessThanOne) .setRepetitionPenalties(repetitionPenaltyHost) .setPresencePenalties(presencePenaltyHost) .setFrequencyPenalties(frequencyPenaltyHost) + .setPromptIgnoreLengths(promptIgnoreLengthsHost) .setRepetitionPenaltiesSize(maxBatchSize) .setPresencePenaltiesSize(maxBatchSize) - .setFrequencyPenaltiesSize(maxBatchSize)); + .setFrequencyPenaltiesSize(maxBatchSize) + .setPromptIgnoreLengthsSize(maxBatchSize)); } TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionGreaterThaneOne) @@ -752,11 +811,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionGreaterThaneOne) = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr promptIgnoreLengthsHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 2.01f; bufferCast(*presencePenaltyHost)[i] = 0.0f; bufferCast(*frequencyPenaltyHost)[i] = 0.0f; + bufferCast(*promptIgnoreLengthsHost)[i] = 0; } this->runTest(RepetitionPenaltyTestCase() .setBatchSize(batchSize) @@ -765,9 +827,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionGreaterThaneOne) .setRepetitionPenalties(repetitionPenaltyHost) .setPresencePenalties(presencePenaltyHost) .setFrequencyPenalties(frequencyPenaltyHost) + .setPromptIgnoreLengths(promptIgnoreLengthsHost) .setRepetitionPenaltiesSize(maxBatchSize) .setPresencePenaltiesSize(maxBatchSize) - .setFrequencyPenaltiesSize(maxBatchSize)); + .setFrequencyPenaltiesSize(maxBatchSize) + .setPromptIgnoreLengthsSize(maxBatchSize)); } TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionMixed) @@ -780,11 +844,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionMixed) = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr promptIgnoreLengthsHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f; bufferCast(*presencePenaltyHost)[i] = 0.0f; bufferCast(*frequencyPenaltyHost)[i] = 0.0f; + bufferCast(*promptIgnoreLengthsHost)[i] = 0; } this->runTest(RepetitionPenaltyTestCase() .setBatchSize(batchSize) @@ -793,9 +860,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionMixed) .setRepetitionPenalties(repetitionPenaltyHost) .setPresencePenalties(presencePenaltyHost) .setFrequencyPenalties(frequencyPenaltyHost) + .setPromptIgnoreLengths(promptIgnoreLengthsHost) .setRepetitionPenaltiesSize(maxBatchSize) .setPresencePenaltiesSize(maxBatchSize) - .setFrequencyPenaltiesSize(maxBatchSize)); + .setFrequencyPenaltiesSize(maxBatchSize) + .setPromptIgnoreLengthsSize(maxBatchSize)); } TYPED_TEST(RepetitionPenaltyTest, BatchPresenceMixed) @@ -808,11 +877,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchPresenceMixed) = 
BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr promptIgnoreLengthsHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 1.0f; bufferCast(*presencePenaltyHost)[i] = 0.53 + i * 0.2f; bufferCast(*frequencyPenaltyHost)[i] = 0.0f; + bufferCast(*promptIgnoreLengthsHost)[i] = 0; } this->runTest(RepetitionPenaltyTestCase() .setBatchSize(batchSize) @@ -821,9 +893,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchPresenceMixed) .setRepetitionPenalties(repetitionPenaltyHost) .setPresencePenalties(presencePenaltyHost) .setFrequencyPenalties(frequencyPenaltyHost) + .setPromptIgnoreLengths(promptIgnoreLengthsHost) .setRepetitionPenaltiesSize(maxBatchSize) .setPresencePenaltiesSize(maxBatchSize) - .setFrequencyPenaltiesSize(maxBatchSize)); + .setFrequencyPenaltiesSize(maxBatchSize) + .setPromptIgnoreLengthsSize(maxBatchSize)); } TYPED_TEST(RepetitionPenaltyTest, BatchPresenceHasDefaultValueZero2) @@ -836,11 +910,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchPresenceHasDefaultValueZero2) = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr promptIgnoreLengthsHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 1.0f; bufferCast(*presencePenaltyHost)[i] = i % 2 == 0 ? 1.0f : 0.0f; bufferCast(*frequencyPenaltyHost)[i] = 0.0f; + bufferCast(*promptIgnoreLengthsHost)[i] = 0; } this->runTest(RepetitionPenaltyTestCase() .setBatchSize(batchSize) @@ -849,9 +926,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchPresenceHasDefaultValueZero2) .setRepetitionPenalties(repetitionPenaltyHost) .setPresencePenalties(presencePenaltyHost) .setFrequencyPenalties(frequencyPenaltyHost) + .setPromptIgnoreLengths(promptIgnoreLengthsHost) .setRepetitionPenaltiesSize(maxBatchSize) .setPresencePenaltiesSize(maxBatchSize) - .setFrequencyPenaltiesSize(maxBatchSize)); + .setFrequencyPenaltiesSize(maxBatchSize) + .setPromptIgnoreLengthsSize(maxBatchSize)); } TYPED_TEST(RepetitionPenaltyTest, BatchFrequencyMixed) @@ -864,11 +943,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchFrequencyMixed) = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr promptIgnoreLengthsHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 1.0f; bufferCast(*presencePenaltyHost)[i] = 0.0f; bufferCast(*frequencyPenaltyHost)[i] = 0.53 + i * 0.2f; + bufferCast(*promptIgnoreLengthsHost)[i] = 0; } this->runTest(RepetitionPenaltyTestCase() .setBatchSize(batchSize) @@ -877,9 +959,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchFrequencyMixed) .setRepetitionPenalties(repetitionPenaltyHost) .setPresencePenalties(presencePenaltyHost) .setFrequencyPenalties(frequencyPenaltyHost) + .setPromptIgnoreLengths(promptIgnoreLengthsHost) .setRepetitionPenaltiesSize(maxBatchSize) .setPresencePenaltiesSize(maxBatchSize) - 
.setFrequencyPenaltiesSize(maxBatchSize)); + .setFrequencyPenaltiesSize(maxBatchSize) + .setPromptIgnoreLengthsSize(maxBatchSize)); } TYPED_TEST(RepetitionPenaltyTest, BatchFrequencyHasDefaultValueZero2) @@ -892,11 +976,14 @@ TYPED_TEST(RepetitionPenaltyTest, BatchFrequencyHasDefaultValueZero2) = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr promptIgnoreLengthsHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 1.0f; bufferCast(*presencePenaltyHost)[i] = 0.0f; bufferCast(*frequencyPenaltyHost)[i] = i % 2 == 0 ? 1.0f : 0.0f; + bufferCast(*promptIgnoreLengthsHost)[i] = 0; } this->runTest(RepetitionPenaltyTestCase() .setBatchSize(batchSize) @@ -905,9 +992,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchFrequencyHasDefaultValueZero2) .setRepetitionPenalties(repetitionPenaltyHost) .setPresencePenalties(presencePenaltyHost) .setFrequencyPenalties(frequencyPenaltyHost) + .setPromptIgnoreLengths(promptIgnoreLengthsHost) .setRepetitionPenaltiesSize(maxBatchSize) .setPresencePenaltiesSize(maxBatchSize) - .setFrequencyPenaltiesSize(maxBatchSize)); + .setFrequencyPenaltiesSize(maxBatchSize) + .setPromptIgnoreLengthsSize(maxBatchSize)); } TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeRepetitionPresence) @@ -920,11 +1009,14 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeRepetitionPresence) = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr promptIgnoreLengthsHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f; bufferCast(*presencePenaltyHost)[i] = 0.53 + i * 0.2f; bufferCast(*frequencyPenaltyHost)[i] = 0.0f; + bufferCast(*promptIgnoreLengthsHost)[i] = 0; } this->runTest(RepetitionPenaltyTestCase() .setBatchSize(batchSize) @@ -933,9 +1025,11 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeRepetitionPresence) .setRepetitionPenalties(repetitionPenaltyHost) .setPresencePenalties(presencePenaltyHost) .setFrequencyPenalties(frequencyPenaltyHost) + .setPromptIgnoreLengths(promptIgnoreLengthsHost) .setRepetitionPenaltiesSize(maxBatchSize) .setPresencePenaltiesSize(maxBatchSize) - .setFrequencyPenaltiesSize(maxBatchSize)); + .setFrequencyPenaltiesSize(maxBatchSize) + .setPromptIgnoreLengthsSize(maxBatchSize)); } TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeRepetitionFrequency) @@ -948,11 +1042,14 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeRepetitionFrequency) = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr promptIgnoreLengthsHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f; bufferCast(*presencePenaltyHost)[i] = 0.0f; bufferCast(*frequencyPenaltyHost)[i] = 0.53 + i * 0.2f; + bufferCast(*promptIgnoreLengthsHost)[i] = 0; } this->runTest(RepetitionPenaltyTestCase() .setBatchSize(batchSize) @@ 
-961,9 +1058,11 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeRepetitionFrequency) .setRepetitionPenalties(repetitionPenaltyHost) .setPresencePenalties(presencePenaltyHost) .setFrequencyPenalties(frequencyPenaltyHost) + .setPromptIgnoreLengths(promptIgnoreLengthsHost) .setRepetitionPenaltiesSize(maxBatchSize) .setPresencePenaltiesSize(maxBatchSize) - .setFrequencyPenaltiesSize(maxBatchSize)); + .setFrequencyPenaltiesSize(maxBatchSize) + .setPromptIgnoreLengthsSize(maxBatchSize)); } TYPED_TEST(RepetitionPenaltyTest, PenaltyTypePresenceFrequency) @@ -976,11 +1075,14 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypePresenceFrequency) = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr promptIgnoreLengthsHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 1.0f; bufferCast(*presencePenaltyHost)[i] = 0.53 + i * 0.2f; bufferCast(*frequencyPenaltyHost)[i] = 0.53 + i * 0.2f; + bufferCast(*promptIgnoreLengthsHost)[i] = 0; } this->runTest(RepetitionPenaltyTestCase() .setBatchSize(batchSize) @@ -989,9 +1091,11 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypePresenceFrequency) .setRepetitionPenalties(repetitionPenaltyHost) .setPresencePenalties(presencePenaltyHost) .setFrequencyPenalties(frequencyPenaltyHost) + .setPromptIgnoreLengths(promptIgnoreLengthsHost) .setRepetitionPenaltiesSize(maxBatchSize) .setPresencePenaltiesSize(maxBatchSize) - .setFrequencyPenaltiesSize(maxBatchSize)); + .setFrequencyPenaltiesSize(maxBatchSize) + .setPromptIgnoreLengthsSize(maxBatchSize)); } TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFull) @@ -1004,11 +1108,14 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFull) = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr promptIgnoreLengthsHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f; bufferCast(*presencePenaltyHost)[i] = 0.53 + i * 0.2f; bufferCast(*frequencyPenaltyHost)[i] = 0.53 + i * 0.2f; + bufferCast(*promptIgnoreLengthsHost)[i] = 0; } this->runTest(RepetitionPenaltyTestCase() .setBatchSize(batchSize) @@ -1017,9 +1124,11 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFull) .setRepetitionPenalties(repetitionPenaltyHost) .setPresencePenalties(presencePenaltyHost) .setFrequencyPenalties(frequencyPenaltyHost) + .setPromptIgnoreLengths(promptIgnoreLengthsHost) .setRepetitionPenaltiesSize(maxBatchSize) .setPresencePenaltiesSize(maxBatchSize) - .setFrequencyPenaltiesSize(maxBatchSize)); + .setFrequencyPenaltiesSize(maxBatchSize) + .setPromptIgnoreLengthsSize(maxBatchSize)); } TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFullTokensPerStep) @@ -1032,11 +1141,81 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFullTokensPerStep) = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr promptIgnoreLengthsHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); + for 
(int32_t i = 0; i < maxBatchSize; ++i) + { + bufferCast(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f; + bufferCast(*presencePenaltyHost)[i] = 0.53 + i * 0.2f; + bufferCast(*frequencyPenaltyHost)[i] = 0.53 + i * 0.2f; + bufferCast(*promptIgnoreLengthsHost)[i] = 0; + } + this->runTest(RepetitionPenaltyTestCase() + .setBatchSize(batchSize) + .setVocabSize(4) + .setMaxInputLength(5) + .setRepetitionPenalties(repetitionPenaltyHost) + .setPresencePenalties(presencePenaltyHost) + .setFrequencyPenalties(frequencyPenaltyHost) + .setPromptIgnoreLengths(promptIgnoreLengthsHost) + .setRepetitionPenaltiesSize(maxBatchSize) + .setPresencePenaltiesSize(maxBatchSize) + .setFrequencyPenaltiesSize(maxBatchSize) + .setPromptIgnoreLengthsSize(maxBatchSize) + .setMaxTokensPerStep(4)); +} + +TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFullWithPartialPromptIgnore) +{ + int32_t batchSize = 6; + int32_t maxBatchSize = 2 * batchSize; + TensorPtr repetitionPenaltyHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr presencePenaltyHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr frequencyPenaltyHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr promptIgnoreLengthsHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); + for (int32_t i = 0; i < maxBatchSize; ++i) + { + bufferCast(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f; + bufferCast(*presencePenaltyHost)[i] = 0.53 + i * 0.2f; + bufferCast(*frequencyPenaltyHost)[i] = 0.53 + i * 0.2f; + bufferCast(*promptIgnoreLengthsHost)[i] = 1; // set to 1 to ignore first prompt token + } + this->runTest(RepetitionPenaltyTestCase() + .setBatchSize(batchSize) + .setVocabSize(4) + .setMaxInputLength(5) + .setRepetitionPenalties(repetitionPenaltyHost) + .setPresencePenalties(presencePenaltyHost) + .setFrequencyPenalties(frequencyPenaltyHost) + .setPromptIgnoreLengths(promptIgnoreLengthsHost) + .setRepetitionPenaltiesSize(maxBatchSize) + .setPresencePenaltiesSize(maxBatchSize) + .setFrequencyPenaltiesSize(maxBatchSize) + .setPromptIgnoreLengthsSize(maxBatchSize)); +} + +TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFullTokensPerStepWithFullPromptIgnore) +{ + int32_t batchSize = 6; + int32_t maxBatchSize = 2 * batchSize; + TensorPtr repetitionPenaltyHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr presencePenaltyHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr frequencyPenaltyHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr promptIgnoreLengthsHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f; bufferCast(*presencePenaltyHost)[i] = 0.53 + i * 0.2f; bufferCast(*frequencyPenaltyHost)[i] = 0.53 + i * 0.2f; + bufferCast(*promptIgnoreLengthsHost)[i] = 5; // set to max input length to ignore full prompt } this->runTest(RepetitionPenaltyTestCase() .setBatchSize(batchSize) @@ -1045,9 +1224,11 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFullTokensPerStep) .setRepetitionPenalties(repetitionPenaltyHost) .setPresencePenalties(presencePenaltyHost) .setFrequencyPenalties(frequencyPenaltyHost) + .setPromptIgnoreLengths(promptIgnoreLengthsHost) 
.setRepetitionPenaltiesSize(maxBatchSize) .setPresencePenaltiesSize(maxBatchSize) .setFrequencyPenaltiesSize(maxBatchSize) + .setPromptIgnoreLengthsSize(maxBatchSize) .setMaxTokensPerStep(4)); } @@ -1257,8 +1438,8 @@ class MinLengthPenaltyTest : public SamplingKernelTest InvokeBatchApplyPenaltyParams penaltyParams{reinterpret_cast(bufferCast(*mLogitsPtrs)), bufferCast(*mOutLogitsDevice), nullptr, bufferCast(*mPenaltyWorkspaceDevice), nullptr, nullptr, - nullptr, nullptr, nullptr, mBatchSize, 1, mSequenceLength, mVocabSize, mVocabSizePadded, nullptr, nullptr, - bufferCast(*mContextLengthDevice), bufferCast(*mSeqLengthDevice), + nullptr, nullptr, nullptr, nullptr, mBatchSize, 1, mSequenceLength, mVocabSize, mVocabSizePadded, nullptr, + nullptr, bufferCast(*mContextLengthDevice), bufferCast(*mSeqLengthDevice), bufferCast(*mMinLengthDevice), bufferCast(*mEndIdsDevice), bufferCast(*mBatchSlots), mMaxTokensPerStep, bufferCast(*mTokensPerStep), nullptr, mStream->get()}; @@ -1415,6 +1596,7 @@ class MinLengthPenaltyOOBSafetyTest : public SamplingKernelTest /*repetitionPenalties=*/nullptr, /*presencePenalties=*/nullptr, /*frequencyPenalties=*/nullptr, + /*promptIgnoreLengths=*/nullptr, /*batchSize=*/mBatchSize, /*beamWidth=*/1, /*maxSeqLen=*/mSequenceLength, diff --git a/cpp/tests/unit_tests/layers/dynamicDecodeLayerTest.h b/cpp/tests/unit_tests/layers/dynamicDecodeLayerTest.h index 24ae9463ce3..ca059c2266d 100644 --- a/cpp/tests/unit_tests/layers/dynamicDecodeLayerTest.h +++ b/cpp/tests/unit_tests/layers/dynamicDecodeLayerTest.h @@ -40,6 +40,7 @@ struct TestSamplingParams std::vector repetitionPenalties; std::vector presencePenalties; std::vector frequencyPenalties; + std::vector promptIgnoreLengths; std::vector minLengths; std::vector decay; std::vector minTopP; diff --git a/cpp/tests/unit_tests/runtime/samplingConfigTest.cpp b/cpp/tests/unit_tests/runtime/samplingConfigTest.cpp index 34dab538dec..c22346e56b8 100644 --- a/cpp/tests/unit_tests/runtime/samplingConfigTest.cpp +++ b/cpp/tests/unit_tests/runtime/samplingConfigTest.cpp @@ -37,17 +37,18 @@ void test(bool const useExternalDraftTokensConfig, SizeType32 beamWidth = 1, std std::optional randomSeed = no, std::optional temperature = no, std::optional minLength = no, std::optional beamSearchDiversityRate = no, std::optional repetitionPenalty = no, std::optional presencePenalty = no, - std::optional frequencyPenalty = no, std::optional lengthPenalty = no, - std::optional earlyStopping = no, std::optional noRepeatNgramSize = no, - std::optional numReturnSequences = no, std::optional minP = no, - std::optional> beamWidthArray = no) + std::optional frequencyPenalty = no, std::optional promptIgnoreLength = no, + std::optional lengthPenalty = no, std::optional earlyStopping = no, + std::optional noRepeatNgramSize = no, std::optional numReturnSequences = no, + std::optional minP = no, std::optional> beamWidthArray = no) { - // 19 parameters for SamplingConfig, from `beamWidth` to `beamWidthArray` + // 20 parameters for SamplingConfig, from `beamWidth` to `beamWidthArray` try { te::SamplingConfig execSamplingCfg(beamWidth, topK, topP, topPMin, topPResetIds, topPDecay, randomSeed, temperature, minLength, beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty, - lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP, beamWidthArray); + promptIgnoreLength, lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP, + beamWidthArray); std::optional specCfg = std::nullopt; if 
(useExternalDraftTokensConfig) { @@ -110,18 +111,20 @@ TEST(samplingConfigTest, validInputs) test(false, 1, no, no, no, no, no, no, no, no, no, no, 1.f); // Frequency penalty test(false, 1, no, no, no, no, no, no, no, no, no, no, no, 1.f); + // Prompt ignore length + test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, 1); // Length penalty - test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, 1.f); - // Early stopping test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, 1.f); + // Early stopping + test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 1.f); // No repeat ngram size - test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2); + test(false, 1, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2); // NumReturnSequences - test(false, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2); - // MinP, 18 arguments - test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 0.5f); + test(false, 4, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 2); + // MinP, 19 arguments + test(false, 1, no, 0.9, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, 0.5f); // BeamWidthArray - test(false, 5, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, + test(false, 5, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, std::vector{2, 3, 4, 5}); // All parameters @@ -139,6 +142,7 @@ TEST(samplingConfigTest, validInputs) te::FloatType repetitionPenalty{0.5f}; te::FloatType presencePenalty{0.5f}; te::FloatType frequencyPenalty{0.5f}; + te::SizeType32 promptIgnoreLength{1}; te::FloatType lengthPenalty{0.5f}; te::SizeType32 earlyStopping{1}; te::SizeType32 noRepeatNgramSize{5}; @@ -148,7 +152,8 @@ TEST(samplingConfigTest, validInputs) te::SamplingConfig execSamplingCfg(beamWidth, topK, topP, topPMin, topPResetIds, topPDecay, randomSeed, temperature, minLength, beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty, - lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP, beamWidthArray); + promptIgnoreLength, lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP, + beamWidthArray); te::ExternalDraftTokensConfig specCfg({1}, no, 0.5f); tr::SamplingConfig samplingCfg(execSamplingCfg, specCfg); EXPECT_EQ(samplingCfg.beamWidth, execSamplingCfg.getBeamWidth()); @@ -166,6 +171,7 @@ TEST(samplingConfigTest, validInputs) EXPECT_THAT(samplingCfg.repetitionPenalty.value(), testing::ElementsAre(repetitionPenalty)); EXPECT_THAT(samplingCfg.presencePenalty.value(), testing::ElementsAre(presencePenalty)); EXPECT_THAT(samplingCfg.frequencyPenalty.value(), testing::ElementsAre(frequencyPenalty)); + EXPECT_THAT(samplingCfg.promptIgnoreLength.value(), testing::ElementsAre(promptIgnoreLength)); EXPECT_THAT(samplingCfg.lengthPenalty.value(), testing::ElementsAre(lengthPenalty)); EXPECT_THAT(samplingCfg.earlyStopping.value(), testing::ElementsAre(earlyStopping)); EXPECT_THAT(samplingCfg.noRepeatNgramSize.value(), testing::ElementsAre(noRepeatNgramSize)); diff --git a/examples/eval_long_context.py b/examples/eval_long_context.py index 337b4bc3494..90b7ef2dd27 100644 --- a/examples/eval_long_context.py +++ b/examples/eval_long_context.py @@ -281,6 +281,7 @@ def main(args): repetition_penalty=args.repetition_penalty, presence_penalty=args.presence_penalty, frequency_penalty=args.frequency_penalty, + prompt_ignore_length=args.prompt_ignore_length, # stop_words_list=stop_words_list, # 
bad_words_list=bad_words_list, output_cum_log_probs=(args.output_cum_log_probs_npy != None), diff --git a/examples/run.py b/examples/run.py index 0f19b56d768..7ce36bbe984 100755 --- a/examples/run.py +++ b/examples/run.py @@ -540,6 +540,7 @@ def main(args): repetition_penalty=args.repetition_penalty, presence_penalty=args.presence_penalty, frequency_penalty=args.frequency_penalty, + prompt_ignore_length=args.prompt_ignore_length, min_p=args.min_p, stop_words_list=stop_words_list, bad_words_list=bad_words_list, @@ -639,6 +640,7 @@ def main(args): repetition_penalty=args.repetition_penalty, presence_penalty=args.presence_penalty, frequency_penalty=args.frequency_penalty, + prompt_ignore_length=args.prompt_ignore_length, min_p=args.min_p, stop_words_list=stop_words_list, bad_words_list=bad_words_list, @@ -677,6 +679,7 @@ def main(args): repetition_penalty=args.repetition_penalty, presence_penalty=args.presence_penalty, frequency_penalty=args.frequency_penalty, + prompt_ignore_length=args.prompt_ignore_length, stop_words_list=stop_words_list, bad_words_list=bad_words_list, output_cum_log_probs=(args.output_cum_log_probs_npy diff --git a/examples/summarize.py b/examples/summarize.py index 77406a45c46..1f6f8979bb7 100644 --- a/examples/summarize.py +++ b/examples/summarize.py @@ -211,6 +211,7 @@ def main(args): repetition_penalty = args.repetition_penalty presence_penalty = args.presence_penalty frequency_penalty = args.frequency_penalty + prompt_ignore_length = args.prompt_ignore_length random_seed = args.random_seed torch.manual_seed(random_seed) @@ -353,6 +354,7 @@ def eval_trt_llm(datapoint, repetition_penalty=repetition_penalty, presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, + prompt_ignore_length=prompt_ignore_length, lora_uids=args.lora_task_uids, lookahead_config=args.lookahead_config, output_sequence_lengths=True, diff --git a/examples/utils.py b/examples/utils.py index 8956e4979e0..9b0aaf735d4 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -304,6 +304,7 @@ def add_common_args(parser): parser.add_argument('--repetition_penalty', type=float, default=1.0) parser.add_argument('--presence_penalty', type=float, default=0.0) parser.add_argument('--frequency_penalty', type=float, default=0.0) + parser.add_argument('--prompt_ignore_length', type=int, default=0) parser.add_argument('--min_p', type=float, default=0.0) parser.add_argument('--beam_search_diversity_rate', type=float, default=0.0) parser.add_argument('--random_seed', type=int, default=0) diff --git a/tensorrt_llm/runtime/generation.py b/tensorrt_llm/runtime/generation.py index 36cdbf0aca5..cb633656777 100755 --- a/tensorrt_llm/runtime/generation.py +++ b/tensorrt_llm/runtime/generation.py @@ -721,6 +721,7 @@ class SamplingConfig: min_length: Union[int, torch.Tensor] = field(default=1) presence_penalty: Union[float, torch.Tensor] = field(default=0.0) frequency_penalty: Union[float, torch.Tensor] = field(default=0.0) + prompt_ignore_length: Union[int, torch.Tensor] = field(default=0) use_beam_hyps: bool = field(default=True) # None here means user didn't set it, and dynamicDecodeOp.cpp take optional value @@ -1474,6 +1475,16 @@ def __setup_decoder(self, input_ids: torch.Tensor, scfg.frequency_penalty, dtype=torch.float32) + if isinstance(scfg.prompt_ignore_length, torch.Tensor): + assert scfg.prompt_ignore_length.dtype == torch.int32, f"scfg.prompt_ignore_length.dtype ({scfg.prompt_ignore_length.dtype}) must be torch.int32" + assert scfg.prompt_ignore_length.shape[ + 0] == batch_size, 
f"scfg.prompt_ignore_length.shape[0] ({scfg.prompt_ignore_length.shape[0]}) must equal to batch_size ({batch_size})" + self.prompt_ignore_length = scfg.prompt_ignore_length + else: + self.prompt_ignore_length = torch.full([batch_size], + scfg.prompt_ignore_length, + dtype=torch.int32) + if isinstance(scfg.min_length, torch.Tensor): assert scfg.min_length.dtype == torch.int32, f"scfg.min_length.dtype ({scfg.min_length.dtype}) must be torch.int32" assert scfg.min_length.shape[ @@ -1543,6 +1554,7 @@ def __setup_decoder(self, input_ids: torch.Tensor, self.repetition_penalty, self.presence_penalty, self.frequency_penalty, + self.prompt_ignore_length, self.min_length, self.host_length_penalty, self.host_early_stopping, diff --git a/tensorrt_llm/runtime/model_runner_cpp.py b/tensorrt_llm/runtime/model_runner_cpp.py index 96895268074..e6a5d52a82d 100644 --- a/tensorrt_llm/runtime/model_runner_cpp.py +++ b/tensorrt_llm/runtime/model_runner_cpp.py @@ -648,6 +648,7 @@ def generate( "repetition_penalty", "presence_penalty", "frequency_penalty", + "prompt_ignore_length", "length_penalty", "early_stopping", "no_repeat_ngram_size", diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py index 686eab1bbf7..b7ad63821ad 100644 --- a/tensorrt_llm/sampling_params.py +++ b/tensorrt_llm/sampling_params.py @@ -165,6 +165,7 @@ class SamplingParams: repetition_penalty (float, optional): Used to penalize tokens based on how often they appear in the sequence. It can have any value > 0.f. Values < 1.f encourages repetition, values > 1.f discourages it. None means using C++ runtime default 1.f. Defaults to None. presence_penalty (float, optional): Used to penalize tokens already present in the sequence (irrespective of the number of appearances). It can have any values. Values < 0.f encourage repetition, values > 0.f discourage it. None means using C++ runtime default 0.f. Defaults to None. frequency_penalty (float, optional): Used to penalize tokens already present in the sequence (dependent on the number of appearances). It can have any values. Values < 0.f encourage repetition, values > 0.f discourage it. None means using C++ runtime default 0.f. Defaults to None. + prompt_ignore_length (int, optional): Controls how many tokens to ignore from the prompt for presence and frequency penalties. Values <= 0 have no effect. Values > input (prompt) length will be clamped. None means using C++ runtime default 0. Defaults to None. length_penalty (float, optional): Controls how to penalize longer sequences in beam search. None means using C++ runtime default 0.f. Defaults to None. early_stopping (int, optional): Controls whether the generation process finishes once beamWidth sentences are generated (ends with end_token). None means using C++ runtime default 1. Defaults to None. no_repeat_ngram_size (int, optional): Controls how many repeat ngram size are acceptable. None means using C++ runtime default 1 << 30. Defaults to None. 
@@ -232,6 +233,7 @@ class SamplingParams: repetition_penalty: Optional[float] = None presence_penalty: Optional[float] = None frequency_penalty: Optional[float] = None + prompt_ignore_length: Optional[int] = None length_penalty: Optional[float] = None early_stopping: Optional[int] = None no_repeat_ngram_size: Optional[int] = None diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py index 2303b89089f..af8111d1f07 100644 --- a/tensorrt_llm/serve/openai_protocol.py +++ b/tensorrt_llm/serve/openai_protocol.py @@ -229,6 +229,7 @@ class CompletionRequest(OpenAIBaseModel): top_p: Optional[float] = None user: Optional[str] = None lora_request: Optional[LoRARequest] = None + prompt_ignore_length: Optional[int] = 0 # doc: begin-completion-sampling-params use_beam_search: bool = False @@ -283,6 +284,7 @@ def to_sampling_params(self, vocab_size: int = 32000) -> SamplingParams: temperature=(self.temperature if self.temperature is not None else 1.0), top_p=(self.top_p if self.top_p is not None else 1.0), + prompt_ignore_length=self.prompt_ignore_length, # completion-sampling-params use_beam_search=self.use_beam_search, @@ -530,6 +532,7 @@ class ChatCompletionRequest(OpenAIBaseModel): "reasoning is shown in the model's response. Options: " "'low', 'medium', 'high'."), ) + prompt_ignore_length: Optional[int] = 0 # doc: begin-chat-completion-sampling-params best_of: Optional[int] = None @@ -622,6 +625,7 @@ def to_sampling_params(self, stop=self.stop, temperature=(self.temperature if self.temperature is not None else 1.0), + prompt_ignore_length=self.prompt_ignore_length, # chat-completion-sampling-params best_of=self.best_of, diff --git a/tests/unittest/api_stability/references/sampling_params.yaml b/tests/unittest/api_stability/references/sampling_params.yaml index e48f9fae493..d6b3e6156e3 100644 --- a/tests/unittest/api_stability/references/sampling_params.yaml +++ b/tests/unittest/api_stability/references/sampling_params.yaml @@ -12,5 +12,8 @@ methods: beam_width_array: annotation: Optional[List[int]] default: null + prompt_ignore_length: + annotation: Optional[int] + default: null return_annotation: None properties: {} diff --git a/tests/unittest/api_stability/test_llm_api.py b/tests/unittest/api_stability/test_llm_api.py index 6960f993286..3edd14ecfab 100644 --- a/tests/unittest/api_stability/test_llm_api.py +++ b/tests/unittest/api_stability/test_llm_api.py @@ -35,6 +35,7 @@ def test_get_sampling_config(self): "repetition_penalty", "presence_penalty", "frequency_penalty", + "prompt_ignore_length", "length_penalty", "early_stopping", "no_repeat_ngram_size",
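To make the semantics of prompt_ignore_length concrete, here is a short Python paraphrase of the reference path in RepetitionPenaltyTest::computeReference above (an editorial sketch, not the CUDA kernel and not part of this diff; all names are illustrative): repetition penalty is still applied once per distinct token over the whole sequence, while presence and frequency penalties only count tokens at positions >= prompt_ignore_length, with the ignore length clamped to the sequence length.

# Editorial paraphrase of the test's reference logic; names are illustrative.
from collections import Counter

def apply_penalties(logits, output_ids, repetition_penalty, presence_penalty,
                    frequency_penalty, prompt_ignore_length):
    logits = list(logits)
    # Clamp to the number of tokens present (the input length in the test setup).
    ignore = min(prompt_ignore_length, len(output_ids))

    # Repetition penalty: once per distinct token, over the whole sequence.
    for token in set(output_ids):
        logits[token] = (logits[token] * repetition_penalty if logits[token] < 0.0
                         else logits[token] / repetition_penalty)

    # Presence/frequency penalties: only positions >= ignore are counted.
    counts = Counter(output_ids[ignore:])
    for token, n in counts.items():
        logits[token] -= presence_penalty       # once per distinct counted token
        logits[token] -= n * frequency_penalty  # scaled by occurrence count
    return logits

# With prompt_ignore_length=3 the first three (prompt) tokens are excluded from
# presence/frequency counting but still receive the repetition penalty.
print(apply_penalties([1.0, -0.5, 2.0, 0.1], [0, 1, 1, 2, 2, 2],
                      repetition_penalty=1.2, presence_penalty=0.5,
                      frequency_penalty=0.3, prompt_ignore_length=3))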