Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions LLama/Native/LLamaTokenDataArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ public int Compare(LLamaTokenData x, LLamaTokenData y)
}

/// <summary>
/// Contains a pointer to an array of LLamaTokenData which is pinned in memory.
/// </summary>
/// <remarks>C# equivalent of llama_token_data_array</remarks>
/// Contains a pointer to an array of LLamaTokenData which is pinned in memory.
/// </summary>
/// <remarks>C# equivalent of llama_token_data_array</remarks>
[StructLayout(LayoutKind.Sequential)]
public struct LLamaTokenDataArrayNative
{
Expand Down
27 changes: 23 additions & 4 deletions LLama/Sampling/BaseSamplingPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public BaseSamplingPipeline()
protected abstract SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context);

/// <inheritdoc />
public void Dispose()
public virtual void Dispose()
{
_chain?.Dispose();
_chain = null;
Expand All @@ -32,21 +32,40 @@ public void Dispose()
}

/// <inheritdoc />
public LLamaToken Sample(SafeLLamaContextHandle ctx, int index)
public virtual LLamaToken Sample(SafeLLamaContextHandle ctx, int index)
{
_chain ??= CreateChain(ctx);

return _chain.Sample(ctx, index);
}

/// <inheritdoc />
public void Reset()
public virtual void Apply(SafeLLamaContextHandle ctx, LLamaTokenDataArray data)
{
_chain ??= CreateChain(ctx);
using (LLamaTokenDataArrayNative.Create(data, out var native))
_chain.Apply(ref native);
}

/// <summary>
/// Apply this sampling chain to a LLamaTokenDataArrayNative
/// </summary>
/// <param name="ctx">Context used to lazily construct the sampler chain the first time it is needed</param>
/// <param name="data">Token data the chain is applied to, modified in place</param>
public virtual void Apply(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative data)
{
    // Build the chain on first use; it is cached for subsequent calls
    if (_chain == null)
        _chain = CreateChain(ctx);

    _chain.Apply(ref data);
}

/// <inheritdoc />
public virtual void Reset()
{
    // Nothing to reset if the chain has never been created
    if (_chain != null)
        _chain.Reset();
}

/// <inheritdoc />
public void Accept(LLamaToken token)
public virtual void Accept(LLamaToken token)
{
_chain?.Accept(token);
}
Expand Down
209 changes: 190 additions & 19 deletions LLama/Sampling/DefaultSamplingPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,16 @@ public float PresencePenalty
/// Seed to use for random sampling
/// </summary>
public uint Seed { get; set; } = GetRandomSeed();

/// <summary>
/// Selected grammar optimization mode
/// </summary>
public GrammarOptimizationMode GrammarOptimization { get; init; } = GrammarOptimizationMode.Extended;

/// <summary>
/// A chain with just the grammar
/// </summary>
private SafeLLamaSamplerChainHandle? _grammarChain;


private static readonly Random RandomSeedGenerator = new();
Expand All @@ -121,37 +131,71 @@ private static uint GetRandomSeed()
return (uint) RandomSeedGenerator.Next(0, int.MaxValue) + (uint) RandomSeedGenerator.Next(0, int.MaxValue);
}

/// <inheritdoc />
public override void Dispose()
{
    // Dispose the main sampler chain held by the base class first
    base.Dispose();

    // Then release the grammar-only chain, if one was ever created
    if (_grammarChain != null)
    {
        _grammarChain.Dispose();
        _grammarChain = null;
    }
}

/// <inheritdoc />
public override void Reset()
{
    // Reset the main sampler chain in the base class
    base.Reset();

    // Keep the grammar-only chain in sync with the main chain's state
    if (_grammarChain != null)
        _grammarChain.Reset();
}

/// <inheritdoc />
public override void Accept(LLamaToken token)
{
    // Feed the accepted token to the main sampler chain
    base.Accept(token);

    // The grammar chain must also see every accepted token so its state advances in lockstep
    if (_grammarChain != null)
        _grammarChain.Accept(token);
}

