diff --git a/LLama.Unittest/DictionaryExtensionsTests.cs b/LLama.Unittest/DictionaryExtensionsTests.cs new file mode 100644 index 000000000..a3f0c9b1d --- /dev/null +++ b/LLama.Unittest/DictionaryExtensionsTests.cs @@ -0,0 +1,38 @@ +using LLama.Extensions; + +namespace LLama.Unittest +{ + public class DictionaryExtensionsTests + { + [Fact] + public void GetDefaultValueEmptyDict() + { + var dict = new Dictionary(); + + Assert.Equal(42, DictionaryExtensions.GetValueOrDefaultImpl(dict, 0, 42)); + } + + [Fact] + public void GetDefaultValueMissingKey() + { + var dict = new Dictionary() + { + { 3, 4 } + }; + + Assert.Equal(43, DictionaryExtensions.GetValueOrDefaultImpl(dict, 0, 43)); + } + + [Fact] + public void GetValue() + { + var dict = new Dictionary() + { + { 3, 4 }, + { 4, 5 }, + }; + + Assert.Equal(4, DictionaryExtensions.GetValueOrDefaultImpl(dict, 3, 42)); + } + } +} diff --git a/LLama.Unittest/EncodingExtensionsTests.cs b/LLama.Unittest/EncodingExtensionsTests.cs new file mode 100644 index 000000000..705980a03 --- /dev/null +++ b/LLama.Unittest/EncodingExtensionsTests.cs @@ -0,0 +1,37 @@ +using System.Text; +using EncodingExtensions = LLama.Extensions.EncodingExtensions; + +namespace LLama.Unittest +{ + public class EncodingExtensionsTests + { + private static void GetCharsTest(string str) + { + var bytes = Encoding.UTF8.GetBytes(str); + + var chars = new char[128]; + var count = EncodingExtensions.GetCharsImpl(Encoding.UTF8, bytes, chars); + + Assert.Equal(str.Length, count); + Assert.True(chars[..count].SequenceEqual(str)); + } + + [Fact] + public void GetCharsEmptyString() + { + GetCharsTest(""); + } + + [Fact] + public void GetCharsString() + { + GetCharsTest("Hello World"); + } + + [Fact] + public void GetCharsChineseString() + { + GetCharsTest("猫坐在垫子上"); + } + } +} diff --git a/LLama.Unittest/IEnumerableExtensionsTests.cs b/LLama.Unittest/IEnumerableExtensionsTests.cs new file mode 100644 index 000000000..18bc45f7c --- /dev/null +++ b/LLama.Unittest/IEnumerableExtensionsTests.cs @@ -0,0 +1,36 @@ +using LLama.Extensions; + +namespace LLama.Unittest; + +public class IEnumerableExtensionsTests +{ + [Fact] + public void TakeLastEmpty() + { + var arr = Array.Empty(); + + var last = IEnumerableExtensions.TakeLastImpl(arr, 5).ToList(); + + Assert.Empty(last); + } + + [Fact] + public void TakeLastAll() + { + var arr = new[] { 1, 2, 3, 4, 5 }; + + var last = IEnumerableExtensions.TakeLastImpl(arr, 5).ToList(); + + Assert.True(last.SequenceEqual(arr)); + } + + [Fact] + public void TakeLast() + { + var arr = new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; + + var last = IEnumerableExtensions.TakeLastImpl(arr, 5).ToList(); + + Assert.True(last.SequenceEqual(arr[5..])); + } +} \ No newline at end of file diff --git a/LLama.Unittest/IReadOnlyListExtensionsTests.cs b/LLama.Unittest/IReadOnlyListExtensionsTests.cs new file mode 100644 index 000000000..a5a653759 --- /dev/null +++ b/LLama.Unittest/IReadOnlyListExtensionsTests.cs @@ -0,0 +1,22 @@ +using LLama.Extensions; + +namespace LLama.Unittest; + +public class IReadOnlyListExtensionsTests +{ + [Fact] + public void IndexOfItem() + { + var items = (IReadOnlyList)new List { 1, 2, 3, 4, }; + + Assert.Equal(2, items.IndexOf(3)); + } + + [Fact] + public void IndexOfItemNotFound() + { + var items = (IReadOnlyList)new List { 1, 2, 3, 4, }; + + Assert.Null(items.IndexOf(42)); + } +} \ No newline at end of file diff --git a/LLama.Unittest/KeyValuePairExtensionsTests.cs b/LLama.Unittest/KeyValuePairExtensionsTests.cs new file mode 100644 index 000000000..70ffc086e --- /dev/null +++ b/LLama.Unittest/KeyValuePairExtensionsTests.cs @@ -0,0 +1,15 @@ +namespace LLama.Unittest; + +public class KeyValuePairExtensionsTests +{ + [Fact] + public void Deconstruct() + { + var kvp = new KeyValuePair(1, "2"); + + var (a, b) = kvp; + + Assert.Equal(1, a); + Assert.Equal("2", b); + } +} \ No newline at end of file diff --git a/LLama.Unittest/TokenTests.cs b/LLama.Unittest/TokenTests.cs index fd17727b3..7ec484b81 100644 --- a/LLama.Unittest/TokenTests.cs +++ b/LLama.Unittest/TokenTests.cs @@ -43,7 +43,7 @@ public void TokensEndSubstring() { var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8); - var result = tokens.TokensEndsWithAnyString(new[] + var result = tokens.TokensEndsWithAnyString((IList)new[] { "at", }, _model.NativeHandle, Encoding.UTF8); @@ -55,7 +55,7 @@ public void TokensNotEndWith() { var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8); - var result = tokens.TokensEndsWithAnyString(new[] + var result = tokens.TokensEndsWithAnyString((IList)new[] { "a fish", "The cat sat on the edge of the ma", @@ -69,7 +69,7 @@ public void TokensNotEndWithNothing() { var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8); - var result = tokens.TokensEndsWithAnyString(Array.Empty(), _model.NativeHandle, Encoding.UTF8); + var result = tokens.TokensEndsWithAnyString((IList)Array.Empty(), _model.NativeHandle, Encoding.UTF8); Assert.False(result); } } \ No newline at end of file diff --git a/LLama/Extensions/DictionaryExtensions.cs b/LLama/Extensions/DictionaryExtensions.cs index e5a27d6d5..a39ed7e8b 100644 --- a/LLama/Extensions/DictionaryExtensions.cs +++ b/LLama/Extensions/DictionaryExtensions.cs @@ -7,8 +7,13 @@ internal static class DictionaryExtensions #if NETSTANDARD2_0 public static TValue GetValueOrDefault(this IReadOnlyDictionary dictionary, TKey key, TValue defaultValue) { - return dictionary.TryGetValue(key, out var value) ? value : defaultValue; + return GetValueOrDefaultImpl(dictionary, key, defaultValue); } #endif + + internal static TValue GetValueOrDefaultImpl(IReadOnlyDictionary dictionary, TKey key, TValue defaultValue) + { + return dictionary.TryGetValue(key, out var value) ? value : defaultValue; + } } } diff --git a/LLama/Extensions/EncodingExtensions.cs b/LLama/Extensions/EncodingExtensions.cs index 29073fea5..28f5c2c00 100644 --- a/LLama/Extensions/EncodingExtensions.cs +++ b/LLama/Extensions/EncodingExtensions.cs @@ -9,6 +9,20 @@ internal static class EncodingExtensions #if NETSTANDARD2_0 public static int GetChars(this Encoding encoding, ReadOnlySpan bytes, Span output) { + return GetCharsImpl(encoding, bytes, output); + } + + public static int GetCharCount(this Encoding encoding, ReadOnlySpan bytes) + { + return GetCharCountImpl(encoding, bytes); + } +#endif + + internal static int GetCharsImpl(Encoding encoding, ReadOnlySpan bytes, Span output) + { + if (bytes.Length == 0) + return 0; + unsafe { fixed (byte* bytePtr = bytes) @@ -19,7 +33,7 @@ public static int GetChars(this Encoding encoding, ReadOnlySpan bytes, Spa } } - public static int GetCharCount(this Encoding encoding, ReadOnlySpan bytes) + internal static int GetCharCountImpl(Encoding encoding, ReadOnlySpan bytes) { unsafe { @@ -29,5 +43,4 @@ public static int GetCharCount(this Encoding encoding, ReadOnlySpan bytes) } } } -#endif } \ No newline at end of file diff --git a/LLama/Extensions/IEnumerableExtensions.cs b/LLama/Extensions/IEnumerableExtensions.cs index ebc234be0..9e01feb85 100644 --- a/LLama/Extensions/IEnumerableExtensions.cs +++ b/LLama/Extensions/IEnumerableExtensions.cs @@ -7,15 +7,20 @@ internal static class IEnumerableExtensions { #if NETSTANDARD2_0 public static IEnumerable TakeLast(this IEnumerable source, int count) + { + return TakeLastImpl(source, count); + } +#endif + + internal static IEnumerable TakeLastImpl(IEnumerable source, int count) { var list = source.ToList(); - + if (count >= list.Count) return list; list.RemoveRange(0, list.Count - count); return list; } -#endif } } diff --git a/LLama/Extensions/ListExtensions.cs b/LLama/Extensions/ListExtensions.cs deleted file mode 100644 index c78d311ce..000000000 --- a/LLama/Extensions/ListExtensions.cs +++ /dev/null @@ -1,14 +0,0 @@ -using System; -using System.Collections.Generic; - -namespace LLama.Extensions -{ - internal static class ListExtensions - { - public static void AddRangeSpan(this List list, ReadOnlySpan span) - { - for (var i = 0; i < span.Length; i++) - list.Add(span[i]); - } - } -}