Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions LLama.Unittest/FixedSizeQueueTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
using LLama.Common;

namespace LLama.Unittest;

/// <summary>
/// Unit tests for <see cref="FixedSizeQueue{T}"/>, covering construction,
/// indexing, enqueueing and overflow (oldest-item eviction) behaviour.
/// </summary>
public sealed class FixedSizeQueueTests
{
    [Fact]
    public void Create()
    {
        var q = new FixedSizeQueue<int>(7);

        // A freshly created queue reports its capacity but holds no items.
        Assert.Equal(7, q.Capacity);
        Assert.Empty(q);
    }

    [Fact]
    public void CreateFromItems()
    {
        var q = new FixedSizeQueue<int>(7, new[] { 1, 2, 3 });

        Assert.Equal(7, q.Capacity);
        Assert.Equal(3, q.Count);
        // Collection Assert.Equal reports a per-element diff on failure,
        // unlike Assert.True(SequenceEqual) which only reports true/false.
        Assert.Equal(new[] { 1, 2, 3 }, q);
    }

    [Fact]
    public void Indexing()
    {
        var q = new FixedSizeQueue<int>(7, new[] { 1, 2, 3 });

        Assert.Equal(1, q[0]);
        Assert.Equal(2, q[1]);
        Assert.Equal(3, q[2]);

        // One past the last valid index must throw.
        Assert.Throws<ArgumentOutOfRangeException>(() => q[3]);
    }

    [Fact]
    public void CreateFromFullItems()
    {
        // Initial items exactly fill the queue to capacity.
        var q = new FixedSizeQueue<int>(3, new[] { 1, 2, 3 });

        Assert.Equal(3, q.Capacity);
        Assert.Equal(3, q.Count);
        Assert.Equal(new[] { 1, 2, 3 }, q);
    }

    [Fact]
    public void CreateFromTooManyItems()
    {
        // More initial items than capacity must be rejected.
        Assert.Throws<ArgumentException>(() => new FixedSizeQueue<int>(2, new[] { 1, 2, 3 }));
    }

    [Fact]
    public void CreateFromTooManyItemsNonCountable()
    {
        // The same rejection must occur for a lazy sequence whose length
        // cannot be read up front (exercises the non-ICollection path).
        Assert.Throws<ArgumentException>(() => new FixedSizeQueue<int>(2, Items()));
        return;

        static IEnumerable<int> Items()
        {
            yield return 1;
            yield return 2;
            yield return 3;
        }
    }

    [Fact]
    public void Enqueue()
    {
        var q = new FixedSizeQueue<int>(7, new[] { 1, 2, 3 });

        q.Enqueue(4);
        q.Enqueue(5);

        // Still under capacity, so nothing is evicted.
        Assert.Equal(7, q.Capacity);
        Assert.Equal(5, q.Count);
        Assert.Equal(new[] { 1, 2, 3, 4, 5 }, q);
    }

    [Fact]
    public void EnqueueOverflow()
    {
        var q = new FixedSizeQueue<int>(5, new[] { 1, 2, 3 });

        q.Enqueue(4);
        q.Enqueue(5);
        q.Enqueue(6);
        q.Enqueue(7);

        // Capacity exceeded: the oldest items (1 and 2) are dropped.
        Assert.Equal(5, q.Capacity);
        Assert.Equal(5, q.Count);
        Assert.Equal(new[] { 3, 4, 5, 6, 7 }, q);
    }
}
21 changes: 19 additions & 2 deletions LLama.Unittest/GrammarTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,22 @@ public void CreateBasicGrammar()
using var handle = SafeLLamaGrammarHandle.Create(rules, 0);
}

[Fact]
public void CreateGrammar_StartIndexOutOfRange()
{
    // Build a grammar containing a single rule ("alpha": one char in a-z).
    var alphaRule = new GrammarRule("alpha", new[]
    {
        new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'),
        new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 'z'),
        new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
    });
    var rules = new List<GrammarRule> { alphaRule };

    // Start rule index 3 is out of range for a one-rule grammar, so the
    // Grammar constructor must reject it.
    Assert.Throws<ArgumentOutOfRangeException>(() => new Grammar(rules, 3));
}

[Fact]
public async Task SampleWithTrivialGrammar()
{
Expand All @@ -56,14 +72,15 @@ public async Task SampleWithTrivialGrammar()
}),
};

using var grammar = SafeLLamaGrammarHandle.Create(rules, 0);
var grammar = new Grammar(rules, 0);
using var grammarInstance = grammar.CreateInstance();

var executor = new StatelessExecutor(_model, _params);
var inferenceParams = new InferenceParams
{
MaxTokens = 3,
AntiPrompts = new [] { ".", "Input:", "\n" },
Grammar = grammar,
Grammar = grammarInstance,
};

var result = await executor.InferAsync("Q. 7 + 12\nA. ", inferenceParams).ToListAsync();
Expand Down
22 changes: 1 addition & 21 deletions LLama.Unittest/LLamaEmbedderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