/// <summary>
/// Build a sampler chain containing only the grammar sampler, used for the
/// grammar-optimization fast paths in <see cref="Sample"/>.
/// </summary>
/// <param name="context">Context providing the model handle the grammar is attached to</param>
/// <returns>A new sampler chain holding just the grammar sampler</returns>
/// <exception cref="InvalidOperationException">Thrown if <see cref="Grammar"/> is null</exception>
private SafeLLamaSamplerChainHandle CreateGrammarChain(SafeLLamaContextHandle context)
{
    var grammar = Grammar ?? throw new InvalidOperationException(nameof(Grammar) + " is null");

    var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default());
    chain.AddGrammar(context.ModelHandle, grammar.Gbnf, grammar.Root);
    return chain;
}

/// <inheritdoc />
protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context)
{
var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default());

// Rent a temporary array and copy the biases into it
var biases = ArrayPool<LLamaLogitBias>.Shared.Rent(LogitBias.Count);
try
if (LogitBias.Count > 0)
{
var index = 0;
foreach (var bias in LogitBias)
// Rent a temporary array and copy the biases into it
var biases = ArrayPool<LLamaLogitBias>.Shared.Rent(LogitBias.Count);
try
{
biases[index++] = new LLamaLogitBias
var index = 0;
foreach (var bias in LogitBias)
{
Token = bias.Key,
Bias = bias.Value
};
}
biases[index++] = new LLamaLogitBias
{
Token = bias.Key,
Bias = bias.Value
};
}

// Add the biases to the sampler
chain.AddLogitBias(context.Vocab.Count, biases.AsSpan(0, LogitBias.Count));
}
finally
{
ArrayPool<LLamaLogitBias>.Shared.Return(biases);
// Add the biases to the sampler
chain.AddLogitBias(context.Vocab.Count, biases.AsSpan(0, LogitBias.Count));
}
finally
{
ArrayPool<LLamaLogitBias>.Shared.Return(biases);
}
}

if (Grammar != null)
chain.AddGrammar(context.ModelHandle, Grammar.Gbnf, Grammar.Root);

chain.AddPenalties(PenaltyCount, RepeatPenalty, FrequencyPenalty, PresencePenalty);

chain.AddTopK(TopK);
Expand All @@ -164,4 +208,131 @@ protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandl

return chain;
}

