Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions LLama.Unittest/BeamTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ public void BasicBeam()
{
const int num_beams = 2;
const int n_predict = 3;
const string prompt = "The cat sat on";

var context = _model.CreateContext(_params);

var result = new StringBuilder();

var initial_tokens = context.Tokenize("The cat sat on");
result.Append(context.DeTokenize(initial_tokens.ToArray()));
var initial_tokens = context.Tokenize(prompt);
result.Append(prompt);
context.Eval(initial_tokens, 0);

NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
Expand Down
125 changes: 125 additions & 0 deletions LLama.Unittest/TokenTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,129 @@ public void TokensNotEndWithNothing()
var result = tokens.TokensEndsWithAnyString((IList<string>)Array.Empty<string>(), _model.NativeHandle, Encoding.UTF8);
Assert.False(result);
}

[Fact]
public void TokensEndWith2()
{
    // Tokenize a sentence and push all of its tokens through the streaming decoder
    var sentence = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);

    var tokenDecoder = new StreamingTokenDecoder(Encoding.UTF8, _model);
    tokenDecoder.AddRange(sentence);

    // "the mat" is a suffix of the decoded text, so the processor must report a match
    var antiprompts = new[]
    {
        "a fish",
        "the mat",
        "this is an improbably long query to be using for this method"
    };
    var processor = new AntipromptProcessor(antiprompts);

    Assert.True(processor.Add(tokenDecoder.Read()));
}

[Fact]
public void TokensEndSubstring2()
{
    // Tokenize a sentence ending in "...the mat" and decode it back to text
    var sentence = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);

    var tokenDecoder = new StreamingTokenDecoder(Encoding.UTF8, _model);
    tokenDecoder.AddRange(sentence);

    // "at" is a suffix of the decoded text even though it is only part of a word
    var matched = new AntipromptProcessor(new[] { "at" }).Add(tokenDecoder.Read());

    Assert.True(matched);
}

[Fact]
public void TokensNotEndWith2()
{
    // Tokenize a sentence and decode it back to text
    var sentence = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);

    var tokenDecoder = new StreamingTokenDecoder(Encoding.UTF8, _model);
    tokenDecoder.AddRange(sentence);

    // None of these are a suffix of the decoded text (one is a strict prefix of it,
    // which must not count as a match)
    var antiprompts = new[]
    {
        "a fish",
        "The cat sat on the edge of the ma",
        "this is an improbably long query to be using for this method"
    };
    var processor = new AntipromptProcessor(antiprompts);

    Assert.False(processor.Add(tokenDecoder.Read()));
}

[Fact]
public void TokensNotEndWithNothing2()
{
    // Tokenize a sentence and decode it back to text
    var sentence = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8);

    var tokenDecoder = new StreamingTokenDecoder(Encoding.UTF8, _model);
    tokenDecoder.AddRange(sentence);

    // With no antiprompts configured, nothing can ever match
    var matched = new AntipromptProcessor().Add(tokenDecoder.Read());

    Assert.False(matched);
}

[Fact]
public void RoundTrip()
{
    // Scratch buffer for TokensToSpan output
    var buffer = new char[1024];

    // Mix of ASCII, Hangul and surrogate-pair emoji to exercise the UTF8 handling
    var inputs = new[]
    {
        "Hello world",
        "철수",
        "😀 😃 😄 😁 😆철수😅 😂 😊 😇 🙂 ",
    };

    foreach (var text in inputs)
    {
        // string -> llama tokens -> characters
        var sentence = _model.NativeHandle.Tokenize(text, false, false, Encoding.UTF8);
        var decoded = _model.NativeHandle.TokensToSpan(sentence, buffer.AsSpan(), Encoding.UTF8);

        // llama.cpp adds a space to the start of strings, remove that
        // before checking the round trip was lossless
        Assert.Equal(text, new string(decoded).TrimStart(' '));
    }
}

[Fact]
public void StreamingDecoderRoundTrip()
{
    // Mix of ASCII, Hangul and surrogate-pair emoji to exercise the UTF8 handling
    var inputs = new[]
    {
        "Hello world",
        "철수",
        "😀 😃 😄 😁 😆철수😅 😂 😊 😇 🙂 ",
    };

    var tokenDecoder = new StreamingTokenDecoder(Encoding.UTF8, _model);

    foreach (var text in inputs)
    {
        // Reuse one decoder across inputs, clearing it between round trips
        tokenDecoder.Reset();

        // string -> llama tokens, fed into the decoder one token at a time
        var sentence = _model.NativeHandle.Tokenize(text, false, false, Encoding.UTF8);
        foreach (var token in sentence)
            tokenDecoder.Add(token);

        // llama.cpp adds a space to the start of strings, remove that
        // before checking the round trip was lossless
        Assert.Equal(text, tokenDecoder.Read().TrimStart(' '));
    }
}
}
66 changes: 66 additions & 0 deletions LLama/AntipromptProcessor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
using System;
using System.Collections.Generic;

