Skip to content

Commit 5b6408b

Browse files
authored
Merge pull request #205 from martindevans/roundtrip_tokenization_investigation
RoundTrip Tokenization Errors
2 parents 4a63197 + a03fe00 commit 5b6408b

File tree

11 files changed

+425
-195
lines changed

11 files changed

+425
-195
lines changed

LLama.Unittest/BeamTests.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,14 @@ public void BasicBeam()
3232
{
3333
const int num_beams = 2;
3434
const int n_predict = 3;
35+
const string prompt = "The cat sat on";
3536

3637
var context = _model.CreateContext(_params);
3738

3839
var result = new StringBuilder();
3940

40-
var initial_tokens = context.Tokenize("The cat sat on");
41-
result.Append(context.DeTokenize(initial_tokens.ToArray()));
41+
var initial_tokens = context.Tokenize(prompt);
42+
result.Append(prompt);
4243
context.Eval(initial_tokens, 0);
4344

4445
NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>

LLama.Unittest/TokenTests.cs

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,129 @@ public void TokensNotEndWithNothing()
7272
var result = tokens.TokensEndsWithAnyString((IList<string>)Array.Empty<string>(), _model.NativeHandle, Encoding.UTF8);
7373
Assert.False(result);
7474
}
75+
76+
[Fact]
public void TokensEndWith2()
{
    // Tokenize a sentence, decode it back to text, and check that an
    // antiprompt matching the end of the sentence is detected.
    var toks = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);

    var dec = new StreamingTokenDecoder(Encoding.UTF8, _model);
    dec.AddRange(toks);

    var antiprompts = new[]
    {
        "a fish",
        "the mat",
        "this is an improbably long query to be using for this method"
    };

    Assert.True(new AntipromptProcessor(antiprompts).Add(dec.Read()));
}
94+
95+
[Fact]
public void TokensEndSubstring2()
{
    // "at" is a suffix of the final word "mat", so the decoded buffer ends with it.
    var toks = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);

    var dec = new StreamingTokenDecoder(Encoding.UTF8, _model);
    dec.AddRange(toks);

    Assert.True(new AntipromptProcessor(new[] { "at" }).Add(dec.Read()));
}
108+
109+
[Fact]
public void TokensNotEndWith2()
{
    // None of these antiprompts is a suffix of the decoded text
    // (the second one is missing the final "t"), so no match is expected.
    var toks = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);

    var dec = new StreamingTokenDecoder(Encoding.UTF8, _model);
    dec.AddRange(toks);

    var antiprompts = new[]
    {
        "a fish",
        "The cat sat on the edge of the ma",
        "this is an improbably long query to be using for this method"
    };

    Assert.False(new AntipromptProcessor(antiprompts).Add(dec.Read()));
}
127+
128+
[Fact]
public void TokensNotEndWithNothing2()
{
    // A processor with no antiprompts configured can never report a match.
    var toks = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);

    var dec = new StreamingTokenDecoder(Encoding.UTF8, _model);
    dec.AddRange(toks);

    Assert.False(new AntipromptProcessor().Add(dec.Read()));
}
141+
142+
[Fact]
public void RoundTrip()
{
    // Tokenizing a string and converting the tokens back should reproduce the
    // original text, including multi-byte unicode characters spanning several tokens.
    var cases = new[]
    {
        "Hello world",
        "철수",
        "😀 😃 😄 😁 😆철수😅 😂 😊 😇 🙂 ",
    };

    var buffer = new char[1024];

    foreach (var text in cases)
    {
        // String -> llama tokens
        var toks = _model.NativeHandle.Tokenize(text, false, false, Encoding.UTF8);

        // llama tokens -> characters
        var span = _model.NativeHandle.TokensToSpan(toks, buffer.AsSpan(), Encoding.UTF8);

        // llama.cpp adds a space to the start of strings, remove that
        Assert.Equal(text, new string(span).TrimStart(' '));
    }
}
169+
170+
[Fact]
public void StreamingDecoderRoundTrip()
{
    // Same round trip as above, but feeding the tokens through a single
    // StreamingTokenDecoder one token at a time, resetting between cases.
    var dec = new StreamingTokenDecoder(Encoding.UTF8, _model);

    var cases = new[]
    {
        "Hello world",
        "철수",
        "😀 😃 😄 😁 😆철수😅 😂 😊 😇 🙂 ",
    };

    foreach (var text in cases)
    {
        dec.Reset();

        // String -> llama tokens
        var toks = _model.NativeHandle.Tokenize(text, false, false, Encoding.UTF8);

        // Push tokens into the decoder individually, as a streaming consumer would
        foreach (var tok in toks)
            dec.Add(tok);

        // llama.cpp adds a space to the start of strings, remove that
        Assert.Equal(text, dec.Read().TrimStart(' '));
    }
}
75200
}

