diff --git a/LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs b/LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs
index a12fe439b..e04ee9e4b 100644
--- a/LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs
+++ b/LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs
@@ -1,4 +1,6 @@
 using Microsoft.SemanticKernel.AI;
+using System.Text.Json;
+using System.Text.Json.Serialization;
 
 namespace LLamaSharp.SemanticKernel.ChatCompletion;
 
@@ -8,12 +10,14 @@ public class ChatRequestSettings : AIRequestSettings
     /// Temperature controls the randomness of the completion.
     /// The higher the temperature, the more random the completion.
     /// </summary>
+    [JsonPropertyName("temperature")]
     public double Temperature { get; set; } = 0;
 
     /// <summary>
     /// TopP controls the diversity of the completion.
     /// The higher the TopP, the more diverse the completion.
     /// </summary>
+    [JsonPropertyName("top_p")]
     public double TopP { get; set; } = 0;
 
     /// <summary>
@@ -21,6 +25,7 @@ public class ChatRequestSettings : AIRequestSettings
     /// based on whether they appear in the text so far, increasing the
     /// model's likelihood to talk about new topics.
     /// </summary>
+    [JsonPropertyName("presence_penalty")]
     public double PresencePenalty { get; set; } = 0;
 
     /// <summary>
@@ -28,11 +33,13 @@ public class ChatRequestSettings : AIRequestSettings
     /// based on their existing frequency in the text so far, decreasing
     /// the model's likelihood to repeat the same line verbatim.
     /// </summary>
+    [JsonPropertyName("frequency_penalty")]
     public double FrequencyPenalty { get; set; } = 0;
 
     /// <summary>
     /// Sequences where the completion will stop generating further tokens.
     /// </summary>
+    [JsonPropertyName("stop_sequences")]
     public IList<string> StopSequences { get; set; } = Array.Empty<string>();
 
     /// <summary>
@@ -40,15 +47,66 @@ public class ChatRequestSettings : AIRequestSettings
     /// Note: Because this parameter generates many completions, it can quickly consume your token quota.
     /// Use carefully and ensure that you have reasonable settings for max_tokens and stop.
     /// </summary>
+    [JsonPropertyName("results_per_prompt")]
     public int ResultsPerPrompt { get; set; } = 1;
 
     /// <summary>
     /// The maximum number of tokens to generate in the completion.
     /// </summary>
+    [JsonPropertyName("max_tokens")]
     public int? MaxTokens { get; set; }
 
     /// <summary>
     /// Modify the likelihood of specified tokens appearing in the completion.
     /// </summary>
+    [JsonPropertyName("token_selection_biases")]
     public IDictionary<int, int> TokenSelectionBiases { get; set; } = new Dictionary<int, int>();
+
+    /// <summary>
+    /// Converts arbitrary request settings into <see cref="ChatRequestSettings"/>.
+    /// </summary>
+    /// <remarks>
+    /// A <see cref="ChatRequestSettings"/> argument is returned as the very same instance.
+    /// Any other settings object is serialized to JSON and deserialized again with
+    /// <see cref="ChatRequestSettingsConverter"/> registered.
+    /// </remarks>
+    /// <param name="requestSettings">Settings to convert; may be <c>null</c>.</param>
+    /// <param name="defaultMaxTokens">Assigned to <see cref="MaxTokens"/> only when <paramref name="requestSettings"/> is <c>null</c>.</param>
+    /// <returns>An instance of <see cref="ChatRequestSettings"/>.</returns>
+    /// <exception cref="ArgumentException">Deserialization of the intermediate JSON yielded <c>null</c>.</exception>
+    public static ChatRequestSettings FromRequestSettings(AIRequestSettings? requestSettings, int? defaultMaxTokens = null)
+    {
+        switch (requestSettings)
+        {
+            case null:
+                return new ChatRequestSettings { MaxTokens = defaultMaxTokens };
+            case ChatRequestSettings existing:
+                return existing;
+        }
+
+        // Round-trip through JSON; the converter registered in s_options maps the property names.
+        string payload = JsonSerializer.Serialize(requestSettings);
+
+        return JsonSerializer.Deserialize<ChatRequestSettings>(payload, s_options)
+            ?? throw new ArgumentException($"Invalid request settings, cannot convert to {nameof(ChatRequestSettings)}", nameof(requestSettings));
+    }
+
+    // Serializer options used by FromRequestSettings; built once per process.
+    private static readonly JsonSerializerOptions s_options = CreateOptions();
+
+    // Builds the serializer options and registers ChatRequestSettingsConverter on them.
+    private static JsonSerializerOptions CreateOptions()
+    {
+        var options = new JsonSerializerOptions
+        {
+            WriteIndented = true,
+            MaxDepth = 20,
+            AllowTrailingCommas = true,
+            PropertyNameCaseInsensitive = true,
+            ReadCommentHandling = JsonCommentHandling.Skip,
+        };
+        options.Converters.Add(new ChatRequestSettingsConverter());
+
+        return options;
+    }
 }
