diff --git a/LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs b/LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs
index a12fe439b..e04ee9e4b 100644
--- a/LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs
+++ b/LLama.SemanticKernel/ChatCompletion/ChatRequestSettings.cs
@@ -1,4 +1,6 @@
using Microsoft.SemanticKernel.AI;
+using System.Text.Json;
+using System.Text.Json.Serialization;
namespace LLamaSharp.SemanticKernel.ChatCompletion;
@@ -8,12 +10,14 @@ public class ChatRequestSettings : AIRequestSettings
/// Temperature controls the randomness of the completion.
/// The higher the temperature, the more random the completion.
/// </summary>
+ [JsonPropertyName("temperature")]
public double Temperature { get; set; } = 0;
/// <summary>
/// TopP controls the diversity of the completion.
/// The higher the TopP, the more diverse the completion.
/// </summary>
+ [JsonPropertyName("top_p")]
public double TopP { get; set; } = 0;
/// <summary>
@@ -21,6 +25,7 @@ public class ChatRequestSettings : AIRequestSettings
/// based on whether they appear in the text so far, increasing the
/// model's likelihood to talk about new topics.
/// </summary>
+ [JsonPropertyName("presence_penalty")]
public double PresencePenalty { get; set; } = 0;
/// <summary>
@@ -28,11 +33,13 @@ public class ChatRequestSettings : AIRequestSettings
/// based on their existing frequency in the text so far, decreasing
/// the model's likelihood to repeat the same line verbatim.
/// </summary>
+ [JsonPropertyName("frequency_penalty")]
public double FrequencyPenalty { get; set; } = 0;
/// <summary>
/// Sequences where the completion will stop generating further tokens.
/// </summary>
+ [JsonPropertyName("stop_sequences")]
public IList<string> StopSequences { get; set; } = Array.Empty<string>();
/// <summary>
@@ -40,15 +47,67 @@ public class ChatRequestSettings : AIRequestSettings
/// Note: Because this parameter generates many completions, it can quickly consume your token quota.
/// Use carefully and ensure that you have reasonable settings for max_tokens and stop.
/// </summary>
+ [JsonPropertyName("results_per_prompt")]
public int ResultsPerPrompt { get; set; } = 1;
/// <summary>
/// The maximum number of tokens to generate in the completion.
/// </summary>
+ [JsonPropertyName("max_tokens")]
public int? MaxTokens { get; set; }
/// <summary>
/// Modify the likelihood of specified tokens appearing in the completion.
/// </summary>
+ [JsonPropertyName("token_selection_biases")]
public IDictionary<int, int> TokenSelectionBiases { get; set; } = new Dictionary<int, int>();
+
+ /// <summary>
+ /// Create a new settings object with the values from another settings object.
+ /// </summary>
+ /// <param name="requestSettings">Template configuration</param>
+ /// <param name="defaultMaxTokens">Default max tokens</param>
+ /// <returns>An instance of ChatRequestSettings</returns>
+ public static ChatRequestSettings FromRequestSettings(AIRequestSettings? requestSettings, int? defaultMaxTokens = null)
+ {
+ if (requestSettings is null)
+ {
+ return new ChatRequestSettings()
+ {
+ MaxTokens = defaultMaxTokens
+ };
+ }
+
+ if (requestSettings is ChatRequestSettings requestSettingsChatRequestSettings)
+ {
+ return requestSettingsChatRequestSettings;
+ }
+
+ var json = JsonSerializer.Serialize(requestSettings);
+ var chatRequestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, s_options);
+
+ if (chatRequestSettings is not null)
+ {
+ return chatRequestSettings;
+ }
+
+ throw new ArgumentException($"Invalid request settings, cannot convert to {nameof(ChatRequestSettings)}", nameof(requestSettings));
+ }
+
+ private static readonly JsonSerializerOptions s_options = CreateOptions();
+
+ private static JsonSerializerOptions CreateOptions()
+ {
+ JsonSerializerOptions options = new()
+ {
+ WriteIndented = true,
+ MaxDepth = 20,
+ AllowTrailingCommas = true,
+ PropertyNameCaseInsensitive = true,
+ ReadCommentHandling = JsonCommentHandling.Skip,
+ Converters = { new ChatRequestSettingsConverter() }
+ };
+
+ return options;
+ }
}
diff --git a/LLama.SemanticKernel/ChatCompletion/ChatRequestSettingsConverter.cs b/LLama.SemanticKernel/ChatCompletion/ChatRequestSettingsConverter.cs
new file mode 100644
index 000000000..f0d3a4307
--- /dev/null
+++ b/LLama.SemanticKernel/ChatCompletion/ChatRequestSettingsConverter.cs
@@ -0,0 +1,105 @@
+using System;
+using System.Collections.Generic;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+namespace LLamaSharp.SemanticKernel.ChatCompletion;
+
+/// <summary>
+/// JSON converter for <see cref="ChatRequestSettings"/>
+/// </summary>
+public class ChatRequestSettingsConverter : JsonConverter<ChatRequestSettings>
+{
+ /// <inheritdoc/>
+ public override ChatRequestSettings? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ {
+ var requestSettings = new ChatRequestSettings();
+
+ while (reader.Read() && reader.TokenType != JsonTokenType.EndObject)
+ {
+ if (reader.TokenType == JsonTokenType.PropertyName)
+ {
+ string? propertyName = reader.GetString();
+
+ if (propertyName is not null)
+ {
+ // normalise property name to uppercase
+ propertyName = propertyName.ToUpperInvariant();
+ }
+
+ reader.Read();
+
+ switch (propertyName)
+ {
+ case "TEMPERATURE":
+ requestSettings.Temperature = reader.GetDouble();
+ break;
+ case "TOPP":
+ case "TOP_P":
+ requestSettings.TopP = reader.GetDouble();
+ break;
+ case "FREQUENCYPENALTY":
+ case "FREQUENCY_PENALTY":
+ requestSettings.FrequencyPenalty = reader.GetDouble();
+ break;
+ case "PRESENCEPENALTY":
+ case "PRESENCE_PENALTY":
+ requestSettings.PresencePenalty = reader.GetDouble();
+ break;
+ case "MAXTOKENS":
+ case "MAX_TOKENS":
+ requestSettings.MaxTokens = reader.GetInt32();
+ break;
+ case "STOPSEQUENCES":
+ case "STOP_SEQUENCES":
+ requestSettings.StopSequences = JsonSerializer.Deserialize<List<string>>(ref reader, options) ?? Array.Empty<string>();
+ break;
+ case "RESULTSPERPROMPT":
+ case "RESULTS_PER_PROMPT":
+ requestSettings.ResultsPerPrompt = reader.GetInt32();
+ break;
+ case "TOKENSELECTIONBIASES":
+ case "TOKEN_SELECTION_BIASES":
+ requestSettings.TokenSelectionBiases = JsonSerializer.Deserialize<Dictionary<int, int>>(ref reader, options) ?? new Dictionary<int, int>();
+ break;
+ case "SERVICEID":
+ case "SERVICE_ID":
+ requestSettings.ServiceId = reader.GetString();
+ break;
+ default:
+ reader.Skip();
+ break;
+ }
+ }
+ }
+
+ return requestSettings;
+ }
+
+ /// <inheritdoc/>
+ public override void Write(Utf8JsonWriter writer, ChatRequestSettings value, JsonSerializerOptions options)
+ {
+ writer.WriteStartObject();
+
+ writer.WriteNumber("temperature", value.Temperature);
+ writer.WriteNumber("top_p", value.TopP);
+ writer.WriteNumber("frequency_penalty", value.FrequencyPenalty);
+ writer.WriteNumber("presence_penalty", value.PresencePenalty);
+ if (value.MaxTokens is null)
+ {
+ writer.WriteNull("max_tokens");
+ }
+ else
+ {
+ writer.WriteNumber("max_tokens", (decimal)value.MaxTokens);
+ }
+ writer.WritePropertyName("stop_sequences");
+ JsonSerializer.Serialize(writer, value.StopSequences, options);
+ writer.WriteNumber("results_per_prompt", value.ResultsPerPrompt);
+ writer.WritePropertyName("token_selection_biases");
+ JsonSerializer.Serialize(writer, value.TokenSelectionBiases, options);
+ writer.WriteString("service_id", value.ServiceId);
+
+ writer.WriteEndObject();
+ }
+}
\ No newline at end of file
diff --git a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
index fd8693c59..4fcb5baab 100644
--- a/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
+++ b/LLama.SemanticKernel/ChatCompletion/LLamaSharpChatCompletion.cs
@@ -61,7 +61,7 @@ public ChatHistory CreateNewChat(string? instructions = "")
public Task<IReadOnlyList<IChatResult>> GetChatCompletionsAsync(ChatHistory chat, AIRequestSettings? requestSettings = null, CancellationToken cancellationToken = default)
{
var settings = requestSettings != null
- ? (ChatRequestSettings)requestSettings
+ ? ChatRequestSettings.FromRequestSettings(requestSettings)
: defaultRequestSettings;
// This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
@@ -76,7 +76,7 @@ public async IAsyncEnumerable<IChatStreamingResult> GetStreamingChatCompletionsA
#pragma warning restore CS1998
{
var settings = requestSettings != null
- ? (ChatRequestSettings)requestSettings
+ ? ChatRequestSettings.FromRequestSettings(requestSettings)
: defaultRequestSettings;
// This call is not awaited because LLamaSharpChatResult accepts an IAsyncEnumerable.
diff --git a/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs b/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs
index cc41e5d83..059a9ff33 100644
--- a/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs
+++ b/LLama.SemanticKernel/TextCompletion/LLamaSharpTextCompletion.cs
@@ -21,7 +21,7 @@ public LLamaSharpTextCompletion(ILLamaExecutor executor)
public async Task<IReadOnlyList<ITextResult>> GetCompletionsAsync(string text, AIRequestSettings? requestSettings, CancellationToken cancellationToken = default)
{
- var settings = (ChatRequestSettings?)requestSettings;
+ var settings = ChatRequestSettings.FromRequestSettings(requestSettings);
var result = executor.InferAsync(text, settings?.ToLLamaSharpInferenceParams(), cancellationToken);
return await Task.FromResult(new List<ITextResult> { new LLamaTextResult(result) }.AsReadOnly()).ConfigureAwait(false);
}
@@ -30,7 +30,7 @@ public async Task<IReadOnlyList<ITextResult>> GetCompletionsAsync(string text, A
public async IAsyncEnumerable<ITextResult> GetStreamingCompletionsAsync(string text, AIRequestSettings? requestSettings,[EnumeratorCancellation] CancellationToken cancellationToken = default)
#pragma warning restore CS1998
{
- var settings = (ChatRequestSettings?)requestSettings;
+ var settings = ChatRequestSettings.FromRequestSettings(requestSettings);
var result = executor.InferAsync(text, settings?.ToLLamaSharpInferenceParams(), cancellationToken);
yield return new LLamaTextResult(result);
}
diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj
index 0532244df..bcd5feeff 100644
--- a/LLama.Unittest/LLama.Unittest.csproj
+++ b/LLama.Unittest/LLama.Unittest.csproj
@@ -30,6 +30,7 @@
+    <ProjectReference Include="..\LLama.SemanticKernel\LLamaSharp.SemanticKernel.csproj" />
diff --git a/LLama.Unittest/SemanticKernel/ChatRequestSettingsConverterTests.cs b/LLama.Unittest/SemanticKernel/ChatRequestSettingsConverterTests.cs
new file mode 100644
index 000000000..4190e852c
--- /dev/null
+++ b/LLama.Unittest/SemanticKernel/ChatRequestSettingsConverterTests.cs
@@ -0,0 +1,107 @@
+using LLamaSharp.SemanticKernel.ChatCompletion;
+using System.Text.Json;
+
+namespace LLama.Unittest.SemanticKernel
+{
+ public class ChatRequestSettingsConverterTests
+ {
+ [Fact]
+ public void ChatRequestSettingsConverter_DeserializeWithDefaults()
+ {
+ // Arrange
+ var options = new JsonSerializerOptions();
+ options.Converters.Add(new ChatRequestSettingsConverter());
+ var json = "{}";
+
+ // Act
+ var requestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, options);
+
+ // Assert
+ Assert.NotNull(requestSettings);
+ Assert.Equal(0, requestSettings.FrequencyPenalty);
+ Assert.Null(requestSettings.MaxTokens);
+ Assert.Equal(0, requestSettings.PresencePenalty);
+ Assert.Equal(1, requestSettings.ResultsPerPrompt);
+ Assert.NotNull(requestSettings.StopSequences);
+ Assert.Empty(requestSettings.StopSequences);
+ Assert.Equal(0, requestSettings.Temperature);
+ Assert.NotNull(requestSettings.TokenSelectionBiases);
+ Assert.Empty(requestSettings.TokenSelectionBiases);
+ Assert.Equal(0, requestSettings.TopP);
+ }
+
+ [Fact]
+ public void ChatRequestSettingsConverter_DeserializeWithSnakeCase()
+ {
+ // Arrange
+ var options = new JsonSerializerOptions();
+ options.AllowTrailingCommas = true;
+ options.Converters.Add(new ChatRequestSettingsConverter());
+ var json = @"{
+ ""frequency_penalty"": 0.5,
+ ""max_tokens"": 250,
+ ""presence_penalty"": 0.5,
+ ""results_per_prompt"": -1,
+ ""stop_sequences"": [ ""foo"", ""bar"" ],
+ ""temperature"": 0.5,
+ ""token_selection_biases"": { ""1"": 2, ""3"": 4 },
+ ""top_p"": 0.5,
+}";
+
+ // Act
+ var requestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, options);
+
+ // Assert
+ Assert.NotNull(requestSettings);
+ Assert.Equal(0.5, requestSettings.FrequencyPenalty);
+ Assert.Equal(250, requestSettings.MaxTokens);
+ Assert.Equal(0.5, requestSettings.PresencePenalty);
+ Assert.Equal(-1, requestSettings.ResultsPerPrompt);
+ Assert.NotNull(requestSettings.StopSequences);
+ Assert.Contains("foo", requestSettings.StopSequences);
+ Assert.Contains("bar", requestSettings.StopSequences);
+ Assert.Equal(0.5, requestSettings.Temperature);
+ Assert.NotNull(requestSettings.TokenSelectionBiases);
+ Assert.Equal(2, requestSettings.TokenSelectionBiases[1]);
+ Assert.Equal(4, requestSettings.TokenSelectionBiases[3]);
+ Assert.Equal(0.5, requestSettings.TopP);
+ }
+
+ [Fact]
+ public void ChatRequestSettingsConverter_DeserializeWithPascalCase()
+ {
+ // Arrange
+ var options = new JsonSerializerOptions();
+ options.AllowTrailingCommas = true;
+ options.Converters.Add(new ChatRequestSettingsConverter());
+ var json = @"{
+ ""FrequencyPenalty"": 0.5,
+ ""MaxTokens"": 250,
+ ""PresencePenalty"": 0.5,
+ ""ResultsPerPrompt"": -1,
+ ""StopSequences"": [ ""foo"", ""bar"" ],
+ ""Temperature"": 0.5,
+ ""TokenSelectionBiases"": { ""1"": 2, ""3"": 4 },
+ ""TopP"": 0.5,
+}";
+
+ // Act
+ var requestSettings = JsonSerializer.Deserialize<ChatRequestSettings>(json, options);
+
+ // Assert
+ Assert.NotNull(requestSettings);
+ Assert.Equal(0.5, requestSettings.FrequencyPenalty);
+ Assert.Equal(250, requestSettings.MaxTokens);
+ Assert.Equal(0.5, requestSettings.PresencePenalty);
+ Assert.Equal(-1, requestSettings.ResultsPerPrompt);
+ Assert.NotNull(requestSettings.StopSequences);
+ Assert.Contains("foo", requestSettings.StopSequences);
+ Assert.Contains("bar", requestSettings.StopSequences);
+ Assert.Equal(0.5, requestSettings.Temperature);
+ Assert.NotNull(requestSettings.TokenSelectionBiases);
+ Assert.Equal(2, requestSettings.TokenSelectionBiases[1]);
+ Assert.Equal(4, requestSettings.TokenSelectionBiases[3]);
+ Assert.Equal(0.5, requestSettings.TopP);
+ }
+ }
+}
diff --git a/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs b/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs
new file mode 100644
index 000000000..99881b575
--- /dev/null
+++ b/LLama.Unittest/SemanticKernel/ChatRequestSettingsTests.cs
@@ -0,0 +1,169 @@
+using LLamaSharp.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel.AI;
+
+namespace LLama.Unittest.SemanticKernel
+{
+ public class ChatRequestSettingsTests
+ {
+ [Fact]
+ public void ChatRequestSettings_FromRequestSettingsNull()
+ {
+ // Arrange
+ // Act
+ var requestSettings = ChatRequestSettings.FromRequestSettings(null, null);
+
+ // Assert
+ Assert.NotNull(requestSettings);
+ Assert.Equal(0, requestSettings.FrequencyPenalty);
+ Assert.Null(requestSettings.MaxTokens);
+ Assert.Equal(0, requestSettings.PresencePenalty);
+ Assert.Equal(1, requestSettings.ResultsPerPrompt);
+ Assert.NotNull(requestSettings.StopSequences);
+ Assert.Empty(requestSettings.StopSequences);
+ Assert.Equal(0, requestSettings.Temperature);
+ Assert.NotNull(requestSettings.TokenSelectionBiases);
+ Assert.Empty(requestSettings.TokenSelectionBiases);
+ Assert.Equal(0, requestSettings.TopP);
+ }
+
+ [Fact]
+ public void ChatRequestSettings_FromRequestSettingsNullWithMaxTokens()
+ {
+ // Arrange
+ // Act
+ var requestSettings = ChatRequestSettings.FromRequestSettings(null, 200);
+
+ // Assert
+ Assert.NotNull(requestSettings);
+ Assert.Equal(0, requestSettings.FrequencyPenalty);
+ Assert.Equal(200, requestSettings.MaxTokens);
+ Assert.Equal(0, requestSettings.PresencePenalty);
+ Assert.Equal(1, requestSettings.ResultsPerPrompt);
+ Assert.NotNull(requestSettings.StopSequences);
+ Assert.Empty(requestSettings.StopSequences);
+ Assert.Equal(0, requestSettings.Temperature);
+ Assert.NotNull(requestSettings.TokenSelectionBiases);
+ Assert.Empty(requestSettings.TokenSelectionBiases);
+ Assert.Equal(0, requestSettings.TopP);
+ }
+
+ [Fact]
+ public void ChatRequestSettings_FromExistingRequestSettings()
+ {
+ // Arrange
+ var originalRequestSettings = new ChatRequestSettings()
+ {
+ FrequencyPenalty = 0.5,
+ MaxTokens = 100,
+ PresencePenalty = 0.5,
+ ResultsPerPrompt = -1,
+ StopSequences = new[] { "foo", "bar" },
+ Temperature = 0.5,
+ TokenSelectionBiases = new Dictionary<int, int>() { { 1, 2 }, { 3, 4 } },
+ TopP = 0.5,
+ };
+
+ // Act
+ var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings);
+
+ // Assert
+ Assert.NotNull(requestSettings);
+ Assert.Equal(originalRequestSettings, requestSettings);
+ }
+
+ [Fact]
+ public void ChatRequestSettings_FromAIRequestSettings()
+ {
+ // Arrange
+ var originalRequestSettings = new AIRequestSettings()
+ {
+ ServiceId = "test",
+ };
+
+ // Act
+ var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings);
+
+ // Assert
+ Assert.NotNull(requestSettings);
+ Assert.Equal(originalRequestSettings.ServiceId, requestSettings.ServiceId);
+ }
+
+ [Fact]
+ public void ChatRequestSettings_FromAIRequestSettingsWithExtraPropertiesInSnakeCase()
+ {
+ // Arrange
+ var originalRequestSettings = new AIRequestSettings()
+ {
+ ServiceId = "test",
+ ExtensionData = new Dictionary<string, object>
+ {
+ { "frequency_penalty", 0.5 },
+ { "max_tokens", 250 },
+ { "presence_penalty", 0.5 },
+ { "results_per_prompt", -1 },
+ { "stop_sequences", new [] { "foo", "bar" } },
+ { "temperature", 0.5 },
+ { "token_selection_biases", new Dictionary<int, int>() { { 1, 2 }, { 3, 4 } } },
+ { "top_p", 0.5 },
+ }
+ };
+
+ // Act
+ var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings);
+
+ // Assert
+ Assert.NotNull(requestSettings);
+ Assert.Equal(0.5, requestSettings.FrequencyPenalty);
+ Assert.Equal(250, requestSettings.MaxTokens);
+ Assert.Equal(0.5, requestSettings.PresencePenalty);
+ Assert.Equal(-1, requestSettings.ResultsPerPrompt);
+ Assert.NotNull(requestSettings.StopSequences);
+ Assert.Contains("foo", requestSettings.StopSequences);
+ Assert.Contains("bar", requestSettings.StopSequences);
+ Assert.Equal(0.5, requestSettings.Temperature);
+ Assert.NotNull(requestSettings.TokenSelectionBiases);
+ Assert.Equal(2, requestSettings.TokenSelectionBiases[1]);
+ Assert.Equal(4, requestSettings.TokenSelectionBiases[3]);
+ Assert.Equal(0.5, requestSettings.TopP);
+ }
+
+ [Fact]
+ public void ChatRequestSettings_FromAIRequestSettingsWithExtraPropertiesInPascalCase()
+ {
+ // Arrange
+ var originalRequestSettings = new AIRequestSettings()
+ {
+ ServiceId = "test",
+ ExtensionData = new Dictionary<string, object>
+ {
+ { "FrequencyPenalty", 0.5 },
+ { "MaxTokens", 250 },
+ { "PresencePenalty", 0.5 },
+ { "ResultsPerPrompt", -1 },
+ { "StopSequences", new [] { "foo", "bar" } },
+ { "Temperature", 0.5 },
+ { "TokenSelectionBiases", new Dictionary<int, int>() { { 1, 2 }, { 3, 4 } } },
+ { "TopP", 0.5 },
+ }
+ };
+
+ // Act
+ var requestSettings = ChatRequestSettings.FromRequestSettings(originalRequestSettings);
+
+ // Assert
+ Assert.NotNull(requestSettings);
+ Assert.Equal(0.5, requestSettings.FrequencyPenalty);
+ Assert.Equal(250, requestSettings.MaxTokens);
+ Assert.Equal(0.5, requestSettings.PresencePenalty);
+ Assert.Equal(-1, requestSettings.ResultsPerPrompt);
+ Assert.NotNull(requestSettings.StopSequences);
+ Assert.Contains("foo", requestSettings.StopSequences);
+ Assert.Contains("bar", requestSettings.StopSequences);
+ Assert.Equal(0.5, requestSettings.Temperature);
+ Assert.NotNull(requestSettings.TokenSelectionBiases);
+ Assert.Equal(2, requestSettings.TokenSelectionBiases[1]);
+ Assert.Equal(4, requestSettings.TokenSelectionBiases[3]);
+ Assert.Equal(0.5, requestSettings.TopP);
+ }
+ }
+}