diff --git a/LLama.Unittest/BeamTests.cs b/LLama.Unittest/BeamTests.cs index f8d5cf01c..3014894ec 100644 --- a/LLama.Unittest/BeamTests.cs +++ b/LLama.Unittest/BeamTests.cs @@ -32,13 +32,14 @@ public void BasicBeam() { const int num_beams = 2; const int n_predict = 3; + const string prompt = "The cat sat on"; var context = _model.CreateContext(_params); var result = new StringBuilder(); - var initial_tokens = context.Tokenize("The cat sat on"); - result.Append(context.DeTokenize(initial_tokens.ToArray())); + var initial_tokens = context.Tokenize(prompt); + result.Append(prompt); context.Eval(initial_tokens, 0); NativeApi.llama_beam_search(context.NativeHandle, (data, state) => diff --git a/LLama.Unittest/TokenTests.cs b/LLama.Unittest/TokenTests.cs index a699b9b85..e39df5f47 100644 --- a/LLama.Unittest/TokenTests.cs +++ b/LLama.Unittest/TokenTests.cs @@ -72,4 +72,129 @@ public void TokensNotEndWithNothing() var result = tokens.TokensEndsWithAnyString((IList)Array.Empty(), _model.NativeHandle, Encoding.UTF8); Assert.False(result); } + + [Fact] + public void TokensEndWith2() + { + var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8); + + var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model); + decoder.AddRange(tokens); + + var processor = new AntipromptProcessor(new[] + { + "a fish", + "the mat", + "this is an improbably long query to be using for this method" + }); + var result = processor.Add(decoder.Read()); + + Assert.True(result); + } + + [Fact] + public void TokensEndSubstring2() + { + var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8); + + var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model); + decoder.AddRange(tokens); + + var processor = new AntipromptProcessor(new[] { "at" }); + var result = processor.Add(decoder.Read()); + + Assert.True(result); + } + + [Fact] + public void TokensNotEndWith2() + { + var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8); + + var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model); + decoder.AddRange(tokens); + + var processor = new AntipromptProcessor(new[] + { + "a fish", + "The cat sat on the edge of the ma", + "this is an improbably long query to be using for this method" + }); + var result = processor.Add(decoder.Read()); + + Assert.False(result); + } + + [Fact] + public void TokensNotEndWithNothing2() + { + var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8); + + var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model); + decoder.AddRange(tokens); + + var processor = new AntipromptProcessor(); + var result = processor.Add(decoder.Read()); + + Assert.False(result); + } + + [Fact] + public void RoundTrip() + { + var strings = new[] + { + "Hello world", + "철수", + "πŸ˜€ πŸ˜ƒ πŸ˜„ 😁 πŸ˜†μ² μˆ˜πŸ˜… πŸ˜‚ 😊 πŸ˜‡ πŸ™‚ ", + }; + + var charsArr = new char[1024]; + + foreach (var input in strings) + { + // Convert into llama tokens + var tokens = _model.NativeHandle.Tokenize(input, false, false, Encoding.UTF8); + + // Convert tokens back into characters + var chars = _model.NativeHandle.TokensToSpan(tokens, charsArr.AsSpan(), Encoding.UTF8); + + // llama.cpp adds a space to the start of strings, remove that + var output = new string(chars).TrimStart(' '); + + // Check that the input equals the output + Assert.Equal(input, output); + } + } + + [Fact] + public void StreamingDecoderRoundTrip() + { + var decoder = new StreamingTokenDecoder(Encoding.UTF8, _model); + + var strings = new[] + { + "Hello world", + "철수", + "πŸ˜€ πŸ˜ƒ πŸ˜„ 😁 πŸ˜†μ² μˆ˜πŸ˜… πŸ˜‚ 😊 πŸ˜‡ πŸ™‚ ", + }; + + foreach (var input in strings) + { + decoder.Reset(); + + // Convert into llama tokens + var tokens = _model.NativeHandle.Tokenize(input, false, false, Encoding.UTF8); + + // Add tokens to decoder + foreach (var token in tokens) + decoder.Add(token); + + // llama.cpp adds a space to the start of strings, remove that + var output = decoder.Read().TrimStart(' '); + + // Check that the input equals the output + Assert.Equal(input, output); + } + } } \ No newline at end of file diff --git a/LLama/AntipromptProcessor.cs b/LLama/AntipromptProcessor.cs new file mode 100644 index 000000000..4d969cea2 --- /dev/null +++ b/LLama/AntipromptProcessor.cs @@ -0,0 +1,66 @@ +ο»Ώusing System; +using System.Collections.Generic; + +namespace LLama; + +internal sealed class AntipromptProcessor +{ + private int _longestAntiprompt; + private readonly List _antiprompts = new(); + + private string? _string; + + public AntipromptProcessor(IEnumerable? antiprompts = null) + { + if (antiprompts != null) + SetAntiprompts(antiprompts); + } + + /// + /// Add an antiprompt to the collection + /// + /// + public void AddAntiprompt(string antiprompt) + { + _antiprompts.Add(antiprompt); + _longestAntiprompt = Math.Max(_longestAntiprompt, antiprompt.Length); + } + + /// + /// Overwrite all current antiprompts with a new set + /// + /// + public void SetAntiprompts(IEnumerable antiprompts) + { + _antiprompts.Clear(); + _antiprompts.AddRange(antiprompts); + + _longestAntiprompt = 0; + foreach (var antiprompt in _antiprompts) + _longestAntiprompt = Math.Max(_longestAntiprompt, antiprompt.Length); + } + + /// + /// Add some text and check if the buffer now ends with any antiprompt + /// + /// + /// true if the text buffer ends with any antiprompt + public bool Add(string text) + { + _string += text; + + // When the string gets very long (4x antiprompt length) trim it down (to 2x antiprompt length). + // This trimming leaves a lot of extra characters because two sequences can be considered "equal" in unicode + // even with different numbers of characters. Hopefully there are enough characters here to handle all those weird circumstances! + var maxLength = Math.Max(32, _longestAntiprompt * 4); + var trimLength = Math.Max(16, _longestAntiprompt * 2); + if (_string.Length > maxLength) + _string = _string.Substring(_string.Length - trimLength); + + foreach (var antiprompt in _antiprompts) + if (_string.EndsWith(antiprompt, StringComparison.CurrentCulture)) + return true; + + return false; + } +} \ No newline at end of file diff --git a/LLama/Extensions/IReadOnlyListExtensions.cs b/LLama/Extensions/IReadOnlyListExtensions.cs index 4d1c6f093..7a3473b71 100644 --- a/LLama/Extensions/IReadOnlyListExtensions.cs +++ b/LLama/Extensions/IReadOnlyListExtensions.cs @@ -36,6 +36,7 @@ internal static class IReadOnlyListExtensions /// Model to use to convert tokens into bytes /// Encoding to use to convert bytes into characters /// + [Obsolete("Use an Antiprompt processor instead")] internal static bool TokensEndsWithAnyString(this TTokens tokens, TQueries? queries, SafeLlamaModelHandle model, Encoding encoding) where TTokens : IReadOnlyList where TQueries : IReadOnlyList @@ -68,13 +69,6 @@ internal static bool TokensEndsWithAnyString(this TTokens tok } } - internal static bool TokensEndsWithAnyString(this TTokens tokens, TQueries? queries, LLamaContext context) - where TTokens : IReadOnlyList - where TQueries : IReadOnlyList - { - return TokensEndsWithAnyString(tokens, queries, context.NativeHandle.ModelHandle, context.Encoding); - } - /// /// Check if the given set of tokens ends with any of the given strings /// @@ -83,6 +77,7 @@ internal static bool TokensEndsWithAnyString(this TTokens tok /// Model to use to convert tokens into bytes /// Encoding to use to convert bytes into characters /// + [Obsolete("Use an Antiprompt processor instead")] internal static bool TokensEndsWithAnyString(this TTokens tokens, IList? queries, SafeLlamaModelHandle model, Encoding encoding) where TTokens : IReadOnlyList { diff --git a/LLama/Extensions/ListExtensions.cs b/LLama/Extensions/ListExtensions.cs new file mode 100644 index 000000000..11a1d4f00 --- /dev/null +++ b/LLama/Extensions/ListExtensions.cs @@ -0,0 +1,24 @@ +ο»Ώusing System; +using System.Collections.Generic; + +namespace LLama.Extensions +{ + internal static class ListExtensions + { +#if NETSTANDARD2_0 + public static void EnsureCapacity(this List list, int capacity) + { + if (list.Capacity < capacity) + list.Capacity = capacity; + } +#endif + + public static void AddSpan(this List list, ReadOnlySpan items) + { + list.EnsureCapacity(list.Count + items.Length); + + for (var i = 0; i < items.Length; i++) + list.Add(items[i]); + } + } +} diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index a190c075f..46b0ae3f8 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -102,13 +102,15 @@ public llama_token[] Tokenize(string text, bool addBos = true, bool special = fa /// /// /// - public string DeTokenize(IEnumerable tokens) + [Obsolete("Use a `StreamingTokenDecoder` instead")] + public string DeTokenize(IReadOnlyList tokens) { - var sb = new StringBuilder(); - foreach (var token in tokens) - NativeHandle.TokenToString(token, Encoding, sb); + // Do **not** use this method as an example of how to correctly use the StreamingTokenDecoder! + // It should be kept around for the entire time you are decoding one stream of tokens. - return sb.ToString(); + var decoder = new StreamingTokenDecoder(this); + decoder.AddRange(tokens); + return decoder.ToString(); } /// @@ -418,26 +420,6 @@ public int Eval(ReadOnlySpan tokens, int pastTokensCount) } #endregion - /// - /// Convert a token into a string - /// - /// - /// - public string TokenToString(llama_token token) - { - return NativeHandle.TokenToString(token, Encoding); - } - - /// - /// Append a single token to a string builder - /// - /// Token to decode - /// string builder to append the result to - public void TokenToString(llama_token token, StringBuilder dest) - { - NativeHandle.TokenToString(token, Encoding, dest); - } - /// public void Dispose() { diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index 1a12c6b2d..578bd4d84 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -294,10 +294,7 @@ public virtual async IAsyncEnumerable InferAsync(string text, IInference await InferInternal(inferenceParams, args); if (args.ReturnValue) - { - foreach (var id in _embeds) - yield return Context.TokenToString(id); - } + yield return Context.DeTokenize(_embeds); var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args); if (extraOutputs is { Count: > 0 }) diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 80488b712..ab1e9bbc0 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -56,6 +56,9 @@ public async IAsyncEnumerable InferAsync(string text, IInferenceParams? Context.Dispose(); Context = _weights.CreateContext(Context.Params, _logger); + var decoder = new StreamingTokenDecoder(Context); + var antiprocessor = new AntipromptProcessor(inferenceParams?.AntiPrompts ?? Array.Empty()); + if (inferenceParams != null) { if (inferenceParams.TokensKeep > Context.ContextSize) @@ -64,7 +67,6 @@ public async IAsyncEnumerable InferAsync(string text, IInferenceParams? cancellationToken.ThrowIfCancellationRequested(); - var antiprompts = inferenceParams?.AntiPrompts.ToArray() ?? Array.Empty(); inferenceParams ??= new InferenceParams(); var lastTokens = new List(inferenceParams.RepeatLastTokensCount); @@ -95,13 +97,16 @@ public async IAsyncEnumerable InferAsync(string text, IInferenceParams? inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar); lastTokens.Add(id); - yield return Context.TokenToString(id); + + decoder.Add(id); + var decoded = decoder.Read(); + yield return decoded; tokens.Clear(); tokens.Add(id); // Check if any of the antiprompts have been generated - if (lastTokens.TokensEndsWithAnyString(antiprompts, Context)) + if (antiprocessor.Add(decoded)) break; // when run out of context diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index c411385c9..7fb5edf74 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -1,5 +1,6 @@ ο»Ώusing System; using System.Buffers; +using System.Collections.Generic; using System.Text; using LLama.Exceptions; @@ -158,28 +159,6 @@ public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding } } - /// - /// Convert a token into a string - /// - /// Token to decode into a string - /// - /// - public string TokenToString(int token, Encoding encoding) - { - return ThrowIfDisposed().TokenToString(token, encoding); - } - - /// - /// Append a single llama token to a string builder - /// - /// Token to decode - /// - /// string builder to append the result to - public void TokenToString(int token, Encoding encoding, StringBuilder dest) - { - ThrowIfDisposed().TokenToString(token, encoding, dest); - } - /// /// Convert a single llama token into bytes /// @@ -190,7 +169,7 @@ public int TokenToSpan(int token, Span dest) { return ThrowIfDisposed().TokenToSpan(token, dest); } - #endregion +#endregion /// /// Run the llama inference to obtain the logits and probabilities for the next token. diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 5f3900e97..b93c2b89c 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -1,10 +1,7 @@ ο»Ώusing System; -using System.Buffers; using System.Collections.Generic; -using System.Diagnostics; using System.Text; using LLama.Exceptions; -using LLama.Extensions; namespace LLama.Native { @@ -119,67 +116,7 @@ public int TokenToSpan(int llama_token, Span dest) } /// - /// Convert a single llama token into a string - /// - /// - /// Encoding to use to decode the bytes into a string - /// - public string TokenToString(int llama_token, Encoding encoding) - { - unsafe - { - var length = NativeApi.llama_token_to_piece(this, llama_token, null, 0); - if (length == 0) - return ""; - - Span bytes = stackalloc byte[-length]; - - fixed (byte* bytePtr = bytes) - { - var written = NativeApi.llama_token_to_piece(this, llama_token, bytePtr, bytes.Length); - Debug.Assert(written == bytes.Length); - - return encoding.GetString(bytePtr, bytes.Length); - } - } - } - - /// - /// Append a single llama token to a string builder - /// - /// Token to decode - /// - /// string builder to append the result to - public void TokenToString(int llama_token, Encoding encoding, StringBuilder dest) - { - unsafe - { - var length = NativeApi.llama_token_to_piece(this, llama_token, null, 0); - if (length == 0) - return; - - Span bytes = stackalloc byte[-length]; - fixed (byte* bytePtr = bytes) - { - // Decode into bytes - var written = NativeApi.llama_token_to_piece(this, llama_token, bytePtr, bytes.Length); - Debug.Assert(written == bytes.Length); - - // Decode into chars - var charCount = encoding.GetCharCount(bytePtr, bytes.Length); - Span chars = stackalloc char[charCount]; - fixed (char* charPtr = chars) - encoding.GetChars(bytePtr, bytes.Length, charPtr, chars.Length); - - // Write it to the output - for (var i = 0; i < chars.Length; i++) - dest.Append(chars[i]); - } - } - } - - /// - /// Convert a sequence of tokens into characters. If there + /// Convert a sequence of tokens into characters. /// /// /// @@ -188,80 +125,25 @@ public void TokenToString(int llama_token, Encoding encoding, StringBuilder dest /// If there was insufficient space in the output span this will be /// filled with as many characters as possible, starting from the _last_ token. /// + [Obsolete("Use a StreamingTokenDecoder instead")] internal Span TokensToSpan(IReadOnlyList tokens, Span dest, Encoding encoding) { - // Rent an array to detokenize into - var tokenBytesArr = ArrayPool.Shared.Rent(16); - var tokenCharsArr = ArrayPool.Shared.Rent(16); - try - { - var totalCharacters = 0; - var unused = dest; - - for (var i = tokens.Count - 1; i >= 0; i--) - { - var token = tokens[i]; + var decoder = new StreamingTokenDecoder(encoding, this); - // Get bytes for this token - var tokenBytes = TokenToBytes(ref tokenBytesArr, token, this); + foreach (var token in tokens) + decoder.Add(token); - // Get chars for this token - var tokenChars = BytesToChars(ref tokenCharsArr, tokenBytes, encoding); + var str = decoder.Read(); - // Trim down number of characters if there are too many - if (tokenChars.Length > unused.Length) - tokenChars = tokenChars.Slice(tokenChars.Length - unused.Length, unused.Length); - - // Copy characters - tokenChars.CopyTo(unused.Slice(unused.Length - tokenChars.Length, tokenChars.Length)); - unused = unused.Slice(0, unused.Length - tokenChars.Length); - totalCharacters += tokenChars.Length; - - // Break out if we've run out of space - if (unused.Length == 0) - break; - } - - return dest.Slice(dest.Length - totalCharacters, totalCharacters); - } - finally + if (str.Length < dest.Length) { - ArrayPool.Shared.Return(tokenBytesArr); - ArrayPool.Shared.Return(tokenCharsArr); + str.AsSpan().CopyTo(dest); + return dest.Slice(0, str.Length); } - - // vvv Local Functions vvv - - static Span TokenToBytes(ref byte[] bytes, int token, SafeLlamaModelHandle model) + else { - // Try to get bytes, if that fails we known the length - var l = model.TokenToSpan(token, bytes); - - // Array was too small, get a bigger one - if (l < 0) - { - ArrayPool.Shared.Return(bytes); - bytes = ArrayPool.Shared.Rent(-l * 2); - - // Get bytes, this time it can't fail - l = model.TokenToSpan(token, bytes); - } - - Debug.Assert(l >= 0); - return new Span(bytes, 0, l); - } - - static Span BytesToChars(ref char[] chars, ReadOnlySpan bytes, Encoding encoding) - { - var count = encoding.GetCharCount(bytes); - if (count > chars.Length) - { - ArrayPool.Shared.Return(chars); - chars = ArrayPool.Shared.Rent(count * 2); - } - - encoding.GetChars(bytes, chars); - return chars.AsSpan(0, count); + str.AsSpan().Slice(str.Length - dest.Length).CopyTo(dest); + return dest; } } @@ -304,7 +186,7 @@ public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding } } } - #endregion +#endregion #region context /// diff --git a/LLama/StreamingTokenDecoder.cs b/LLama/StreamingTokenDecoder.cs new file mode 100644 index 000000000..c5d9683e9 --- /dev/null +++ b/LLama/StreamingTokenDecoder.cs @@ -0,0 +1,174 @@ +ο»Ώusing System.Buffers; +using System.Diagnostics; +using System; +using System.Collections.Generic; +using System.Text; +using LLama.Extensions; +using LLama.Native; + +namespace LLama; + +/// +/// Decodes a stream of tokens into a stream of characters +/// +public sealed class StreamingTokenDecoder +{ + private readonly SafeLlamaModelHandle _weights; + private readonly Decoder _decoder; + + private readonly List _characters = new(); + + /// + /// The number of decoded characters waiting to be read + /// + public int AvailableCharacters => _characters.Count; + + #region constructors + /// + /// Create a new decoder + /// + /// Text encoding to use + /// Model weights + public StreamingTokenDecoder(Encoding encoding, LLamaWeights weights) + : this(encoding, weights.NativeHandle) + { + } + + /// + /// Create a new decoder + /// + /// Context to retrieve encoding and model weights from + public StreamingTokenDecoder(LLamaContext context) + : this(context.Encoding, context.NativeHandle) + { + } + + /// + /// Create a new decoder + /// + /// Text encoding to use + /// Context to retrieve model weights from + public StreamingTokenDecoder(Encoding encoding, SafeLLamaContextHandle context) + : this(encoding, context.ModelHandle) + { + } + + /// + /// Create a new decoder + /// + /// Text encoding to use + /// Models weights to use + public StreamingTokenDecoder(Encoding encoding, SafeLlamaModelHandle weights) + { + _weights = weights; + _decoder = encoding.GetDecoder(); + } + #endregion + + /// + /// Add a single token to the decoder + /// + /// + public void Add(int token) + { + var charsArr = ArrayPool.Shared.Rent(16); + var bytesArr = ArrayPool.Shared.Rent(16); + try + { + // Convert this token into bytes + var bytesAvailable = TokenToBytes(ref bytesArr, token, _weights).Length; + + // Convert those bytes into characters + var bytesOffset = 0; + var completed = false; + while (!completed) + { + // Decode some of the bytes into the temp char buffer. Keep doing this + // until all bytes have been consumed + _decoder.Convert( + bytesArr, bytesOffset, bytesAvailable, + charsArr, 0, charsArr.Length, + false, + out var bytesUsed, out var charsUsed, out completed + ); + bytesOffset += bytesUsed; + bytesAvailable -= bytesUsed; + + // Add the decoded characters to the output buffer + _characters.AddSpan(charsArr.AsSpan(0, charsUsed)); + } + } + finally + { + ArrayPool.Shared.Return(charsArr); + ArrayPool.Shared.Return(bytesArr); + } + + return; + + // Converts a single token into bytes, using the `bytes` array as temporary storage. + // If the `bytes` array is too small it will get a larger one from the ArrayPool. + static Span TokenToBytes(ref byte[] bytes, int token, SafeLlamaModelHandle model) + { + // Try to get bytes + var l = model.TokenToSpan(token, bytes); + + // Negative length indicates that the output was too small. Expand it to twice that size and try again. + if (l < 0) + { + // Return the old array to the pool and get a new one + ArrayPool.Shared.Return(bytes); + bytes = ArrayPool.Shared.Rent(-l * 2); + + // Get bytes, this time it can't fail + l = model.TokenToSpan(token, bytes); + } + + Debug.Assert(l >= 0); + return new Span(bytes, 0, l); + } + } + + /// + /// Add all tokens in the given enumerable + /// + /// + public void AddRange(IEnumerable tokens) + { + foreach (var item in tokens) + Add(item); + } + + /// + /// Read all decoded characters and clear the buffer + /// + /// + public void Read(List dest) + { + dest.AddRange(_characters); + _characters.Clear(); + } + + /// + /// Read all decoded characters as a string and clear the buffer + /// + /// + public string Read() + { + if (_characters.Count == 0) + return ""; + + var str = string.Join("", _characters); + _characters.Clear(); + return str; + } + + /// + /// Set the decoder back to its initial state + /// + public void Reset() + { + _decoder.Reset(); + _characters.Clear(); + } +} \ No newline at end of file