 using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Threading;
-using System.Threading.Tasks;
+using LLama.Exceptions;
 using LLama.Native;
 using LLama.Sampling;
 using Microsoft.Extensions.Logging;
@@ -22,6 +22,7 @@ public class StatelessExecutor
     private readonly LLamaWeights _weights;
     private readonly IContextParams _params;
     private readonly ILogger? _logger;
+    private readonly LLamaBatch _batch;
 
     /// <summary>
     /// The context used by the executor when running the inference.
@@ -39,6 +40,7 @@ public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger?
         _weights = weights;
         _params = @params;
         _logger = logger;
+        _batch = new LLamaBatch(1);
 
         Context = _weights.CreateContext(_params, logger);
         Context.Dispose();
@@ -71,16 +73,29 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
         var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
         var lastTokens = new List<LLamaToken>(repeat_last_n);
         for (var i = 0; i < repeat_last_n; i++)
-            lastTokens.Add((LLamaToken)0);
+            lastTokens.Add(0);
 
         // Tokenize the prompt
         var tokens = Context.Tokenize(prompt).ToList();
         lastTokens.AddRange(tokens);
-        var n_past = 1 + tokens.Count;
 
-        // Evaluate the prompt
-        await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
-            .ConfigureAwait(false);
+        // Evaluate the prompt, in chunks smaller than the max batch size
+        var n_past = 0;
+        var batchSize = (int)Context.Params.BatchSize;
+        for (var i = 0; i < tokens.Count; i += batchSize)
+        {
+            var n_eval = tokens.Count - i;
+            if (n_eval > batchSize)
+                n_eval = batchSize;
+
+            _batch.Clear();
+            for (var j = 0; j < n_eval; j++)
+                _batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, (i + j) == tokens.Count - 1);
+
+            var returnCode = await Context.DecodeAsync(_batch, cancellationToken);
+            if (returnCode != 0)
+                throw new LLamaDecodeError(returnCode);
+        }
 
         // Begin loop, evaluating one token at a time
         var mu = (float?)null;
@@ -90,12 +105,12 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
             LLamaToken id;
             if (inferenceParams.SamplingPipeline is not null)
             {
-                id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+                id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(_batch.TokenCount - 1), lastTokens);
             }
             else
             {
                 // Penalize the generated tokens by various penalties
-                var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+                var tokenDataArray = Context.ApplyPenalty(_batch.TokenCount - 1, lastTokens, inferenceParams.LogitBias, repeat_last_n,
                     inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
                 // Sample a single token
@@ -136,9 +151,12 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
                 n_past -= n_discard;
             }
 
-            // ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
-            n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
-                .ConfigureAwait(false);
+            // Evaluate with this new token
+            _batch.Clear();
+            _batch.Add(id, n_past++, LLamaSeqId.Zero, true);
+            var returnCode = await context.DecodeAsync(_batch, cancellationToken);
+            if (returnCode != 0)
+                throw new LLamaDecodeError(returnCode);
         }
     }
 }
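
Note: the Clear/Add/DecodeAsync pattern above now drives both prompt ingestion and the per-token generation step, replacing the removed Context.Eval calls. Logits are requested only for the last token added to each batch, which is why sampling reads position _batch.TokenCount - 1. The following is a minimal standalone sketch of the chunked evaluation loop, not part of this commit: the EvaluatePromptAsync helper and its wrapper class are hypothetical names, while the LLamaSharp calls (LLamaBatch.Clear/Add, LLamaContext.DecodeAsync, Context.Params.BatchSize) are used exactly as they appear in the diff.

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using LLama;
using LLama.Exceptions;
using LLama.Native;

internal static class PromptEvaluation
{
    // Hypothetical helper mirroring the prompt-evaluation loop added above.
    // Returns the new n_past (the number of tokens now in the KV cache).
    public static async Task<int> EvaluatePromptAsync(LLamaContext context, LLamaBatch batch, IReadOnlyList<LLamaToken> tokens, CancellationToken cancellationToken)
    {
        var n_past = 0;
        var batchSize = (int)context.Params.BatchSize;

        // Feed the prompt in chunks of at most BatchSize tokens.
        for (var i = 0; i < tokens.Count; i += batchSize)
        {
            var n_eval = Math.Min(batchSize, tokens.Count - i);

            batch.Clear();
            for (var j = 0; j < n_eval; j++)
            {
                // Request logits only for the very last prompt token; it is
                // the only position the sampler reads afterwards.
                var isLast = (i + j) == tokens.Count - 1;
                batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, isLast);
            }

            var returnCode = await context.DecodeAsync(batch, cancellationToken);
            if (returnCode != 0)
                throw new LLamaDecodeError(returnCode);
        }

        return n_past;
    }
}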
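Usage is unchanged by this commit. The sketch below shows a typical stateless streaming call, assuming the standard LLamaSharp surface (ModelParams, LLamaWeights.LoadFromFile, InferenceParams); the model path, prompt, and MaxTokens value are placeholders.

using System;
using LLama;
using LLama.Common;

var parameters = new ModelParams("path/to/model.gguf"); // placeholder path
using var weights = LLamaWeights.LoadFromFile(parameters);

var executor = new StatelessExecutor(weights, parameters);
var inferenceParams = new InferenceParams { MaxTokens = 64 };

// Each InferAsync call runs against a fresh context, so no state carries over between prompts.
await foreach (var text in executor.InferAsync("Question: What is a llama? Answer:", inferenceParams))
    Console.Write(text);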