diff --git a/LLama/Extensions/IReadOnlyListExtensions.cs b/LLama/Extensions/IReadOnlyListExtensions.cs
index b07d90cfa..131a88525 100644
--- a/LLama/Extensions/IReadOnlyListExtensions.cs
+++ b/LLama/Extensions/IReadOnlyListExtensions.cs
@@ -9,6 +9,13 @@ namespace LLama.Extensions
 {
     internal static class IReadOnlyListExtensions
     {
+        /// <summary>
+        /// Find the index of `item` in `list`
+        /// </summary>
+        /// <typeparam name="T"></typeparam>
+        /// <param name="list">list to search</param>
+        /// <param name="item">item to search for</param>
+        /// <returns></returns>
         public static int? IndexOf<T>(this IReadOnlyList<T> list, T item)
             where T : IEquatable<T>
         {
@@ -61,6 +68,14 @@ internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tok
             }
         }
 
+        /// <summary>
+        /// Check if the given set of tokens ends with any of the given strings
+        /// </summary>
+        /// <param name="tokens">Tokens to check</param>
+        /// <param name="queries">Strings to search for</param>
+        /// <param name="model">Model to use to convert tokens into bytes</param>
+        /// <param name="encoding">Encoding to use to convert bytes into characters</param>
+        /// <returns></returns>
         internal static bool TokensEndsWithAnyString<TTokens>(this TTokens tokens, IList<string>? queries, SafeLlamaModelHandle model, Encoding encoding)
             where TTokens : IReadOnlyList<int>
         {
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 5b3853fc6..7b7647e71 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -484,6 +484,16 @@ public string TokenToString(llama_token token)
             return NativeHandle.TokenToString(token, Encoding);
         }
 
+        /// <summary>
+        /// Append a single token to a string builder
+        /// </summary>
+        /// <param name="token">Token to decode</param>
+        /// <param name="dest">string builder to append the result to</param>
+        public void TokenToString(llama_token token, StringBuilder dest)
+        {
+            NativeHandle.TokenToString(token, Encoding, dest);
+        }
+
         /// <inheritdoc />
         public void Dispose()
         {
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 712c2c239..2d46728fd 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -8,6 +8,7 @@
 using System.Text;
 using System.Text.Json;
 using System.Text.Json.Serialization;
+using LLama.Extensions;
 
 namespace LLama
 {
@@ -139,21 +140,10 @@ protected override bool PostProcess(IInferenceParams inferenceParams, InferState
             extraOutputs = null;
             if (_embed_inps.Count <= _consumedTokensCount)
             {
-                if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
+                if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
                 {
-                    var last_output_builder = new StringBuilder();
-                    foreach (var token in _last_n_tokens)
-                        Context.NativeHandle.TokenToString(token, Context.Encoding, last_output_builder);
-                    var last_output = last_output_builder.ToString();
-
-                    foreach (var antiprompt in args.Antiprompts)
-                    {
-                        if (last_output.EndsWith(antiprompt))
-                        {
-                            args.WaitForInput = true;
-                            return true;
-                        }
-                    }
+                    args.WaitForInput = true;
+                    return true;
                 }
 
                 if (_pastTokensCount > 0 && args.WaitForInput)