From 0db5c646fc1649ac7e1bf81fe1b5bee9a9fa1346 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Fri, 21 Feb 2025 22:23:25 +0000 Subject: [PATCH 1/8] First sketch of grammar resampling --- LLama/Native/LLamaTokenDataArray.cs | 6 +- LLama/Sampling/BaseSamplingPipeline.cs | 19 ++++-- LLama/Sampling/DefaultSamplingPipeline.cs | 75 ++++++++++++++++++++++- LLama/Sampling/ISamplingPipeline.cs | 7 +++ 4 files changed, 97 insertions(+), 10 deletions(-) diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs index e1a19555d..32d2cccf5 100644 --- a/LLama/Native/LLamaTokenDataArray.cs +++ b/LLama/Native/LLamaTokenDataArray.cs @@ -134,9 +134,9 @@ public int Compare(LLamaTokenData x, LLamaTokenData y) } /// - /// Contains a pointer to an array of LLamaTokenData which is pinned in memory. - /// - /// C# equivalent of llama_token_data_array + /// Contains a pointer to an array of LLamaTokenData which is pinned in memory. + /// + /// C# equivalent of llama_token_data_array [StructLayout(LayoutKind.Sequential)] public struct LLamaTokenDataArrayNative { diff --git a/LLama/Sampling/BaseSamplingPipeline.cs b/LLama/Sampling/BaseSamplingPipeline.cs index d0c3348a2..ffe9b8d30 100644 --- a/LLama/Sampling/BaseSamplingPipeline.cs +++ b/LLama/Sampling/BaseSamplingPipeline.cs @@ -23,7 +23,7 @@ public BaseSamplingPipeline() protected abstract SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context); /// - public void Dispose() + public virtual void Dispose() { _chain?.Dispose(); _chain = null; @@ -32,21 +32,32 @@ public void Dispose() } /// - public LLamaToken Sample(SafeLLamaContextHandle ctx, int index) + public virtual LLamaToken Sample(SafeLLamaContextHandle ctx, int index) { _chain ??= CreateChain(ctx); return _chain.Sample(ctx, index); } + /// + /// Apply this sampling chain to a LLamaTokenDataArrayNative + /// + /// + /// + public virtual void Apply(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative data) + { + _chain ??= 
CreateChain(ctx); + _chain.Apply(ref data); + } + /// - public void Reset() + public virtual void Reset() { _chain?.Reset(); } /// - public void Accept(LLamaToken token) + public virtual void Accept(LLamaToken token) { _chain?.Accept(token); } diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs index 8bc074062..f2c0790f8 100644 --- a/LLama/Sampling/DefaultSamplingPipeline.cs +++ b/LLama/Sampling/DefaultSamplingPipeline.cs @@ -113,6 +113,11 @@ public float PresencePenalty /// public uint Seed { get; set; } = GetRandomSeed(); + /// + /// A chain with just the grammar + /// + private SafeLLamaSamplerChainHandle? _grammarChain; + private static readonly Random RandomSeedGenerator = new(); private static uint GetRandomSeed() @@ -121,6 +126,41 @@ private static uint GetRandomSeed() return (uint) RandomSeedGenerator.Next(0, int.MaxValue) + (uint) RandomSeedGenerator.Next(0, int.MaxValue); } + /// + public override void Dispose() + { + base.Dispose(); + + _grammarChain?.Dispose(); + _grammarChain = null; + } + + /// + public override void Reset() + { + base.Reset(); + + _grammarChain?.Reset(); + } + + /// + public override void Accept(LLamaToken token) + { + base.Accept(token); + + _grammarChain?.Accept(token); + } + + private SafeLLamaSamplerChainHandle CreateGrammarChain(SafeLLamaContextHandle context) + { + if (Grammar == null) + throw new InvalidOperationException(nameof(Grammar) + " is null"); + + var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default()); + chain.AddGrammar(context.ModelHandle, Grammar.Gbnf, Grammar.Root); + chain.AddDistributionSampler(Seed); + return chain; + } /// protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context) @@ -149,9 +189,6 @@ protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandl ArrayPool.Shared.Return(biases); } - if (Grammar != null) - chain.AddGrammar(context.ModelHandle, Grammar.Gbnf, Grammar.Root); - 
chain.AddPenalties(PenaltyCount, RepeatPenalty, FrequencyPenalty, PresencePenalty); chain.AddTopK(TopK); @@ -164,4 +201,36 @@ protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandl return chain; } + + public override LLamaToken Sample(SafeLLamaContextHandle ctx, int index) + { + if (Grammar == null) + return base.Sample(ctx, index); + + // Create a chain with the grammar + _grammarChain ??= CreateGrammarChain(ctx); + + //todo: pass in rented temporary to LLamaTokenDataArray.Create (x2) + + using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index)), out var nativeAll)) + { + // Apply the chain without the grammar to select one token which may or may not be valid + Apply(ctx, ref nativeAll); + var candidateToken = nativeAll.Data[checked((int)nativeAll.Selected)].ID; + + // Now create another token data array with just that one token + using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(new[] { new LLamaTokenData(candidateToken, 1, 0) }, true), out var nativeSingleCandidate)) + { + // Apply the grammar to this single candidate. 
+ _grammarChain.Apply(ref nativeSingleCandidate); + + // Test if that single token was rejected by the grammar + if (!float.IsNegativeInfinity(nativeSingleCandidate.Data[0].Logit)) + return candidateToken; + } + } + + // If we get here the grammar rejected the token, fallback to applying the grammar first and then the entire pipeline + throw new NotImplementedException(); + } } \ No newline at end of file diff --git a/LLama/Sampling/ISamplingPipeline.cs b/LLama/Sampling/ISamplingPipeline.cs index d98ad342f..245108701 100644 --- a/LLama/Sampling/ISamplingPipeline.cs +++ b/LLama/Sampling/ISamplingPipeline.cs @@ -17,6 +17,13 @@ public interface ISamplingPipeline /// LLamaToken Sample(SafeLLamaContextHandle ctx, int index); + /// + /// Apply this pipeline to a set of token data + /// + /// + /// + public void Apply(SafeLLamaContextHandle ctx, LLamaTokenDataArray data); + /// /// Reset all internal state of the sampling pipeline /// From b3b6397f265c5e94c7b79979cd0fed31398886f4 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Sun, 23 Feb 2025 00:35:23 +0000 Subject: [PATCH 2/8] Completed resampling system --- LLama/Sampling/BaseSamplingPipeline.cs | 8 ++ LLama/Sampling/DefaultSamplingPipeline.cs | 93 +++++++++++++++-------- 2 files changed, 69 insertions(+), 32 deletions(-) diff --git a/LLama/Sampling/BaseSamplingPipeline.cs b/LLama/Sampling/BaseSamplingPipeline.cs index ffe9b8d30..303412b8c 100644 --- a/LLama/Sampling/BaseSamplingPipeline.cs +++ b/LLama/Sampling/BaseSamplingPipeline.cs @@ -39,6 +39,14 @@ public virtual LLamaToken Sample(SafeLLamaContextHandle ctx, int index) return _chain.Sample(ctx, index); } + /// + public virtual void Apply(SafeLLamaContextHandle ctx, LLamaTokenDataArray data) + { + _chain ??= CreateChain(ctx); + using (LLamaTokenDataArrayNative.Create(data, out var native)) + _chain.Apply(ref native); + } + /// /// Apply this sampling chain to a LLamaTokenDataArrayNative /// diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs 
b/LLama/Sampling/DefaultSamplingPipeline.cs index f2c0790f8..08fbf05d8 100644 --- a/LLama/Sampling/DefaultSamplingPipeline.cs +++ b/LLama/Sampling/DefaultSamplingPipeline.cs @@ -167,26 +167,29 @@ protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandl { var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default()); - // Rent a temporary array and copy the biases into it - var biases = ArrayPool.Shared.Rent(LogitBias.Count); - try + if (LogitBias.Count > 0) { - var index = 0; - foreach (var bias in LogitBias) + // Rent a temporary array and copy the biases into it + var biases = ArrayPool.Shared.Rent(LogitBias.Count); + try { - biases[index++] = new LLamaLogitBias + var index = 0; + foreach (var bias in LogitBias) { - Token = bias.Key, - Bias = bias.Value - }; + biases[index++] = new LLamaLogitBias + { + Token = bias.Key, + Bias = bias.Value + }; + } + + // Add the biases to the sampler + chain.AddLogitBias(context.Vocab.Count, biases.AsSpan(0, LogitBias.Count)); + } + finally + { + ArrayPool.Shared.Return(biases); } - - // Add the biases to the sampler - chain.AddLogitBias(context.Vocab.Count, biases.AsSpan(0, LogitBias.Count)); - } - finally - { - ArrayPool.Shared.Return(biases); } chain.AddPenalties(PenaltyCount, RepeatPenalty, FrequencyPenalty, PresencePenalty); @@ -202,6 +205,7 @@ protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandl return chain; } + /// public override LLamaToken Sample(SafeLLamaContextHandle ctx, int index) { if (Grammar == null) @@ -210,27 +214,52 @@ public override LLamaToken Sample(SafeLLamaContextHandle ctx, int index) // Create a chain with the grammar _grammarChain ??= CreateGrammarChain(ctx); - //todo: pass in rented temporary to LLamaTokenDataArray.Create (x2) - - using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index)), out var nativeAll)) + // Rent some buffers to use later + var rentedBufferVocabSize = 
ArrayPool.Shared.Rent(ctx.ModelHandle.Vocab.Count); + var rentedBufferSingleItem = ArrayPool.Shared.Rent(1); + try { - // Apply the chain without the grammar to select one token which may or may not be valid - Apply(ctx, ref nativeAll); - var candidateToken = nativeAll.Data[checked((int)nativeAll.Selected)].ID; + using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll)) + { + // Apply the chain without the grammar to select one token which may or may not be valid + Apply(ctx, ref nativeAll); + var candidateToken = nativeAll.Data[checked((int)nativeAll.Selected)].ID; - // Now create another token data array with just that one token - using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(new[] { new LLamaTokenData(candidateToken, 1, 0) }, true), out var nativeSingleCandidate)) + // Now create another token data array with just that one token + rentedBufferSingleItem[0] = new LLamaTokenData(candidateToken, 1, 0); + using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferSingleItem, true), out var nativeSingleCandidate)) + { + // Apply the grammar to this single candidate. + _grammarChain.Apply(ref nativeSingleCandidate); + + // Test if that single token was rejected by the grammar + if (!float.IsNegativeInfinity(nativeSingleCandidate.Data[0].Logit)) + { + Accept(candidateToken); + return candidateToken; + } + } + } + + // If we get here the grammar rejected the token + using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll)) { - // Apply the grammar to this single candidate. - _grammarChain.Apply(ref nativeSingleCandidate); + // Apply the grammar _first_. 
This is slower (since it has to work on the entire vocab), but guaranteed to work + _grammarChain.Apply(ref nativeAll); + + // Now apply the rest of the pipeline + Apply(ctx, ref nativeAll); - // Test if that single token was rejected by the grammar - if (!float.IsNegativeInfinity(nativeSingleCandidate.Data[0].Logit)) - return candidateToken; + // Take the selected token + var token = nativeAll.Data[checked((int)nativeAll.Selected)].ID; + Accept(token); + return token; } } - - // If we get here the grammar rejected the token, fallback to applying the grammar first and then the entire pipeline - throw new NotImplementedException(); + finally + { + ArrayPool.Shared.Return(rentedBufferVocabSize); + ArrayPool.Shared.Return(rentedBufferSingleItem); + } } } \ No newline at end of file From 80eef40ed03e333871b47b0e3dfeea995b2d239d Mon Sep 17 00:00:00 2001 From: m0nsky Date: Sun, 23 Feb 2025 20:35:35 +0100 Subject: [PATCH 3/8] Add grammar optimization mode --- LLama/Sampling/DefaultSamplingPipeline.cs | 26 +++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs index 08fbf05d8..0783c9942 100644 --- a/LLama/Sampling/DefaultSamplingPipeline.cs +++ b/LLama/Sampling/DefaultSamplingPipeline.cs @@ -112,6 +112,11 @@ public float PresencePenalty /// Seed to use for random sampling /// public uint Seed { get; set; } = GetRandomSeed(); + + /// + /// Selected grammar optimization mode for processing + /// + public GrammarOptimizationMode GrammarOptimization { get; init; } = GrammarOptimizationMode.None; /// /// A chain with just the grammar @@ -262,4 +267,25 @@ public override LLamaToken Sample(SafeLLamaContextHandle ctx, int index) ArrayPool.Shared.Return(rentedBufferSingleItem); } } + + /// + /// Grammar Optimization Mode + /// + public enum GrammarOptimizationMode + { + /// + /// No grammar optimization, slow because it has to apply the grammar to the entire vocab. 
+ /// + None, + + /// + /// Attempts to return early by only applying the grammar to the selected token and checking if it's valid. + /// + Basic, + + /// + /// Attempts to return early by applying the grammar to the top K tokens and checking if the selected token is valid. + /// + Extended + } } \ No newline at end of file From ba71fbded4f87cf885e9714a8de2ee4bf80fc0b2 Mon Sep 17 00:00:00 2001 From: m0nsky Date: Sun, 23 Feb 2025 20:36:33 +0100 Subject: [PATCH 4/8] Remove the distribution sampler from the grammar chain --- LLama/Sampling/DefaultSamplingPipeline.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs index 0783c9942..bf47984e4 100644 --- a/LLama/Sampling/DefaultSamplingPipeline.cs +++ b/LLama/Sampling/DefaultSamplingPipeline.cs @@ -163,7 +163,6 @@ private SafeLLamaSamplerChainHandle CreateGrammarChain(SafeLLamaContextHandle co var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default()); chain.AddGrammar(context.ModelHandle, Grammar.Gbnf, Grammar.Root); - chain.AddDistributionSampler(Seed); return chain; } From 6ae495d1707447c5f5db947a460f531f50a1def1 Mon Sep 17 00:00:00 2001 From: m0nsky Date: Sun, 23 Feb 2025 20:46:36 +0100 Subject: [PATCH 5/8] Add the extended grammar optimization logic --- LLama/Sampling/DefaultSamplingPipeline.cs | 77 ++++++++++++++++++----- 1 file changed, 61 insertions(+), 16 deletions(-) diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs index bf47984e4..3db64a7b3 100644 --- a/LLama/Sampling/DefaultSamplingPipeline.cs +++ b/LLama/Sampling/DefaultSamplingPipeline.cs @@ -221,30 +221,75 @@ public override LLamaToken Sample(SafeLLamaContextHandle ctx, int index) // Rent some buffers to use later var rentedBufferVocabSize = ArrayPool.Shared.Rent(ctx.ModelHandle.Vocab.Count); var rentedBufferSingleItem = ArrayPool.Shared.Rent(1); + try { - using 
(LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll)) + // Handle grammar optimization modes + if (GrammarOptimization != GrammarOptimizationMode.None) { - // Apply the chain without the grammar to select one token which may or may not be valid - Apply(ctx, ref nativeAll); - var candidateToken = nativeAll.Data[checked((int)nativeAll.Selected)].ID; - - // Now create another token data array with just that one token - rentedBufferSingleItem[0] = new LLamaTokenData(candidateToken, 1, 0); - using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferSingleItem, true), out var nativeSingleCandidate)) + // Basic optimization : Apply the grammar to the selected token and check if it's valid + using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll)) { - // Apply the grammar to this single candidate. - _grammarChain.Apply(ref nativeSingleCandidate); - - // Test if that single token was rejected by the grammar - if (!float.IsNegativeInfinity(nativeSingleCandidate.Data[0].Logit)) + // Apply the chain without the grammar to select one token which may or may not be valid + Apply(ctx, ref nativeAll); + + // Select the candidate token + var candidateToken = nativeAll.Data[checked((int)nativeAll.Selected)].ID; + + // Now create another token data array with just that one token + rentedBufferSingleItem[0] = new LLamaTokenData(candidateToken, 1, 0); + using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferSingleItem, true), out var nativeSingleCandidate)) { - Accept(candidateToken); - return candidateToken; + // Apply the grammar chain to the single candidate + _grammarChain.Apply(ref nativeSingleCandidate); + + // Check if the token passes the grammar + if (!float.IsNegativeInfinity(nativeSingleCandidate.Data[0].Logit)) + { + Accept(candidateToken); + return candidateToken; + } + } + + // 
Extended optimization : Apply the grammar to the TopK tokens and check if the selected token is valid + if (GrammarOptimization == GrammarOptimizationMode.Extended) + { + // Calculate a safe TopK value + int safeTopK = Math.Min(TopK, nativeAll.Data.Length); + + // Rent a buffer for the TopK candidates + var rentedBufferTopK = ArrayPool.Shared.Rent(safeTopK); + try + { + // Copy only the TopK tokens from the existing candidate pool to the new buffer + nativeAll.Data.Slice(0, safeTopK).CopyTo(rentedBufferTopK.AsSpan(0, safeTopK)); + + // Create a native array with the TopK tokens + using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferTopK, true), out var nativeTopK)) + { + // Apply the grammar chain to the TopK candidates + _grammarChain.Apply(ref nativeTopK); + + // Select the candidate token + var candidateTokenTopK = nativeTopK.Data[checked((int)nativeTopK.Selected)]; + + // Check if the token passes the grammar + if (!float.IsNegativeInfinity(candidateTokenTopK.Logit)) + { + // Accept and return the token + Accept(candidateTokenTopK.ID); + return candidateTokenTopK.ID; + } + } + } + finally + { + ArrayPool.Shared.Return(rentedBufferTopK); + } } } } - + // If we get here the grammar rejected the token using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll)) { From 0b7f50892584b6aaaddf0ac856e7c46f92327021 Mon Sep 17 00:00:00 2001 From: m0nsky Date: Sun, 23 Feb 2025 21:22:26 +0100 Subject: [PATCH 6/8] Fix comment --- LLama/Sampling/DefaultSamplingPipeline.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs index 3db64a7b3..b8da759af 100644 --- a/LLama/Sampling/DefaultSamplingPipeline.cs +++ b/LLama/Sampling/DefaultSamplingPipeline.cs @@ -114,7 +114,7 @@ public float PresencePenalty public uint Seed { get; set; } = GetRandomSeed(); /// - /// Selected grammar 
optimization mode for processing + /// Selected grammar optimization mode /// public GrammarOptimizationMode GrammarOptimization { get; init; } = GrammarOptimizationMode.None; From ea9437d9b12d67ca62bc4dd7647fcc071d1dbba6 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Mon, 24 Feb 2025 16:04:46 +0000 Subject: [PATCH 7/8] Set `Extended` as default grammar mode --- LLama/Sampling/DefaultSamplingPipeline.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs index b8da759af..cca39c091 100644 --- a/LLama/Sampling/DefaultSamplingPipeline.cs +++ b/LLama/Sampling/DefaultSamplingPipeline.cs @@ -116,7 +116,7 @@ public float PresencePenalty /// /// Selected grammar optimization mode /// - public GrammarOptimizationMode GrammarOptimization { get; init; } = GrammarOptimizationMode.None; + public GrammarOptimizationMode GrammarOptimization { get; init; } = GrammarOptimizationMode.Extended; /// /// A chain with just the grammar From 7bffe1cbe0a3be9e80955919824833e84806c8bf Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 25 Feb 2025 20:36:29 +0000 Subject: [PATCH 8/8] Fixed usage of rented arrays, slicing them down to the right size. Before this there was always some junk data at the end. 
--- LLama/Sampling/DefaultSamplingPipeline.cs | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs index cca39c091..cd8f57f27 100644 --- a/LLama/Sampling/DefaultSamplingPipeline.cs +++ b/LLama/Sampling/DefaultSamplingPipeline.cs @@ -219,9 +219,11 @@ public override LLamaToken Sample(SafeLLamaContextHandle ctx, int index) _grammarChain ??= CreateGrammarChain(ctx); // Rent some buffers to use later - var rentedBufferVocabSize = ArrayPool.Shared.Rent(ctx.ModelHandle.Vocab.Count); - var rentedBufferSingleItem = ArrayPool.Shared.Rent(1); - + var rentedBufferVocabSizeArr = ArrayPool.Shared.Rent(ctx.ModelHandle.Vocab.Count); + var rentedBufferVocabSize = rentedBufferVocabSizeArr.AsMemory(0, ctx.ModelHandle.Vocab.Count); + var rentedBufferSingleItemArr = ArrayPool.Shared.Rent(1); + var rentedBufferSingleItem = rentedBufferSingleItemArr.AsMemory(0, 1); + try { // Handle grammar optimization modes @@ -237,7 +239,7 @@ public override LLamaToken Sample(SafeLLamaContextHandle ctx, int index) var candidateToken = nativeAll.Data[checked((int)nativeAll.Selected)].ID; // Now create another token data array with just that one token - rentedBufferSingleItem[0] = new LLamaTokenData(candidateToken, 1, 0); + rentedBufferSingleItem.Span[0] = new LLamaTokenData(candidateToken, 1, 0); using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferSingleItem, true), out var nativeSingleCandidate)) { // Apply the grammar chain to the single candidate @@ -255,14 +257,15 @@ public override LLamaToken Sample(SafeLLamaContextHandle ctx, int index) if (GrammarOptimization == GrammarOptimizationMode.Extended) { // Calculate a safe TopK value - int safeTopK = Math.Min(TopK, nativeAll.Data.Length); + var safeTopK = Math.Min(TopK, nativeAll.Data.Length); // Rent a buffer for the TopK candidates - var rentedBufferTopK = ArrayPool.Shared.Rent(safeTopK); + var 
rentedBufferTopKArr = ArrayPool.Shared.Rent(safeTopK); + var rentedBufferTopK = rentedBufferTopKArr.AsMemory(0, safeTopK); try { // Copy only the TopK tokens from the existing candidate pool to the new buffer - nativeAll.Data.Slice(0, safeTopK).CopyTo(rentedBufferTopK.AsSpan(0, safeTopK)); + nativeAll.Data.Slice(0, safeTopK).CopyTo(rentedBufferTopK.Span); // Create a native array with the TopK tokens using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferTopK, true), out var nativeTopK)) @@ -284,7 +287,7 @@ public override LLamaToken Sample(SafeLLamaContextHandle ctx, int index) } finally { - ArrayPool.Shared.Return(rentedBufferTopK); + ArrayPool.Shared.Return(rentedBufferTopKArr); } } } @@ -307,8 +310,8 @@ public override LLamaToken Sample(SafeLLamaContextHandle ctx, int index) } finally { - ArrayPool.Shared.Return(rentedBufferVocabSize); - ArrayPool.Shared.Return(rentedBufferSingleItem); + ArrayPool.Shared.Return(rentedBufferVocabSizeArr); + ArrayPool.Shared.Return(rentedBufferSingleItemArr); } }