Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions LLama.Examples/Examples/SemanticKernelChat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@ public static async Task Run()
// Load weights into memory
var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var ex = new InteractiveExecutor(context);
var ex = new StatelessExecutor(model, parameters);

var chatGPT = new LLamaSharpChatCompletion(ex);

var chatHistory = chatGPT.CreateNewChat("You are a librarian, expert about books");
var chatHistory = chatGPT.CreateNewChat("This is a conversation between the assistant and the user. \n\n You are a librarian, expert about books. ");

Console.WriteLine("Chat content:");
Console.WriteLine("------------------------");
Expand Down
8 changes: 4 additions & 4 deletions LLama.SemanticKernel/ChatCompletion/HistoryTransform.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using static LLama.LLamaTransforms;
using LLama.Common;
using System.Text;
using static LLama.LLamaTransforms;

namespace LLamaSharp.SemanticKernel.ChatCompletion;

Expand All @@ -10,8 +12,6 @@ public class HistoryTransform : DefaultHistoryTransform
/// <inheritdoc/>
public override string HistoryToText(global::LLama.Common.ChatHistory history)
{
var prompt = base.HistoryToText(history);
return prompt + "\nAssistant:";

return base.HistoryToText(history) + $"{AuthorRole.Assistant}: ";
}
}
43 changes: 22 additions & 21 deletions LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using LLama;
using LLama.Abstractions;
using Microsoft.SemanticKernel.AI;
using Microsoft.SemanticKernel.AI.ChatCompletion;
using System.Runtime.CompilerServices;
using static LLama.LLamaTransforms;

namespace LLamaSharp.SemanticKernel.ChatCompletion;

Expand All @@ -10,10 +12,10 @@ namespace LLamaSharp.SemanticKernel.ChatCompletion;
/// </summary>
public sealed class LLamaSharpChatCompletion : IChatCompletion
{
private const string UserRole = "user:";
private const string AssistantRole = "assistant:";
private ChatSession session;
private readonly StatelessExecutor _model;
private ChatRequestSettings defaultRequestSettings;
private readonly IHistoryTransform historyTransform;
private readonly ITextStreamTransform outputTransform;

private readonly Dictionary<string, string> _attributes = new();

Expand All @@ -30,18 +32,17 @@ static ChatRequestSettings GetDefaultSettings()
};
}

public LLamaSharpChatCompletion(InteractiveExecutor model, ChatRequestSettings? defaultRequestSettings = default)
public LLamaSharpChatCompletion(StatelessExecutor model,
ChatRequestSettings? defaultRequestSettings = default,
IHistoryTransform? historyTransform = null,
ITextStreamTransform? outputTransform = null)
{
this.session = new ChatSession(model)
.WithHistoryTransform(new HistoryTransform())
.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { UserRole, AssistantRole }));
this.defaultRequestSettings = defaultRequestSettings ??= GetDefaultSettings();
}

public LLamaSharpChatCompletion(ChatSession session, ChatRequestSettings? defaultRequestSettings = default)
{
this.session = session;
this.defaultRequestSettings = defaultRequestSettings ??= GetDefaultSettings();
this._model = model;
this.defaultRequestSettings = defaultRequestSettings ?? GetDefaultSettings();
this.historyTransform = historyTransform ?? new HistoryTransform();
this.outputTransform = outputTransform ?? new KeywordTextOutputStreamTransform(new[] { $"{LLama.Common.AuthorRole.User}:",
$"{LLama.Common.AuthorRole.Assistant}:",
$"{LLama.Common.AuthorRole.System}:"});
}

/// <inheritdoc/>
Expand All @@ -60,14 +61,14 @@ public ChatHistory CreateNewChat(string? instructions = "")
/// <inheritdoc/>
public Task<IReadOnlyList<IChatResult>> GetChatCompletionsAsync(ChatHistory chat, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default)
{
var settings = requestSettings != null
var settings = requestSettings != null
? ChatRequestSettings.FromRequestSettings(requestSettings)
: defaultRequestSettings;
var prompt = historyTransform.HistoryToText(chat.ToLLamaSharpChatHistory());

// This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
var result = this.session.ChatAsync(chat.ToLLamaSharpChatHistory(), settings.ToLLamaSharpInferenceParams(), cancellationToken);
var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);

return Task.FromResult<IReadOnlyList<IChatResult>>(new List<IChatResult> { new LLamaSharpChatResult(result) }.AsReadOnly());
return Task.FromResult<IReadOnlyList<IChatResult>>(new List<IChatResult> { new LLamaSharpChatResult(outputTransform.TransformAsync(result)) }.AsReadOnly());
}

/// <inheritdoc/>
Expand All @@ -78,10 +79,10 @@ public async IAsyncEnumerable<IChatStreamingResult> GetStreamingChatCompletionsA
var settings = requestSettings != null
? ChatRequestSettings.FromRequestSettings(requestSettings)
: defaultRequestSettings;

var prompt = historyTransform.HistoryToText(chat.ToLLamaSharpChatHistory());
// This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
var result = this.session.ChatAsync(chat.ToLLamaSharpChatHistory(), settings.ToLLamaSharpInferenceParams(), cancellationToken);
var result = _model.InferAsync(prompt, settings.ToLLamaSharpInferenceParams(), cancellationToken);

yield return new LLamaSharpChatResult(result);
yield return new LLamaSharpChatResult(outputTransform.TransformAsync(result));
}
}
5 changes: 4 additions & 1 deletion LLama.SemanticKernel/ExtensionMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ public static class ExtensionMethods
throw new ArgumentNullException(nameof(requestSettings));
}

var antiPrompts = new List<string>(requestSettings.StopSequences) { AuthorRole.User.ToString() + ":" };
var antiPrompts = new List<string>(requestSettings.StopSequences)
{ LLama.Common.AuthorRole.User.ToString() + ":" ,
LLama.Common.AuthorRole.Assistant.ToString() + ":",
LLama.Common.AuthorRole.System.ToString() + ":"};
return new global::LLama.Common.InferenceParams
{
Temperature = (float)requestSettings.Temperature,
Expand Down