diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs
index e1a19555d..32d2cccf5 100644
--- a/LLama/Native/LLamaTokenDataArray.cs
+++ b/LLama/Native/LLamaTokenDataArray.cs
@@ -134,9 +134,9 @@ public int Compare(LLamaTokenData x, LLamaTokenData y)
     }
 
     /// <summary>
-    /// Contains a pointer to an array of LLamaTokenData which is pinned in memory.
-    /// </summary>
-    /// <remarks>C# equivalent of llama_token_data_array</remarks>
+    /// Contains a pointer to an array of LLamaTokenData which is pinned in memory.
+    /// </summary>
+    /// <remarks>C# equivalent of llama_token_data_array</remarks>
     [StructLayout(LayoutKind.Sequential)]
     public struct LLamaTokenDataArrayNative
     {
diff --git a/LLama/Sampling/BaseSamplingPipeline.cs b/LLama/Sampling/BaseSamplingPipeline.cs
index d0c3348a2..303412b8c 100644
--- a/LLama/Sampling/BaseSamplingPipeline.cs
+++ b/LLama/Sampling/BaseSamplingPipeline.cs
@@ -23,7 +23,7 @@ public BaseSamplingPipeline()
     protected abstract SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context);
 
     /// <inheritdoc />
-    public void Dispose()
+    public virtual void Dispose()
     {
         _chain?.Dispose();
         _chain = null;
@@ -32,7 +32,7 @@ public void Dispose()
     }
 
     /// <inheritdoc />
-    public LLamaToken Sample(SafeLLamaContextHandle ctx, int index)
+    public virtual LLamaToken Sample(SafeLLamaContextHandle ctx, int index)
     {
         _chain ??= CreateChain(ctx);
 
@@ -40,13 +40,32 @@ public LLamaToken Sample(SafeLLamaContextHandle ctx, int index)
     }
 
     /// <inheritdoc />
-    public void Reset()
+    public virtual void Apply(SafeLLamaContextHandle ctx, LLamaTokenDataArray data)
+    {
+        _chain ??= CreateChain(ctx);
+        using (LLamaTokenDataArrayNative.Create(data, out var native))
+            _chain.Apply(ref native);
+    }
+
+    /// <summary>
+    /// Apply this sampling chain to a LLamaTokenDataArrayNative
+    /// </summary>
+    /// <param name="ctx">Context used to create the sampler chain if it does not exist yet</param>
+    /// <param name="data">Native token data array, passed by reference to the sampler chain</param>
+    public virtual void Apply(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative data)
+    {
+        _chain ??= CreateChain(ctx);
+        _chain.Apply(ref data);
+    }
+
+    /// <inheritdoc />
+    public virtual void Reset()
     {
         _chain?.Reset();
     }
 
     /// <inheritdoc />
-    public void Accept(LLamaToken token)
+    public virtual void Accept(LLamaToken token)
     {
         _chain?.Accept(token);
     }
diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs
index 8bc074062..cd8f57f27 100644
--- a/LLama/Sampling/DefaultSamplingPipeline.cs
+++ b/LLama/Sampling/DefaultSamplingPipeline.cs
@@ -112,6 +112,16 @@ public float PresencePenalty
     /// Seed to use for random sampling
     /// </summary>
     public uint Seed { get; set; } = GetRandomSeed();
+
+    /// <summary>
+    /// Selected grammar optimization mode
+    /// </summary>
+    public GrammarOptimizationMode GrammarOptimization { get; init; } = GrammarOptimizationMode.Extended;
+
+    /// <summary>
+    /// A chain with just the grammar
+    /// </summary>
+    private SafeLLamaSamplerChainHandle? _grammarChain;
 
     private static readonly Random RandomSeedGenerator = new();
 
@@ -121,37 +131,72 @@ private static uint GetRandomSeed()
             return (uint) RandomSeedGenerator.Next(0, int.MaxValue) + (uint) RandomSeedGenerator.Next(0, int.MaxValue);
     }
 
