@@ -23,6 +23,16 @@ public class LLamaBatch
2323    /// </summary> 
2424    private  readonly  Dictionary < ( LLamaToken ,  LLamaPos ) ,  int >  _index  =  new ( ) ; 
2525
26+     /// <summary> 
27+     /// Keep a list of where logits can be sampled from 
28+     /// </summary> 
29+     private  readonly  List < ( LLamaSeqId ,  int ) >  _logitPositions  =  new ( ) ; 
30+ 
31+     /// <summary> 
32+     /// Get the number of logit positions that will be generated from this batch 
33+     /// </summary> 
34+     internal  int  LogitPositionCount  =>  _logitPositions . Count ; 
35+ 
2636    /// <summary> 
2737    /// The number of tokens in this batch 
2838    /// </summary> 
@@ -175,8 +185,16 @@ public int Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequence
175185        for  ( var  i  =  0 ;  i  <  sequences . Length ;  i ++ ) 
176186            _sequenceIds [ TokenCount ] [ i ]  =  sequences [ i ] ; 
177187        _logits [ TokenCount ]  =  Convert . ToByte ( logits ) ; 
188+         TokenCount ++ ; 
178189
179-         return  TokenCount ++ ; 
190+         // Store this position in the logits lookup if necessary 
191+         if  ( logits ) 
192+         { 
193+             foreach  ( var  sequence  in  sequences ) 
194+                 _logitPositions . Add ( ( sequence ,  TokenCount ) ) ; 
195+         } 
196+ 
197+         return  TokenCount ; 
180198    } 
181199
182200    /// <summary> 
@@ -257,6 +275,20 @@ public int AddRange(ReadOnlySpan<LLamaToken> tokens, LLamaPos start, LLamaSeqId
257275    public  void  Clear ( ) 
258276    { 
259277        TokenCount  =  0 ; 
278+ 
260279        _index . Clear ( ) ; 
280+         _logitPositions . Clear ( ) ; 
281+     } 
282+ 
283+     /// <summary> 
284+     /// Get the positions where logits can be sampled from 
285+     /// </summary> 
286+     /// <returns></returns> 
287+     internal  Span < ( LLamaSeqId ,  int ) >  GetLogitPositions ( Span < ( LLamaSeqId ,  int ) >  dest ) 
288+     { 
289+         for  ( var  i  =  0 ;  i  <  _logitPositions . Count ;  i ++ ) 
290+             dest [ i ]  =  _logitPositions [ i ] ; 
291+ 
292+         return  dest . Slice ( 0 ,  _logitPositions . Count ) ; 
261293    } 
262294} 
0 commit comments