Skip to content

Commit e2cfed6

Browse files
authored
Merge pull request #1109 from martindevans/grammar_resampling
Grammar Resampling
2 parents 2d1f639 + 7bffe1c commit e2cfed6

File tree

4 files changed

+223
-26
lines changed

4 files changed

+223
-26
lines changed

LLama/Native/LLamaTokenDataArray.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,9 @@ public int Compare(LLamaTokenData x, LLamaTokenData y)
134134
}
135135

136136
/// <summary>
137-
/// Contains a pointer to an array of LLamaTokenData which is pinned in memory.
138-
/// </summary>
139-
/// <remarks>C# equivalent of llama_token_data_array</remarks>
137+
/// Contains a pointer to an array of LLamaTokenData which is pinned in memory.
138+
/// </summary>
139+
/// <remarks>C# equivalent of llama_token_data_array</remarks>
140140
[StructLayout(LayoutKind.Sequential)]
141141
public struct LLamaTokenDataArrayNative
142142
{

LLama/Sampling/BaseSamplingPipeline.cs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public BaseSamplingPipeline()
2323
protected abstract SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context);
2424

2525
/// <inheritdoc />
26-
public void Dispose()
26+
public virtual void Dispose()
2727
{
2828
_chain?.Dispose();
2929
_chain = null;
@@ -32,21 +32,40 @@ public void Dispose()
3232
}
3333

3434
/// <inheritdoc />
35-
public LLamaToken Sample(SafeLLamaContextHandle ctx, int index)
35+
public virtual LLamaToken Sample(SafeLLamaContextHandle ctx, int index)
3636
{
3737
_chain ??= CreateChain(ctx);
3838

3939
return _chain.Sample(ctx, index);
4040
}
4141

4242
/// <inheritdoc />
43-
public void Reset()
43+
public virtual void Apply(SafeLLamaContextHandle ctx, LLamaTokenDataArray data)
44+
{
45+
_chain ??= CreateChain(ctx);
46+
using (LLamaTokenDataArrayNative.Create(data, out var native))
47+
_chain.Apply(ref native);
48+
}
49+
50+
/// <summary>
51+
/// Apply this sampling chain to a LLamaTokenDataArrayNative
52+
/// </summary>
53+
/// <param name="ctx"></param>
54+
/// <param name="data"></param>
55+
public virtual void Apply(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative data)
56+
{
57+
_chain ??= CreateChain(ctx);
58+
_chain.Apply(ref data);
59+
}
60+
61+
/// <inheritdoc />
62+
public virtual void Reset()
4463
{
4564
_chain?.Reset();
4665
}
4766

4867
/// <inheritdoc />
49-
public void Accept(LLamaToken token)
68+
public virtual void Accept(LLamaToken token)
5069
{
5170
_chain?.Accept(token);
5271
}

LLama/Sampling/DefaultSamplingPipeline.cs

Lines changed: 190 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,16 @@ public float PresencePenalty
112112
/// Seed to use for random sampling
113113
/// </summary>
114114
public uint Seed { get; set; } = GetRandomSeed();
115+
116+
/// <summary>
117+
/// Selected grammar optimization mode
118+
/// </summary>
119+
public GrammarOptimizationMode GrammarOptimization { get; init; } = GrammarOptimizationMode.Extended;
120+
121+
/// <summary>
122+
/// A chain with just the grammar
123+
/// </summary>
124+
private SafeLLamaSamplerChainHandle? _grammarChain;
115125

116126

117127
private static readonly Random RandomSeedGenerator = new();
@@ -121,37 +131,71 @@ private static uint GetRandomSeed()
121131
return (uint) RandomSeedGenerator.Next(0, int.MaxValue) + (uint) RandomSeedGenerator.Next(0, int.MaxValue);
122132
}
123133

