feat: support custom generation control of executors. #364
Closed
@@ -0,0 +1,59 @@
using LLama.Abstractions;
using LLama.Common;
using LLama.Control;
using LLama.Examples.Extensions;

namespace LLama.Examples.Examples
{
    public class CustomGenerationControl
    {
        public class NumberGenerationControl : IGenerationControl
        {
            public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
            {
                return lastOutputText.Any(x => char.IsDigit(x) && (x == '4' || x == '5'));
            }

            public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, int lastOutputId)
            {
                return false;
            }
        }
        public static async Task Run()
        {
            Console.Write("Please input your model path: ");
            var modelPath = Console.ReadLine();

            var parameters = new ModelParams(modelPath)
            {
                ContextSize = 1024,
                Seed = 1337,
                GpuLayerCount = 5
            };
            using var model = LLamaWeights.LoadFromFile(parameters);
            var ex = new StatelessExecutor(model, parameters);

            Console.ForegroundColor = ConsoleColor.Yellow;
            Console.WriteLine("This is an example to show how to customize the generation control of the executors. Here we implement a control mode in which" +
                " the generation will stop once a number 4 or 5 is generated. Please try different questions to lead the model to generate answers with and without numbers." +
                " No anti-prompt is used in this example.");
            Console.ForegroundColor = ConsoleColor.White;

            var inferenceParams = new InferenceParams() { Temperature = 0.6f, MaxTokens = 60, GenerationControl = new NumberGenerationControl() };

            while (true)
            {
                Console.Write("\nQuestion: ");
                Console.ForegroundColor = ConsoleColor.Green;
                var prompt = Console.ReadLine();
                Console.ForegroundColor = ConsoleColor.White;
                Console.Write("Answer: ");
                prompt = $"Question: {prompt?.Trim()} Answer: ";
                await foreach (var text in ex.InferAsync(prompt, inferenceParams).Spinner())
                {
                    Console.Write(text);
                }
            }
        }
    }
}
@@ -1,5 +1,6 @@
using System.Text;
using LLama.Common;
using LLama.Control;
using LLama.Extensions;

namespace LLama.Unittest;
LLama/AntipromptProcessor.cs → LLama/Control/AntipromptProcessor.cs (2 changes: 1 addition & 1 deletion)
@@ -0,0 +1,42 @@
using LLama.Abstractions;
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Control
{
    /// <summary>
    /// The default generation control in LLamaSharp, using antiprompts. This class should not be inherited.
    /// <b>Note that this class has state. The previous outputs fed to it will affect its control.</b>
    /// If you use it in a session, please don't reuse it for another session unless you intend to do so.
    /// </summary>
    public sealed class DefaultGenerationControl : IGenerationControl
    {
        private AntipromptProcessor _antipromptProcessor;

        /// <summary>
        /// <inheritdoc/>
        /// </summary>
        public DefaultGenerationControl()
        {
            _antipromptProcessor = new AntipromptProcessor();
        }

        /// <summary>
        /// <inheritdoc/>
        /// </summary>
        public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
        {
            _antipromptProcessor.SetAntiprompts(inferenceParams.AntiPrompts);
            return _antipromptProcessor.Add(lastOutputText);
        }

        /// <summary>
        /// <inheritdoc/>
        /// </summary>
        public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, int lastOutputId)
        {
            return context.IsEOS(lastOutputId);
        }
    }
}
@@ -0,0 +1,35 @@
using LLama.Abstractions;
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Control
{
    /// <summary>
    /// Control the text generation of LLama Executors.
    /// </summary>
    public interface IGenerationControl
    {
        /// <summary>
        /// Use the last output text to determine whether the generation should stop.
        /// This method will be called after the overload that takes the output id.
        /// The text will still be returned even if this returns true, but the generation will be stopped.
        /// </summary>
        /// <param name="context">The LLamaContext used in the current generation.</param>
        /// <param name="inferenceParams">The inference params used in the current generation.</param>
        /// <param name="lastOutputText">The last output text generated by the model.</param>
        /// <returns></returns>
        bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText);

        /// <summary>
        /// Use the last output token to determine whether the generation should stop.
        /// This method will be called before the overload that takes the output text.
        /// The token will still be returned even if this returns true, but the generation will be stopped.
        /// </summary>
        /// <param name="context">The LLamaContext used in the current generation.</param>
        /// <param name="inferenceParams">The inference params used in the current generation.</param>
        /// <param name="lastOutputId">The last output token generated by the model.</param>
        /// <returns></returns>
        bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, int lastOutputId);
    }
}
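As a side-by-side with the antiprompt-based DefaultGenerationControl above, here is a minimal sketch of another way this interface could be implemented. It is not part of this PR; the class name and the stop condition are made up for illustration:

using System;
using LLama;
using LLama.Abstractions;
using LLama.Control;

// Illustrative only: stop generation once a chosen stop word appears in the output.
public sealed class StopWordGenerationControl : IGenerationControl
{
    private readonly string _stopWord;

    public StopWordGenerationControl(string stopWord)
    {
        _stopWord = stopWord;
    }

    // Text-level check; per the docs above, this is called after the token-level overload.
    public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
    {
        return lastOutputText.Contains(_stopWord, StringComparison.OrdinalIgnoreCase);
    }

    // Token-level check; this control has no token-based stop condition.
    public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, int lastOutputId)
    {
        return false;
    }
}

Unlike DefaultGenerationControl, this sketch keeps no state, so a single instance could be reused across sessions without side effects.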
@@ -0,0 +1,27 @@
using LLama.Abstractions;
using LLama.Control;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace LLama.Extensions
{
    /// <summary>
    /// Extension methods for generation control
    /// </summary>
    public static class GenerationControlExtensions
    {
        public static bool ShouldStopGeneration(this IGenerationControl control, LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds)
        {
            foreach (var id in lastOutputIds)
            {
                if (control.ShouldStopGeneration(context, inferenceParams, id))
                {
                    return true;
                }
            }
            return false;
        }
    }
}
Just a thought - is it possible to decouple this interface from the `IInferenceParams` somehow? It feels a bit odd that `GenerationControl` is overriding parts of `IInferenceParams` and is also configured by it.

The obvious way to do that would be to remove the `inferenceParams` parameter here, that would introduce problems in `DefaultGenerationControl` though.
Yeah I also thought of that and I could come up with the following alternatives:

1. Change the `InferAsync` of `ILLamaExecutor` to `InferAsync(string text, IInferenceParams? inferenceParams = null, IGenerationControl? control = null, CancellationToken token = default);` or similar formats. It will introduce some breaking changes though.
2. Change `ILLamaExecutor` to `ILLamaExecutor<TControl>`, or change the method `InferAsync`.
3. `var executor = new StatelessExecutor().WithControl()`.

Each of them has advantages and disadvantages. The most important question, I think, is whether we should make it class level or method level. Actually I can't think of a case where the user must use different control strategies in different calls of the same executor, though it does have such flexibility. As a compromise, I'd like to suggest the following proposal: add the control at the class level of `ILLamaExecutor`, taking `WithGenerationControl` as an extension method.

BTW there's one thing related to it: the sampling pipeline. Currently `InferenceParams` contains many sampling params, while specifying `SamplingPipeline` will totally override them. It's completely okay in the master branch because we don't want to introduce substantial breaking changes now. However I'm wondering if we should refactor it before v1.0.0. For example, we could add many kinds of sampling pipelines as we discussed here, along with a class named `CustomSamplingPipeline`, which accepts all the parameters related to sampling in `InferenceParams` now.
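(For illustration only, the class-level compromise might look roughly like the sketch below. The extension class and the settable control member do not exist in this PR or in LLamaSharp; they are assumptions to make the idea concrete.)

using LLama.Abstractions;
using LLama.Control;

// Hypothetical sketch: configure the control once on the executor,
// instead of passing it through IInferenceParams on every InferAsync call.
public static class GenerationControlExecutorExtensions
{
    public static TExecutor WithGenerationControl<TExecutor>(this TExecutor executor, IGenerationControl control)
        where TExecutor : ILLamaExecutor
    {
        // ILLamaExecutor exposes no such member today; a settable property is assumed here.
        // executor.GenerationControl = control;
        return executor;
    }

    // Possible usage under that assumption:
    // var executor = new StatelessExecutor(model, parameters).WithGenerationControl(new DefaultGenerationControl());
}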
I was thinking more along the lines of removing the `IInferenceParams inferenceParams` from the `ShouldStopGeneration` methods.
Yeah this is my thinking. It feels like we're moving in a direction where we can completely get rid of the inference params in the future (which is part of why I want to avoid depending on it in this new interface, if possible).
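(To make that direction concrete, a decoupled version of the interface might look something like the sketch below; the name and shape are assumptions, not anything proposed verbatim in this thread.)

using LLama;

// Hypothetical: the stop checks no longer receive IInferenceParams, so configuration
// such as antiprompts would be supplied to the control itself, e.g. via its constructor.
public interface IStandaloneGenerationControl
{
    bool ShouldStopGeneration(LLamaContext context, string lastOutputText);
    bool ShouldStopGeneration(LLamaContext context, int lastOutputId);
}

A DefaultGenerationControl written against this shape would have to take its antiprompts in the constructor, which is the wrinkle noted at the start of this thread.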