LLama/AntipromptProcessor.cs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
using System;
using System.Collections.Generic;

namespace LLama;

/// <summary>
/// Accumulates streamed text in a rolling buffer and reports when that buffer
/// ends with any configured antiprompt string.
/// </summary>
internal sealed class AntipromptProcessor
{
    private readonly List<string> _antiprompts = new();
    private int _longestAntiprompt;

    // Rolling buffer of recently seen text (null until the first Add call).
    private string? _string;

    /// <summary>
    /// Create a processor, optionally seeded with an initial set of antiprompts.
    /// </summary>
    /// <param name="antiprompts">Initial antiprompts, or null for none</param>
    public AntipromptProcessor(IEnumerable<string>? antiprompts = null)
    {
        if (antiprompts is not null)
            SetAntiprompts(antiprompts);
    }

    /// <summary>
    /// Add an antiprompt to the collection
    /// </summary>
    /// <param name="antiprompt"></param>
    public void AddAntiprompt(string antiprompt)
    {
        _antiprompts.Add(antiprompt);
        if (antiprompt.Length > _longestAntiprompt)
            _longestAntiprompt = antiprompt.Length;
    }

    /// <summary>
    /// Overwrite all current antiprompts with a new set
    /// </summary>
    /// <param name="antiprompts"></param>
    public void SetAntiprompts(IEnumerable<string> antiprompts)
    {
        _antiprompts.Clear();
        _antiprompts.AddRange(antiprompts);

        var longest = 0;
        foreach (var item in _antiprompts)
            longest = Math.Max(longest, item.Length);
        _longestAntiprompt = longest;
    }

    /// <summary>
    /// Add some text and check if the buffer now ends with any antiprompt
    /// </summary>
    /// <param name="text"></param>
    /// <returns>true if the text buffer ends with any antiprompt</returns>
    public bool Add(string text)
    {
        _string += text;

        // When the string gets very long (4x antiprompt length) trim it down (to 2x antiprompt length).
        // This trimming leaves a lot of extra characters because two sequences can be considered "equal" in unicode
        // even with different numbers of characters. Hopefully there are enough characters here to handle all those weird circumstances!
        var maxLength = Math.Max(32, _longestAntiprompt * 4);
        var trimLength = Math.Max(16, _longestAntiprompt * 2);
        if (_string.Length > maxLength)
            _string = _string.Substring(_string.Length - trimLength);

        for (var i = 0; i < _antiprompts.Count; i++)
        {
            if (_string.EndsWith(_antiprompts[i], StringComparison.CurrentCulture))
                return true;
        }

        return false;
    }
}

LLama/Extensions/IReadOnlyListExtensions.cs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ internal static class IReadOnlyListExtensions
3636
/// <param name="model">Model to use to convert tokens into bytes</param>
3737
/// <param name="encoding">Encoding to use to convert bytes into characters</param>
3838
/// <returns></returns>
39+
[Obsolete("Use an Antiprompt processor instead")]
3940
internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, SafeLlamaModelHandle model, Encoding encoding)
4041
where TTokens : IReadOnlyList<int>
4142
where TQueries : IReadOnlyList<string>
@@ -68,13 +69,6 @@ internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tok
6869
}
6970
}
7071

