@@ -46,14 +46,41 @@ public static LLamaTokenDataArray Create(ReadOnlySpan<float> logits)
4646 return new LLamaTokenDataArray ( candidates ) ;
4747 }
4848
/// <summary>
/// Replace the logit of each listed token with the supplied value.
/// </summary>
/// <param name="values">
/// Pairs of (token, new logit). A token that is not found among the candidates is skipped;
/// if several candidates carry the same token id, only the first one found is updated.
/// </param>
public void OverwriteLogits(ReadOnlySpan<(llama_token token, float logit)> values)
{
    // Nothing to apply: leave both the candidates and the sorted flag untouched.
    if (values.IsEmpty)
        return;

    var candidates = data.Span;
    for (var v = 0; v < values.Length; v++)
    {
        var (wanted, replacement) = values[v];

        // Linear scan for the first candidate carrying the wanted token id.
        var index = 0;
        while (index < candidates.Length && candidates[index].id != wanted)
            index++;

        if (index < candidates.Length)
            candidates[index].logit = replacement;
    }

    // Logits may have been modified, so a previously established ordering can no longer be assumed.
    sorted = false;
}
72+
4973 #region sampling
5074 /// <summary>
5175 /// Apply grammar rules to candidate tokens
5276 /// </summary>
5377 /// <param name="ctx"></param>
5478 /// <param name="grammar"></param>
55- public void ApplyGrammar ( SafeLLamaContextHandle ctx , SafeLLamaGrammarHandle grammar )
79+ public void ApplyGrammar ( SafeLLamaContextHandle ctx , SafeLLamaGrammarHandle ? grammar )
5680 {
81+ if ( grammar == null )
82+ return ;
83+
5784 using ( LLamaTokenDataArrayNative . Create ( this , out var st ) )
5885 {
5986 NativeApi . llama_sample_grammar ( ctx , ref st , grammar ) ;
@@ -145,15 +172,17 @@ public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong min_ke
145172 /// <param name="penalty_repeat"></param>
146173 /// <param name="penalty_freq"></param>
147174 /// <param name="penalty_present"></param>
public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
{
    // The native API takes the token count as an unsigned 64-bit value.
    var tokenCount = (ulong)last_tokens.Length;

    unsafe
    {
        // Build the native view of this array first; it is disposed when the using block ends.
        using (LLamaTokenDataArrayNative.Create(this, out var nativeArray))
        {
            // Pin the token history so a raw pointer can be handed to the native call.
            fixed (int* tokensPtr = last_tokens)
            {
                NativeApi.llama_sample_repetition_penalties(
                    context,
                    ref nativeArray,
                    tokensPtr,
                    tokenCount,
                    penalty_repeat,
                    penalty_freq,
                    penalty_present);

                // Copy the sorted flag of the native array back onto this instance.
                sorted = nativeArray.sorted;
            }
        }
    }
}
0 commit comments