Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cpp/include/tensorrt_llm/executor/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class SamplingConfig
std::optional<FloatType> const& repetitionPenalty = std::nullopt,
std::optional<FloatType> const& presencePenalty = std::nullopt,
std::optional<FloatType> const& frequencyPenalty = std::nullopt,
std::optional<SizeType32> const& promptIgnoreLength = std::nullopt,
std::optional<FloatType> const& lengthPenalty = std::nullopt,
std::optional<SizeType32> const& earlyStopping = std::nullopt,
std::optional<SizeType32> const& noRepeatNgramSize = std::nullopt,
Expand All @@ -94,6 +95,7 @@ class SamplingConfig
[[nodiscard]] std::optional<FloatType> getRepetitionPenalty() const;
[[nodiscard]] std::optional<FloatType> getPresencePenalty() const;
[[nodiscard]] std::optional<FloatType> getFrequencyPenalty() const;
[[nodiscard]] std::optional<SizeType32> getPromptIgnoreLength() const;
[[nodiscard]] std::optional<FloatType> getLengthPenalty() const;
[[nodiscard]] std::optional<SizeType32> getEarlyStopping() const;
[[nodiscard]] std::optional<SizeType32> getNoRepeatNgramSize() const;
Expand All @@ -114,6 +116,7 @@ class SamplingConfig
void setRepetitionPenalty(std::optional<FloatType> const& repetitionPenalty);
void setPresencePenalty(std::optional<FloatType> const& presencePenalty);
void setFrequencyPenalty(std::optional<FloatType> const& frequencyPenalty);
void setPromptIgnoreLength(std::optional<SizeType32> const& promptIgnoreLength);
void setLengthPenalty(std::optional<FloatType> const& lengthPenalty);
void setEarlyStopping(std::optional<SizeType32> const& earlyStopping);
void setNoRepeatNgramSize(std::optional<SizeType32> const& noRepeatNgramSize);
Expand All @@ -133,6 +136,8 @@ class SamplingConfig
static std::optional<FloatType> const& checkBeamSearchDiversityRate(
std::optional<FloatType> const& beamSearchDiversityRate);
static std::optional<FloatType> const& checkRepetitionPenalty(std::optional<FloatType> const& repetitionpenalty);
static std::optional<SizeType32> const& checkPromptIgnoreLength(
std::optional<SizeType32> const& promptIgnoreLength);
static std::optional<FloatType> const& checkLengthPenalty(std::optional<FloatType> const& lengthPenalty);
static std::optional<SizeType32> const& checkEarlyStopping(std::optional<SizeType32> const& earlyStopping);
static std::optional<SizeType32> const& checkNoRepeatNgramSize(std::optional<SizeType32> const& noRepeatNgramSize);
Expand Down Expand Up @@ -174,6 +179,9 @@ class SamplingConfig
/// @brief Used to penalize tokens already present in the sequence (dependent on the number of appearances). It can
/// have any values. Values < 0.f encourage repetition, values > 0.f discourage it. Default is 0.f
std::optional<FloatType> mFrequencyPenalty;
/// @brief Controls how many tokens to ignore from the prompt for presence and frequency penalties. Values <= 0 have
/// no effect. Values > input (prompt) length will be clamped. Default is 0.
std::optional<SizeType32> mPromptIgnoreLength;
/// @brief Controls how to penalize longer sequences in beam search. Default is 0.f
std::optional<FloatType> mLengthPenalty;
/// @brief Controls whether the generation process finishes once beamWidth sentences are generated (ends with
Expand Down
5 changes: 5 additions & 0 deletions cpp/include/tensorrt_llm/layers/defaultDecodingParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ class DefaultDecodingParams
return 1;
}

[[nodiscard]] __host__ __device__ static constexpr runtime::SizeType32 getPromptIgnoreLength()
{
return 0;
}

