Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions LLama.Unittest/EncodingExtensionsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ private static void GetCharsTest(string str)
Assert.True(chars[..count].SequenceEqual(str));
}

private static void GetCharCountTest(string str)
{
var bytes = Encoding.UTF8.GetBytes(str);

var count = EncodingExtensions.GetCharCountImpl(Encoding.UTF8, bytes);

Assert.Equal(str.Length, count);
}

[Fact]
public void GetCharsEmptyString()
{
Expand All @@ -33,5 +42,23 @@ public void GetCharsChineseString()
{
GetCharsTest("猫坐在垫子上");
}

[Fact]
public void GetCharCountEmptyString()
{
GetCharCountTest("");
}

[Fact]
public void GetCharCountString()
{
GetCharCountTest("Hello World");
}

[Fact]
public void GetCharCountChineseString()
{
GetCharCountTest("猫坐在垫子上");
}
}
}
89 changes: 88 additions & 1 deletion LLama.Unittest/GrammarParserTest.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LLama.Exceptions;
using System.Text;
using LLama.Exceptions;
using LLama.Native;
using LLama.Grammars;

Expand Down Expand Up @@ -211,6 +212,61 @@ public void ParseExtraComplexGrammar()
CheckGrammar(grammarBytes, "root", expected, expectedRules);
}

[Fact]
public void ParseGrammarNotSequence()
{
var grammarBytes = @"root ::= [^a]";

var expected = new List<KeyValuePair<string, uint>>
{
new KeyValuePair<string, uint>("root", 0),
};

var expectedRules = new List<LLamaGrammarElement>
{
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_NOT, 97),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
};

CheckGrammar(grammarBytes, "root", expected, expectedRules);
}

[Fact]
public void ParseGrammarWithMultibyteCharacter()
{
var grammarBytes = @"root ::= [罗]*";

var expected = new List<KeyValuePair<string, uint>>
{
new KeyValuePair<string, uint>("root", 0),
new KeyValuePair<string, uint>("root_1", 1),
};

var expectedRules = new List<LLamaGrammarElement>
{
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 32599),
new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1),
new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
};

CheckGrammar(grammarBytes, "root", expected, expectedRules);
}


[Fact]
public void InvalidGrammarMissingRuleDefinition()
{
var parsedGrammar = new GBNFGrammarParser();
var grammarBytes = @"root := [^a]";

Assert.Throws<GrammarExpectedNext>(() =>
{
parsedGrammar.Parse(grammarBytes, "root");
});
}

[Fact]
public void InvalidGrammarNoClosingBracket()
Expand Down Expand Up @@ -269,6 +325,37 @@ public void InvalidGrammarBadHex()
});
}

[Fact]
public void InvalidGrammarBadEscapeCharacter()
{
var parsedGrammar = new GBNFGrammarParser();
var grammarBytes = @"
root ::= (expr ""="" ws term ""\z"")+ ## <--- `\z` is not a valid escape character
expr ::= term ([-+*/] term)*
term ::= ident | num | ""("" ws expr "")"" ws
ident ::= [a-z] [a-z0-9_]* ws
num ::= [0-9]+ ws
ws ::= [ \t\n]*
";

Assert.Throws<GrammarUnknownEscapeCharacter>(() =>
{
parsedGrammar.Parse(grammarBytes, "root");
});
}

[Fact]
public void InvalidGrammarUnexpectedEndOfInput()
{
var parsedGrammar = new GBNFGrammarParser();
var grammarBytes = @"root ::= (expr ""="" ws term ""\";

Assert.Throws<GrammarUnexpectedEndOfInput>(() =>
{
parsedGrammar.Parse(grammarBytes, "root");
});
}


[Fact]
public void InvalidRuleNoElements()
Expand Down
2 changes: 1 addition & 1 deletion LLama.Unittest/StatelessExecutorTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public async Task Stateless()
Assert.Equal(result1, result2);
}

[Fact]
[Fact(Skip = "Very very slow in CI")]
public async Task OutOfContext()
{
var executor = new StatelessExecutor(_weights, _params);
Expand Down
27 changes: 27 additions & 0 deletions LLama.Unittest/TextTransformTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
namespace LLama.Unittest
{
public sealed class TextTransformTests
{
[Fact]
public void NaiveTextInputTransformTrimsText()
{
var transform = new LLamaTransforms.NaiveTextInputTransform();

Assert.Equal("hello", transform.Transform("hello"));
Assert.Equal("hello", transform.Transform(" hello"));
Assert.Equal("hello", transform.Transform("hello "));
Assert.Equal("hello", transform.Transform(" hello "));
Assert.Equal("hello world", transform.Transform(" hello world "));
}

[Fact]
public async Task EmptyTextOutputStreamTransformDoesNothing()
{
var input = new[] { "Hello", "world" };

var transform = new LLamaTransforms.EmptyTextOutputStreamTransform();

Assert.Equal(input, await transform.TransformAsync(input.ToAsyncEnumerable()).ToArrayAsync());
}
}
}
3 changes: 3 additions & 0 deletions LLama/Extensions/EncodingExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ internal static int GetCharsImpl(Encoding encoding, ReadOnlySpan<byte> bytes, Sp

internal static int GetCharCountImpl(Encoding encoding, ReadOnlySpan<byte> bytes)
{
if (bytes.Length == 0)
return 0;

unsafe
{
fixed (byte* bytePtr = bytes)
Expand Down
4 changes: 3 additions & 1 deletion LLama/Extensions/KeyValuePairExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ namespace LLama.Extensions
/// <summary>
/// Extensions to the KeyValuePair struct
/// </summary>
public static class KeyValuePairExtensions
internal static class KeyValuePairExtensions
{
#if NETSTANDARD2_0
/// <summary>
/// Deconstruct a KeyValuePair into it's constituent parts.
/// </summary>
Expand All @@ -20,5 +21,6 @@ public static void Deconstruct<TKey, TValue>(this KeyValuePair<TKey, TValue> pai
first = pair.Key;
second = pair.Value;
}
#endif
}
}
40 changes: 16 additions & 24 deletions LLama/Grammars/GBNFGrammarParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -155,35 +155,27 @@ private uint ParseChar(ref ReadOnlySpan<byte> src)
{
if (src[0] == '\\')
{
if (src.Length < 2)
throw new GrammarUnexpectedEndOfInput();

var chr = src[1];
src = src.Slice(2);
switch (chr)

return (char)chr switch
{
case (byte)'x':
return ParseHex(ref src, 2);
case (byte)'u':
return ParseHex(ref src, 4);
case (byte)'U':
return ParseHex(ref src, 8);
case (byte)'t':
return '\t';
case (byte)'r':
return '\r';
case (byte)'n':
return '\n';
case (byte)'\\':
case (byte)'"':
case (byte)'[':
case (byte)']':
return chr;
default:
throw new GrammarUnknownEscapeCharacter(Encoding.UTF8.GetString(src.ToArray()));
}
'x' => ParseHex(ref src, 2),
'u' => ParseHex(ref src, 4),
'U' => ParseHex(ref src, 8),
't' => '\t',
'r' => '\r',
'n' => '\n',
'\\' or '"' or '[' or ']' => chr,
_ => throw new GrammarUnknownEscapeCharacter(Encoding.UTF8.GetString(src.ToArray())),
};
}
else if (!src.IsEmpty)
{

if (!src.IsEmpty)
return DecodeUTF8(ref src);
}

throw new GrammarUnexpectedEndOfInput();
}
Expand Down
Loading