Skip to content

Commit 5025ad9

Browse files
authored
Merge pull request #852 from dpmm99/feat/continuation
Allow continuation in Instruct and Interact executors; fix a minor leak
2 parents e66b375 + 317a7b0 commit 5025ad9

File tree

3 files changed

+40
-27
lines changed

3 files changed

+40
-27
lines changed

LLama/LLamaExecutorBase.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ protected virtual void TryReuseMatchingPrefix()
251251
/// </summary>
252252
/// <param name="text"></param>
253253
/// <param name="args"></param>
254-
protected abstract Task PreprocessInputs(string text, InferStateArgs args);
254+
protected abstract Task PreprocessInputs(string? text, InferStateArgs args);
255255

256256
/// <summary>
257257
/// Do some post processing after the inference.
@@ -296,11 +296,11 @@ protected virtual void TryReuseMatchingPrefix()
296296
/// <summary>
297297
/// Execute the inference.
298298
/// </summary>
299-
/// <param name="text"></param>
299+
/// <param name="text">The prompt. If null, generation will continue where it left off previously.</param>
300300
/// <param name="inferenceParams"></param>
301301
/// <param name="cancellationToken"></param>
302302
/// <returns></returns>
303-
public virtual async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
303+
public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
304304
{
305305
cancellationToken.ThrowIfCancellationRequested();
306306
inferenceParams ??= new InferenceParams();

LLama/LLamaInstructExecutor.cs

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -116,30 +116,38 @@ protected override Task<bool> GetLoopCondition(InferStateArgs args)
116116
}
117117

118118
/// <inheritdoc />
119-
protected override Task PreprocessInputs(string text, InferStateArgs args)
119+
protected override Task PreprocessInputs(string? text, InferStateArgs args)
120120
{
121121
args.Antiprompts ??= [ ];
122-
args.Antiprompts.Add(_instructionPrefix);
122+
if (!args.Antiprompts.Contains(_instructionPrefix))
123+
args.Antiprompts.Add(_instructionPrefix);
124+
123125
if (_is_prompt_run)
124126
{
125127
// When running the first input (prompt) in interactive mode, we should specially process it.
128+
if (text == null) throw new ArgumentException("Prompt cannot be null to trigger continuation if a prompt has not been provided previously.");
126129
_embed_inps = Context.Tokenize(text, true, true).ToList();
127130
}
128131
else
129132
{
130-
if (!text.EndsWith("\n"))
131-
{
132-
text += "\n";
133-
}
134133
_consumedTokensCount = _embed_inps.Count;
135-
_embed_inps.AddRange(_inp_pfx);
136134

137-
var line_inp = Context.Tokenize(text, false, true);
138-
_embed_inps.AddRange(line_inp);
135+
// Don't append the template tokens if continuation is requested (by providing a null prompt)
136+
if (text != null)
137+
{
138+
if (!text.EndsWith("\n"))
139+
{
140+
text += "\n";
141+
}
142+
_embed_inps.AddRange(_inp_pfx);
143+
144+
var line_inp = Context.Tokenize(text, false, true);
145+
_embed_inps.AddRange(line_inp);
139146

140-
_embed_inps.AddRange(_inp_sfx);
147+
_embed_inps.AddRange(_inp_sfx);
141148

142-
args.RemainedTokens -= line_inp.Length;
149+
args.RemainedTokens -= line_inp.Length;
150+
}
143151
}
144152

145153
return Task.CompletedTask;

LLama/LLamaInteractExecutor.cs

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,12 @@ protected override Task<bool> GetLoopCondition(InferStateArgs args)
111111
}
112112

113113
/// <inheritdoc />
114-
protected override Task PreprocessInputs(string text, InferStateArgs args)
114+
protected override Task PreprocessInputs(string? text, InferStateArgs args)
115115
{
116116
if (_is_prompt_run)
117117
{
118118
// When running the first input (prompt) in interactive mode, we should specially process it.
119+
if (text == null) throw new ArgumentException("Prompt cannot be null to trigger continuation if a prompt has not been provided previously.");
119120
if (!this.IsMultiModal)
120121
{
121122
_embed_inps = Context.Tokenize(text, true, true).ToList();
@@ -127,20 +128,24 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
127128
}
128129
else
129130
{
130-
if (!text.EndsWith("\n"))
131+
// Don't add any tokens if continuation is requested (by providing a null prompt)
132+
if (text != null)
131133
{
132-
text += "\n";
133-
}
134+
if (!text.EndsWith("\n"))
135+
{
136+
text += "\n";
137+
}
134138

135-
if (!this.IsMultiModal)
136-
{
137-
var line_inp = Context.Tokenize(text, false, true);
138-
_embed_inps.AddRange(line_inp);
139-
args.RemainedTokens -= line_inp.Length;
140-
}
141-
else
142-
{
143-
PreprocessLlava(text, args, false);
139+
if (!this.IsMultiModal)
140+
{
141+
var line_inp = Context.Tokenize(text, false, true);
142+
_embed_inps.AddRange(line_inp);
143+
args.RemainedTokens -= line_inp.Length;
144+
}
145+
else
146+
{
147+
PreprocessLlava(text, args, false);
148+
}
144149
}
145150
}
146151

0 commit comments

Comments
 (0)