diff --git a/LLama/StreamingTokenDecoder.cs b/LLama/StreamingTokenDecoder.cs index 9252e532d..60de20769 100644 --- a/LLama/StreamingTokenDecoder.cs +++ b/LLama/StreamingTokenDecoder.cs @@ -1,9 +1,8 @@ -using System.Buffers; +using System.Buffers; using System.Diagnostics; using System; using System.Collections.Generic; using System.Text; -using LLama.Extensions; using LLama.Native; namespace LLama @@ -23,6 +22,11 @@ public sealed class StreamingTokenDecoder /// public int AvailableCharacters => _characters.Count; + /// + /// If true, special characters will be converted to text. If false they will be invisible. + /// + public bool DecodeSpecialTokens { get; set; } + #region constructors /// /// Create a new decoder @@ -76,7 +80,7 @@ public void Add(LLamaToken token) try { // Convert this token into bytes - var bytesAvailable = TokenToBytes(ref bytesArr, token, _weights).Length; + var bytesAvailable = TokenToBytes(ref bytesArr, token, _weights, DecodeSpecialTokens).Length; // Convert those bytes into characters var bytesOffset = 0; @@ -108,10 +112,10 @@ public void Add(LLamaToken token) // Converts a single token into bytes, using the `bytes` array as temporary storage. // If the `bytes` array is too small it will get a larger one from the ArrayPool. - static Span TokenToBytes(ref byte[] bytes, LLamaToken token, SafeLlamaModelHandle model) + static Span TokenToBytes(ref byte[] bytes, LLamaToken token, SafeLlamaModelHandle model, bool special) { // Try to get bytes - var l = model.TokenToSpan(token, bytes); + var l = model.TokenToSpan(token, bytes, special); // Check if the length was larger than the buffer. If so expand the buffer and try again if (l > bytes.Length)