namespace LLama.Unittest;

public class LLamaEmbedderTests
public sealed class LLamaEmbedderTests
: IDisposable
{
private readonly LLamaEmbedder _embedder;
Expand Down Expand Up @@ -37,26 +37,6 @@ private static float Dot(float[] a, float[] b)
return a.Zip(b, (x, y) => x * y).Sum();
}

private static void AssertApproxStartsWith(float[] expected, float[] actual, float epsilon = 0.08f)
{
for (int i = 0; i < expected.Length; i++)
Assert.Equal(expected[i], actual[i], epsilon);
}

// todo: enable this one llama2 7B gguf is available
//[Fact]
//public void EmbedBasic()
//{
// var cat = _embedder.GetEmbeddings("cat");

// Assert.NotNull(cat);
// Assert.NotEmpty(cat);

// // Expected value generate with llama.cpp embedding.exe
// var expected = new float[] { -0.127304f, -0.678057f, -0.085244f, -0.956915f, -0.638633f };
// AssertApproxStartsWith(expected, cat);
//}

[Fact]
public void EmbedCompare()
{
Expand Down
24 changes: 5 additions & 19 deletions LLama/Common/FixedSizeQueue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using LLama.Extensions;

namespace LLama.Common
{
Expand All @@ -10,11 +11,12 @@ namespace LLama.Common
/// Currently it's only a naive implementation and needs to be further optimized in the future.
/// </summary>
public class FixedSizeQueue<T>
: IEnumerable<T>
: IReadOnlyList<T>
{
private readonly List<T> _storage;

internal IReadOnlyList<T> Items => _storage;
/// <inheritdoc />
public T this[int index] => _storage[index];

/// <summary>
/// Number of items in this queue
Expand Down Expand Up @@ -59,31 +61,15 @@ public FixedSizeQueue(int size, IEnumerable<T> data)
throw new ArgumentException($"The max size set for the quene is {size}, but got {_storage.Count} initial values.");
}

/// <summary>
/// Replace every item in the queue with the given value
/// </summary>
/// <param name="value">The value to replace all items with</param>
/// <returns>returns this</returns>
public FixedSizeQueue<T> FillWith(T value)
{
for(var i = 0; i < Count; i++)
{
_storage[i] = value;
}
return this;
}

/// <summary>
/// Enqueue an element. If the queue is already at capacity the oldest
/// item is evicted to make room.
/// </summary>
/// <param name="item">The item to add to the back of the queue</param>
public void Enqueue(T item)
{
    _storage.Add(item);

    // Evict the oldest element once the add pushes us over capacity.
    if (_storage.Count > Capacity)
        _storage.RemoveAt(0);
}

/// <inheritdoc />
Expand Down
2 changes: 1 addition & 1 deletion LLama/LLamaExecutorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null)
_pastTokensCount = 0;
_consumedTokensCount = 0;
_n_session_consumed = 0;
_last_n_tokens = new FixedSizeQueue<llama_token>(Context.ContextSize).FillWith(0);
_last_n_tokens = new FixedSizeQueue<llama_token>(Context.ContextSize);
_decoder = new StreamingTokenDecoder(context);
}

Expand Down
2 changes: 1 addition & 1 deletion LLama/LLamaInstructExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
{
args.WaitForInput = true;
return (true, Array.Empty<string>());
Expand Down
2 changes: 1 addition & 1 deletion LLama/LLamaInteractExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
args.WaitForInput = true;

if (_pastTokensCount > 0 && args.WaitForInput)
Expand Down
37 changes: 2 additions & 35 deletions LLama/Native/LLamaGrammarElement.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System;
using System.Diagnostics;
using System.Diagnostics;
using System.Runtime.InteropServices;

namespace LLama.Native
Expand Down Expand Up @@ -52,8 +51,7 @@ public enum LLamaGrammarElementType
/// </summary>
[StructLayout(LayoutKind.Sequential)]
[DebuggerDisplay("{Type} {Value}")]
public struct LLamaGrammarElement
: IEquatable<LLamaGrammarElement>
public record struct LLamaGrammarElement
{
/// <summary>
/// The type of this element
Expand All @@ -76,37 +74,6 @@ public LLamaGrammarElement(LLamaGrammarElementType type, uint value)
Value = value;
}

/// <inheritdoc />
public bool Equals(LLamaGrammarElement other)
{
if (Type != other.Type)
return false;

// No need to compare values for the END rule
if (Type == LLamaGrammarElementType.END)
return true;

return Value == other.Value;
}

/// <inheritdoc />
public override bool Equals(object? obj)
{
return obj is LLamaGrammarElement other && Equals(other);
}

/// <inheritdoc />
public override int GetHashCode()
{
unchecked
{
var hash = 2999;
hash = hash * 7723 + (int)Type;
hash = hash * 7723 + (int)Value;
return hash;
}
}

internal bool IsCharElement()
{
switch (Type)
Expand Down