101 changes: 80 additions & 21 deletions LLama/Batched/BatchedExecutor.cs
@@ -1,4 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
@@ -13,13 +15,15 @@ public sealed class BatchedExecutor
: IDisposable
{
private int _nextSequenceId;
private readonly List<LLamaBatch> _batchQueue = [ ];

private LLamaBatch _promptingBatch = new();
private LLamaBatch _nextBatch = new();
internal LLamaBatch Batch => _promptingBatch;
/// <summary>
/// Held while inference is running
/// </summary>
private readonly object _inferenceLock = new();

/// <summary>
/// Epoch is incremented every time Infer is called. Conversations can use this to keep track of
/// Epoch is incremented twice every time Infer is called. Conversations can use this to keep track of
/// whether they're waiting for inference, or can be sampled.
/// </summary>
internal ulong Epoch { get; private set; }
@@ -33,11 +37,11 @@ public sealed class BatchedExecutor
/// The <see cref="LLamaWeights"/> this executor is using
/// </summary>
public LLamaWeights Model { get; }

/// <summary>
/// Get the number of tokens in the batch, waiting for <see cref="Infer"/> to be called
/// </summary>
public int BatchedTokenCount => Batch.TokenCount;
public int BatchedTokenCount => _batchQueue.Sum(a => a.TokenCount);

/// <summary>
/// Check if this executor has been disposed.
@@ -112,26 +116,53 @@ public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
if (IsDisposed)
throw new ObjectDisposedException(nameof(BatchedExecutor));

// Swap over batches. This means the next batch can be filled with
// tokens while inference is still running for the previous one.
var batch = _promptingBatch;
(_promptingBatch, _nextBatch) = (_nextBatch, _promptingBatch);

var status = await Context.DecodeAsync(batch, cancellation);
// If there's no work to do then all available work has been completed; exit immediately.
var next = GetNextBatch();
if (next == null)
return DecodeResult.Ok;

// If there was an error swap the previous batch back into place. This allows infer to be called again
// after the issue has been fixed (e.g. some KV cache space has been freed) to "retry" this operation.
if (status != DecodeResult.Ok)
// Take the inference lock, if this fails it's because inference is already running.
if (!Monitor.TryEnter(_inferenceLock))
throw new InvalidOperationException("Cannot start inference while it is already running");
try
{
(_promptingBatch, _nextBatch) = (_nextBatch, _promptingBatch);
// Advance epoch by one. This ensures that _nothing_ can be sampled while inference is running.
// Only do this if the epoch is even. If it's odd that means it was previously advanced by another
// inference run that failed, and this run is a retry.
if ((Epoch & 1) == 0)
Epoch++;

// Run the actual inference. This is the slow bit!
var status = await Context.DecodeAsync(next, cancellation);

// If there was an error then early exit without incrementing the epoch. This allows infer to be called
// again after the issue has been fixed (e.g. some KV cache space has been freed) to retry this operation.
if (status != DecodeResult.Ok)
{
_batchQueue.Insert(0, next);
return status;
}

// Everything was ok, advance the epoch and clear the batch we just ran inference for.
Epoch++;
next.Clear();

return status;
}
finally
{
Monitor.Exit(_inferenceLock);
}

// Everything was ok, advance the epoch and clear the batch we just ran inference for.
Epoch++;
batch.Clear();

return status;
LLamaBatch? GetNextBatch()
{
if (_batchQueue.Count == 0)
return null;

var nextBatch = _batchQueue[0];
_batchQueue.RemoveAt(0);
return nextBatch;
}
}

/// <inheritdoc />
@@ -148,4 +179,32 @@ internal LLamaSeqId GetNextSequenceId()
{
return checked((LLamaSeqId)_nextSequenceId++);
}

/// <summary>
/// Get a reference to a batch that tokens can be added to.
/// </summary>
/// <param name="minCapacity"></param>
/// <returns></returns>
/// <exception cref="ArgumentOutOfRangeException"></exception>
internal (LLamaBatch batch, ulong epoch) GetTokenBatch(int minCapacity = 1)
{
if (minCapacity > Context.BatchSize)
throw new ArgumentOutOfRangeException(nameof(minCapacity), $"Requested batch capacity must be less than or equal to BatchSize ({Context.BatchSize})");

// Find a batch with space for at least minCapacity tokens
for (var i = 0; i < _batchQueue.Count; i++)
{
var capacity = Context.BatchSize - _batchQueue[i].TokenCount;
if (capacity < minCapacity)
continue;

if (_batchQueue[i].TokenCount < Context.BatchSize)
return (_batchQueue[i], Epoch + (uint)(i + 1) * 2);
}

// Add a new batch to the end of the queue
var end = new LLamaBatch();
_batchQueue.Add(end);
return (end, Epoch + (uint)_batchQueue.Count * 2);
}
}
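
Note on the epoch arithmetic above: the batch at queue index i will have been decoded after i + 1 further successful Infer calls, and each successful Infer advances Epoch by two (to an odd value while the decode is in flight, then to the next even value on completion), which is why GetTokenBatch returns Epoch + (i + 1) * 2 as the epoch at which that batch's tokens can be sampled. The toy model below reproduces only this bookkeeping; it is a sketch for illustration, not LLamaSharp code, and EpochModel, Enqueue and the string batch names are invented here.

using System;
using System.Collections.Generic;

// Toy model of the epoch/queue bookkeeping; not part of LLamaSharp.
class EpochModel
{
    private ulong _epoch;                       // even while idle, odd while a decode is in flight
    private readonly Queue<string> _queue = new();

    public ulong Epoch => _epoch;

    // Mirrors GetTokenBatch: the batch just queued becomes sampleable at _epoch + count * 2.
    public ulong Enqueue(string batch)
    {
        _queue.Enqueue(batch);
        return _epoch + (ulong)_queue.Count * 2;
    }

    // Mirrors Infer: bump to odd before decoding (unless retrying), bump again only on success.
    public void Infer(bool succeeds = true)
    {
        if (_queue.Count == 0)
            return;                             // no work queued
        if ((_epoch & 1) == 0)
            _epoch++;                           // odd => sampling blocked while decoding
        if (!succeeds)
            return;                             // batch stays queued, epoch stays odd => retry later
        _queue.Dequeue();
        _epoch++;                               // back to even: one more batch completed
    }
}

class EpochModelDemo
{
    static void Main()
    {
        var m = new EpochModel();
        var readyA = m.Enqueue("A");            // 2
        var readyB = m.Enqueue("B");            // 4
        m.Infer();                              // epoch 0 -> 1 -> 2, "A" now sampleable
        m.Infer(succeeds: false);               // epoch 2 -> 3, "B" must be retried
        m.Infer();                              // retry: epoch stays 3 during decode, then -> 4
        Console.WriteLine($"{m.Epoch} {readyA} {readyB}");   // prints "4 2 4"
    }
}
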
38 changes: 23 additions & 15 deletions LLama/Batched/Conversation.cs
@@ -232,22 +232,33 @@ public void Prompt(ReadOnlySpan<LLamaToken> tokens, bool allLogits = false)

_batchSampleCount = tokens.Length;

// We need to add all tokens to a single batch, so they can all be sampled at once.
// Request a batch with sufficient space.
(var batch, _requiredEpoch) = Executor.GetTokenBatch(tokens.Length);

// Add everything to that batch
for (var i = 0; i < tokens.Length; i++)
_batchSampleIndices[i] = Executor.Batch.Add(tokens[i], _end++, ConversationId, true);
_batchSampleIndices[i] = batch.Add(tokens[i], _end++, ConversationId, true);
}
else
{
_batchSampleCount = 1;

for (var i = 0; i < tokens.Length; i++)
_batchSampleIndices[0] = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);

while (tokens.Length > 0)
{
// Get a batch with capacity for at least 1 token
(var batch, _requiredEpoch) = Executor.GetTokenBatch();

// Add as many tokens as possible
var count = Math.Min(tokens.Length, checked((int)Executor.Context.BatchSize) - batch.TokenCount);
for (var i = 0; i < count; i++)
_batchSampleIndices[0] = batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);

// Slice the array to remove tokens we've already added to a batch
tokens = tokens.Slice(count);
}
}



// Mark this conversation as needing inference/sampling
_requiredEpoch = Executor.Epoch + 1;

// Unset the forked flag. Since this conversation has just been prompted it's no longer
// sharing anything with any other conversations.
_forked = false;
@@ -263,12 +274,9 @@ public void Prompt(ReadOnlySpan<LLamaToken> tokens, bool allLogits = false)
public void Prompt(LLamaToken token)
{
AssertCanBePrompted();

unsafe
{
Span<LLamaToken> span = stackalloc LLamaToken[1] { token };
Prompt(span);
}

Span<LLamaToken> span = [ token ];
Prompt(span);
}
#endregion

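
The reworked Prompt above lets a conversation queue more tokens than one batch can hold: each pass requests a batch with at least one free slot, fills it up to Context.BatchSize, slices the remaining span, and repeats, leaving _requiredEpoch pointing at the last batch that received tokens. Logits are only requested when i == tokens.Length - 1, which can only happen on the final pass, so exactly one sample index is produced. Below is a minimal sketch of just the slicing arithmetic; SplitIntoChunks and its parameters are invented for illustration and are not LLamaSharp API (it assumes freeInFirstBatch >= 1, which GetTokenBatch guarantees).

using System;
using System.Collections.Generic;

static class PromptChunking
{
    // Split `tokens` into pieces the way the while-loop in Prompt slices its span.
    // `freeInFirstBatch` models a partially filled batch at the head of the queue.
    public static List<int[]> SplitIntoChunks(ReadOnlySpan<int> tokens, int batchSize, int freeInFirstBatch)
    {
        var chunks = new List<int[]>();
        var free = freeInFirstBatch;
        while (tokens.Length > 0)
        {
            var count = Math.Min(tokens.Length, free);
            chunks.Add(tokens.Slice(0, count).ToArray());
            tokens = tokens.Slice(count);       // drop the tokens already added to a batch
            free = batchSize;                   // every later chunk gets a fresh, empty batch
        }
        return chunks;
    }
}
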
47 changes: 23 additions & 24 deletions LLama/Native/LLamaBatch.cs
@@ -1,4 +1,4 @@
using System;
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;

@@ -54,51 +54,50 @@ public class LLamaBatch
public LLamaBatch()
{
// These can both be grown later, start off with reasonable numbers.
const int n_tokens = 128;
const int n_seq_max = 1;
const int tokensCapacity = 128;
const int seqCapacity = 1;

SequenceCapacity = n_seq_max;
TokenCapacity = n_tokens;
SequenceCapacity = seqCapacity;
TokenCapacity = tokensCapacity;

_logits = new byte[n_tokens];
_tokens = new LLamaToken[n_tokens];
_positions = new LLamaPos[n_tokens];
_logits = new byte[tokensCapacity];
_tokens = new LLamaToken[tokensCapacity];
_positions = new LLamaPos[tokensCapacity];

_sequenceIdCount = new int[n_tokens];
_sequenceIdCount = new int[tokensCapacity];
_sequenceIdsPtrs = new IntPtr[_sequenceIdCount.Length];

_sequenceIds = new LLamaSeqId[n_tokens][];
_sequenceIds = new LLamaSeqId[tokensCapacity][];
for (var i = 0; i < _sequenceIds.Length; i++)
_sequenceIds[i] = new LLamaSeqId[SequenceCapacity];
}

#region grow
private void GrowTokenCapacity()
{
var n_tokens = TokenCount * 2;
TokenCapacity = n_tokens;
var tokenCapacity = TokenCount * 2;
TokenCapacity = tokenCapacity;

Array.Resize(ref _logits, n_tokens);
Array.Resize(ref _tokens, n_tokens);
Array.Resize(ref _positions, n_tokens);
Array.Resize(ref _logits, tokenCapacity);
Array.Resize(ref _tokens, tokenCapacity);
Array.Resize(ref _positions, tokenCapacity);

Array.Resize(ref _sequenceIdCount, n_tokens);
Array.Resize(ref _sequenceIdsPtrs, n_tokens);
Array.Resize(ref _sequenceIdCount, tokenCapacity);
Array.Resize(ref _sequenceIdsPtrs, tokenCapacity);

Array.Resize(ref _sequenceIds, n_tokens);
for (int i = 0; i < _sequenceIds.Length; i++)
Array.Resize(ref _sequenceIds, tokenCapacity);
for (var i = 0; i < _sequenceIds.Length; i++)
{
// Growing the array filled elements with null, temporarily violating the nullability contract!
// ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract
if (_sequenceIds[i] == null)
_sequenceIds[i] = new LLamaSeqId[SequenceCapacity];
// ReSharper disable once NullCoalescingConditionIsAlwaysNotNullAccordingToAPIContract
_sequenceIds[i] ??= new LLamaSeqId[SequenceCapacity];
}
}

private void GrowMaxSequences(int atLeast)
{
var n_seq = Math.Max(SequenceCapacity * 2, atLeast);
SequenceCapacity = n_seq;
var seqCapacity = Math.Max(SequenceCapacity * 2, atLeast);
SequenceCapacity = seqCapacity;

for (var i = 0; i < _sequenceIds.Length; i++)
Array.Resize(ref _sequenceIds[i], SequenceCapacity);
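
The renames above (tokensCapacity/seqCapacity in place of n_tokens/n_seq) do not change behaviour: a LLamaBatch still starts with room for 128 tokens and grows by doubling, which keeps repeated Add calls amortised O(1). A stripped-down sketch of that growth pattern follows; the GrowableBuffer type and its members are invented for illustration and are not part of LLamaSharp.

using System;

// Minimal illustration of the doubling strategy used by GrowTokenCapacity; not LLamaSharp code.
class GrowableBuffer
{
    private int[] _items = new int[128];        // same starting capacity as LLamaBatch's token arrays
    private int _count;

    public void Add(int value)
    {
        if (_count == _items.Length)
            Array.Resize(ref _items, _count * 2);   // copy into an array twice the size, as GrowTokenCapacity does
        _items[_count++] = value;
    }

    public int Count => _count;
    public int Capacity => _items.Length;
}
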