namespace LLama;

/// <summary>
/// Accumulates generated text and reports when the stream ends with any configured
/// "antiprompt" (stop string), even when a stop string is split across multiple additions.
/// </summary>
internal sealed class AntipromptProcessor
{
    // Length (in chars) of the longest antiprompt; bounds how much text must be retained.
    private int _longestAntiprompt;
    private readonly List<string> _antiprompts = new();

    // Rolling buffer of the most recently added text (periodically trimmed, see Add).
    private string _string = "";

    /// <summary>
    /// Create a processor, optionally initialised with a set of antiprompts
    /// </summary>
    /// <param name="antiprompts">Initial antiprompts, or null for none</param>
    public AntipromptProcessor(IEnumerable<string>? antiprompts = null)
    {
        if (antiprompts != null)
            SetAntiprompts(antiprompts);
    }

    /// <summary>
    /// Add an antiprompt to the collection
    /// </summary>
    /// <param name="antiprompt">Stop string to watch for</param>
    public void AddAntiprompt(string antiprompt)
    {
        _antiprompts.Add(antiprompt);
        _longestAntiprompt = Math.Max(_longestAntiprompt, antiprompt.Length);
    }

    /// <summary>
    /// Overwrite all current antiprompts with a new set
    /// </summary>
    /// <param name="antiprompts">New set of stop strings</param>
    public void SetAntiprompts(IEnumerable<string> antiprompts)
    {
        _antiprompts.Clear();
        _antiprompts.AddRange(antiprompts);

        _longestAntiprompt = 0;
        foreach (var antiprompt in _antiprompts)
            _longestAntiprompt = Math.Max(_longestAntiprompt, antiprompt.Length);
    }

    /// <summary>
    /// Add some text and check if the buffer now ends with any antiprompt
    /// </summary>
    /// <param name="text">Newly generated text to append to the buffer</param>
    /// <returns>true if the text buffer ends with any antiprompt</returns>
    public bool Add(string text)
    {
        _string += text;

        // When the string gets very long (4x antiprompt length) trim it down (to 2x antiprompt length).
        // This trimming leaves a lot of extra characters because two sequences can be considered "equal" in unicode
        // even with different numbers of characters. Hopefully there are enough characters here to handle all those weird circumstances!
        var maxLength = Math.Max(32, _longestAntiprompt * 4);
        var trimLength = Math.Max(16, _longestAntiprompt * 2);
        if (_string.Length > maxLength)
            _string = _string.Substring(_string.Length - trimLength);

        // Ordinal comparison: antiprompts are exact character sequences, not linguistic text,
        // so matching must not vary with the host machine's current culture (CA1310).
        foreach (var antiprompt in _antiprompts)
            if (_string.EndsWith(antiprompt, StringComparison.Ordinal))
                return true;

        return false;
    }
}
9 changes: 2 additions & 7 deletions LLama/Extensions/IReadOnlyListExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ internal static class IReadOnlyListExtensions
/// <param name="model">Model to use to convert tokens into bytes</param>
/// <param name="encoding">Encoding to use to convert bytes into characters</param>
/// <returns></returns>
[Obsolete("Use an Antiprompt processor instead")]
internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, SafeLlamaModelHandle model, Encoding encoding)
where TTokens : IReadOnlyList<int>
where TQueries : IReadOnlyList<string>
Expand Down Expand Up @@ -68,13 +69,6 @@ internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tok
}
}

internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, LLamaContext context)
where TTokens : IReadOnlyList<int>
where TQueries : IReadOnlyList<string>
{
return TokensEndsWithAnyString(tokens, queries, context.NativeHandle.ModelHandle, context.Encoding);
}

