From 5647a37b6e3a9a0aab6db980c4c2c612b595d64e Mon Sep 17 00:00:00 2001
From: Rinne
Date: Thu, 14 Dec 2023 18:15:52 +0800
Subject: [PATCH 1/3] feat: support custom generation control of executors.

---
 .../Examples/CustomGenerationControl.cs       | 63 +++++++++++++++++++
 LLama.Examples/Examples/Runner.cs             |  1 +
 LLama.Web/Common/InferenceOptions.cs          |  4 ++
 LLama/Abstractions/IInferenceParams.cs        |  8 ++-
 LLama/Common/InferenceParams.cs               |  4 ++
 LLama/{ => Control}/AntipromptProcessor.cs    |  2 +-
 LLama/Control/DefaultGenerationControl.cs     | 42 +++++++++++++
 LLama/Control/IGenerationControl.cs           | 31 +++++++++
 LLama/LLamaExecutorBase.cs                    |  9 ++-
 LLama/LLamaInteractExecutor.cs                | 14 ++++-
 LLama/LLamaStatelessExecutor.cs               |  8 ++-
 11 files changed, 180 insertions(+), 6 deletions(-)
 create mode 100644 LLama.Examples/Examples/CustomGenerationControl.cs
 rename LLama/{ => Control}/AntipromptProcessor.cs (99%)
 create mode 100644 LLama/Control/DefaultGenerationControl.cs
 create mode 100644 LLama/Control/IGenerationControl.cs