+    /// <inheritdoc />
+    public override void Dispose()
+    {
+        base.Dispose();
+
+        _grammarChain?.Dispose();
+        _grammarChain = null;
+    }
+
+    /// <inheritdoc />
+    public override void Reset()
+    {
+        base.Reset();
+
+        _grammarChain?.Reset();
+    }
+
+    /// <inheritdoc />
+    public override void Accept(LLamaToken token)
+    {
+        base.Accept(token);
+
+        _grammarChain?.Accept(token);
+    }
+
+    /// <summary>Create a sampler chain containing only the grammar sampler; throws InvalidOperationException when Grammar is null</summary>
+    private SafeLLamaSamplerChainHandle CreateGrammarChain(SafeLLamaContextHandle context)
+    {
+        if (Grammar == null)
+            throw new InvalidOperationException(nameof(Grammar) + " is null");
+
+        var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default());
+        chain.AddGrammar(context.ModelHandle, Grammar.Gbnf, Grammar.Root);
+        return chain;
+    }
 
     /// <inheritdoc />
     protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context)
     {
         var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default());
 
-        // Rent a temporary array and copy the biases into it
-        var biases = ArrayPool<LLamaLogitBias>.Shared.Rent(LogitBias.Count);
-        try
+        if (LogitBias.Count > 0)
         {
-            var index = 0;
-            foreach (var bias in LogitBias)
+            // Rent a temporary array and copy the biases into it
+            var biases = ArrayPool<LLamaLogitBias>.Shared.Rent(LogitBias.Count);
+            try
             {
-                biases[index++] = new LLamaLogitBias
+                var index = 0;
+                foreach (var bias in LogitBias)
                 {
-                    Token = bias.Key,
-                    Bias = bias.Value
-                };
-            }
+                    biases[index++] = new LLamaLogitBias
+                    {
+                        Token = bias.Key,
+                        Bias = bias.Value
+                    };
+                }
 
-            // Add the biases to the sampler
-            chain.AddLogitBias(context.Vocab.Count, biases.AsSpan(0, LogitBias.Count));
-        }
-        finally
-        {
-            ArrayPool<LLamaLogitBias>.Shared.Return(biases);
+                // Add the biases to the sampler
+                chain.AddLogitBias(context.Vocab.Count, biases.AsSpan(0, LogitBias.Count));
+            }
+            finally
+            {
+                ArrayPool<LLamaLogitBias>.Shared.Return(biases);
+            }
         }
 
-        if (Grammar != null)
-            chain.AddGrammar(context.ModelHandle, Grammar.Gbnf, Grammar.Root);
-
         chain.AddPenalties(PenaltyCount, RepeatPenalty, FrequencyPenalty, PresencePenalty);
 
         chain.AddTopK(TopK);
@@ -164,4 +209,131 @@ protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandl
 
         return chain;
     }