/// <summary>
/// Check if the given set of tokens ends with any of the given strings
/// </summary>
Expand All @@ -83,6 +77,7 @@ internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tok
/// <param name="model">Model to use to convert tokens into bytes</param>
/// <param name="encoding">Encoding to use to convert bytes into characters</param>
/// <returns></returns>
[Obsolete("Use an Antiprompt processor instead")]
internal static bool TokensEndsWithAnyString<TTokens>(this TTokens tokens, IList<string>? queries, SafeLlamaModelHandle model, Encoding encoding)
where TTokens : IReadOnlyList<int>
{
Expand Down
24 changes: 24 additions & 0 deletions LLama/Extensions/ListExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;

namespace LLama.Extensions
{
internal static class ListExtensions
{
#if NETSTANDARD2_0
    /// <summary>
    /// Polyfill for List&lt;T&gt;.EnsureCapacity, which is not available on netstandard2.0.
    /// Grows the backing array so it can hold at least <paramref name="capacity"/> items.
    /// </summary>
    public static void EnsureCapacity<T>(this List<T> list, int capacity)
    {
        if (list.Capacity < capacity)
            list.Capacity = capacity;
    }
#endif

    /// <summary>
    /// Append every item in a span to the end of a list.
    /// </summary>
    public static void AddSpan<T>(this List<T> list, ReadOnlySpan<T> items)
    {
        // Reserve space up front so the list reallocates at most once
        list.EnsureCapacity(list.Count + items.Length);

        foreach (var item in items)
            list.Add(item);
    }
}
}
32 changes: 7 additions & 25 deletions LLama/LLamaContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,15 @@ public llama_token[] Tokenize(string text, bool addBos = true, bool special = fa
/// </summary>
/// <param name="tokens"></param>
/// <returns></returns>
public string DeTokenize(IEnumerable<llama_token> tokens)
[Obsolete("Use a `StreamingTokenDecoder` instead")]
public string DeTokenize(IReadOnlyList<llama_token> tokens)
{
var sb = new StringBuilder();
foreach (var token in tokens)
NativeHandle.TokenToString(token, Encoding, sb);
// Do **not** use this method as an example of how to correctly use the StreamingTokenDecoder!
// It should be kept around for the entire time you are decoding one stream of tokens.

return sb.ToString();
var decoder = new StreamingTokenDecoder(this);
decoder.AddRange(tokens);
return decoder.ToString();
}

/// <summary>
Expand Down Expand Up @@ -418,26 +420,6 @@ public int Eval(ReadOnlySpan<llama_token> tokens, int pastTokensCount)
}
#endregion

/// <summary>
/// Convert a token into a string
/// </summary>
/// <param name="token"></param>
/// <returns></returns>
public string TokenToString(llama_token token)
{
return NativeHandle.TokenToString(token, Encoding);
}

/// <summary>
/// Append a single token to a string builder
/// </summary>
/// <param name="token">Token to decode</param>
/// <param name="dest">string builder to append the result to</param>
public void TokenToString(llama_token token, StringBuilder dest)
{
NativeHandle.TokenToString(token, Encoding, dest);
}

/// <inheritdoc />
public void Dispose()
{
Expand Down
5 changes: 1 addition & 4 deletions LLama/LLamaExecutorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,7 @@ public virtual async IAsyncEnumerable<string> InferAsync(string text, IInference
await InferInternal(inferenceParams, args);

if (args.ReturnValue)
{
foreach (var id in _embeds)
yield return Context.TokenToString(id);
}
yield return Context.DeTokenize(_embeds);

var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
if (extraOutputs is { Count: > 0 })
Expand Down
11 changes: 8 additions & 3 deletions LLama/LLamaStatelessExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams?
Context.Dispose();
Context = _weights.CreateContext(Context.Params, _logger);

var decoder = new StreamingTokenDecoder(Context);
var antiprocessor = new AntipromptProcessor(inferenceParams?.AntiPrompts ?? Array.Empty<string>());

if (inferenceParams != null)
{
if (inferenceParams.TokensKeep > Context.ContextSize)
Expand All @@ -64,7 +67,6 @@ public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams?

cancellationToken.ThrowIfCancellationRequested();

var antiprompts = inferenceParams?.AntiPrompts.ToArray() ?? Array.Empty<string>();
inferenceParams ??= new InferenceParams();

var lastTokens = new List<llama_token>(inferenceParams.RepeatLastTokensCount);
Expand Down Expand Up @@ -95,13 +97,16 @@ public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams?
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);

lastTokens.Add(id);
yield return Context.TokenToString(id);

decoder.Add(id);
var decoded = decoder.Read();
yield return decoded;

tokens.Clear();
tokens.Add(id);

// Check if any of the antiprompts have been generated
if (lastTokens.TokensEndsWithAnyString(antiprompts, Context))
if (antiprocessor.Add(decoded))
break;

// when run out of context
Expand Down
Loading