134+
/// <inheritdoc />
135+
public override void Dispose()
136+
{
137+
base.Dispose();
138+
139+
_grammarChain?.Dispose();
140+
_grammarChain = null;
141+
}
142+
143+
/// <inheritdoc />
144+
public override void Reset()
145+
{
146+
base.Reset();
147+
148+
_grammarChain?.Reset();
149+
}
150+
151+
/// <inheritdoc />
152+
public override void Accept(LLamaToken token)
153+
{
154+
base.Accept(token);
155+
156+
_grammarChain?.Accept(token);
157+
}
158+
159+
private SafeLLamaSamplerChainHandle CreateGrammarChain(SafeLLamaContextHandle context)
160+
{
161+
if (Grammar == null)
162+
throw new InvalidOperationException(nameof(Grammar) + " is null");
163+
164+
var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default());
165+
chain.AddGrammar(context.ModelHandle, Grammar.Gbnf, Grammar.Root);
166+
return chain;
167+
}
124168

125169
/// <inheritdoc />
126170
protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context)
127171
{
128172
var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default());
129173

130-
// Rent a temporary array and copy the biases into it
131-
var biases = ArrayPool<LLamaLogitBias>.Shared.Rent(LogitBias.Count);
132-
try
174+
if (LogitBias.Count > 0)
133175
{
134-
var index = 0;
135-
foreach (var bias in LogitBias)
176+
// Rent a temporary array and copy the biases into it
177+
var biases = ArrayPool<LLamaLogitBias>.Shared.Rent(LogitBias.Count);
178+
try
136179
{
137-
biases[index++] = new LLamaLogitBias
180+
var index = 0;
181+
foreach (var bias in LogitBias)
138182
{
139-
Token = bias.Key,
140-
Bias = bias.Value
141-
};
142-
}
183+
biases[index++] = new LLamaLogitBias
184+
{
185+
Token = bias.Key,
186+
Bias = bias.Value
187+
};
188+
}
143189

144-
// Add the biases to the sampler
145-
chain.AddLogitBias(context.Vocab.Count, biases.AsSpan(0, LogitBias.Count));
146-
}
147-
finally
148-
{
149-
ArrayPool<LLamaLogitBias>.Shared.Return(biases);
190+
// Add the biases to the sampler
191+
chain.AddLogitBias(context.Vocab.Count, biases.AsSpan(0, LogitBias.Count));
192+
}
193+
finally
194+
{
195+
ArrayPool<LLamaLogitBias>.Shared.Return(biases);
196+
}
150197
}
151198

152-
if (Grammar != null)
153-
chain.AddGrammar(context.ModelHandle, Grammar.Gbnf, Grammar.Root);
154-
155199
chain.AddPenalties(PenaltyCount, RepeatPenalty, FrequencyPenalty, PresencePenalty);
156200

157201
chain.AddTopK(TopK);
@@ -164,4 +208,131 @@ protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandl
164208

