diff --git a/LLama/Extensions/IReadOnlyListExtensions.cs b/LLama/Extensions/IReadOnlyListExtensions.cs
index b07d90cfa..131a88525 100644
--- a/LLama/Extensions/IReadOnlyListExtensions.cs
+++ b/LLama/Extensions/IReadOnlyListExtensions.cs
@@ -9,6 +9,13 @@ namespace LLama.Extensions
{
internal static class IReadOnlyListExtensions
{
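+ /// <summary>
+ /// Find the index of `item` in `list`
+ /// </summary>
+ /// <typeparam name="T">Type of the items held in the list</typeparam>
+ /// <param name="list">list to search</param>
+ /// <param name="item">item to search for</param>
+ /// <returns>Nullable index of the item (presumably null when the item is not present; confirm against the method body)</returns>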
public static int? IndexOf(this IReadOnlyList list, T item)
where T : IEquatable
{
@@ -61,6 +68,14 @@ internal static bool TokensEndsWithAnyString(this TTokens tok
}
}
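+ /// <summary>
+ /// Check if the given set of tokens ends with any of the given strings
+ /// </summary>
+ /// <param name="tokens">Tokens to check</param>
+ /// <param name="queries">Strings to search for</param>
+ /// <param name="model">Model to use to convert tokens into bytes</param>
+ /// <param name="encoding">Encoding to use to convert bytes into characters</param>
+ /// <returns>Whether the tokens end with any of the given strings</returns>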
internal static bool TokensEndsWithAnyString(this TTokens tokens, IList? queries, SafeLlamaModelHandle model, Encoding encoding)
where TTokens : IReadOnlyList
{
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 5b3853fc6..7b7647e71 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -484,6 +484,16 @@ public string TokenToString(llama_token token)
return NativeHandle.TokenToString(token, Encoding);
}
+ /// <summary>
+ /// Append a single token to a string builder
+ /// </summary>
+ /// <param name="token">Token to decode</param>
+ /// <param name="dest">string builder to append the result to</param>
+ public void TokenToString(llama_token token, StringBuilder dest)
+ {
+ NativeHandle.TokenToString(token, Encoding, dest);
+ }
+
///
public void Dispose()
{
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 712c2c239..2d46728fd 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -8,6 +8,7 @@
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
+using LLama.Extensions;
namespace LLama
{
@@ -139,21 +140,10 @@ protected override bool PostProcess(IInferenceParams inferenceParams, InferState
extraOutputs = null;
if (_embed_inps.Count <= _consumedTokensCount)
{
- if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
+ if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
{
- var last_output_builder = new StringBuilder();
- foreach (var token in _last_n_tokens)
- Context.NativeHandle.TokenToString(token, Context.Encoding, last_output_builder);
- var last_output = last_output_builder.ToString();
-
- foreach (var antiprompt in args.Antiprompts)
- {
- if (last_output.EndsWith(antiprompt))
- {
- args.WaitForInput = true;
- return true;
- }
- }
+ args.WaitForInput = true;
+ return true;
}
if (_pastTokensCount > 0 && args.WaitForInput)