diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs
index e1a19555d..32d2cccf5 100644
--- a/LLama/Native/LLamaTokenDataArray.cs
+++ b/LLama/Native/LLamaTokenDataArray.cs
@@ -134,9 +134,9 @@ public int Compare(LLamaTokenData x, LLamaTokenData y)
}
///
- /// Contains a pointer to an array of LLamaTokenData which is pinned in memory.
- ///
- /// C# equivalent of llama_token_data_array
+ /// Contains a pointer to an array of LLamaTokenData which is pinned in memory.
+ ///
+ /// C# equivalent of llama_token_data_array
[StructLayout(LayoutKind.Sequential)]
public struct LLamaTokenDataArrayNative
{
diff --git a/LLama/Sampling/BaseSamplingPipeline.cs b/LLama/Sampling/BaseSamplingPipeline.cs
index d0c3348a2..303412b8c 100644
--- a/LLama/Sampling/BaseSamplingPipeline.cs
+++ b/LLama/Sampling/BaseSamplingPipeline.cs
@@ -23,7 +23,7 @@ public BaseSamplingPipeline()
protected abstract SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context);
///
- public void Dispose()
+ public virtual void Dispose()
{
_chain?.Dispose();
_chain = null;
@@ -32,7 +32,7 @@ public void Dispose()
}
///
- public LLamaToken Sample(SafeLLamaContextHandle ctx, int index)
+ public virtual LLamaToken Sample(SafeLLamaContextHandle ctx, int index)
{
_chain ??= CreateChain(ctx);
@@ -40,13 +40,32 @@ public LLamaToken Sample(SafeLLamaContextHandle ctx, int index)
}
///
- public void Reset()
+ public virtual void Apply(SafeLLamaContextHandle ctx, LLamaTokenDataArray data)
+ {
+ _chain ??= CreateChain(ctx);
+ using (LLamaTokenDataArrayNative.Create(data, out var native))
+ _chain.Apply(ref native);
+ }
+
+ ///
+ /// Apply this sampling chain to a LLamaTokenDataArrayNative
+ ///
+ ///
+ ///
+ public virtual void Apply(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative data)
+ {
+ _chain ??= CreateChain(ctx);
+ _chain.Apply(ref data);
+ }
+
+ ///
+ public virtual void Reset()
{
_chain?.Reset();
}
///
- public void Accept(LLamaToken token)
+ public virtual void Accept(LLamaToken token)
{
_chain?.Accept(token);
}
diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs
index 8bc074062..cd8f57f27 100644
--- a/LLama/Sampling/DefaultSamplingPipeline.cs
+++ b/LLama/Sampling/DefaultSamplingPipeline.cs
@@ -112,6 +112,16 @@ public float PresencePenalty
/// Seed to use for random sampling
///
public uint Seed { get; set; } = GetRandomSeed();
+
+ ///
+ /// Selected grammar optimization mode
+ ///
+ public GrammarOptimizationMode GrammarOptimization { get; init; } = GrammarOptimizationMode.Extended;
+
+ ///
+ /// A chain with just the grammar
+ ///
+ private SafeLLamaSamplerChainHandle? _grammarChain;
private static readonly Random RandomSeedGenerator = new();
@@ -121,37 +131,71 @@ private static uint GetRandomSeed()
return (uint) RandomSeedGenerator.Next(0, int.MaxValue) + (uint) RandomSeedGenerator.Next(0, int.MaxValue);
}
+ ///
+ public override void Dispose()
+ {
+ base.Dispose();
+
+ _grammarChain?.Dispose();
+ _grammarChain = null;
+ }
+
+ ///
+ public override void Reset()
+ {
+ base.Reset();
+
+ _grammarChain?.Reset();
+ }
+
+ ///
+ public override void Accept(LLamaToken token)
+ {
+ base.Accept(token);
+
+ _grammarChain?.Accept(token);
+ }
+
+ private SafeLLamaSamplerChainHandle CreateGrammarChain(SafeLLamaContextHandle context)
+ {
+ if (Grammar == null)
+ throw new InvalidOperationException(nameof(Grammar) + " is null");
+
+ var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default());
+ chain.AddGrammar(context.ModelHandle, Grammar.Gbnf, Grammar.Root);
+ return chain;
+ }
///
protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context)
{
var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default());
- // Rent a temporary array and copy the biases into it
- var biases = ArrayPool<LLamaLogitBias>.Shared.Rent(LogitBias.Count);
- try
+ if (LogitBias.Count > 0)
{
- var index = 0;
- foreach (var bias in LogitBias)
+ // Rent a temporary array and copy the biases into it
+ var biases = ArrayPool<LLamaLogitBias>.Shared.Rent(LogitBias.Count);
+ try
{
- biases[index++] = new LLamaLogitBias
+ var index = 0;
+ foreach (var bias in LogitBias)
{
- Token = bias.Key,
- Bias = bias.Value
- };
- }
+ biases[index++] = new LLamaLogitBias
+ {
+ Token = bias.Key,
+ Bias = bias.Value
+ };
+ }
- // Add the biases to the sampler
- chain.AddLogitBias(context.Vocab.Count, biases.AsSpan(0, LogitBias.Count));
- }
- finally
- {
- ArrayPool<LLamaLogitBias>.Shared.Return(biases);
+ // Add the biases to the sampler
+ chain.AddLogitBias(context.Vocab.Count, biases.AsSpan(0, LogitBias.Count));
+ }
+ finally
+ {
+ ArrayPool<LLamaLogitBias>.Shared.Return(biases);
+ }
}
- if (Grammar != null)
- chain.AddGrammar(context.ModelHandle, Grammar.Gbnf, Grammar.Root);
-
chain.AddPenalties(PenaltyCount, RepeatPenalty, FrequencyPenalty, PresencePenalty);
chain.AddTopK(TopK);
@@ -164,4 +208,131 @@ protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandl
return chain;
}
+
+ ///
+ public override LLamaToken Sample(SafeLLamaContextHandle ctx, int index)
+ {
+ if (Grammar == null)
+ return base.Sample(ctx, index);
+
+ // Create a chain with the grammar
+ _grammarChain ??= CreateGrammarChain(ctx);
+
+ // Rent some buffers to use later
+ var rentedBufferVocabSizeArr = ArrayPool<LLamaTokenData>.Shared.Rent(ctx.ModelHandle.Vocab.Count);
+ var rentedBufferVocabSize = rentedBufferVocabSizeArr.AsMemory(0, ctx.ModelHandle.Vocab.Count);
+ var rentedBufferSingleItemArr = ArrayPool<LLamaTokenData>.Shared.Rent(1);
+ var rentedBufferSingleItem = rentedBufferSingleItemArr.AsMemory(0, 1);
+
+ try
+ {
+ // Handle grammar optimization modes
+ if (GrammarOptimization != GrammarOptimizationMode.None)
+ {
+ // Basic optimization : Apply the grammar to the selected token and check if it's valid
+ using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll))
+ {
+ // Apply the chain without the grammar to select one token which may or may not be valid
+ Apply(ctx, ref nativeAll);
+
+ // Select the candidate token
+ var candidateToken = nativeAll.Data[checked((int)nativeAll.Selected)].ID;
+
+ // Now create another token data array with just that one token
+ rentedBufferSingleItem.Span[0] = new LLamaTokenData(candidateToken, 1, 0);
+ using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferSingleItem, true), out var nativeSingleCandidate))
+ {
+ // Apply the grammar chain to the single candidate
+ _grammarChain.Apply(ref nativeSingleCandidate);
+
+ // Check if the token passes the grammar
+ if (!float.IsNegativeInfinity(nativeSingleCandidate.Data[0].Logit))
+ {
+ Accept(candidateToken);
+ return candidateToken;
+ }
+ }
+
+ // Extended optimization : Apply the grammar to the TopK tokens and check if the selected token is valid
+ if (GrammarOptimization == GrammarOptimizationMode.Extended)
+ {
+ // Calculate a safe TopK value
+ var safeTopK = Math.Min(TopK, nativeAll.Data.Length);
+
+ // Rent a buffer for the TopK candidates
+ var rentedBufferTopKArr = ArrayPool<LLamaTokenData>.Shared.Rent(safeTopK);
+ var rentedBufferTopK = rentedBufferTopKArr.AsMemory(0, safeTopK);
+ try
+ {
+ // Copy only the TopK tokens from the existing candidate pool to the new buffer
+ nativeAll.Data.Slice(0, safeTopK).CopyTo(rentedBufferTopK.Span);
+
+ // Create a native array with the TopK tokens
+ using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferTopK, true), out var nativeTopK))
+ {
+ // Apply the grammar chain to the TopK candidates
+ _grammarChain.Apply(ref nativeTopK);
+
+ // Select the candidate token
+ var candidateTokenTopK = nativeTopK.Data[checked((int)nativeTopK.Selected)];
+
+ // Check if the token passes the grammar
+ if (!float.IsNegativeInfinity(candidateTokenTopK.Logit))
+ {
+ // Accept and return the token
+ Accept(candidateTokenTopK.ID);
+ return candidateTokenTopK.ID;
+ }
+ }
+ }
+ finally
+ {
+ ArrayPool<LLamaTokenData>.Shared.Return(rentedBufferTopKArr);
+ }
+ }
+ }
+ }
+
+ // If we get here the grammar rejected the token
+ using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll))
+ {
+ // Apply the grammar _first_. This is slower (since it has to work on the entire vocab), but guaranteed to work
+ _grammarChain.Apply(ref nativeAll);
+
+ // Now apply the rest of the pipeline
+ Apply(ctx, ref nativeAll);
+
+ // Take the selected token
+ var token = nativeAll.Data[checked((int)nativeAll.Selected)].ID;
+ Accept(token);
+ return token;
+ }
+ }
+ finally
+ {
+ ArrayPool<LLamaTokenData>.Shared.Return(rentedBufferVocabSizeArr);
+ ArrayPool<LLamaTokenData>.Shared.Return(rentedBufferSingleItemArr);
+ }
+ }
+
+ ///
+ /// Grammar Optimization Mode
+ ///
+ public enum GrammarOptimizationMode
+ {
+ ///
+ /// No grammar optimization, slow because it has to apply the grammar to the entire vocab.
+ ///
+ None,
+
+ ///
+ /// Attempts to return early by only applying the grammar to the selected token and checking if it's valid.
+ ///
+ Basic,
+
+ ///
+ /// Attempts to return early by applying the grammar to the top K tokens and checking if the selected token is valid.
+ ///
+ Extended
+ }
}
\ No newline at end of file
diff --git a/LLama/Sampling/ISamplingPipeline.cs b/LLama/Sampling/ISamplingPipeline.cs
index d98ad342f..245108701 100644
--- a/LLama/Sampling/ISamplingPipeline.cs
+++ b/LLama/Sampling/ISamplingPipeline.cs
@@ -17,6 +17,13 @@ public interface ISamplingPipeline
///
LLamaToken Sample(SafeLLamaContextHandle ctx, int index);
+ ///
+ /// Apply this pipeline to a set of token data
+ ///
+ ///
+ ///
+ public void Apply(SafeLLamaContextHandle ctx, LLamaTokenDataArray data);
+
///
/// Reset all internal state of the sampling pipeline
///