diff --git a/LLama.Examples/Examples/CustomGenerationControl.cs b/LLama.Examples/Examples/CustomGenerationControl.cs
new file mode 100644
index 000000000..9e3fb439d
--- /dev/null
+++ b/LLama.Examples/Examples/CustomGenerationControl.cs
@@ -0,0 +1,63 @@
+using LLama.Abstractions;
+using LLama.Common;
+using LLama.Control;
+using LLama.Examples.Extensions;
+
+namespace LLama.Examples.Examples
+{
+    public class CustomGenerationControl
+    {
+        public class NumberGenerationControl: IGenerationControl
+        {
+            public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
+            {
+                if (lastOutputText.Any(char.IsDigit))
+                {
+                    return true;
+                }
+                return false;
+            }
+
+            public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds)
+            {
+                return false;
+            }
+        }
+        public static async Task Run()
+        {
+            Console.Write("Please input your model path: ");
+            var modelPath = Console.ReadLine();
+
+            var parameters = new ModelParams(modelPath)
+            {
+                ContextSize = 1024,
+                Seed = 1337,
+                GpuLayerCount = 5
+            };
+            using var model = LLamaWeights.LoadFromFile(parameters);
+            var ex = new StatelessExecutor(model, parameters);
+
+            Console.ForegroundColor = ConsoleColor.Yellow;
+            Console.WriteLine("This is an example to show how to customize the generation control of the executors. Here we implement a control mode in which" +
+                " the generation will stop once a number is generated. Please try different questions to lead the model to generate answers with and without numbers." +
+ + " No anti-prompt is used in this example."); + Console.ForegroundColor = ConsoleColor.White; + + var inferenceParams = new InferenceParams() { Temperature = 0.6f, MaxTokens = 60, GenerationControl = new NumberGenerationControl() }; + + while (true) + { + Console.Write("\nQuestion: "); + Console.ForegroundColor = ConsoleColor.Green; + var prompt = Console.ReadLine(); + Console.ForegroundColor = ConsoleColor.White; + Console.Write("Answer: "); + prompt = $"Question: {prompt?.Trim()} Answer: "; + await foreach (var text in ex.InferAsync(prompt, inferenceParams).Spinner()) + { + Console.Write(text); + } + } + } + } +} diff --git a/LLama.Examples/Examples/Runner.cs b/LLama.Examples/Examples/Runner.cs index 3d9858e1d..0bc5202b4 100644 --- a/LLama.Examples/Examples/Runner.cs +++ b/LLama.Examples/Examples/Runner.cs @@ -13,6 +13,7 @@ public class Runner { "Interactive mode chat by using executor.", InteractiveModeExecute.Run }, { "Instruct mode chat by using executor.", InstructModeExecute.Run }, { "Stateless mode chat by using executor.", StatelessModeExecute.Run }, + { "Customize the generation control of executor.", CustomGenerationControl.Run }, { "Load and save chat session.", SaveAndLoadSession.Run }, { "Load and save state of model and executor.", LoadAndSaveState.Run }, { "Get embeddings from LLama model.", () => Task.Run(GetEmbeddings.Run) }, diff --git a/LLama.Web/Common/InferenceOptions.cs b/LLama.Web/Common/InferenceOptions.cs index c604dc0d1..9418acba4 100644 --- a/LLama.Web/Common/InferenceOptions.cs +++ b/LLama.Web/Common/InferenceOptions.cs @@ -4,6 +4,7 @@ using LLama.Abstractions; using LLama.Native; using LLama.Sampling; +using LLama.Control; namespace LLama.Web.Common { @@ -71,5 +72,8 @@ public class InferenceOptions /// public ISamplingPipeline? SamplingPipeline { get; set; } + + /// + public IGenerationControl? GenerationControl { get; set; } } } diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs index e1e894143..06d391bf4 100644 --- a/LLama/Abstractions/IInferenceParams.cs +++ b/LLama/Abstractions/IInferenceParams.cs @@ -1,5 +1,6 @@ using System.Collections.Generic; using LLama.Common; +using LLama.Control; using LLama.Native; using LLama.Sampling; @@ -114,5 +115,10 @@ public interface IInferenceParams /// Set a custom sampling pipeline to use. If this is set All other sampling parameters are ignored! /// ISamplingPipeline? SamplingPipeline { get; set; } - } + + /// + /// Set a custom generation control to use. If this is set antiprompt will be ignored! + /// + IGenerationControl? GenerationControl { get; set; } + } } \ No newline at end of file diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs index c1f395505..e5f54bd3f 100644 --- a/LLama/Common/InferenceParams.cs +++ b/LLama/Common/InferenceParams.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using LLama.Native; using LLama.Sampling; +using LLama.Control; namespace LLama.Common { @@ -80,6 +81,9 @@ public record InferenceParams /// public ISamplingPipeline? SamplingPipeline { get; set; } + + /// + public IGenerationControl? 
     }
 
     /// <summary>
diff --git a/LLama/AntipromptProcessor.cs b/LLama/Control/AntipromptProcessor.cs
similarity index 99%
rename from LLama/AntipromptProcessor.cs
rename to LLama/Control/AntipromptProcessor.cs
index c18c0915d..9b5ff9877 100644
--- a/LLama/AntipromptProcessor.cs
+++ b/LLama/Control/AntipromptProcessor.cs
@@ -1,7 +1,7 @@
 using System;
 using System.Collections.Generic;
 