+
+    /// <inheritdoc />
+    public override LLamaToken Sample(SafeLLamaContextHandle ctx, int index)
+    {
+        if (Grammar == null)
+            return base.Sample(ctx, index);
+
+        // Create a chain with the grammar
+        _grammarChain ??= CreateGrammarChain(ctx);
+
+        // Rent some buffers to use later
+        var rentedBufferVocabSizeArr = ArrayPool<LLamaTokenData>.Shared.Rent(ctx.ModelHandle.Vocab.Count);
+        var rentedBufferVocabSize = rentedBufferVocabSizeArr.AsMemory(0, ctx.ModelHandle.Vocab.Count);
+        var rentedBufferSingleItemArr = ArrayPool<LLamaTokenData>.Shared.Rent(1);
+        var rentedBufferSingleItem = rentedBufferSingleItemArr.AsMemory(0, 1);
+
+        try
+        {
+            // Handle grammar optimization modes
+            if (GrammarOptimization != GrammarOptimizationMode.None)
+            {
+                // Basic optimization : Apply the grammar to the selected token and check if it's valid
+                using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll))
+                {
+                    // Apply the chain without the grammar to select one token which may or may not be valid
+                    Apply(ctx, ref nativeAll);
+
+                    // Select the candidate token
+                    var candidateToken = nativeAll.Data[checked((int)nativeAll.Selected)].ID;
+
+                    // Now create another token data array with just that one token
+                    rentedBufferSingleItem.Span[0] = new LLamaTokenData(candidateToken, 1, 0);
+                    using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferSingleItem, true), out var nativeSingleCandidate))
+                    {
+                        // Apply the grammar chain to the single candidate
+                        _grammarChain.Apply(ref nativeSingleCandidate);
+
+                        // Check if the token passes the grammar
+                        if (!float.IsNegativeInfinity(nativeSingleCandidate.Data[0].Logit))
+                        {
+                            Accept(candidateToken);
+                            return candidateToken;
+                        }
+                    }
+
+                    // Extended optimization : Apply the grammar to the TopK tokens and check if the selected token is valid
+                    if (GrammarOptimization == GrammarOptimizationMode.Extended)
+                    {
+                        // Calculate a safe TopK value. TopK <= 0 disables top-k filtering, so use every remaining candidate instead of renting a zero/negative-length buffer
+                        var safeTopK = TopK > 0 ? Math.Min(TopK, nativeAll.Data.Length) : nativeAll.Data.Length;
+
+                        // Rent a buffer for the TopK candidates
+                        var rentedBufferTopKArr = ArrayPool<LLamaTokenData>.Shared.Rent(safeTopK);
+                        var rentedBufferTopK = rentedBufferTopKArr.AsMemory(0, safeTopK);
+                        try
+                        {
+                            // Copy only the TopK tokens from the existing candidate pool to the new buffer
+                            nativeAll.Data.Slice(0, safeTopK).CopyTo(rentedBufferTopK.Span);
+
+                            // Create a native array with the TopK tokens
+                            using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferTopK, true), out var nativeTopK))
+                            {
+                                // Apply the grammar chain to the TopK candidates
+                                _grammarChain.Apply(ref nativeTopK);
+
+                                // Select the candidate token. NOTE(review): the grammar chain only contains the grammar sampler, so this depends on how Selected is initialised - confirm
+                                var candidateTokenTopK = nativeTopK.Data[checked((int)nativeTopK.Selected)];
+
+                                // Check if the token passes the grammar
+                                if (!float.IsNegativeInfinity(candidateTokenTopK.Logit))
+                                {
+                                    // Accept and return the token
+                                    Accept(candidateTokenTopK.ID);
+                                    return candidateTokenTopK.ID;
+                                }
+                            }
+                        }
+                        finally
+                        {
+                            ArrayPool<LLamaTokenData>.Shared.Return(rentedBufferTopKArr);
+                        }
+                    }
+                }
+            }
+
+            // If we get here the grammar rejected the token
+            using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll))
+            {
+                // Apply the grammar _first_. This is slower (since it has to work on the entire vocab), but guaranteed to work
+                _grammarChain.Apply(ref nativeAll);
+
+                // Now apply the rest of the pipeline
+                Apply(ctx, ref nativeAll);
+
+                // Take the selected token
+                var token = nativeAll.Data[checked((int)nativeAll.Selected)].ID;
+                Accept(token);
+                return token;
+            }
+        }
+        finally
+        {
+            ArrayPool<LLamaTokenData>.Shared.Return(rentedBufferVocabSizeArr);
+            ArrayPool<LLamaTokenData>.Shared.Return(rentedBufferSingleItemArr);
+        }
+    }
+
+    /// <summary>
+    /// Grammar Optimization Mode
+    /// </summary>
+    public enum GrammarOptimizationMode
+    {
+        /// <summary>
+        /// No grammar optimization, slow because it has to apply the grammar to the entire vocab.
+        /// </summary>
+        None,
+
+        /// <summary>
+        /// Attempts to return early by only applying the grammar to the selected token and checking if it's valid.
+        /// </summary>
+        Basic,
+
+        /// <summary>
+        /// Attempts to return early by applying the grammar to the top K tokens and checking if the selected token is valid.
+        /// </summary>
+        Extended
+    }
 }
\ No newline at end of file
diff --git a/LLama/Sampling/ISamplingPipeline.cs b/LLama/Sampling/ISamplingPipeline.cs
index d98ad342f..245108701 100644
--- a/LLama/Sampling/ISamplingPipeline.cs
+++ b/LLama/Sampling/ISamplingPipeline.cs
@@ -17,6 +17,13 @@ public interface ISamplingPipeline
     /// <returns></returns>
     LLamaToken Sample(SafeLLamaContextHandle ctx, int index);
 
+    /// <summary>
+    /// Apply this pipeline to a set of token data
+    /// </summary>
+    /// <param name="ctx">The context the pipeline is being applied for</param>
+    /// <param name="data">The token data to apply this pipeline to</param>
+    public void Apply(SafeLLamaContextHandle ctx, LLamaTokenDataArray data);
+
     /// <summary>
     /// Reset all internal state of the sampling pipeline
     /// </summary>