165209
return chain;
166210
}
211+
212+
/// <inheritdoc />
213+
public override LLamaToken Sample(SafeLLamaContextHandle ctx, int index)
214+
{
215+
if (Grammar == null)
216+
return base.Sample(ctx, index);
217+
218+
// Create a chain with the grammar
219+
_grammarChain ??= CreateGrammarChain(ctx);
220+
221+
// Rent some buffers to use later
222+
var rentedBufferVocabSizeArr = ArrayPool<LLamaTokenData>.Shared.Rent(ctx.ModelHandle.Vocab.Count);
223+
var rentedBufferVocabSize = rentedBufferVocabSizeArr.AsMemory(0, ctx.ModelHandle.Vocab.Count);
224+
var rentedBufferSingleItemArr = ArrayPool<LLamaTokenData>.Shared.Rent(1);
225+
var rentedBufferSingleItem = rentedBufferSingleItemArr.AsMemory(0, 1);
226+
227+
try
228+
{
229+
// Handle grammar optimization modes
230+
if (GrammarOptimization != GrammarOptimizationMode.None)
231+
{
232+
// Basic optimization : Apply the grammar to the selected token and check if it's valid
233+
using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll))
234+
{
235+
// Apply the chain without the grammar to select one token which may or may not be valid
236+
Apply(ctx, ref nativeAll);
237+
238+
// Select the candidate token
239+
var candidateToken = nativeAll.Data[checked((int)nativeAll.Selected)].ID;
240+
241+
// Now create another token data array with just that one token
242+
rentedBufferSingleItem.Span[0] = new LLamaTokenData(candidateToken, 1, 0);
243+
using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferSingleItem, true), out var nativeSingleCandidate))
244+
{
245+
// Apply the grammar chain to the single candidate
246+
_grammarChain.Apply(ref nativeSingleCandidate);
247+
248+
// Check if the token passes the grammar
249+
if (!float.IsNegativeInfinity(nativeSingleCandidate.Data[0].Logit))
250+
{
251+
Accept(candidateToken);
252+
return candidateToken;
253+
}
254+
}
255+
256+
// Extended optimization : Apply the grammar to the TopK tokens and check if the selected token is valid
257+
if (GrammarOptimization == GrammarOptimizationMode.Extended)
258+
{
259+
// Calculate a safe TopK value
260+
var safeTopK = Math.Min(TopK, nativeAll.Data.Length);
261+
262+
// Rent a buffer for the TopK candidates
263+
var rentedBufferTopKArr = ArrayPool<LLamaTokenData>.Shared.Rent(safeTopK);
264+
var rentedBufferTopK = rentedBufferTopKArr.AsMemory(0, safeTopK);
265+
try
266+
{
267+
// Copy only the TopK tokens from the existing candidate pool to the new buffer
268+
nativeAll.Data.Slice(0, safeTopK).CopyTo(rentedBufferTopK.Span);
269+
270+
// Create a native array with the TopK tokens
271+
using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferTopK, true), out var nativeTopK))
272+
{
273+
// Apply the grammar chain to the TopK candidates
274+
_grammarChain.Apply(ref nativeTopK);
275+
276+
// Select the candidate token
277+
var candidateTokenTopK = nativeTopK.Data[checked((int)nativeTopK.Selected)];
278+
279+
// Check if the token passes the grammar
280+
if (!float.IsNegativeInfinity(candidateTokenTopK.Logit))
281+
{
282+
// Accept and return the token
283+
Accept(candidateTokenTopK.ID);
284+
return candidateTokenTopK.ID;
285+
}
286+
}
287+
}
288+
finally
289+
{
290+
ArrayPool<LLamaTokenData>.Shared.Return(rentedBufferTopKArr);
291+
}
292+
}
293+
}
294+
}
295+
296+
// If we get here the grammar rejected the token
297+
using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll))
298+
{
299+
// Apply the grammar _first_. This is slower (since it has to work on the entire vocab), but guaranteed to work
300+
_grammarChain.Apply(ref nativeAll);
301+
302+
// Now apply the rest of the pipeline
303+
Apply(ctx, ref nativeAll);
304+
305+
// Take the selected token
306+
var token = nativeAll.Data[checked((int)nativeAll.Selected)].ID;
307+
Accept(token);
308+
return token;
309+
}
310+
}
311+
finally
312+
{
313+
ArrayPool<LLamaTokenData>.Shared.Return(rentedBufferVocabSizeArr);
314+
ArrayPool<LLamaTokenData>.Shared.Return(rentedBufferSingleItemArr);
315+
}
316+
}
317+
318+
/// <summary>
319+
/// Grammar Optimization Mode
320+
/// </summary>
321+
public enum GrammarOptimizationMode
322+
{
323+
/// <summary>
324+
/// No grammar optimization, slow because it has to apply the grammar to the entire vocab.
325+
/// </summary>
326+
None,
327+
328+
/// <summary>
329+
/// Attempts to return early by only applying the grammar to the selected token and checking if it's valid.
330+
/// </summary>
331+
Basic,
332+
333+
/// <summary>
334+
/// Attempts to return early by applying the grammar to the top K tokens and checking if the selected token is valid.
335+
/// </summary>
336+
Extended
337+
}
167338
}

LLama/Sampling/ISamplingPipeline.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ public interface ISamplingPipeline
1717
/// <returns></returns>
1818
LLamaToken Sample(SafeLLamaContextHandle ctx, int index);
1919

20+
/// <summary>
21+
/// Apply this pipeline to a set of token data
22+
/// </summary>
23+
/// <param name="ctx"></param>
24+
/// <param name="data"></param>
25+
public void Apply(SafeLLamaContextHandle ctx, LLamaTokenDataArray data);
26+
2027
/// <summary>
2128
/// Reset all internal state of the sampling pipeline
2229
/// </summary>

0 commit comments

Comments
 (0)