-namespace LLama
+namespace LLama.Control
 {
     /// <summary>
     /// AntipromptProcessor keeps track of past tokens looking for any set Anti-Prompts
diff --git a/LLama/Control/DefaultGenerationControl.cs b/LLama/Control/DefaultGenerationControl.cs
new file mode 100644
index 000000000..b48f71722
--- /dev/null
+++ b/LLama/Control/DefaultGenerationControl.cs
@@ -0,0 +1,42 @@
+using LLama.Abstractions;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Control
+{
+    /// <summary>
+    /// The default generation control in LLamaSharp, using antiprompts. This class should not be inherited.
+    /// Note that this class is stateful. The previous outputs fed to it will affect its control.
+    /// If you use it in a session, please don't reuse it for another session unless you intend to do so.
+    /// </summary>
+    public sealed class DefaultGenerationControl : IGenerationControl
+    {
+        private AntipromptProcessor _antipromptProcessor;
+
+        /// <summary>
+        /// <inheritdoc/>
+        /// </summary>
+        public DefaultGenerationControl()
+        {
+            _antipromptProcessor = new AntipromptProcessor();
+        }
+
+        /// <summary>
+        /// <inheritdoc/>
+        /// </summary>
+        public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
+        {
+            _antipromptProcessor.SetAntiprompts(inferenceParams.AntiPrompts);
+            return _antipromptProcessor.Add(lastOutputText);
+        }
+
+        /// <summary>
+        /// <inheritdoc/>
+        /// </summary>
+        public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds)
+        {
+            return false;
+        }
+    }
+}
diff --git a/LLama/Control/IGenerationControl.cs b/LLama/Control/IGenerationControl.cs
new file mode 100644
index 000000000..3e01d284a
--- /dev/null
+++ b/LLama/Control/IGenerationControl.cs
@@ -0,0 +1,31 @@
+using LLama.Abstractions;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace LLama.Control
+{
+    /// <summary>
+    /// Control the text generation of LLama Executors.
+    /// </summary>
+    public interface IGenerationControl
+    {
+        /// <summary>
+        /// Use the last output text to determine if the generation should stop.
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="inferenceParams"></param>
+        /// <param name="lastOutputText"></param>
+        /// <returns></returns>
+        bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText);
+
+        /// <summary>
+        /// Use the last output ids to determine if the generation should stop.
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="inferenceParams"></param>
+        /// <param name="lastOutputIds"></param>
+        /// <returns></returns>
+        bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds);
+    }
+}
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index e0fde1edb..163318100 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -60,6 +60,11 @@ public abstract class StatefulExecutorBase : ILLamaExecutor
         /// The last tokens generated by the model.
         /// </summary>
         protected FixedSizeQueue<llama_token> _last_n_tokens;
+
+        /// <summary>
+        /// The last output text generated by the model.
+        /// </summary>
+        protected string _lastOutputText = string.Empty;
         /// <summary>
         /// The context used by the executor.
         /// </summary>