diff --git a/LLama.SemanticKernel/ChatCompletion/ChatRequestSettingsConverter.cs b/LLama.SemanticKernel/ChatCompletion/ChatRequestSettingsConverter.cs
new file mode 100644
index 000000000..f0d3a4307
--- /dev/null
+++ b/LLama.SemanticKernel/ChatCompletion/ChatRequestSettingsConverter.cs
@@ -0,0 +1,118 @@
+using System;
+using System.Collections.Generic;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+namespace LLamaSharp.SemanticKernel.ChatCompletion;
+
+/// <summary>
+/// JSON converter for <see cref="ChatRequestSettings"/>.
+/// </summary>
+/// <remarks>
+/// Reading matches property names case-insensitively in both snake_case and PascalCase
+/// spellings and ignores unknown properties; writing always emits snake_case names.
+/// </remarks>
+public class ChatRequestSettingsConverter : JsonConverter<ChatRequestSettings>
+{
+    /// <inheritdoc/>
+    /// <exception cref="JsonException">The reader is not positioned on the start of a JSON object.</exception>
+    public override ChatRequestSettings? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+    {
+        if (reader.TokenType != JsonTokenType.StartObject)
+        {
+            throw new JsonException($"Expected the start of a JSON object when reading {nameof(ChatRequestSettings)}.");
+        }
+
+        var requestSettings = new ChatRequestSettings();
+
+        while (reader.Read() && reader.TokenType != JsonTokenType.EndObject)
+        {
+            if (reader.TokenType != JsonTokenType.PropertyName)
+            {
+                continue;
+            }
+
+            // Upper-case the name so that matching below is case-insensitive.
+            string? propertyName = reader.GetString()?.ToUpperInvariant();
+
+            // Move from the property name to its value.
+            reader.Read();
+
+            switch (propertyName)
+            {
+                case "TEMPERATURE":
+                    requestSettings.Temperature = reader.GetDouble();
+                    break;
+                case "TOPP":
+                case "TOP_P":
+                    requestSettings.TopP = reader.GetDouble();
+                    break;
+                case "FREQUENCYPENALTY":
+                case "FREQUENCY_PENALTY":
+                    requestSettings.FrequencyPenalty = reader.GetDouble();
+                    break;
+                case "PRESENCEPENALTY":
+                case "PRESENCE_PENALTY":
+                    requestSettings.PresencePenalty = reader.GetDouble();
+                    break;
+                case "MAXTOKENS":
+                case "MAX_TOKENS":
+                    // Write emits an explicit null for an unset MaxTokens, and GetInt32 throws on
+                    // a null token, so map JSON null back to "not set".
+                    requestSettings.MaxTokens = reader.TokenType == JsonTokenType.Null
+                        ? (int?)null
+                        : reader.GetInt32();
+                    break;
+                case "STOPSEQUENCES":
+                case "STOP_SEQUENCES":
+                    requestSettings.StopSequences = JsonSerializer.Deserialize<IList<string>>(ref reader, options) ?? Array.Empty<string>();
+                    break;
+                case "RESULTSPERPROMPT":
+                case "RESULTS_PER_PROMPT":
+                    requestSettings.ResultsPerPrompt = reader.GetInt32();
+                    break;
+                case "TOKENSELECTIONBIASES":
+                case "TOKEN_SELECTION_BIASES":
+                    requestSettings.TokenSelectionBiases = JsonSerializer.Deserialize<IDictionary<int, int>>(ref reader, options) ?? new Dictionary<int, int>();
+                    break;
+                case "SERVICEID":
+                case "SERVICE_ID":
+                    requestSettings.ServiceId = reader.GetString();
+                    break;
+                default:
+                    // Unknown property: ignore its value.
+                    reader.Skip();
+                    break;
+            }
+        }
+
+        return requestSettings;
+    }
+
+    /// <inheritdoc/>
+    public override void Write(Utf8JsonWriter writer, ChatRequestSettings value, JsonSerializerOptions options)
+    {
+        writer.WriteStartObject();
+
+        writer.WriteNumber("temperature", value.Temperature);
+        writer.WriteNumber("top_p", value.TopP);
+        writer.WriteNumber("frequency_penalty", value.FrequencyPenalty);
+        writer.WriteNumber("presence_penalty", value.PresencePenalty);
+        if (value.MaxTokens is int maxTokens)
+        {
+            writer.WriteNumber("max_tokens", maxTokens);
+        }
+        else
+        {
+            writer.WriteNull("max_tokens");
+        }
+        writer.WritePropertyName("stop_sequences");
+        JsonSerializer.Serialize(writer, value.StopSequences, options);
+        writer.WriteNumber("results_per_prompt", value.ResultsPerPrompt);
+        writer.WritePropertyName("token_selection_biases");
+        JsonSerializer.Serialize(writer, value.TokenSelectionBiases, options);
+        writer.WriteString("service_id", value.ServiceId);
+
+        writer.WriteEndObject();
+    }
+}
\ No newline at end of file
diff --git a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
index fd8693c59..4fcb5baab 100644
--- a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
+++ b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
@@ -61,7 +61,7 @@ public ChatHistory CreateNewChat(string? instructions = "")
     public Task<IReadOnlyList<IChatResult>> GetChatCompletionsAsync(ChatHistory chat, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default)
     {
         var settings = requestSettings != null
-            ? (ChatRequestSettings)requestSettings
+            ? ChatRequestSettings.FromRequestSettings(requestSettings)
             : defaultRequestSettings;
 
         // This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
@@ -76,7 +76,7 @@ public async IAsyncEnumerable<IChatStreamingResult> GetStreamingChatCompletionsA
 #pragma warning restore CS1998
     {
         var settings = requestSettings != null
-            ? (ChatRequestSettings)requestSettings
+            ? ChatRequestSettings.FromRequestSettings(requestSettings)
             : defaultRequestSettings;
 
         // This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
diff --git a/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs b/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs
index cc41e5d83..059a9ff33 100644
--- a/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs
+++ b/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs
@@ -21,7 +21,8 @@ public LLamaSharpTextCompletion(ILLamaExecutor executor)
 
     public async Task<IReadOnlyList<ITextResult>> GetCompletionsAsync(string text, AIRequestSettings? requestSettings, CancellationToken cancellationToken = default)
     {
-        var settings = (ChatRequestSettings?)requestSettings;
-        var result = executor.InferAsync(text, settings?.ToLLamaSharpInferenceParams(), cancellationToken);
+        // FromRequestSettings never returns null, so no null-conditional access is needed.
+        var settings = ChatRequestSettings.FromRequestSettings(requestSettings);
+        var result = executor.InferAsync(text, settings.ToLLamaSharpInferenceParams(), cancellationToken);
         return await Task.FromResult(new List<ITextResult> { new LLamaTextResult(result) }.AsReadOnly()).ConfigureAwait(false);
     }
@@ -30,7 +31,8 @@ public async Task<IReadOnlyList<ITextResult>> GetCompletionsAsync(string text, A
     public async IAsyncEnumerable<ITextStreamingResult> GetStreamingCompletionsAsync(string text, AIRequestSettings? requestSettings,[EnumeratorCancellation] CancellationToken cancellationToken = default)
 #pragma warning restore CS1998
     {
-        var settings = (ChatRequestSettings?)requestSettings;
-        var result = executor.InferAsync(text, settings?.ToLLamaSharpInferenceParams(), cancellationToken);
+        // FromRequestSettings never returns null, so no null-conditional access is needed.
+        var settings = ChatRequestSettings.FromRequestSettings(requestSettings);
+        var result = executor.InferAsync(text, settings.ToLLamaSharpInferenceParams(), cancellationToken);
         yield return new LLamaTextResult(result);
     }
diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj
index 0532244df..bcd5feeff 100644
--- a/LLama.Unittest/LLama.Unittest.csproj
+++ b/LLama.Unittest/LLama.Unittest.csproj
@@ -30,6 +30,7 @@
+
diff --git a/LLama.Unittest/SemanticKernel/ChatRequestSettingsConverterTests.cs b/LLama.Unittest/SemanticKernel/ChatRequestSettingsConverterTests.cs
new file mode 100644
index 000000000..4190e852c
--- /dev/null
+++ b/LLama.Unittest/SemanticKernel/ChatRequestSettingsConverterTests.cs
@@ -0,0 +1,110 @@
+using LLamaSharp.SemanticKernel.ChatCompletion;
+using System.Text.Json;
+
+namespace LLama.Unittest.SemanticKernel
+{
+    /// <summary>
+    /// Deserialization tests for <see cref="ChatRequestSettingsConverter"/>.
+    /// </summary>
+    public class ChatRequestSettingsConverterTests
+    {
+        [Fact]
+        public void ChatRequestSettingsConverter_DeserializeWithDefaults()
+        {
+            // Arrange
+            var options = new JsonSerializerOptions();
+            options.Converters.Add(new ChatRequestSettingsConverter());
+            var json = "{}";
+
+            // Act
+            var requestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, options);
+
+            // Assert
+            Assert.NotNull(requestSettings);
+            Assert.Equal(0, requestSettings.FrequencyPenalty);
+            Assert.Null(requestSettings.MaxTokens);
+            Assert.Equal(0, requestSettings.PresencePenalty);
+            Assert.Equal(1, requestSettings.ResultsPerPrompt);
+            Assert.NotNull(requestSettings.StopSequences);
+            Assert.Empty(requestSettings.StopSequences);
+            Assert.Equal(0, requestSettings.Temperature);
+            Assert.NotNull(requestSettings.TokenSelectionBiases);
+            Assert.Empty(requestSettings.TokenSelectionBiases);
+            Assert.Equal(0, requestSettings.TopP);
+        }
+
+        [Fact]
+        public void ChatRequestSettingsConverter_DeserializeWithSnakeCase()
+        {
+            // Arrange
+            var json = @"{
+  ""frequency_penalty"": 0.5,
+  ""max_tokens"": 250,
+  ""presence_penalty"": 0.5,
+  ""results_per_prompt"": -1,
+  ""stop_sequences"": [ ""foo"", ""bar"" ],
+  ""temperature"": 0.5,
+  ""token_selection_biases"": { ""1"": 2, ""3"": 4 },
+  ""top_p"": 0.5,
+}";
+
+            // Act
+            var requestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, CreateLenientOptions());
+
+            // Assert
+            AssertPopulated(requestSettings);
+        }
+
+        [Fact]
+        public void ChatRequestSettingsConverter_DeserializeWithPascalCase()
+        {
+            // Arrange
+            var json = @"{
+  ""FrequencyPenalty"": 0.5,
+  ""MaxTokens"": 250,
+  ""PresencePenalty"": 0.5,
+  ""ResultsPerPrompt"": -1,
+  ""StopSequences"": [ ""foo"", ""bar"" ],
+  ""Temperature"": 0.5,
+  ""TokenSelectionBiases"": { ""1"": 2, ""3"": 4 },
+  ""TopP"": 0.5,
+}";
+
+            // Act
+            var requestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, CreateLenientOptions());
+
+            // Assert
+            AssertPopulated(requestSettings);
+        }
+
+        /// <summary>
+        /// Options with the converter registered; trailing commas are allowed because the JSON fixtures end with one.
+        /// </summary>
+        private static JsonSerializerOptions CreateLenientOptions()
+        {
+            var options = new JsonSerializerOptions { AllowTrailingCommas = true };
+            options.Converters.Add(new ChatRequestSettingsConverter());
+            return options;
+        }
+
+        /// <summary>
+        /// Asserts the values that both the snake_case and the PascalCase fixture carry.
+        /// </summary>
+        private static void AssertPopulated(ChatRequestSettings? requestSettings)
+        {
+            Assert.NotNull(requestSettings);
+            Assert.Equal(0.5, requestSettings.FrequencyPenalty);
+            Assert.Equal(250, requestSettings.MaxTokens);
+            Assert.Equal(0.5, requestSettings.PresencePenalty);
+            Assert.Equal(-1, requestSettings.ResultsPerPrompt);
+            Assert.NotNull(requestSettings.StopSequences);
+            Assert.Contains("foo", requestSettings.StopSequences);
+            Assert.Contains("bar", requestSettings.StopSequences);
+            Assert.Equal(0.5, requestSettings.Temperature);
+            Assert.NotNull(requestSettings.TokenSelectionBiases);
+            Assert.Equal(2, requestSettings.TokenSelectionBiases[1]);
+            Assert.Equal(4, requestSettings.TokenSelectionBiases[3]);
+            Assert.Equal(0.5, requestSettings.TopP);
+        }
+    }
+}
diff --git a/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs b/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs
new file mode 100644
index 000000000..99881b575
--- /dev/null
+++ b/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs
@@ -0,0 +1,166 @@
+using LLamaSharp.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel.AI;
+
+namespace LLama.Unittest.SemanticKernel
+{
+    /// <summary>
+    /// Tests for <see cref="ChatRequestSettings.FromRequestSettings"/>.
+    /// </summary>
+    public class ChatRequestSettingsTests
+    {
+        [Fact]
+        public void ChatRequestSettings_FromRequestSettingsNull()
+        {
+            // Arrange
+            // Act
+            var requestSettings = ChatRequestSettings.FromRequestSettings(null, null);
+
+            // Assert
+            AssertDefaults(requestSettings, expectedMaxTokens: null);
+        }
+
+        [Fact]
+        public void ChatRequestSettings_FromRequestSettingsNullWithMaxTokens()
+        {
+            // Arrange
+            // Act
+            var requestSettings = ChatRequestSettings.FromRequestSettings(null, 200);
+
+            // Assert
+            AssertDefaults(requestSettings, expectedMaxTokens: 200);
+        }
+
+        [Fact]
+        public void ChatRequestSettings_FromExistingRequestSettings()
+        {
+            // Arrange
+            var originalRequestSettings = new ChatRequestSettings()
+            {
+                FrequencyPenalty = 0.5,
+                MaxTokens = 100,
+                PresencePenalty = 0.5,
+                ResultsPerPrompt = -1,
+                StopSequences = new[] { "foo", "bar" },
+                Temperature = 0.5,
+                TokenSelectionBiases = new Dictionary<int, int>() { { 1, 2 }, { 3, 4 } },
+                TopP = 0.5,
+            };
+
+            // Act
+            var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings);
+
+            // Assert
+            // An existing ChatRequestSettings must be passed through, not copied.
+            Assert.Same(originalRequestSettings, requestSettings);
+        }
+
+        [Fact]
+        public void ChatRequestSettings_FromAIRequestSettings()
+        {
+            // Arrange
+            var originalRequestSettings = new AIRequestSettings()
+            {
+                ServiceId = "test",
+            };
+
+            // Act
+            var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings);
+
+            // Assert
+            Assert.NotNull(requestSettings);
+            Assert.Equal(originalRequestSettings.ServiceId, requestSettings.ServiceId);
+        }
+
+        [Fact]
+        public void ChatRequestSettings_FromAIRequestSettingsWithExtraPropertiesInSnakeCase()
+        {
+            // Arrange
+            var originalRequestSettings = new AIRequestSettings()
+            {
+                ServiceId = "test",
+                ExtensionData = new Dictionary<string, object>
+                {
+                    { "frequency_penalty", 0.5 },
+                    { "max_tokens", 250 },
+                    { "presence_penalty", 0.5 },
+                    { "results_per_prompt", -1 },
+                    { "stop_sequences", new [] { "foo", "bar" } },
+                    { "temperature", 0.5 },
+                    { "token_selection_biases", new Dictionary<int, int>() { { 1, 2 }, { 3, 4 } } },
+                    { "top_p", 0.5 },
+                }
+            };
+
+            // Act
+            var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings);
+
+            // Assert
+            AssertPopulated(requestSettings);
+        }
+
+        [Fact]
+        public void ChatRequestSettings_FromAIRequestSettingsWithExtraPropertiesInPascalCase()
+        {
+            // Arrange
+            var originalRequestSettings = new AIRequestSettings()
+            {
+                ServiceId = "test",
+                ExtensionData = new Dictionary<string, object>
+                {
+                    { "FrequencyPenalty", 0.5 },
+                    { "MaxTokens", 250 },
+                    { "PresencePenalty", 0.5 },
+                    { "ResultsPerPrompt", -1 },
+                    { "StopSequences", new [] { "foo", "bar" } },
+                    { "Temperature", 0.5 },
+                    { "TokenSelectionBiases", new Dictionary<int, int>() { { 1, 2 }, { 3, 4 } } },
+                    { "TopP", 0.5 },
+                }
+            };
+
+            // Act
+            var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings);
+
+            // Assert
+            AssertPopulated(requestSettings);
+        }
+
+        /// <summary>
+        /// Asserts that every property other than MaxTokens still has its default value.
+        /// </summary>
+        private static void AssertDefaults(ChatRequestSettings requestSettings, int? expectedMaxTokens)
+        {
+            Assert.NotNull(requestSettings);
+            Assert.Equal(0, requestSettings.FrequencyPenalty);
+            Assert.Equal(expectedMaxTokens, requestSettings.MaxTokens);
+            Assert.Equal(0, requestSettings.PresencePenalty);
+            Assert.Equal(1, requestSettings.ResultsPerPrompt);
+            Assert.NotNull(requestSettings.StopSequences);
+            Assert.Empty(requestSettings.StopSequences);
+            Assert.Equal(0, requestSettings.Temperature);
+            Assert.NotNull(requestSettings.TokenSelectionBiases);
+            Assert.Empty(requestSettings.TokenSelectionBiases);
+            Assert.Equal(0, requestSettings.TopP);
+        }
+
+        /// <summary>
+        /// Asserts the values supplied through ExtensionData in the snake_case and PascalCase tests.
+        /// </summary>
+        private static void AssertPopulated(ChatRequestSettings requestSettings)
+        {
+            Assert.NotNull(requestSettings);
+            Assert.Equal(0.5, requestSettings.FrequencyPenalty);
+            Assert.Equal(250, requestSettings.MaxTokens);
+            Assert.Equal(0.5, requestSettings.PresencePenalty);
+            Assert.Equal(-1, requestSettings.ResultsPerPrompt);
+            Assert.NotNull(requestSettings.StopSequences);
+            Assert.Contains("foo", requestSettings.StopSequences);
+            Assert.Contains("bar", requestSettings.StopSequences);
+            Assert.Equal(0.5, requestSettings.Temperature);
+            Assert.NotNull(requestSettings.TokenSelectionBiases);
+            Assert.Equal(2, requestSettings.TokenSelectionBiases[1]);
+            Assert.Equal(4, requestSettings.TokenSelectionBiases[3]);
+            Assert.Equal(0.5, requestSettings.TopP);
+        }
+    }
+}