 using System;
+using System.Buffers;
 using System.Collections.Generic;
+using System.Runtime.InteropServices;
 using LLama.Native;

 namespace LLama.Batched;
@@ -14,6 +16,7 @@ public sealed class Conversation
     private LLamaPos _end;
     private int _batchIndex;
     private bool _disposed;
+    private bool _forked;

     /// <summary>
     /// The executor which this conversation belongs to
@@ -46,12 +49,10 @@ public sealed class Conversation
     public bool RequiresSampling => _requiredEpoch == Executor.Epoch;

     #region construction/destruction
-    internal Conversation(BatchedExecutor batch, LLamaSeqId id, LLamaPos end)
+    internal Conversation(BatchedExecutor batch, LLamaSeqId id)
     {
         ConversationId = id;
         Executor = batch;
-
-        _end = end;
     }

     /// <summary>
@@ -98,16 +99,24 @@ public Conversation Fork()
     {
         AssertNotDisposed();

-        if (RequiresInference)
-            throw new CannotForkWhileRequiresInferenceException();
-
         // Create a new conversation which references the current position in this one
-        var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
+        var c = new Conversation(Executor, Executor.GetNextSequenceId())
         {
-            _batchIndex = _batchIndex,
+            // Because these values are copied to the forked conversation it means that it will share the exact same output
+            // logits next time sampling is done. This is a problem, because the sampling process is allowed to modify those
+            // logits, so sampling one conversation may mess up the fork! Setting the "forked" flag on both sequences ensures
+            // they both copy the logits before the next sampling run, to fix this issue.
             _requiredEpoch = _requiredEpoch,
+            _batchIndex = _batchIndex,
+            _forked = true,
+
+            _end = _end,
         };

+        // Setting this flag means that logits will be copied next time sampling is called, ensuring that the forked
+        // conversation doesn't share logits with this one.
+        _forked = true;
+
         // Assign tokens to the new sequence
         NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, c.ConversationId, 0, _end);

@@ -131,7 +140,14 @@ public Span<float> Sample()
         if (_requiredEpoch > Executor.Epoch)
             throw new CannotSampleRequiresInferenceException();

-        return Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
+        var span = Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
+
+        // If necessary copy the span, to protect it from modification. This is only done when
+        // this conversation has been forked in this epoch.
+        if (_forked)
+            span = span.ToArray();
+
+        return span;
     }
     #endregion

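To make the intent of the new `_forked` flag concrete, here is a minimal usage sketch of the fork-then-sample scenario it guards against: `Fork()` copies `_batchIndex`, so both conversations read logits from the same slot until one of them is prompted again, which is why `Sample()` must now hand each caller an independent copy. The sketch assumes a `BatchedExecutor` built around a loaded model; the `executor.Create()` factory and `executor.Infer()` call are assumed names used only for illustration and are not part of this diff.

using System;
using LLama.Batched;
using LLama.Native;

static class ForkSketch
{
    // Hypothetical driver showing why forked conversations must not share logits.
    public static void Run(BatchedExecutor executor, LLamaToken prompt)
    {
        var main = executor.Create();                    // assumed factory for a new Conversation
        main.Prompt(prompt);
        executor.Infer().GetAwaiter().GetResult();       // assumed async batch execution, awaited synchronously here

        // Fork copies the KV cache entries and the batch index of the parent.
        var fork = main.Fork();

        // Both conversations sample the same batch slot, but because Fork() set the
        // forked flag on both, Sample() returns a defensive copy for each of them.
        Span<float> mainLogits = main.Sample();
        Span<float> forkLogits = fork.Sample();

        // Mutating one copy (e.g. applying penalties in-place before picking a token)
        // no longer corrupts the logits the other conversation will sample from.
        mainLogits[0] = float.NegativeInfinity;
    }
}
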
@@ -162,20 +178,56 @@ public void Prompt(string input)
     /// <param name="tokens"></param>
     /// <returns></returns>
     /// <exception cref="ObjectDisposedException"></exception>
-    public void Prompt(IReadOnlyList<LLamaToken> tokens)
+    /// <exception cref="AlreadyPromptedConversationException"></exception>
+    public void Prompt(List<LLamaToken> tokens)
+    {
+        AssertCanBePrompted();
+
+#if NET6_0_OR_GREATER
+        var span = CollectionsMarshal.AsSpan(tokens);
+        Prompt(span);
+#else
+        // Borrow an array and copy tokens into it
+        var arr = ArrayPool<LLamaToken>.Shared.Rent(tokens.Count);
+        try
+        {
+            for (var i = 0; i < tokens.Count; i++)
+                arr[i] = tokens[i];
+
+            Prompt(arr.AsSpan(0, tokens.Count));
+        }
+        finally
+        {
+            ArrayPool<LLamaToken>.Shared.Return(arr);
+        }
+#endif
+    }
+
+    /// <summary>
+    /// Add tokens to this conversation
+    /// </summary>
+    /// <param name="tokens"></param>
+    /// <returns></returns>
+    /// <exception cref="ObjectDisposedException"></exception>
+    /// <exception cref="AlreadyPromptedConversationException"></exception>
+    public void Prompt(ReadOnlySpan<LLamaToken> tokens)
     {
         AssertCanBePrompted();

         // No point doing anything if there is no actual prompt!
-        if (tokens.Count == 0)
+        if (tokens.Length == 0)
             return;

         // Add the prompt to the batch
-        for (var i = 0; i < tokens.Count; i++)
-            _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Count - 1);
+        for (var i = 0; i < tokens.Length; i++)
+            _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);

         // Mark this conversation as needing inference/sampling
         _requiredEpoch = Executor.Epoch + 1;
+
+        // Unset the forked flag. Since this conversation has just been prompted it's no longer
+        // sharing anything with any other conversations.
+        _forked = false;
     }

     /// <summary>
@@ -184,16 +236,16 @@ public void Prompt(IReadOnlyList<LLamaToken> tokens)
     /// <param name="token"></param>
     /// <returns></returns>
     /// <exception cref="ObjectDisposedException"></exception>
-    /// <exception cref="InvalidOperationException"></exception>
+    /// <exception cref="AlreadyPromptedConversationException"></exception>
     public void Prompt(LLamaToken token)
     {
         AssertCanBePrompted();

-        // Add this token as input
-        _batchIndex = Executor.Batch.Add(token, _end++, ConversationId, true);
-
-        // Mark this conversation as needing inference/sampling
-        _requiredEpoch = Executor.Epoch + 1;
+        unsafe
+        {
+            Span<LLamaToken> span = stackalloc LLamaToken[1] { token };
+            Prompt(span);
+        }
     }
     #endregion

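For completeness, here is a sketch of how the reworked Prompt overloads are called after this commit. The `List` overload is zero-copy on .NET 6+ via CollectionsMarshal.AsSpan and falls back to a pooled copy on older targets, the `ReadOnlySpan` overload is the core path that feeds the executor's batch, and the single-token overload forwards through a stack-allocated one-element span. As in the earlier sketch, `executor.Create()` and `executor.Infer()` are assumed names used only for illustration; prompting the same conversation a second time before the next inference run would throw AlreadyPromptedConversationException, per the updated doc comments.

using System;
using System.Collections.Generic;
using LLama.Batched;
using LLama.Native;

static class PromptSketch
{
    // Hypothetical driver exercising the three Prompt overloads touched by this commit.
    public static void Run(BatchedExecutor executor, LLamaToken[] tokens)
    {
        var a = executor.Create();   // assumed factory methods, as in the previous sketch
        var b = executor.Create();
        var c = executor.Create();

        // Core path: every token in the span is appended to the executor's batch and the
        // conversation is marked as requiring inference in the next epoch.
        a.Prompt(tokens.AsSpan());

        // List overload: CollectionsMarshal.AsSpan on .NET 6+, ArrayPool copy otherwise,
        // then forwards to the span overload above.
        b.Prompt(new List<LLamaToken>(tokens));

        // Single token: stack-allocates a one-element span and forwards to the span overload.
        c.Prompt(tokens[0]);

        // One inference run decodes all three sequences in a single batch.
        executor.Infer().GetAwaiter().GetResult();   // assumed async, awaited synchronously for brevity
    }
}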