
Commit 268f3a6

BatchedExecutor Fixed Forking (#621)
* Previously, when a conversation was forked, the parent and the child shared exactly the same logits. Since sampling is allowed to modify logits, this could corrupt sampling (e.g. one conversation is sampled and overwrites the logits with zeros, then the second conversation is sampled and generates nonsense). Fixed this by setting a "forked" flag: logits are copied if the flag is set. The flag is cleared the next time the conversation is prompted, so the extra copy only happens once after a fork.

* Removed the finalizer from `BatchedExecutor`. This class does not directly own any unmanaged resources, so a finalizer is not necessary.
1 parent ad682fb commit 268f3a6
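
For context, a rough sketch of the failure mode the first bullet describes, using the `Conversation` members touched in this commit (`Create`, `Fork`, `Sample`, `Prompt`). The executor setup, the `Infer()` call and the sampled-token variables are assumptions for illustration, not part of this diff:

```csharp
// Illustrative sketch only: `executor` is an existing BatchedExecutor and the
// Infer() call / token variables are assumed, not shown in this commit.
var conversation = executor.Create();
conversation.Prompt("The quick brown fox");
await executor.Infer();

// Fork: both conversations now refer to the same batch index, i.e. the same logits.
var fork = conversation.Fork();

// Before this fix both Sample() calls returned the *same* span, so a sampler that
// mutated the logits of one conversation corrupted the other. With the "forked"
// flag set, each call now returns its own copy.
var logitsA = conversation.Sample();
var logitsB = fork.Sample();

// Prompting clears the flag again, so the extra copy happens only once per fork.
conversation.Prompt(tokenA);
fork.Prompt(tokenB);
```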

File tree: 2 files changed (+72 lines, -30 lines)


LLama/Batched/BatchedExecutor.cs

Lines changed: 1 addition & 11 deletions
@@ -55,14 +55,6 @@ public BatchedExecutor(LLamaWeights model, IContextParams contextParams)
         Epoch = 1;
     }
 
-    /// <summary>
-    /// Finalizer for BatchedExecutor
-    /// </summary>
-    ~BatchedExecutor()
-    {
-        Dispose();
-    }
-
     /// <summary>
     /// Start a new <see cref="Conversation"/> with the given prompt
     /// </summary>
@@ -89,7 +81,7 @@ public Conversation Create()
         if (IsDisposed)
             throw new ObjectDisposedException(nameof(BatchedExecutor));
 
-        return new Conversation(this, GetNextSequenceId(), 0);
+        return new Conversation(this, GetNextSequenceId());
     }
 
     /// <summary>
@@ -123,8 +115,6 @@ public void Dispose()
             return;
         IsDisposed = true;
 
-        GC.SuppressFinalize(this);
-
         Context.Dispose();
     }

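On the finalizer removal: a finalizer (and the matching `GC.SuppressFinalize` call) is only warranted when a class directly owns unmanaged resources. `BatchedExecutor` only holds a managed context object that handles its own cleanup, so disposal can simply cascade. A minimal generic sketch of that pattern, with a `MemoryStream` standing in for the wrapped context (the class and field names here are placeholders, not the library's API):

```csharp
using System;
using System.IO;

// Sketch of IDisposable without a finalizer: the wrapped object cleans up after
// itself, so this class only forwards Dispose and guards against double-disposal.
public sealed class ManagedOnlyOwner : IDisposable
{
    private readonly MemoryStream _context = new MemoryStream(); // stand-in for the owned managed resource

    public bool IsDisposed { get; private set; }

    public void Dispose()
    {
        if (IsDisposed)
            return;
        IsDisposed = true;

        _context.Dispose(); // no finalizer, hence no GC.SuppressFinalize call
    }
}
```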
LLama/Batched/Conversation.cs

Lines changed: 71 additions & 19 deletions
@@ -1,5 +1,7 @@
 using System;
+using System.Buffers;
 using System.Collections.Generic;
+using System.Runtime.InteropServices;
 using LLama.Native;
 
 namespace LLama.Batched;
@@ -14,6 +16,7 @@ public sealed class Conversation
     private LLamaPos _end;
     private int _batchIndex;
     private bool _disposed;
+    private bool _forked;
 
     /// <summary>
     /// The executor which this conversation belongs to
@@ -46,12 +49,10 @@ public sealed class Conversation
     public bool RequiresSampling => _requiredEpoch == Executor.Epoch;
 
     #region construction/destruction
-    internal Conversation(BatchedExecutor batch, LLamaSeqId id, LLamaPos end)
+    internal Conversation(BatchedExecutor batch, LLamaSeqId id)
     {
         ConversationId = id;
         Executor = batch;
-
-        _end = end;
     }
 
     /// <summary>
@@ -98,16 +99,24 @@ public Conversation Fork()
     {
         AssertNotDisposed();
 
-        if (RequiresInference)
-            throw new CannotForkWhileRequiresInferenceException();
-
         // Create a new conversation which references the current position in this one
-        var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
+        var c = new Conversation(Executor, Executor.GetNextSequenceId())
         {
-            _batchIndex = _batchIndex,
+            // Because these values are copied to the forked conversation it means that it will share the exact same output
+            // logits next time sampling is done. This is a problem, because the sampling process is allowed to modify those
+            // logits, so sampling one conversation may mess up the fork! Setting the "forked" flag on both sequences ensures
+            // they both copy the logits before the next sampling run, to fix this issue.
             _requiredEpoch = _requiredEpoch,
+            _batchIndex = _batchIndex,
+            _forked = true,
+
+            _end = _end,
         };
 
+        // Setting this flag means that logits will be copied next time sampling is called, ensuring that the forked
+        // conversation doesn't share logits with this one.
+        _forked = true;
+
         // Assign tokens to the new sequence
         NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, c.ConversationId, 0, _end);
 
@@ -131,7 +140,14 @@ public Span<float> Sample()
         if (_requiredEpoch > Executor.Epoch)
             throw new CannotSampleRequiresInferenceException();
 
-        return Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
+        var span = Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
+
+        // If necessary copy the span, to protect it from modification. This is only done when
+        // this conversation has been forked in this epoch.
+        if (_forked)
+            span = span.ToArray();
+
+        return span;
     }
     #endregion
 
@@ -162,20 +178,56 @@ public void Prompt(string input)
     /// <param name="tokens"></param>
     /// <returns></returns>
     /// <exception cref="ObjectDisposedException"></exception>
-    public void Prompt(IReadOnlyList<LLamaToken> tokens)
+    /// <exception cref="AlreadyPromptedConversationException"></exception>
+    public void Prompt(List<LLamaToken> tokens)
+    {
+        AssertCanBePrompted();
+
+#if NET6_0_OR_GREATER
+        var span = CollectionsMarshal.AsSpan(tokens);
+        Prompt(span);
+#else
+        // Borrow an array and copy tokens into it
+        var arr = ArrayPool<LLamaToken>.Shared.Rent(tokens.Count);
+        try
+        {
+            for (var i = 0; i < tokens.Count; i++)
+                arr[i] = tokens[i];
+
+            Prompt(arr.AsSpan());
+        }
+        finally
+        {
+            ArrayPool<LLamaToken>.Shared.Return(arr);
+        }
+#endif
+    }
+
+    /// <summary>
+    /// Add tokens to this conversation
+    /// </summary>
+    /// <param name="tokens"></param>
+    /// <returns></returns>
+    /// <exception cref="ObjectDisposedException"></exception>
+    /// <exception cref="AlreadyPromptedConversationException"></exception>
+    public void Prompt(ReadOnlySpan<LLamaToken> tokens)
     {
         AssertCanBePrompted();
 
         // No point doing anything if there is no actual prompt!
-        if (tokens.Count == 0)
+        if (tokens.Length == 0)
             return;
 
         // Add the prompt to the batch
-        for (var i = 0; i < tokens.Count; i++)
-            _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Count - 1);
+        for (var i = 0; i < tokens.Length; i++)
+            _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
 
         // Mark this conversation as needing inference/sampling
         _requiredEpoch = Executor.Epoch + 1;
+
+        // Unset the forked flag. Since this conversation has just been prompted it's no longer
+        // sharing anything with any other conversations.
+        _forked = false;
     }
 
     /// <summary>
@@ -184,16 +236,16 @@ public void Prompt(IReadOnlyList<LLamaToken> tokens)
     /// <param name="token"></param>
     /// <returns></returns>
     /// <exception cref="ObjectDisposedException"></exception>
-    /// <exception cref="InvalidOperationException"></exception>
+    /// <exception cref="AlreadyPromptedConversationException"></exception>
     public void Prompt(LLamaToken token)
     {
         AssertCanBePrompted();
 
-        // Add this token as input
-        _batchIndex = Executor.Batch.Add(token, _end++, ConversationId, true);
-
-        // Mark this conversation as needing inference/sampling
-        _requiredEpoch = Executor.Epoch + 1;
+        unsafe
+        {
+            Span<LLamaToken> span = stackalloc LLamaToken[1] { token };
+            Prompt(span);
+        }
    }
    #endregion

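The net effect of the `Prompt` changes is that every overload funnels into the span-based implementation, which is also where the `_forked` flag is cleared. A brief sketch of the overloads (illustrative only: `conversation` is an existing `Conversation`, `Tokenize` is an assumed helper returning `LLamaToken` values, and in real use a conversation is prompted once per inference epoch rather than three times in a row):

```csharp
List<LLamaToken> tokens = Tokenize("Hello, world"); // assumed helper

// List overload: on .NET 6+ this uses CollectionsMarshal.AsSpan (no copy); on older
// targets it rents a pooled array, copies the tokens, and forwards to the span overload.
conversation.Prompt(tokens);

// Span overload: the core path; adds each token to the batch, bumps _requiredEpoch,
// and clears the _forked flag set by Fork().
conversation.Prompt(tokens.ToArray().AsSpan());

// Single-token overload: stack-allocates a one-element span and forwards it.
conversation.Prompt(tokens[0]);
```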