/// <inheritdoc />
/// <remarks>
/// When a grammar is set this uses a separate grammar-only chain plus optional fast paths
/// (see <see cref="GrammarOptimization"/>): first try validating just the token the normal
/// pipeline picked, then optionally the TopK candidates, and only fall back to applying the
/// grammar over the entire vocabulary if those fail.
/// </remarks>
public override LLamaToken Sample(SafeLLamaContextHandle ctx, int index)
{
    // No grammar: the plain pipeline in the base class is sufficient
    if (Grammar == null)
        return base.Sample(ctx, index);

    // Create a chain with the grammar
    _grammarChain ??= CreateGrammarChain(ctx);

    // Rent some buffers to use later
    var rentedBufferVocabSizeArr = ArrayPool<LLamaTokenData>.Shared.Rent(ctx.ModelHandle.Vocab.Count);
    var rentedBufferVocabSize = rentedBufferVocabSizeArr.AsMemory(0, ctx.ModelHandle.Vocab.Count);
    var rentedBufferSingleItemArr = ArrayPool<LLamaTokenData>.Shared.Rent(1);
    var rentedBufferSingleItem = rentedBufferSingleItemArr.AsMemory(0, 1);

    try
    {
        // Handle grammar optimization modes
        if (GrammarOptimization != GrammarOptimizationMode.None)
        {
            // Basic optimization : Apply the grammar to the selected token and check if it's valid
            using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll))
            {
                // Apply the chain without the grammar to select one token which may or may not be valid
                Apply(ctx, ref nativeAll);

                // Select the candidate token
                var candidateToken = nativeAll.Data[checked((int)nativeAll.Selected)].ID;

                // Now create another token data array with just that one token
                rentedBufferSingleItem.Span[0] = new LLamaTokenData(candidateToken, 1, 0);
                using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferSingleItem, true), out var nativeSingleCandidate))
                {
                    // Apply the grammar chain to the single candidate
                    _grammarChain.Apply(ref nativeSingleCandidate);

                    // Check if the token passes the grammar
                    // (the grammar sampler marks rejected tokens by setting their logit to -inf)
                    if (!float.IsNegativeInfinity(nativeSingleCandidate.Data[0].Logit))
                    {
                        Accept(candidateToken);
                        return candidateToken;
                    }
                }

                // Extended optimization : Apply the grammar to the TopK tokens and check if the selected token is valid
                if (GrammarOptimization == GrammarOptimizationMode.Extended)
                {
                    // Calculate a safe TopK value
                    // NOTE(review): assumes TopK is positive here — a TopK <= 0 would make the slice below empty; confirm upstream validation
                    var safeTopK = Math.Min(TopK, nativeAll.Data.Length);

                    // Rent a buffer for the TopK candidates
                    var rentedBufferTopKArr = ArrayPool<LLamaTokenData>.Shared.Rent(safeTopK);
                    var rentedBufferTopK = rentedBufferTopKArr.AsMemory(0, safeTopK);
                    try
                    {
                        // Copy only the TopK tokens from the existing candidate pool to the new buffer
                        // (nativeAll has already been through the pipeline, so presumably it is sorted by logit — verify against Apply)
                        nativeAll.Data.Slice(0, safeTopK).CopyTo(rentedBufferTopK.Span);

                        // Create a native array with the TopK tokens
                        using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferTopK, true), out var nativeTopK))
                        {
                            // Apply the grammar chain to the TopK candidates
                            _grammarChain.Apply(ref nativeTopK);

                            // Select the candidate token
                            var candidateTokenTopK = nativeTopK.Data[checked((int)nativeTopK.Selected)];

                            // Check if the token passes the grammar
                            if (!float.IsNegativeInfinity(candidateTokenTopK.Logit))
                            {
                                // Accept and return the token
                                Accept(candidateTokenTopK.ID);
                                return candidateTokenTopK.ID;
                            }
                        }
                    }
                    finally
                    {
                        ArrayPool<LLamaTokenData>.Shared.Return(rentedBufferTopKArr);
                    }
                }
            }
        }

        // If we get here the grammar rejected the token
        using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll))
        {
            // Apply the grammar _first_. This is slower (since it has to work on the entire vocab), but guaranteed to work
            _grammarChain.Apply(ref nativeAll);

            // Now apply the rest of the pipeline
            Apply(ctx, ref nativeAll);

            // Take the selected token
            var token = nativeAll.Data[checked((int)nativeAll.Selected)].ID;
            Accept(token);
            return token;
        }
    }
    finally
    {
        // Return the rented buffers on every exit path
        ArrayPool<LLamaTokenData>.Shared.Return(rentedBufferVocabSizeArr);
        ArrayPool<LLamaTokenData>.Shared.Return(rentedBufferSingleItemArr);
    }
}

/// <summary>
/// Grammar Optimization Mode
/// </summary>
public enum GrammarOptimizationMode
{
    /// <summary>
    /// No optimization: the grammar is always applied across the entire vocabulary, which is slow.
    /// </summary>
    None = 0,

    /// <summary>
    /// Fast path that applies the grammar only to the single selected token, returning early when it is valid.
    /// </summary>
    Basic = 1,

    /// <summary>
    /// Fast path that applies the grammar to the top K tokens, returning early when the selected token is valid.
    /// </summary>
    Extended = 2
}
}
7 changes: 7 additions & 0 deletions LLama/Sampling/ISamplingPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ public interface ISamplingPipeline
/// <returns></returns>
LLamaToken Sample(SafeLLamaContextHandle ctx, int index);

/// <summary>
/// Apply this pipeline to a set of token data
/// </summary>
/// <param name="ctx"></param>
/// <param name="data"></param>
public void Apply(SafeLLamaContextHandle ctx, LLamaTokenDataArray data);

/// <summary>
/// Reset all internal state of the sampling pipeline
/// </summary>
Expand Down
Loading