From 46f01bbc944ea8ecca52e3216a68a3ed84679138 Mon Sep 17 00:00:00 2001
From: Yaohui Liu
Date: Sun, 5 Nov 2023 17:16:50 +0800
Subject: [PATCH] feat(kernel-memory): avoid loading model twice.

---
 LLama.KernelMemory/BuilderExtensions.cs       | 42 ++++++++++++++++-
 .../LLamaSharpTextEmbeddingGeneration.cs      | 45 +++++++++++++++++--
 LLama.KernelMemory/LlamaSharpConfig.cs        |  2 +-
 .../LlamaSharpTextGeneration.cs               | 32 +++++++++++--
 4 files changed, 111 insertions(+), 10 deletions(-)

diff --git a/LLama.KernelMemory/BuilderExtensions.cs b/LLama.KernelMemory/BuilderExtensions.cs
index 7d280b494..0b92ca6ed 100644
--- a/LLama.KernelMemory/BuilderExtensions.cs
+++ b/LLama.KernelMemory/BuilderExtensions.cs
@@ -4,6 +4,9 @@
 using System.Linq;
 using System.Text;
 using System.Threading.Tasks;
+using LLama;
+using LLama.Common;
+using Microsoft.KernelMemory.AI;
 
 namespace LLamaSharp.KernelMemory
 {
@@ -24,6 +27,18 @@ public static KernelMemoryBuilder WithLLamaSharpTextEmbeddingGeneration(this Ker
             return builder;
         }
 
+        /// <summary>
+        /// Adds LLamaSharpTextEmbeddingGeneration to the KernelMemoryBuilder.
+        /// </summary>
+        /// <param name="builder">The KernelMemoryBuilder instance.</param>
+        /// <param name="textEmbeddingGeneration">The LLamaSharpTextEmbeddingGeneration instance.</param>
+        /// <returns>The KernelMemoryBuilder instance with LLamaSharpTextEmbeddingGeneration added.</returns>
+        public static KernelMemoryBuilder WithLLamaSharpTextEmbeddingGeneration(this KernelMemoryBuilder builder, LLamaSharpTextEmbeddingGeneration textEmbeddingGeneration)
+        {
+            builder.WithCustomEmbeddingGeneration(textEmbeddingGeneration);
+            return builder;
+        }
+
         /// <summary>
         /// Adds LLamaSharpTextGeneration to the KernelMemoryBuilder.
         /// </summary>
@@ -36,6 +51,18 @@ public static KernelMemoryBuilder WithLLamaSharpTextGeneration(this KernelMemory
             return builder;
         }
 
+        /// <summary>
+        /// Adds LLamaSharpTextGeneration to the KernelMemoryBuilder.
+        /// </summary>
+        /// <param name="builder">The KernelMemoryBuilder instance.</param>
+        /// <param name="textGeneration">The LlamaSharpTextGeneration instance.</param>
+        /// <returns>The KernelMemoryBuilder instance with LLamaSharpTextGeneration added.</returns>
+        public static KernelMemoryBuilder WithLLamaSharpTextGeneration(this KernelMemoryBuilder builder, LlamaSharpTextGeneration textGeneration)
+        {
+            builder.WithCustomTextGeneration(textGeneration);
+            return builder;
+        }
+
         /// <summary>
         /// Adds LLamaSharpTextEmbeddingGeneration and LLamaSharpTextGeneration to the KernelMemoryBuilder.
         /// </summary>
@@ -44,8 +71,19 @@ public static KernelMemoryBuilder WithLLamaSharpTextGeneration(this KernelMemory
         /// <returns>The KernelMemoryBuilder instance with LLamaSharpTextEmbeddingGeneration and LLamaSharpTextGeneration added.</returns>
         public static KernelMemoryBuilder WithLLamaSharpDefaults(this KernelMemoryBuilder builder, LLamaSharpConfig config)
         {
-            builder.WithLLamaSharpTextEmbeddingGeneration(config);
-            builder.WithLLamaSharpTextGeneration(config);
+            // Load the model once and share it between both services; the wrappers built below do not take ownership of these objects.
+            var parameters = new ModelParams(config.ModelPath)
+            {
+                ContextSize = config?.ContextSize ?? 2048,
+                Seed = config?.Seed ?? 0,
+                GpuLayerCount = config?.GpuLayerCount ?? 20
+            };
+            var weights = LLamaWeights.LoadFromFile(parameters);
+            var context = weights.CreateContext(parameters);
+            var executor = new StatelessExecutor(weights, parameters);
+            var embedder = new LLamaEmbedder(weights, parameters);
+            builder.WithLLamaSharpTextEmbeddingGeneration(new LLamaSharpTextEmbeddingGeneration(embedder));
+            builder.WithLLamaSharpTextGeneration(new LlamaSharpTextGeneration(weights, context, executor));
             return builder;
         }
     }
diff --git a/LLama.KernelMemory/LLamaSharpTextEmbeddingGeneration.cs b/LLama.KernelMemory/LLamaSharpTextEmbeddingGeneration.cs
index a1681e153..cebbbe64b 100644
--- a/LLama.KernelMemory/LLamaSharpTextEmbeddingGeneration.cs
+++ b/LLama.KernelMemory/LLamaSharpTextEmbeddingGeneration.cs
@@ -1,4 +1,5 @@
 using LLama;
+using LLama.Abstractions;
 using LLama.Common;
 using Microsoft.SemanticKernel.AI.Embeddings;
 using System;
@@ -14,9 +15,11 @@ namespace LLamaSharp.KernelMemory
     /// </summary>
     public class LLamaSharpTextEmbeddingGeneration : ITextEmbeddingGeneration, IDisposable
     {
-        private readonly LLamaSharpConfig _config;
+        private readonly LLamaSharpConfig? _config;
+        private readonly LLamaWeights? _weights;
         private readonly LLamaEmbedder _embedder;
-        private readonly LLamaWeights _weights;
+        private readonly bool _ownsEmbedder = false;
+        private readonly bool _ownsWeights = false;
 
         /// <summary>
         /// Initializes a new instance of the <see cref="LLamaSharpTextEmbeddingGeneration"/> class.
@@ -28,13 +31,47 @@ public LLamaSharpTextEmbeddingGeneration(LLamaSharpConfig config)
             var @params = new ModelParams(_config.ModelPath);
             _weights = LLamaWeights.LoadFromFile(@params);
             _embedder = new LLamaEmbedder(_weights, @params);
+            _ownsWeights = true;
+            _ownsEmbedder = true;
+        }
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="LLamaSharpTextEmbeddingGeneration"/> class from reused weights.
+        /// </summary>
+        /// <param name="config">The configuration for LLamaSharp.</param>
+        /// <param name="weights">A LLamaWeights object.</param>
+        public LLamaSharpTextEmbeddingGeneration(LLamaSharpConfig config, LLamaWeights weights)
+        {
+            this._config = config;
+            var @params = new ModelParams(_config.ModelPath);
+            _weights = weights;
+            _embedder = new LLamaEmbedder(_weights, @params);
+            _ownsEmbedder = true;
+        }
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="LLamaSharpTextEmbeddingGeneration"/> class from reused embedder.
+        /// </summary>
+        /// <param name="embedder">A LLamaEmbedder object.</param>
+        public LLamaSharpTextEmbeddingGeneration(LLamaEmbedder embedder)
+        {
+            this._config = null;
+            this._weights = null;
+            _embedder = embedder;
         }
 
         /// <inheritdoc/>
         public void Dispose()
         {
-            _embedder.Dispose();
-            _weights.Dispose();
+            // Release the embedder before the weights it was created from.
+            if (_ownsEmbedder)
+            {
+                _embedder.Dispose();
+            }
+            if (_ownsWeights)
+            {
+                _weights?.Dispose();
+            }
         }
 
         /// <inheritdoc/>
diff --git a/LLama.KernelMemory/LlamaSharpConfig.cs b/LLama.KernelMemory/LlamaSharpConfig.cs
index 2220bf719..7d3aefbef 100644
--- a/LLama.KernelMemory/LlamaSharpConfig.cs
+++ b/LLama.KernelMemory/LlamaSharpConfig.cs
@@ -7,7 +7,7 @@
 namespace LLamaSharp.KernelMemory
 {
     /// <summary>
-    /// Represents the configuration for LLamaSharp.
+    /// Represents the configuration for LLamaSharp. Available properties are `ModelPath`, `ContextSize`, `Seed`, `GpuLayerCount`.
     /// </summary>
     public class LLamaSharpConfig
     {
diff --git a/LLama.KernelMemory/LlamaSharpTextGeneration.cs b/LLama.KernelMemory/LlamaSharpTextGeneration.cs
index abc534b3c..c3734ea4d 100644
--- a/LLama.KernelMemory/LlamaSharpTextGeneration.cs
+++ b/LLama.KernelMemory/LlamaSharpTextGeneration.cs
@@ -1,4 +1,5 @@
 using LLama;
+using LLama.Abstractions;
 using LLama.Common;
 using Microsoft.KernelMemory.AI;
 using System;
@@ -14,10 +15,12 @@ namespace LLamaSharp.KernelMemory
     /// </summary>
     public class LlamaSharpTextGeneration : ITextGeneration, IDisposable
     {
-        private readonly LLamaSharpConfig _config;
+        private readonly LLamaSharpConfig? _config;
         private readonly LLamaWeights _weights;
         private readonly StatelessExecutor _executor;
         private readonly LLamaContext _context;
+        private readonly bool _ownsContext = false;
+        private readonly bool _ownsWeights = false;
 
         /// <summary>
         /// Initializes a new instance of the <see cref="LlamaSharpTextGeneration"/> class.
@@ -35,13 +38,36 @@ public LlamaSharpTextGeneration(LLamaSharpConfig config)
             _weights = LLamaWeights.LoadFromFile(parameters);
             _context = _weights.CreateContext(parameters);
             _executor = new StatelessExecutor(_weights, parameters);
+            _ownsWeights = _ownsContext = true;
+        }
+
+        /// <summary>
+        /// Initializes a new instance of the <see cref="LlamaSharpTextGeneration"/> class from reused weights, context and executor.
+        /// If executor is not specified, then a StatelessExecutor will be created with `context.Params`. So far only `StatelessExecutor` is expected.
+        /// </summary>
+        /// <param name="weights">A LLamaWeights object.</param>
+        /// <param name="context">A LLamaContext object.</param>
+        /// <param name="executor">An executor. Currently only StatelessExecutor is expected.</param>
+        public LlamaSharpTextGeneration(LLamaWeights weights, LLamaContext context, StatelessExecutor? executor = null)
+        {
+            _config = null;
+            _weights = weights;
+            _context = context;
+            _executor = executor ?? new StatelessExecutor(_weights, _context.Params);
         }
 
         /// <inheritdoc/>
         public void Dispose()
         {
-            _context.Dispose();
-            _weights.Dispose();
+            // Release the context before the weights it was created from.
+            if (_ownsContext)
+            {
+                _context.Dispose();
+            }
+            if (_ownsWeights)
+            {
+                _weights.Dispose();
+            }
         }
 
         /// <inheritdoc/>