[[nodiscard]] __host__ __device__ static constexpr uint64_t getSeed()
{
return 0;
Expand Down
20 changes: 13 additions & 7 deletions cpp/include/tensorrt_llm/runtime/samplingConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ class SamplingConfig
frequencyPenalty = fuseValues<FloatType>(
configs, [&configs](size_t ci) { return configs[ci].frequencyPenalty; },
layers::DefaultDecodingParams::getFrequencyPenalty());
promptIgnoreLength = fuseValues<SizeType32>(
configs, [&configs](size_t ci) { return configs[ci].promptIgnoreLength; },
layers::DefaultDecodingParams::getPromptIgnoreLength());
noRepeatNgramSize = fuseValues<SizeType32>(
configs, [&configs](size_t ci) { return configs[ci].noRepeatNgramSize; },
layers::DefaultDecodingParams::getNoRepeatNgramSize());
Expand Down Expand Up @@ -224,6 +227,7 @@ class SamplingConfig
SET_FROM_OPTIONAL(repetitionPenalty, RepetitionPenalty, FloatType)
SET_FROM_OPTIONAL(presencePenalty, PresencePenalty, FloatType)
SET_FROM_OPTIONAL(frequencyPenalty, FrequencyPenalty, FloatType)
SET_FROM_OPTIONAL(promptIgnoreLength, PromptIgnoreLength, SizeType32)
SET_FROM_OPTIONAL(lengthPenalty, LengthPenalty, FloatType)
SET_FROM_OPTIONAL(earlyStopping, EarlyStopping, SizeType32)
SET_FROM_OPTIONAL(noRepeatNgramSize, NoRepeatNgramSize, SizeType32)
Expand Down Expand Up @@ -342,6 +346,7 @@ class SamplingConfig
OptVec<FloatType> repetitionPenalty; // [1] or [batchSize]
OptVec<FloatType> presencePenalty; // [1] or [batchSize]
OptVec<FloatType> frequencyPenalty; // [1] or [batchSize]
OptVec<SizeType32> promptIgnoreLength; // [1] or [batchSize]
OptVec<SizeType32> noRepeatNgramSize; // [1] or [batchSize]

// probs
Expand Down Expand Up @@ -377,13 +382,14 @@ class SamplingConfig
&& temperature == other.temperature && originalTemperature == other.originalTemperature
&& minLength == other.minLength && repetitionPenalty == other.repetitionPenalty
&& presencePenalty == other.presencePenalty && frequencyPenalty == other.frequencyPenalty
&& noRepeatNgramSize == other.noRepeatNgramSize && topK == other.topK && topP == other.topP
&& randomSeed == other.randomSeed && topPDecay == other.topPDecay && topPMin == other.topPMin
&& topPResetIds == other.topPResetIds && beamSearchDiversityRate == other.beamSearchDiversityRate
&& lengthPenalty == other.lengthPenalty && earlyStopping == other.earlyStopping
&& draftAcceptanceThreshold == other.draftAcceptanceThreshold && topKMedusaHeads == other.topKMedusaHeads
&& normalizeLogProbs == other.normalizeLogProbs && outputLogProbs == other.outputLogProbs
&& cumLogProbs == other.cumLogProbs && minP == other.minP && beamWidthArray == other.beamWidthArray;
&& promptIgnoreLength == other.promptIgnoreLength && noRepeatNgramSize == other.noRepeatNgramSize
&& topK == other.topK && topP == other.topP && randomSeed == other.randomSeed
&& topPDecay == other.topPDecay && topPMin == other.topPMin && topPResetIds == other.topPResetIds
&& beamSearchDiversityRate == other.beamSearchDiversityRate && lengthPenalty == other.lengthPenalty
&& earlyStopping == other.earlyStopping && draftAcceptanceThreshold == other.draftAcceptanceThreshold
&& topKMedusaHeads == other.topKMedusaHeads && normalizeLogProbs == other.normalizeLogProbs
&& outputLogProbs == other.outputLogProbs && cumLogProbs == other.cumLogProbs && minP == other.minP
&& beamWidthArray == other.beamWidthArray;
}

SizeType32 getNumReturnBeams() const
Expand Down
33 changes: 27 additions & 6 deletions cpp/tensorrt_llm/executor/samplingConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ SamplingConfig::SamplingConfig(SizeType32 beamWidth, OptSize32 const& topK, OptF
OptFloat const& topPMin, std::optional<TokenIdType> const& topPResetIds, OptFloat const& topPDecay,
std::optional<RandomSeedType> const& seed, OptFloat const& temperature, OptSize32 const& minTokens,
OptFloat const& beamSearchDiversityRate, OptFloat const& repetitionPenalty, OptFloat const& presencePenalty,
OptFloat const& frequencyPenalty, OptFloat const& lengthPenalty, OptSize32 const& earlyStopping,
OptSize32 const& noRepeatNgramSize, OptSize32 const& numReturnSequences, OptFloat const& minP,
OptVec<SizeType32> const& beamWidthArray)
OptFloat const& frequencyPenalty, OptSize32 const& promptIgnoreLength, OptFloat const& lengthPenalty,
OptSize32 const& earlyStopping, OptSize32 const& noRepeatNgramSize, OptSize32 const& numReturnSequences,
OptFloat const& minP, OptVec<SizeType32> const& beamWidthArray)
: mBeamWidth(checkBeamWidth(beamWidth))
, mTopK(checkTopK(topK))
, mTopP(checkTopP(topP))
Expand All @@ -50,6 +50,7 @@ SamplingConfig::SamplingConfig(SizeType32 beamWidth, OptSize32 const& topK, OptF
, mRepetitionPenalty(checkRepetitionPenalty(repetitionPenalty))
, mPresencePenalty(presencePenalty)
, mFrequencyPenalty(frequencyPenalty)
, mPromptIgnoreLength(checkPromptIgnoreLength(promptIgnoreLength))
, mLengthPenalty(checkLengthPenalty(lengthPenalty))
, mEarlyStopping(checkEarlyStopping(earlyStopping))
, mNoRepeatNgramSize(checkNoRepeatNgramSize(noRepeatNgramSize))
Expand All @@ -67,9 +68,10 @@ bool SamplingConfig::operator==(SamplingConfig const& other) const
&& mTemperature == other.mTemperature && mMinTokens == other.mMinTokens
&& mBeamSearchDiversityRate == other.mBeamSearchDiversityRate && mRepetitionPenalty == other.mRepetitionPenalty
&& mPresencePenalty == other.mPresencePenalty && mFrequencyPenalty == other.mFrequencyPenalty
&& mLengthPenalty == other.mLengthPenalty && mEarlyStopping == other.mEarlyStopping
&& mNoRepeatNgramSize == other.mNoRepeatNgramSize && mNumReturnSequences == other.mNumReturnSequences
&& mMinP == other.mMinP && mBeamWidthArray == other.mBeamWidthArray;
&& mPromptIgnoreLength == other.mPromptIgnoreLength && mLengthPenalty == other.mLengthPenalty
&& mEarlyStopping == other.mEarlyStopping && mNoRepeatNgramSize == other.mNoRepeatNgramSize
&& mNumReturnSequences == other.mNumReturnSequences && mMinP == other.mMinP
&& mBeamWidthArray == other.mBeamWidthArray;
}

