From 48c5039054c9e5d59c39531b2b1d184784022608 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Sat, 18 Nov 2023 02:40:36 +0000 Subject: [PATCH] Improved test coverage. Discovered some issues: FixedSizeQueue: - Enqueue would always stop one short of filling the capacity - Fill would only _replace_ existing items. It was only used in a place where there were not existing items! Removed the method entirely. LLamaGrammarElement: - Converted into a `record` struct, removed all of the (now unnecessary) equality stuff. --- LLama.Unittest/FixedSizeQueueTests.cs | 95 +++++++++++++++++++++++++++ LLama.Unittest/GrammarTest.cs | 21 +++++- LLama.Unittest/LLamaEmbedderTests.cs | 22 +------ LLama/Common/FixedSizeQueue.cs | 24 ++----- LLama/LLamaExecutorBase.cs | 2 +- LLama/LLamaInstructExecutor.cs | 2 +- LLama/LLamaInteractExecutor.cs | 2 +- LLama/Native/LLamaGrammarElement.cs | 37 +---------- 8 files changed, 125 insertions(+), 80 deletions(-) create mode 100644 LLama.Unittest/FixedSizeQueueTests.cs diff --git a/LLama.Unittest/FixedSizeQueueTests.cs b/LLama.Unittest/FixedSizeQueueTests.cs new file mode 100644 index 000000000..db75579f8 --- /dev/null +++ b/LLama.Unittest/FixedSizeQueueTests.cs @@ -0,0 +1,95 @@ +using LLama.Common; + +namespace LLama.Unittest; + +public class FixedSizeQueueTests +{ + [Fact] + public void Create() + { + var q = new FixedSizeQueue(7); + + Assert.Equal(7, q.Capacity); + Assert.Empty(q); + } + + [Fact] + public void CreateFromItems() + { + var q = new FixedSizeQueue(7, new [] { 1, 2, 3 }); + + Assert.Equal(7, q.Capacity); + Assert.Equal(3, q.Count); + Assert.True(q.ToArray().SequenceEqual(new[] { 1, 2, 3 })); + } + + [Fact] + public void Indexing() + { + var q = new FixedSizeQueue(7, new[] { 1, 2, 3 }); + + Assert.Equal(1, q[0]); + Assert.Equal(2, q[1]); + Assert.Equal(3, q[2]); + + Assert.Throws(() => q[3]); + } + + [Fact] + public void CreateFromFullItems() + { + var q = new FixedSizeQueue(3, new[] { 1, 2, 3 }); + + Assert.Equal(3, q.Capacity); + Assert.Equal(3, q.Count); + Assert.True(q.ToArray().SequenceEqual(new[] { 1, 2, 3 })); + } + + [Fact] + public void CreateFromTooManyItems() + { + Assert.Throws(() => new FixedSizeQueue(2, new[] { 1, 2, 3 })); + } + + [Fact] + public void CreateFromTooManyItemsNonCountable() + { + Assert.Throws(() => new FixedSizeQueue(2, Items())); + return; + + static IEnumerable Items() + { + yield return 1; + yield return 2; + yield return 3; + } + } + + [Fact] + public void Enqueue() + { + var q = new FixedSizeQueue(7, new[] { 1, 2, 3 }); + + q.Enqueue(4); + q.Enqueue(5); + + Assert.Equal(7, q.Capacity); + Assert.Equal(5, q.Count); + Assert.True(q.ToArray().SequenceEqual(new[] { 1, 2, 3, 4, 5 })); + } + + [Fact] + public void EnqueueOverflow() + { + var q = new FixedSizeQueue(5, new[] { 1, 2, 3 }); + + q.Enqueue(4); + q.Enqueue(5); + q.Enqueue(6); + q.Enqueue(7); + + Assert.Equal(5, q.Capacity); + Assert.Equal(5, q.Count); + Assert.True(q.ToArray().SequenceEqual(new[] { 3, 4, 5, 6, 7 })); + } +} \ No newline at end of file diff --git a/LLama.Unittest/GrammarTest.cs b/LLama.Unittest/GrammarTest.cs index 7bd012b80..3d7d1dada 100644 --- a/LLama.Unittest/GrammarTest.cs +++ b/LLama.Unittest/GrammarTest.cs @@ -40,6 +40,22 @@ public void CreateBasicGrammar() using var handle = SafeLLamaGrammarHandle.Create(rules, 0); } + [Fact] + public void CreateGrammar_StartIndexOutOfRange() + { + var rules = new List + { + new GrammarRule("alpha", new[] + { + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 'z'), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + }), + }; + + Assert.Throws(() => new Grammar(rules, 3)); + } + [Fact] public async Task SampleWithTrivialGrammar() { @@ -56,14 +72,15 @@ public async Task SampleWithTrivialGrammar() }), }; - using var grammar = SafeLLamaGrammarHandle.Create(rules, 0); + var grammar = new Grammar(rules, 0); + using var grammarInstance = grammar.CreateInstance(); var executor = new StatelessExecutor(_model, _params); var inferenceParams = new InferenceParams { MaxTokens = 3, AntiPrompts = new [] { ".", "Input:", "\n" }, - Grammar = grammar, + Grammar = grammarInstance, }; var result = await executor.InferAsync("Q. 7 + 12\nA. ", inferenceParams).ToListAsync(); diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs index a4bd5867a..b8fede8f5 100644 --- a/LLama.Unittest/LLamaEmbedderTests.cs +++ b/LLama.Unittest/LLamaEmbedderTests.cs @@ -2,7 +2,7 @@ namespace LLama.Unittest; -public class LLamaEmbedderTests +public sealed class LLamaEmbedderTests : IDisposable { private readonly LLamaEmbedder _embedder; @@ -37,26 +37,6 @@ private static float Dot(float[] a, float[] b) return a.Zip(b, (x, y) => x * y).Sum(); } - private static void AssertApproxStartsWith(float[] expected, float[] actual, float epsilon = 0.08f) - { - for (int i = 0; i < expected.Length; i++) - Assert.Equal(expected[i], actual[i], epsilon); - } - - // todo: enable this one llama2 7B gguf is available - //[Fact] - //public void EmbedBasic() - //{ - // var cat = _embedder.GetEmbeddings("cat"); - - // Assert.NotNull(cat); - // Assert.NotEmpty(cat); - - // // Expected value generate with llama.cpp embedding.exe - // var expected = new float[] { -0.127304f, -0.678057f, -0.085244f, -0.956915f, -0.638633f }; - // AssertApproxStartsWith(expected, cat); - //} - [Fact] public void EmbedCompare() { diff --git a/LLama/Common/FixedSizeQueue.cs b/LLama/Common/FixedSizeQueue.cs index 37fb1cf51..6d272f23f 100644 --- a/LLama/Common/FixedSizeQueue.cs +++ b/LLama/Common/FixedSizeQueue.cs @@ -2,6 +2,7 @@ using System.Collections; using System.Collections.Generic; using System.Linq; +using LLama.Extensions; namespace LLama.Common { @@ -10,11 +11,12 @@ namespace LLama.Common /// Currently it's only a naive implementation and needs to be further optimized in the future. /// public class FixedSizeQueue - : IEnumerable + : IReadOnlyList { private readonly List _storage; - internal IReadOnlyList Items => _storage; + /// + public T this[int index] => _storage[index]; /// /// Number of items in this queue @@ -59,20 +61,6 @@ public FixedSizeQueue(int size, IEnumerable data) throw new ArgumentException($"The max size set for the quene is {size}, but got {_storage.Count} initial values."); } - /// - /// Replace every item in the queue with the given value - /// - /// The value to replace all items with - /// returns this - public FixedSizeQueue FillWith(T value) - { - for(var i = 0; i < Count; i++) - { - _storage[i] = value; - } - return this; - } - /// /// Enquene an element. /// @@ -80,10 +68,8 @@ public FixedSizeQueue FillWith(T value) public void Enqueue(T item) { _storage.Add(item); - if(_storage.Count >= Capacity) - { + if (_storage.Count > Capacity) _storage.RemoveAt(0); - } } /// diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index f047ab892..e0fde1edb 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -84,7 +84,7 @@ protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null) _pastTokensCount = 0; _consumedTokensCount = 0; _n_session_consumed = 0; - _last_n_tokens = new FixedSizeQueue(Context.ContextSize).FillWith(0); + _last_n_tokens = new FixedSizeQueue(Context.ContextSize); _decoder = new StreamingTokenDecoder(context); } diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 33cbd23e9..d81630aa9 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -151,7 +151,7 @@ protected override Task PreprocessInputs(string text, InferStateArgs args) { if (_embed_inps.Count <= _consumedTokensCount) { - if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding)) + if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding)) { args.WaitForInput = true; return (true, Array.Empty()); diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 98b45814c..4d28274b4 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -134,7 +134,7 @@ protected override Task PreprocessInputs(string text, InferStateArgs args) { if (_embed_inps.Count <= _consumedTokensCount) { - if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding)) + if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding)) args.WaitForInput = true; if (_pastTokensCount > 0 && args.WaitForInput) diff --git a/LLama/Native/LLamaGrammarElement.cs b/LLama/Native/LLamaGrammarElement.cs index 96313f239..c41d93b70 100644 --- a/LLama/Native/LLamaGrammarElement.cs +++ b/LLama/Native/LLamaGrammarElement.cs @@ -1,5 +1,4 @@ -using System; -using System.Diagnostics; +using System.Diagnostics; using System.Runtime.InteropServices; namespace LLama.Native @@ -52,8 +51,7 @@ public enum LLamaGrammarElementType /// [StructLayout(LayoutKind.Sequential)] [DebuggerDisplay("{Type} {Value}")] - public struct LLamaGrammarElement - : IEquatable + public record struct LLamaGrammarElement { /// /// The type of this element @@ -76,37 +74,6 @@ public LLamaGrammarElement(LLamaGrammarElementType type, uint value) Value = value; } - /// - public bool Equals(LLamaGrammarElement other) - { - if (Type != other.Type) - return false; - - // No need to compare values for the END rule - if (Type == LLamaGrammarElementType.END) - return true; - - return Value == other.Value; - } - - /// - public override bool Equals(object? obj) - { - return obj is LLamaGrammarElement other && Equals(other); - } - - /// - public override int GetHashCode() - { - unchecked - { - var hash = 2999; - hash = hash * 7723 + (int)Type; - hash = hash * 7723 + (int)Value; - return hash; - } - } - internal bool IsCharElement() { switch (Type)