71-
internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, LLamaContext context)
72-
where TTokens : IReadOnlyList<int>
73-
where TQueries : IReadOnlyList<string>
74-
{
75-
return TokensEndsWithAnyString(tokens, queries, context.NativeHandle.ModelHandle, context.Encoding);
76-
}
77-
7872
/// <summary>
7973
/// Check if the given set of tokens ends with any of the given strings
8074
/// </summary>
@@ -83,6 +77,7 @@ internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tok
8377
/// <param name="model">Model to use to convert tokens into bytes</param>
8478
/// <param name="encoding">Encoding to use to convert bytes into characters</param>
8579
/// <returns></returns>
80+
[Obsolete("Use an Antiprompt processor instead")]
8681
internal static bool TokensEndsWithAnyString<TTokens>(this TTokens tokens, IList<string>? queries, SafeLlamaModelHandle model, Encoding encoding)
8782
where TTokens : IReadOnlyList<int>
8883
{

LLama/Extensions/ListExtensions.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;

namespace LLama.Extensions
{
    /// <summary>
    /// Helper extension methods for <see cref="List{T}"/>.
    /// </summary>
    internal static class ListExtensions
    {
#if NETSTANDARD2_0
        // List<T>.EnsureCapacity is not available on netstandard2.0, so polyfill it here.
        public static void EnsureCapacity<T>(this List<T> list, int capacity)
        {
            if (list.Capacity < capacity)
                list.Capacity = capacity;
        }
#endif

        /// <summary>
        /// Append every element of a span to the end of a list, in order.
        /// </summary>
        public static void AddSpan<T>(this List<T> list, ReadOnlySpan<T> items)
        {
            // Reserve space up front to avoid repeated reallocation while adding.
            list.EnsureCapacity(list.Count + items.Length);

            foreach (var item in items)
                list.Add(item);
        }
    }
}

LLama/LLamaContext.cs

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,15 @@ public llama_token[] Tokenize(string text, bool addBos = true, bool special = fa
102102
/// </summary>
103103
/// <param name="tokens"></param>
104104
/// <returns></returns>
105-
public string DeTokenize(IEnumerable<llama_token> tokens)
105+
[Obsolete("Use a `StreamingTokenDecoder` instead")]
106+
public string DeTokenize(IReadOnlyList<llama_token> tokens)
106107
{
107-
var sb = new StringBuilder();
108-
foreach (var token in tokens)
109-
NativeHandle.TokenToString(token, Encoding, sb);
108+
// Do **not** use this method as an example of how to correctly use the StreamingTokenDecoder!
109+
// It should be kept around for the entire time you are decoding one stream of tokens.
110110

111-
return sb.ToString();
111+
var decoder = new StreamingTokenDecoder(this);
112+
decoder.AddRange(tokens);
113+
return decoder.ToString();
112114
}
113115

114116
/// <summary>
@@ -418,26 +420,6 @@ public int Eval(ReadOnlySpan<llama_token> tokens, int pastTokensCount)
418420
}
419421
#endregion
420422

421-
/// <summary>
422-
/// Convert a token into a string
423-
/// </summary>
424-
/// <param name="token"></param>
425-
/// <returns></returns>
426-
public string TokenToString(llama_token token)
427-
{
428-
return NativeHandle.TokenToString(token, Encoding);
429-
}
430-
431-
/// <summary>
432-
/// Append a single token to a string builder
433-
/// </summary>
434-
/// <param name="token">Token to decode</param>
435-
/// <param name="dest">string builder to append the result to</param>
436-
public void TokenToString(llama_token token, StringBuilder dest)
437-
{
438-
NativeHandle.TokenToString(token, Encoding, dest);
439-
}
440-
441423
/// <inheritdoc />
442424
public void Dispose()
443425
{

LLama/LLamaExecutorBase.cs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -294,10 +294,7 @@ public virtual async IAsyncEnumerable<string> InferAsync(string text, IInference
294294
await InferInternal(inferenceParams, args);
295295

296296
if (args.ReturnValue)
297-
{
298-
foreach (var id in _embeds)
299-
yield return Context.TokenToString(id);
300-
}
297+
yield return Context.DeTokenize(_embeds);
301298

302299
var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
303300
if (extraOutputs is { Count: > 0 })

LLama/LLamaStatelessExecutor.cs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams?
5656
Context.Dispose();
5757
Context = _weights.CreateContext(Context.Params, _logger);
5858

59+
var decoder = new StreamingTokenDecoder(Context);
60+
var antiprocessor = new AntipromptProcessor(inferenceParams?.AntiPrompts ?? Array.Empty<string>());
61+
5962
if (inferenceParams != null)
6063
{
6164
if (inferenceParams.TokensKeep > Context.ContextSize)
@@ -64,7 +67,6 @@ public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams?
6467

6568
cancellationToken.ThrowIfCancellationRequested();
6669

67-
var antiprompts = inferenceParams?.AntiPrompts.ToArray() ?? Array.Empty<string>();
6870
inferenceParams ??= new InferenceParams();
6971

7072
var lastTokens = new List<llama_token>(inferenceParams.RepeatLastTokensCount);
@@ -95,13 +97,16 @@ public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams?
9597
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);
9698

9799
lastTokens.Add(id);
98-
yield return Context.TokenToString(id);
100+
101+
decoder.Add(id);
102+
var decoded = decoder.Read();
103+
yield return decoded;
99104

100105
tokens.Clear();
101106
tokens.Add(id);
102107

103108
// Check if any of the antiprompts have been generated
104-
if (lastTokens.TokensEndsWithAnyString(antiprompts, Context))
109+
if (antiprocessor.Add(decoded))
105110
break;
106111

107112
// when run out of context

0 commit comments

Comments
 (0)