@@ -112,6 +112,16 @@ public float PresencePenalty
112112 /// Seed to use for random sampling
113113 /// </summary>
114114 public uint Seed { get ; set ; } = GetRandomSeed ( ) ;
115+
116+ /// <summary>
117+ /// Selected grammar optimization mode. Only consulted by Sample when a Grammar is set; defaults to Extended.
118+ /// </summary>
119+ public GrammarOptimizationMode GrammarOptimization { get ; init ; } = GrammarOptimizationMode . Extended ;
120+
121+ /// <summary>
122+ /// A chain with just the grammar. Null until Sample first runs with a Grammar set; disposed and cleared again by Dispose.
123+ /// </summary>
124+ private SafeLLamaSamplerChainHandle ? _grammarChain ;
115125
116126
117127 private static readonly Random RandomSeedGenerator = new ( ) ;
@@ -121,37 +131,71 @@ private static uint GetRandomSeed()
121131 return ( uint ) RandomSeedGenerator . Next ( 0 , int . MaxValue ) + ( uint ) RandomSeedGenerator . Next ( 0 , int . MaxValue ) ;
122132 }
123133
/// <inheritdoc />
public override void Dispose()
{
    // Let the base pipeline release its own resources first
    base.Dispose();

    // Then release the grammar-only chain, if one was ever created
    if (_grammarChain != null)
    {
        _grammarChain.Dispose();
    }

    _grammarChain = null;
}
142+
/// <inheritdoc />
public override void Reset()
{
    // Reset the base pipeline first
    base.Reset();

    // The grammar chain only exists once a grammar-constrained sample has been taken
    if (_grammarChain != null)
    {
        _grammarChain.Reset();
    }
}
150+
/// <inheritdoc />
public override void Accept(LLamaToken token)
{
    // Forward the token to the base pipeline first
    base.Accept(token);

    // Forward the same token to the separate grammar chain, when it exists
    if (_grammarChain != null)
    {
        _grammarChain.Accept(token);
    }
}
158+
/// <summary>
/// Create a new sampler chain which contains only the grammar sampler.
/// </summary>
/// <param name="context">Context whose model handle is passed to the grammar sampler</param>
/// <returns>A new chain, owned by the caller, with the grammar added to it</returns>
/// <exception cref="InvalidOperationException">Thrown if <c>Grammar</c> is null</exception>
private SafeLLamaSamplerChainHandle CreateGrammarChain(SafeLLamaContextHandle context)
{
    if (Grammar == null)
        throw new InvalidOperationException(nameof(Grammar) + " is null");

    var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default());
    try
    {
        chain.AddGrammar(context.ModelHandle, Grammar.Gbnf, Grammar.Root);
        return chain;
    }
    catch
    {
        // Nobody else holds a reference to the chain yet, so if adding the grammar fails
        // release the handle right away instead of leaving it for finalization.
        chain.Dispose();
        throw;
    }
}
124168
125169 /// <inheritdoc />
126170 protected override SafeLLamaSamplerChainHandle CreateChain ( SafeLLamaContextHandle context )
127171 {
128172 var chain = SafeLLamaSamplerChainHandle . Create ( LLamaSamplerChainParams . Default ( ) ) ;
129173
130- // Rent a temporary array and copy the biases into it
131- var biases = ArrayPool < LLamaLogitBias > . Shared . Rent ( LogitBias . Count ) ;
132- try
174+ if ( LogitBias . Count > 0 )
133175 {
134- var index = 0 ;
135- foreach ( var bias in LogitBias )
176+ // Rent a temporary array and copy the biases into it
177+ var biases = ArrayPool < LLamaLogitBias > . Shared . Rent ( LogitBias . Count ) ;
178+ try
136179 {
137- biases [ index ++ ] = new LLamaLogitBias
180+ var index = 0 ;
181+ foreach ( var bias in LogitBias )
138182 {
139- Token = bias . Key ,
140- Bias = bias . Value
141- } ;
142- }
183+ biases [ index ++ ] = new LLamaLogitBias
184+ {
185+ Token = bias . Key ,
186+ Bias = bias . Value
187+ } ;
188+ }
143189
144- // Add the biases to the sampler
145- chain . AddLogitBias ( context . Vocab . Count , biases . AsSpan ( 0 , LogitBias . Count ) ) ;
146- }
147- finally
148- {
149- ArrayPool < LLamaLogitBias > . Shared . Return ( biases ) ;
190+ // Add the biases to the sampler
191+ chain . AddLogitBias ( context . Vocab . Count , biases . AsSpan ( 0 , LogitBias . Count ) ) ;
192+ }
193+ finally
194+ {
195+ ArrayPool < LLamaLogitBias > . Shared . Return ( biases ) ;
196+ }
150197 }
151198
152- if ( Grammar != null )
153- chain . AddGrammar ( context . ModelHandle , Grammar . Gbnf , Grammar . Root ) ;
154-
155199 chain . AddPenalties ( PenaltyCount , RepeatPenalty , FrequencyPenalty , PresencePenalty ) ;
156200
157201 chain . AddTopK ( TopK ) ;
@@ -164,4 +208,131 @@ protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandl
164208
165209 return chain ;
166210 }
211+
/// <inheritdoc />
/// <remarks>
/// When no grammar is set this defers to the base implementation. Otherwise the grammar is kept in a
/// separate chain and, depending on <c>GrammarOptimization</c>, up to three attempts are made:
/// <list type="number">
/// <item>Basic/Extended: run the normal pipeline, then check only the selected token against the grammar.</item>
/// <item>Extended: copy the first TopK remaining candidates, apply the grammar to them and check the entry at the selected index.</item>
/// <item>Fallback: apply the grammar to the whole vocabulary first, then run the normal pipeline.</item>
/// </list>
/// The returned token has always been passed to <c>Accept</c> before this method returns.
/// </remarks>
public override LLamaToken Sample(SafeLLamaContextHandle ctx, int index)
{
    if (Grammar == null)
        return base.Sample(ctx, index);

    // Create a chain with the grammar (only on first use, it is kept for later calls)
    _grammarChain ??= CreateGrammarChain(ctx);

    // Rent some buffers to use later
    var rentedBufferVocabSizeArr = ArrayPool<LLamaTokenData>.Shared.Rent(ctx.ModelHandle.Vocab.Count);
    var rentedBufferVocabSize = rentedBufferVocabSizeArr.AsMemory(0, ctx.ModelHandle.Vocab.Count);
    var rentedBufferSingleItemArr = ArrayPool<LLamaTokenData>.Shared.Rent(1);
    var rentedBufferSingleItem = rentedBufferSingleItemArr.AsMemory(0, 1);

    try
    {
        // Handle grammar optimization modes
        if (GrammarOptimization != GrammarOptimizationMode.None)
        {
            // Basic optimization : Apply the grammar to the selected token and check if it's valid
            using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll))
            {
                // Apply the chain without the grammar to select one token which may or may not be valid
                Apply(ctx, ref nativeAll);

                // Select the candidate token
                var candidateToken = nativeAll.Data[checked((int)nativeAll.Selected)].ID;

                // Now create another token data array with just that one token
                rentedBufferSingleItem.Span[0] = new LLamaTokenData(candidateToken, 1, 0);
                using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferSingleItem, true), out var nativeSingleCandidate))
                {
                    // Apply the grammar chain to the single candidate
                    _grammarChain.Apply(ref nativeSingleCandidate);

                    // Check if the token passes the grammar (a negative-infinity logit is treated as a rejection)
                    if (!float.IsNegativeInfinity(nativeSingleCandidate.Data[0].Logit))
                    {
                        Accept(candidateToken);
                        return candidateToken;
                    }
                }

                // Extended optimization : Apply the grammar to the TopK tokens and check if the selected token is valid
                if (GrammarOptimization == GrammarOptimizationMode.Extended)
                {
                    // Calculate a safe TopK value
                    var safeTopK = Math.Min(TopK, nativeAll.Data.Length);

                    // A non-positive count cannot be used here: ArrayPool.Rent throws for a negative
                    // length, and with zero candidates the index lookup below would be out of range.
                    // In that case skip this stage and fall through to the full-vocabulary path.
                    // NOTE(review): presumably TopK <= 0 means "top-k disabled" - confirm against the sampler docs.
                    if (safeTopK > 0)
                    {
                        // Rent a buffer for the TopK candidates
                        var rentedBufferTopKArr = ArrayPool<LLamaTokenData>.Shared.Rent(safeTopK);
                        var rentedBufferTopK = rentedBufferTopKArr.AsMemory(0, safeTopK);
                        try
                        {
                            // Copy only the TopK tokens from the existing candidate pool to the new buffer
                            nativeAll.Data.Slice(0, safeTopK).CopyTo(rentedBufferTopK.Span);

                            // Create a native array with the TopK tokens
                            using (LLamaTokenDataArrayNative.Create(new LLamaTokenDataArray(rentedBufferTopK, true), out var nativeTopK))
                            {
                                // Apply the grammar chain to the TopK candidates
                                _grammarChain.Apply(ref nativeTopK);

                                // Select the candidate token
                                var candidateTokenTopK = nativeTopK.Data[checked((int)nativeTopK.Selected)];

                                // Check if the token passes the grammar
                                if (!float.IsNegativeInfinity(candidateTokenTopK.Logit))
                                {
                                    // Accept and return the token
                                    Accept(candidateTokenTopK.ID);
                                    return candidateTokenTopK.ID;
                                }
                            }
                        }
                        finally
                        {
                            ArrayPool<LLamaTokenData>.Shared.Return(rentedBufferTopKArr);
                        }
                    }
                }
            }
        }

        // If we get here either optimization is disabled or the grammar rejected the token.
        // The vocab-sized buffer is refilled from the logits, so earlier modifications do not carry over.
        using (LLamaTokenDataArrayNative.Create(LLamaTokenDataArray.Create(ctx.GetLogitsIth(index), rentedBufferVocabSize), out var nativeAll))
        {
            // Apply the grammar _first_. This is slower (since it has to work on the entire vocab), but guaranteed to work
            _grammarChain.Apply(ref nativeAll);

            // Now apply the rest of the pipeline
            Apply(ctx, ref nativeAll);

            // Take the selected token
            var token = nativeAll.Data[checked((int)nativeAll.Selected)].ID;
            Accept(token);
            return token;
        }
    }
    finally
    {
        // Always hand the rented buffers back, including on the early-return paths above
        ArrayPool<LLamaTokenData>.Shared.Return(rentedBufferVocabSizeArr);
        ArrayPool<LLamaTokenData>.Shared.Return(rentedBufferSingleItemArr);
    }
}
317+
/// <summary>
/// Strategies for reducing the cost of grammar-constrained sampling.
/// </summary>
public enum GrammarOptimizationMode
{
    /// <summary>
    /// Optimization disabled: the grammar is always applied to the entire vocab, which is slow.
    /// </summary>
    None = 0,

    /// <summary>
    /// Tries to finish early by applying the grammar only to the selected token and checking whether it is valid.
    /// </summary>
    Basic = 1,

    /// <summary>
    /// Tries to finish early by applying the grammar to the top K tokens and checking whether the selected token is valid.
    /// </summary>
    Extended = 2
}
167338}
0 commit comments