Skip to content

Commit 91d72e7

Browse files
authored
Keeping track of positions where logits will be generated in a batch and what sequence those logits are associated with. (#624)
1 parent 268f3a6 commit 91d72e7

File tree

1 file changed

+33
-1
lines changed

1 file changed

+33
-1
lines changed

LLama/Native/LLamaBatch.cs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@ public class LLamaBatch
2323
/// </summary>
2424
private readonly Dictionary<(LLamaToken, LLamaPos), int> _index = new();
2525

26+
/// <summary>
27+
/// Keep a list of where logits can be sampled from
28+
/// </summary>
29+
private readonly List<(LLamaSeqId, int)> _logitPositions = new();
30+
31+
/// <summary>
32+
/// Get the number of logit positions that will be generated from this batch
33+
/// </summary>
34+
internal int LogitPositionCount => _logitPositions.Count;
35+
2636
/// <summary>
2737
/// The number of tokens in this batch
2838
/// </summary>
@@ -175,8 +185,16 @@ public int Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequence
175185
for (var i = 0; i < sequences.Length; i++)
176186
_sequenceIds[TokenCount][i] = sequences[i];
177187
_logits[TokenCount] = Convert.ToByte(logits);
188+
TokenCount++;
178189

179-
return TokenCount++;
190+
// Store this position in the logits lookup if necessary
191+
if (logits)
192+
{
193+
foreach (var sequence in sequences)
194+
_logitPositions.Add((sequence, TokenCount));
195+
}
196+
197+
return TokenCount;
180198
}
181199

182200
/// <summary>
@@ -257,6 +275,20 @@ public int AddRange(ReadOnlySpan<LLamaToken> tokens, LLamaPos start, LLamaSeqId
257275
public void Clear()
258276
{
259277
TokenCount = 0;
278+
260279
_index.Clear();
280+
_logitPositions.Clear();
281+
}
282+
283+
/// <summary>
284+
/// Get the positions where logits can be sampled from
285+
/// </summary>
286+
/// <returns></returns>
287+
internal Span<(LLamaSeqId, int)> GetLogitPositions(Span<(LLamaSeqId, int)> dest)
288+
{
289+
for (var i = 0; i < _logitPositions.Count; i++)
290+
dest[i] = _logitPositions[i];
291+
292+
return dest.Slice(0, _logitPositions.Count);
261293
}
262294
}

0 commit comments

Comments
 (0)