From 007f501cd50364884f152d1b14f2724c3cc90592 Mon Sep 17 00:00:00 2001 From: xbotter Date: Fri, 17 Nov 2023 22:50:01 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=F0=9F=94=A7=20Update=20LlamaSharpConfig=20?= =?UTF-8?q?and=20LlamaSharpTextGeneration?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add DefaultInferenceParams property to LlamaSharpConfig - Update GenerateTextAsync method in LlamaSharpTextGeneration to use DefaultInferenceParams if available - Update OptionsToParams method in LlamaSharpTextGeneration to handle defaultParams --- LLama.KernelMemory/LlamaSharpConfig.cs | 9 ++++- .../LlamaSharpTextGeneration.cs | 35 +++++++++++++------ 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/LLama.KernelMemory/LlamaSharpConfig.cs b/LLama.KernelMemory/LlamaSharpConfig.cs index 7d3aefbef..9299759e3 100644 --- a/LLama.KernelMemory/LlamaSharpConfig.cs +++ b/LLama.KernelMemory/LlamaSharpConfig.cs @@ -1,4 +1,5 @@ -using System; +using LLama.Common; +using System; using System.Collections.Generic; using System.Linq; using System.Text; @@ -39,5 +40,11 @@ public LLamaSharpConfig(string modelPath) /// Gets or sets the number of GPU layers. /// public int? GpuLayerCount { get; set; } + + + /// + /// Set the default inference parameters. + /// + public InferenceParams? DefaultInferenceParams { get; set; } } } diff --git a/LLama.KernelMemory/LlamaSharpTextGeneration.cs b/LLama.KernelMemory/LlamaSharpTextGeneration.cs index c3734ea4d..dab918f66 100644 --- a/LLama.KernelMemory/LlamaSharpTextGeneration.cs +++ b/LLama.KernelMemory/LlamaSharpTextGeneration.cs @@ -72,20 +72,35 @@ public void Dispose() /// public IAsyncEnumerable GenerateTextAsync(string prompt, TextGenerationOptions options, CancellationToken cancellationToken = default) { - return _executor.InferAsync(prompt, OptionsToParams(options), cancellationToken: cancellationToken); + return _executor.InferAsync(prompt, OptionsToParams(options, this._config?.DefaultInferenceParams), cancellationToken: cancellationToken); } - private static InferenceParams OptionsToParams(TextGenerationOptions options) + private static InferenceParams OptionsToParams(TextGenerationOptions options, InferenceParams? defaultParams) { - return new InferenceParams() + if (defaultParams != null) { - AntiPrompts = options.StopSequences.ToList().AsReadOnly(), - Temperature = (float)options.Temperature, - MaxTokens = options.MaxTokens ?? 1024, - FrequencyPenalty = (float)options.FrequencyPenalty, - PresencePenalty = (float)options.PresencePenalty, - TopP = (float)options.TopP, - }; + return defaultParams with + { + AntiPrompts = defaultParams.AntiPrompts.Concat(options.StopSequences).ToList().AsReadOnly(), + Temperature = options.Temperature == default ? defaultParams.Temperature : default, + MaxTokens = options.MaxTokens ?? defaultParams.MaxTokens, + FrequencyPenalty = options.FrequencyPenalty == default ? defaultParams.FrequencyPenalty : default, + PresencePenalty = options.PresencePenalty == default ? defaultParams.PresencePenalty : default, + TopP = options.TopP == default ? defaultParams.TopP : default + }; + } + else + { + return new InferenceParams() + { + AntiPrompts = options.StopSequences.ToList().AsReadOnly(), + Temperature = (float)options.Temperature, + MaxTokens = options.MaxTokens ?? 1024, + FrequencyPenalty = (float)options.FrequencyPenalty, + PresencePenalty = (float)options.PresencePenalty, + TopP = (float)options.TopP, + }; + } } } } From 286904920b499b50ab343f85990c509656ef7ef6 Mon Sep 17 00:00:00 2001 From: xbotter Date: Sat, 18 Nov 2023 21:40:04 +0800 Subject: [PATCH 2/2] update DefaultInferenceParams in WithLLamaSharpDefaults --- LLama.Examples/Examples/KernelMemory.cs | 8 +++++++- LLama.Examples/Examples/Runner.cs | 3 ++- LLama.KernelMemory/BuilderExtensions.cs | 2 +- LLama.KernelMemory/LlamaSharpTextGeneration.cs | 18 +++++++++--------- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/LLama.Examples/Examples/KernelMemory.cs b/LLama.Examples/Examples/KernelMemory.cs index 0f63447f4..0c2b908fd 100644 --- a/LLama.Examples/Examples/KernelMemory.cs +++ b/LLama.Examples/Examples/KernelMemory.cs @@ -17,7 +17,13 @@ public static async Task Run() Console.Write("Please input your model path: "); var modelPath = Console.ReadLine(); var memory = new KernelMemoryBuilder() - .WithLLamaSharpDefaults(new LLamaSharpConfig(modelPath)) + .WithLLamaSharpDefaults(new LLamaSharpConfig(modelPath) + { + DefaultInferenceParams = new Common.InferenceParams + { + AntiPrompts = new List { "\n\n" } + } + }) .With(new TextPartitioningOptions { MaxTokensPerParagraph = 300, diff --git a/LLama.Examples/Examples/Runner.cs b/LLama.Examples/Examples/Runner.cs index f2f1351f6..aca0a7daa 100644 --- a/LLama.Examples/Examples/Runner.cs +++ b/LLama.Examples/Examples/Runner.cs @@ -42,7 +42,8 @@ public static async Task Run() AnsiConsole.Write(new Rule(choice)); await example(); } - + Console.WriteLine("Press any key to continue..."); + Console.ReadKey(); AnsiConsole.Clear(); } } diff --git a/LLama.KernelMemory/BuilderExtensions.cs b/LLama.KernelMemory/BuilderExtensions.cs index 0b92ca6ed..dc746dc60 100644 --- a/LLama.KernelMemory/BuilderExtensions.cs +++ b/LLama.KernelMemory/BuilderExtensions.cs @@ -82,7 +82,7 @@ public static KernelMemoryBuilder WithLLamaSharpDefaults(this KernelMemoryBuilde var executor = new StatelessExecutor(weights, parameters); var embedder = new LLamaEmbedder(weights, parameters); builder.WithLLamaSharpTextEmbeddingGeneration(new LLamaSharpTextEmbeddingGeneration(embedder)); - builder.WithLLamaSharpTextGeneration(new LlamaSharpTextGeneration(weights, context, executor)); + builder.WithLLamaSharpTextGeneration(new LlamaSharpTextGeneration(weights, context, executor, config?.DefaultInferenceParams)); return builder; } } diff --git a/LLama.KernelMemory/LlamaSharpTextGeneration.cs b/LLama.KernelMemory/LlamaSharpTextGeneration.cs index dab918f66..663a77cf6 100644 --- a/LLama.KernelMemory/LlamaSharpTextGeneration.cs +++ b/LLama.KernelMemory/LlamaSharpTextGeneration.cs @@ -15,10 +15,10 @@ namespace LLamaSharp.KernelMemory /// public class LlamaSharpTextGeneration : ITextGeneration, IDisposable { - private readonly LLamaSharpConfig? _config; private readonly LLamaWeights _weights; private readonly StatelessExecutor _executor; private readonly LLamaContext _context; + private readonly InferenceParams? _defaultInferenceParams; private bool _ownsContext = false; private bool _ownsWeights = false; @@ -28,7 +28,6 @@ public class LlamaSharpTextGeneration : ITextGeneration, IDisposable /// The configuration for LLamaSharp. public LlamaSharpTextGeneration(LLamaSharpConfig config) { - this._config = config; var parameters = new ModelParams(config.ModelPath) { ContextSize = config?.ContextSize ?? 2048, @@ -38,6 +37,7 @@ public LlamaSharpTextGeneration(LLamaSharpConfig config) _weights = LLamaWeights.LoadFromFile(parameters); _context = _weights.CreateContext(parameters); _executor = new StatelessExecutor(_weights, parameters); + _defaultInferenceParams = config?.DefaultInferenceParams; _ownsWeights = _ownsContext = true; } @@ -48,12 +48,12 @@ public LlamaSharpTextGeneration(LLamaSharpConfig config) /// A LLamaWeights object. /// A LLamaContext object. /// An executor. Currently only StatelessExecutor is expected. - public LlamaSharpTextGeneration(LLamaWeights weights, LLamaContext context, StatelessExecutor? executor = null) + public LlamaSharpTextGeneration(LLamaWeights weights, LLamaContext context, StatelessExecutor? executor = null, InferenceParams? inferenceParams = null) { - _config = null; _weights = weights; _context = context; _executor = executor ?? new StatelessExecutor(_weights, _context.Params); + _defaultInferenceParams = inferenceParams; } /// @@ -72,7 +72,7 @@ public void Dispose() /// public IAsyncEnumerable GenerateTextAsync(string prompt, TextGenerationOptions options, CancellationToken cancellationToken = default) { - return _executor.InferAsync(prompt, OptionsToParams(options, this._config?.DefaultInferenceParams), cancellationToken: cancellationToken); + return _executor.InferAsync(prompt, OptionsToParams(options, this._defaultInferenceParams), cancellationToken: cancellationToken); } private static InferenceParams OptionsToParams(TextGenerationOptions options, InferenceParams? defaultParams) @@ -82,11 +82,11 @@ private static InferenceParams OptionsToParams(TextGenerationOptions options, In return defaultParams with { AntiPrompts = defaultParams.AntiPrompts.Concat(options.StopSequences).ToList().AsReadOnly(), - Temperature = options.Temperature == default ? defaultParams.Temperature : default, + Temperature = options.Temperature == defaultParams.Temperature ? defaultParams.Temperature : (float)options.Temperature, MaxTokens = options.MaxTokens ?? defaultParams.MaxTokens, - FrequencyPenalty = options.FrequencyPenalty == default ? defaultParams.FrequencyPenalty : default, - PresencePenalty = options.PresencePenalty == default ? defaultParams.PresencePenalty : default, - TopP = options.TopP == default ? defaultParams.TopP : default + FrequencyPenalty = options.FrequencyPenalty == defaultParams.FrequencyPenalty ? defaultParams.FrequencyPenalty : (float)options.FrequencyPenalty, + PresencePenalty = options.PresencePenalty == defaultParams.PresencePenalty ? defaultParams.PresencePenalty : (float)options.PresencePenalty, + TopP = options.TopP == defaultParams.TopP ? defaultParams.TopP : (float)options.TopP }; } else