@@ -299,9 +304,11 @@ public virtual async IAsyncEnumerable<string> InferAsync(string text, IInference
             if (args.ReturnValue)
             {
                 _decoder.AddRange(_embeds);
-                yield return _decoder.Read();
+                _lastOutputText = _decoder.Read();
+                yield return _lastOutputText;
             }
 
+            // TODO(Rinne): Refactor the logic here.
             var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
             if (extraOutputs is { Count: > 0 })
             {
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 9cecf4378..8c4073ad8 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -10,6 +10,7 @@
 using System.Threading.Tasks;
 using LLama.Extensions;
 using Microsoft.Extensions.Logging;
+using LLama.Control;
 
 namespace LLama
 {
@@ -21,6 +22,7 @@ public class InteractiveExecutor : StatefulExecutorBase
     {
         private bool _is_prompt_run = true;
         private readonly llama_token _llama_token_newline;
+        private IGenerationControl _control;
 
         /// <summary>
         /// </summary>
@@ -31,6 +33,7 @@ public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
             : base(context, logger)
         {
             _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle.ModelHandle);
+            _control = new DefaultGenerationControl();
         }
 
         /// <summary>
@@ -134,8 +137,17 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
         {
             if (_embed_inps.Count <= _consumedTokensCount)
             {
-                if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
+                var control = inferenceParams.GenerationControl ?? _control;
+                // Get stop signal by ids
+                if (control.ShouldStopGeneration(Context, inferenceParams, _embeds))
+                {
                     args.WaitForInput = true;
+                }
+                // Get stop signal by text
+                else if (control.ShouldStopGeneration(Context, inferenceParams, _lastOutputText))
+                {
+                    args.WaitForInput = true;
+                }
 
                 if (_pastTokensCount > 0 && args.WaitForInput)
                     return (true, Array.Empty<string>());
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 831aceb26..8f97de0d1 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -8,6 +8,7 @@
 using System.Threading.Tasks;
 using LLama.Native;
 using LLama.Sampling;
+using LLama.Control;
 using Microsoft.Extensions.Logging;
 
 namespace LLama
@@ -63,8 +64,8 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
             throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})");
 
         // Create decoders for the token stream
+        IGenerationControl control = inferenceParams.GenerationControl ?? new DefaultGenerationControl();
         var decoder = new StreamingTokenDecoder(Context);
-        var antiprocessor = new AntipromptProcessor(inferenceParams.AntiPrompts);
 
         // Keep track of the last N tokens emitted
         var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
@@ -105,13 +106,16 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
                 );
             }
 
+            if (control.ShouldStopGeneration(context, inferenceParams, new int[] { id }))
+                break;
+
             // Decode this token into text
             decoder.Add(id);
             var decoded = decoder.Read();
             yield return decoded;
 
             // Check if any of the antiprompts have been generated
-            if (antiprocessor.Add(decoded))
+            if (control.ShouldStopGeneration(Context, inferenceParams, decoded))
                 break;
 
             lastTokens.Add(id);

From 601122ffe229ad087e234fd96b1dd0a7f3ad7c3a Mon Sep 17 00:00:00 2001
From: Rinne
Date: Thu, 14 Dec 2023 18:26:13 +0800
Subject: [PATCH 2/3] fix: unit test error.

---
 LLama.Unittest/TokenTests.cs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/LLama.Unittest/TokenTests.cs b/LLama.Unittest/TokenTests.cs
index e39df5f47..e9f3a7e39 100644
--- a/LLama.Unittest/TokenTests.cs
+++ b/LLama.Unittest/TokenTests.cs
@@ -1,5 +1,6 @@
 using System.Text;
 using LLama.Common;
+using LLama.Control;
 using LLama.Extensions;
 
 namespace LLama.Unittest;

From 5c949fb955c50bbd96078b76beb04f8e6d6bdd63 Mon Sep 17 00:00:00 2001
From: Rinne
Date: Fri, 15 Dec 2023 08:52:04 +0800
Subject: [PATCH 3/3] fix: resolve comments.

---
 .../Examples/CustomGenerationControl.cs       | 10 +++----
 LLama/Control/DefaultGenerationControl.cs     |  4 +--
 LLama/Control/IGenerationControl.cs           | 20 ++++++++------
 .../Extensions/GenerationControlExtensions.cs | 27 +++++++++++++++++++
 LLama/LLamaContext.cs                         | 10 +++++++
 LLama/LLamaStatelessExecutor.cs               | 16 ++++++-----
 6 files changed, 63 insertions(+), 24 deletions(-)
 create mode 100644 LLama/Extensions/GenerationControlExtensions.cs

diff --git a/LLama.Examples/Examples/CustomGenerationControl.cs b/LLama.Examples/Examples/CustomGenerationControl.cs
index 9e3fb439d..1ad8943c2 100644
--- a/LLama.Examples/Examples/CustomGenerationControl.cs
+++ b/LLama.Examples/Examples/CustomGenerationControl.cs
@@ -11,14 +11,10 @@ public class NumberGenerationControl: IGenerationControl
         {
             public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
             {
-                if (lastOutputText.Any(char.IsDigit))
-                {
-                    return true;
-                }
-                return false;
+                return lastOutputText.Any(x => char.IsDigit(x) && (x == '4' || x == '5'));
             }
 
-            public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds)
+            public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, int lastOutputId)
             {
                 return false;
             }
