4 changes: 2 additions & 2 deletions LLama.Examples/NewVersion/ChatSessionStripRoleName.cs
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class ChatSessionStripRoleName
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ public static void Run()
             Console.Write(prompt);
             while (true)
             {
-                foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
+                await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
                 {
                     Console.Write(text);
                 }
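Note: every example in this PR migrates from the blocking Chat/Infer enumerables to their IAsyncEnumerable counterparts. For reference, a minimal self-contained sketch of the new calling pattern, assuming the LLamaSharp types shown in the diff (ChatSession, InferenceParams, and ChatAsync returning IAsyncEnumerable<string>):

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using LLama;
using LLama.Common;

public static class AsyncChatDemo
{
    // Streams the reply token by token; await foreach consumes the
    // IAsyncEnumerable<string> without blocking a thread between tokens.
    public static async Task StreamReplyAsync(ChatSession session, string prompt)
    {
        var inferenceParams = new InferenceParams
        {
            Temperature = 0.6f,
            AntiPrompts = new List<string> { "User:" }
        };

        await foreach (var text in session.ChatAsync(prompt, inferenceParams))
        {
            Console.Write(text);
        }
    }
}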
4 changes: 2 additions & 2 deletions LLama.Examples/NewVersion/ChatSessionWithRoleName.cs
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class ChatSessionWithRoleName
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ public static void Run()
             Console.Write(prompt);
             while (true)
             {
-                foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
+                await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
                 {
                     Console.Write(text);
                 }
6 changes: 3 additions & 3 deletions LLama.Examples/NewVersion/GrammarJsonResponse.cs
@@ -5,9 +5,9 @@ namespace LLama.Examples.NewVersion
 {
     public class GrammarJsonResponse
     {
-        public static void Run()
+        public static async Task Run()
         {
-            var gbnf = File.ReadAllText("Assets/json.gbnf").Trim();
+            var gbnf = (await File.ReadAllTextAsync("Assets/json.gbnf")).Trim();
             var grammar = Grammar.Parse(gbnf, "root");

             Console.Write("Please input your model path: ");
@@ -43,7 +43,7 @@ public static void Run()
             Console.ForegroundColor = ConsoleColor.White;
             Console.Write("Answer: ");
             prompt = $"Question: {prompt?.Trim()} Answer: ";
-            foreach (var text in ex.Infer(prompt, inferenceParams))
+            await foreach (var text in ex.InferAsync(prompt, inferenceParams))
             {
                 Console.Write(text);
             }
4 changes: 2 additions & 2 deletions LLama.Examples/NewVersion/InstructModeExecute.cs
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class InstructModeExecute
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -29,7 +29,7 @@ public static void Run()

             while (true)
             {
-                foreach (var text in executor.Infer(prompt, inferenceParams))
+                await foreach (var text in executor.InferAsync(prompt, inferenceParams))
                 {
                     Console.Write(text);
                 }
4 changes: 2 additions & 2 deletions LLama.Examples/NewVersion/LoadAndSaveSession.cs
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class SaveAndLoadSession
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ public static void Run()
             Console.Write(prompt);
             while (true)
             {
-                foreach (var text in session.Chat(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
+                await foreach (var text in session.ChatAsync(prompt, new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" } }))
                 {
                     Console.Write(text);
                 }
4 changes: 2 additions & 2 deletions LLama.Examples/NewVersion/LoadAndSaveState.cs
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class LoadAndSaveState
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -30,7 +30,7 @@ public static void Run()

             while (true)
             {
-                foreach (var text in ex.Infer(prompt, inferenceParams))
+                await foreach (var text in ex.InferAsync(prompt, inferenceParams))
                 {
                     Console.Write(text);
                 }
4 changes: 2 additions & 2 deletions LLama.Examples/NewVersion/StatelessModeExecute.cs
@@ -4,7 +4,7 @@ namespace LLama.Examples.NewVersion
 {
     public class StatelessModeExecute
     {
-        public static void Run()
+        public static async Task Run()
         {
             Console.Write("Please input your model path: ");
             var modelPath = Console.ReadLine();
@@ -35,7 +35,7 @@ public static void Run()
             Console.ForegroundColor = ConsoleColor.White;
             Console.Write("Answer: ");
             prompt = $"Question: {prompt?.Trim()} Answer: ";
-            foreach (var text in ex.Infer(prompt, inferenceParams))
+            await foreach (var text in ex.InferAsync(prompt, inferenceParams))
             {
                 Console.Write(text);
             }
14 changes: 7 additions & 7 deletions LLama.Examples/NewVersion/TestRunner.cs
@@ -29,31 +29,31 @@ public static async Task Run()

             if (choice == 0)
             {
-                ChatSessionWithRoleName.Run();
+                await ChatSessionWithRoleName.Run();
             }
             else if (choice == 1)
             {
-                ChatSessionStripRoleName.Run();
+                await ChatSessionStripRoleName.Run();
             }
             else if(choice == 2)
             {
                 await InteractiveModeExecute.Run();
             }
             else if(choice == 3)
             {
-                InstructModeExecute.Run();
+                await InstructModeExecute.Run();
             }
             else if(choice == 4)
             {
-                StatelessModeExecute.Run();
+                await StatelessModeExecute.Run();
             }
             else if(choice == 5)
             {
-                SaveAndLoadSession.Run();
+                await SaveAndLoadSession.Run();
             }
             else if(choice == 6)
             {
-                LoadAndSaveState.Run();
+                await LoadAndSaveState.Run();
             }
             else if(choice == 7)
             {
@@ -69,7 +69,7 @@ public static async Task Run()
             }
             else if (choice == 10)
             {
-                GrammarJsonResponse.Run();
+                await GrammarJsonResponse.Run();
             }
             else if (choice == 11)
             {
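Because each example's Run() now returns a Task, the menu awaits every call, and the awaiting has to continue up to the entry point. A minimal sketch of a suitable entry point, assuming a Program class that is not shown in this diff:

using System.Threading.Tasks;
using LLama.Examples.NewVersion;

public static class Program
{
    // C# 7.1+ supports an async Main, so the Task returned by
    // TestRunner.Run() is awaited rather than blocked on.
    public static async Task Main(string[] args)
    {
        await TestRunner.Run();
    }
}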
4 changes: 2 additions & 2 deletions LLama.Unittest/GrammarTest.cs
@@ -41,7 +41,7 @@ public void CreateBasicGrammar()
         }

         [Fact]
-        public void SampleWithTrivialGrammar()
+        public async Task SampleWithTrivialGrammar()
         {
             // Create a grammar that constrains the output to be "cat" and nothing else. This is a nonsense answer, so
             // we can be confident it's not what the LLM would say if not constrained by the grammar!
@@ -66,7 +66,7 @@ public void SampleWithTrivialGrammar()
                 Grammar = grammar,
             };

-            var result = executor.Infer("Q. 7 + 12\nA. ", inferenceParams).ToList();
+            var result = await executor.InferAsync("Q. 7 + 12\nA. ", inferenceParams).ToListAsync();

             Assert.Equal("cat", result[0]);
         }
1 change: 1 addition & 0 deletions LLama.Unittest/LLama.Unittest.csproj
@@ -12,6 +12,7 @@

   <ItemGroup>
     <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.7.1" />
+    <PackageReference Include="System.Linq.Async" Version="6.0.1" />
     <PackageReference Include="xunit" Version="2.5.0" />
     <PackageReference Include="xunit.runner.visualstudio" Version="2.5.0">
       <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
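ToListAsync, used by the updated tests, is not part of the base class library; it comes from the System.Linq.Async package referenced above, which adds LINQ operators over IAsyncEnumerable<T>. A small self-contained example of the operator:

using System;
using System.Collections.Generic;
using System.Linq; // System.Linq.Async extends this namespace
using System.Threading.Tasks;

public static class ToListAsyncDemo
{
    private static async IAsyncEnumerable<string> Tokens()
    {
        yield return "cat";
        await Task.Yield(); // pretend each token arrives asynchronously
        yield return ".";
    }

    public static async Task Main()
    {
        // ToListAsync materializes an IAsyncEnumerable<T> into a List<T>.
        var list = await Tokens().ToListAsync();
        Console.WriteLine(string.Join("", list)); // prints "cat."
    }
}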
12 changes: 6 additions & 6 deletions LLama.Unittest/StatelessExecutorTest.cs
@@ -27,15 +27,15 @@ public void Dispose()
         }

         [Fact]
-        public void Stateless()
+        public async Task Stateless()
         {
             var executor = new StatelessExecutor(_weights, _params);

             const string question = "Question. what is a cat?\nAnswer: ";
             var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };

-            var result1 = string.Join("", executor.Infer(question, @params));
-            var result2 = string.Join("", executor.Infer(question, @params));
+            var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
+            var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());

             _testOutputHelper.WriteLine(result1);

@@ -44,7 +44,7 @@ public void Stateless()
         }

         [Fact]
-        public void OutOfContext()
+        public async Task OutOfContext()
         {
             var executor = new StatelessExecutor(_weights, _params);

@@ -58,8 +58,8 @@ public void OutOfContext()
                 TokensKeep = question.Length,
             };

-            var result1 = string.Join("", executor.Infer(question, @params));
-            var result2 = string.Join("", executor.Infer(question, @params));
+            var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
+            var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());

             _testOutputHelper.WriteLine(result1);

2 changes: 1 addition & 1 deletion LLama.WebAPI/Controllers/ChatController.cs
@@ -18,7 +18,7 @@ public ChatController(ILogger<ChatController> logger)
         }

         [HttpPost("Send")]
-        public string SendMessage([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service)
+        public Task<string> SendMessage([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service)
         {
             return _service.Send(input);
         }
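ASP.NET Core awaits Task-returning action methods before writing the response, so the action can forward the service's Task directly. Returning the Task is equivalent to the more explicit form sketched below (shown only to illustrate the equivalence; the extra async state machine is only needed if the action has work to do after the await):

[HttpPost("Send")]
public async Task<string> SendMessage([FromBody] SendMessageInput input, [FromServices] StatefulChatService _service)
{
    // Equivalent behavior; exceptions surface at the await instead of
    // inside the framework's own await of the returned Task.
    return await _service.Send(input);
}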
6 changes: 3 additions & 3 deletions LLama.WebAPI/Services/StatefulChatService.cs
@@ -28,7 +28,7 @@ public void Dispose()
             _context?.Dispose();
         }

-        public string Send(SendMessageInput input)
+        public async Task<string> Send(SendMessageInput input)
         {
             var userInput = input.Text;
             if (!_continue)
@@ -42,13 +42,13 @@ public string Send(SendMessageInput input)
             Console.Write(input.Text);

             Console.ForegroundColor = ConsoleColor.White;
-            var outputs = _session.Chat(userInput, new Common.InferenceParams()
+            var outputs = _session.ChatAsync(userInput, new Common.InferenceParams()
             {
                 RepeatPenalty = 1.0f,
                 AntiPrompts = new string[] { "User:" },
             });
             var result = "";
-            foreach (var output in outputs)
+            await foreach (var output in outputs)
             {
                 Console.Write(output);
                 result += output;
9 changes: 0 additions & 9 deletions LLama/Abstractions/ILLamaExecutor.cs
@@ -13,15 +13,6 @@ public interface ILLamaExecutor
         /// </summary>
         public LLamaContext Context { get; }

-        /// <summary>
-        /// Infers a response from the model.
-        /// </summary>
-        /// <param name="text">Your prompt</param>
-        /// <param name="inferenceParams">Any additional parameters</param>
-        /// <param name="token">A cancellation token.</param>
-        /// <returns></returns>
-        IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
-
         /// <summary>
         /// Asynchronously infers a response from the model.
         /// </summary>
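With the synchronous method removed, ILLamaExecutor becomes async-only. As a shape sketch (not standalone code, since it reuses the repo's LLamaContext and IInferenceParams types), the interface after this PR presumably looks like the following, assuming InferAsync mirrors the removed Infer signature:

using System.Collections.Generic;
using System.Threading;

namespace LLama.Abstractions
{
    public interface ILLamaExecutor
    {
        // The executor's context, as shown in the diff.
        public LLamaContext Context { get; }

        // Assumed: same parameters as the removed Infer, but streaming
        // results as IAsyncEnumerable<string> instead of IEnumerable<string>.
        IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
    }
}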
7 changes: 0 additions & 7 deletions LLama/Abstractions/ITextStreamTransform.cs
@@ -7,13 +7,6 @@ namespace LLama.Abstractions
     /// </summary>
     public interface ITextStreamTransform
     {
-        /// <summary>
-        /// Takes a stream of tokens and transforms them, returning a new stream of tokens.
-        /// </summary>
-        /// <param name="tokens"></param>
-        /// <returns></returns>
-        IEnumerable<string> Transform(IEnumerable<string> tokens);
-
         /// <summary>
         /// Takes a stream of tokens and transforms them, returning a new stream of tokens asynchronously.
         /// </summary>
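Implementations of ITextStreamTransform now only supply the asynchronous path. A hypothetical transform (not part of this PR) sketching what that looks like, assuming TransformAsync maps IAsyncEnumerable<string> to IAsyncEnumerable<string> as the surviving doc comment implies:

using System.Collections.Generic;

// Hypothetical example: upper-cases every token. In the real codebase this
// would implement ITextStreamTransform, whose exact TransformAsync
// signature is assumed here.
public class UppercaseTransform
{
    public async IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> tokens)
    {
        await foreach (var token in tokens)
        {
            yield return token.ToUpperInvariant();
        }
    }
}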
56 changes: 3 additions & 53 deletions LLama/ChatSession.cs
@@ -134,26 +134,6 @@ public virtual void LoadSession(string path)
             }
         }

-        /// <summary>
-        /// Get the response from the LLama model with chat histories.
-        /// </summary>
-        /// <param name="history"></param>
-        /// <param name="inferenceParams"></param>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
-        public IEnumerable<string> Chat(ChatHistory history, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
-        {
-            var prompt = HistoryTransform.HistoryToText(history);
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
-            StringBuilder sb = new();
-            foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
-            {
-                yield return result;
-                sb.Append(result);
-            }
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
-        }
-
         /// <summary>
         /// Get the response from the LLama model. Note that prompt could not only be the preset words,
         /// but also the question you want to ask.
@@ -162,15 +142,14 @@ public IEnumerable<string> Chat(ChatHistory history, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
         /// <param name="inferenceParams"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public IEnumerable<string> Chat(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+        public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             foreach(var inputTransform in InputTransformPipeline)
             {
                 prompt = inputTransform.Transform(prompt);
             }
-
             History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
             StringBuilder sb = new();
-            foreach (var result in ChatInternal(prompt, inferenceParams, cancellationToken))
+            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
             {
                 yield return result;
                 sb.Append(result);
@@ -198,35 +177,6 @@ public async IAsyncEnumerable<string> ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
             History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
         }

-        /// <summary>
-        /// Get the response from the LLama model with chat histories asynchronously.
-        /// </summary>
-        /// <param name="prompt"></param>
-        /// <param name="inferenceParams"></param>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
-        public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
-        {
-            foreach (var inputTransform in InputTransformPipeline)
-            {
-                prompt = inputTransform.Transform(prompt);
-            }
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
-            StringBuilder sb = new();
-            await foreach (var result in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
-            {
-                yield return result;
-                sb.Append(result);
-            }
-            History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
-        }
-
-        private IEnumerable<string> ChatInternal(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
-        {
-            var results = _executor.Infer(prompt, inferenceParams, cancellationToken);
-            return OutputTransform.Transform(results);
-        }
-
         private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
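The diff is cut off inside ChatAsyncInternal. Judging from the removed synchronous ChatInternal, the remainder presumably pipes the executor's stream through the output transform; a hedged sketch of that shape, assuming ITextStreamTransform.TransformAsync is the async counterpart of the removed Transform:

private async IAsyncEnumerable<string> ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);

    // Assumed: stream the transformed tokens back out one by one.
    await foreach (var item in OutputTransform.TransformAsync(results))
    {
        yield return item;
    }
}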