Skip to content

Commit 1e1a131

Browse files
authored
Merge pull request #986 from martindevans/logit_bias
Implemented `LogitBias` for `DefaultSamplingPipeline`
2 parents f68c1f1 + 07ec3fc commit 1e1a131

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

LLama/Sampling/DefaultSamplingPipeline.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,28 @@ protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandl
118118
{
119119
var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default());
120120

121+
// Rent a temporary array and copy the biases into it
122+
var biases = ArrayPool<LLamaLogitBias>.Shared.Rent(LogitBias.Count);
123+
try
124+
{
125+
var index = 0;
126+
foreach (var bias in LogitBias)
127+
{
128+
biases[index++] = new LLamaLogitBias
129+
{
130+
Token = bias.Key,
131+
Bias = bias.Value
132+
};
133+
}
134+
135+
// Add the biases to the sampler
136+
chain.AddLogitBias(context.ModelHandle.VocabCount, biases.AsSpan(0, LogitBias.Count));
137+
}
138+
finally
139+
{
140+
ArrayPool<LLamaLogitBias>.Shared.Return(biases);
141+
}
142+
121143
if (Grammar != null)
122144
chain.AddGrammar(context.ModelHandle, Grammar.Gbnf, Grammar.Root);
123145

0 commit comments

Comments
 (0)