@@ -39,7 +35,7 @@ public static async Task Run()
 
             Console.ForegroundColor = ConsoleColor.Yellow;
             Console.WriteLine("This is an example to show how to customize the generation control of the executors. Here we implement a control mode in which" +
-                " the generation will stop once a number is generated. Please try different questions to lead the model to generate answers with and without numbers." +
+                " the generation will stop once the digit 4 or 5 is generated. Please try different questions to lead the model to generate answers with and without numbers." +
+ " No anti-prompt is used in this example."); Console.ForegroundColor = ConsoleColor.White; diff --git a/LLama/Control/DefaultGenerationControl.cs b/LLama/Control/DefaultGenerationControl.cs index b48f71722..e8a619682 100644 --- a/LLama/Control/DefaultGenerationControl.cs +++ b/LLama/Control/DefaultGenerationControl.cs @@ -34,9 +34,9 @@ public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenc /// /// /// - public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable lastOutputIds) + public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, int lastOutputId) { - return false; + return context.IsEOS(lastOutputId); } } } diff --git a/LLama/Control/IGenerationControl.cs b/LLama/Control/IGenerationControl.cs index 3e01d284a..ab5061ca3 100644 --- a/LLama/Control/IGenerationControl.cs +++ b/LLama/Control/IGenerationControl.cs @@ -12,20 +12,24 @@ public interface IGenerationControl { /// /// Use the last output text to determine if the generation should stop. + /// This method will be called after the overload with output id. + /// The text will be returned even if this returns true but the generation will be stopped. /// - /// - /// - /// + /// The LLamaContext used in the current generation. + /// The inference params used in the current generation. + /// The last output text generated by the model. /// bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText); /// - /// Use the last output ids to determine if the generation should stop. + /// Use the last output token to determine if the generation should stop. + /// This method will be called before the overload with output text. + /// The token will be returned even if this returns true but the generation will be stopped. /// - /// - /// - /// + /// The LLamaContext used in the current generation. + /// The inference params used in the current generation. + /// The last output token generated by the model. /// - bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, IEnumerable lastOutputIds); + bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, int lastOutputId); } } diff --git a/LLama/Extensions/GenerationControlExtensions.cs b/LLama/Extensions/GenerationControlExtensions.cs new file mode 100644 index 000000000..81ddb626c --- /dev/null +++ b/LLama/Extensions/GenerationControlExtensions.cs @@ -0,0 +1,27 @@ +using LLama.Abstractions; +using LLama.Control; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace LLama.Extensions +{ + /// + /// Extension methods for generation control + /// + public static class GenerationControlExtensions + { + public static bool ShouldStopGeneration(this IGenerationControl control, LLamaContext context, IInferenceParams inferenceParams, IEnumerable lastOutputIds) + { + foreach (var id in lastOutputIds) + { + if(control.ShouldStopGeneration(context, inferenceParams, id)) + { + return true; + } + } + return false; + } + } +} diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 2902dc8f9..3e9c0f446 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -86,6 +86,16 @@ public LLamaContext(LLamaWeights model, IContextParams @params, ILogger? logger NativeHandle = SafeLLamaContextHandle.Create(model.NativeHandle, lparams); } + /// + /// Return if a token marks the end of a sentence. 
+ /// + /// + /// + public bool IsEOS(int token) + { + return NativeApi.llama_token_eos(this.NativeHandle.ModelHandle) == token; + } + /// /// Tokenize a string. /// diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 8f97de0d1..a836716c8 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -106,20 +106,22 @@ public async IAsyncEnumerable InferAsync(string prompt, IInferenceParams ); } - if(control.ShouldStopGeneration(context, inferenceParams, new int[] { id })) - break; - // Decode this token into text decoder.Add(id); var decoded = decoder.Read(); yield return decoded; - // Check if any of the antiprompts have been generated - if(control.ShouldStopGeneration(Context, inferenceParams, decoded)) - break; - lastTokens.Add(id); tokens.Clear(); + + // Check if we should steop generation by ids + if (control.ShouldStopGeneration(context, inferenceParams, id)) + break; + // Check if we should steop generation by text + if (control.ShouldStopGeneration(Context, inferenceParams, decoded)) + break; + + // prepare for the next loop tokens.Add(id); // when run out of context