// Getters
Expand Down Expand Up @@ -143,6 +145,11 @@ OptFloat SamplingConfig::getFrequencyPenalty() const
return mFrequencyPenalty;
}

OptSize32 SamplingConfig::getPromptIgnoreLength() const
{
return mPromptIgnoreLength;
}

OptFloat SamplingConfig::getLengthPenalty() const
{
return mLengthPenalty;
Expand Down Expand Up @@ -240,6 +247,11 @@ void SamplingConfig::setFrequencyPenalty(OptFloat const& frequencyPenalty)
mFrequencyPenalty = frequencyPenalty;
}

void SamplingConfig::setPromptIgnoreLength(OptSize32 const& promptIgnoreLength)
{
mPromptIgnoreLength = checkPromptIgnoreLength(promptIgnoreLength);
}

void SamplingConfig::setLengthPenalty(OptFloat const& lengthPenalty)
{
mLengthPenalty = lengthPenalty; // TODO: re-enable `checkLengthPenalty` later
Expand Down Expand Up @@ -362,6 +374,15 @@ OptFloat const& SamplingConfig::checkRepetitionPenalty(OptFloat const& repetitio
return repetitionpenalty;
}

OptSize32 const& SamplingConfig::checkPromptIgnoreLength(OptSize32 const& promptIgnoreLength)
{
if (promptIgnoreLength.has_value())
{
TLLM_CHECK(promptIgnoreLength.value() >= 0);
}
return promptIgnoreLength;
}

OptFloat const& SamplingConfig::checkLengthPenalty(OptFloat const& lengthPenalty)
{
if (lengthPenalty.has_value())
Expand Down
7 changes: 5 additions & 2 deletions cpp/tensorrt_llm/executor/serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ SamplingConfig Serialization::deserializeSamplingConfig(std::istream& is)
auto repetitionPenalty = su::deserialize<std::optional<FloatType>>(is);
auto presencePenalty = su::deserialize<std::optional<FloatType>>(is);
auto frequencyPenalty = su::deserialize<std::optional<FloatType>>(is);
auto promptIgnoreLength = su::deserialize<std::optional<SizeType32>>(is);
auto lengthPenalty = su::deserialize<std::optional<FloatType>>(is);
auto earlyStopping = su::deserialize<std::optional<SizeType32>>(is);
auto noRepeatNgramSize = su::deserialize<std::optional<SizeType32>>(is);
Expand All @@ -167,8 +168,8 @@ SamplingConfig Serialization::deserializeSamplingConfig(std::istream& is)
auto beamWidthArray = su::deserialize<std::optional<std::vector<SizeType32>>>(is);

return SamplingConfig{beamWidth, topK, topP, topPMin, topPResetIds, topPDecay, randomSeed, temperature, minLength,
beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty, lengthPenalty, earlyStopping,
noRepeatNgramSize, numReturnSequences, minP, beamWidthArray};
beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty, promptIgnoreLength,
lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP, beamWidthArray};
}

void Serialization::serialize(SamplingConfig const& config, std::ostream& os)
Expand All @@ -186,6 +187,7 @@ void Serialization::serialize(SamplingConfig const& config, std::ostream& os)
su::serialize(config.mRepetitionPenalty, os);
su::serialize(config.mPresencePenalty, os);
su::serialize(config.mFrequencyPenalty, os);
su::serialize(config.mPromptIgnoreLength, os);
su::serialize(config.mLengthPenalty, os);
su::serialize(config.mEarlyStopping, os);
su::serialize(config.mNoRepeatNgramSize, os);
Expand All @@ -210,6 +212,7 @@ size_t Serialization::serializedSize(SamplingConfig const& config)
totalSize += su::serializedSize(config.mRepetitionPenalty);
totalSize += su::serializedSize(config.mPresencePenalty);
totalSize += su::serializedSize(config.mFrequencyPenalty);
totalSize += su::serializedSize(config.mPromptIgnoreLength);
totalSize += su::serializedSize(config.mLengthPenalty);
totalSize += su::serializedSize(config.mEarlyStopping);
totalSize += su::serializedSize(config.mNoRepeatNgramSize);
Expand Down
Loading