diff --git a/LLama.KernelMemory/BuilderExtensions.cs b/LLama.KernelMemory/BuilderExtensions.cs index 84e29862e..474cf8be4 100644 --- a/LLama.KernelMemory/BuilderExtensions.cs +++ b/LLama.KernelMemory/BuilderExtensions.cs @@ -74,8 +74,10 @@ public static IKernelMemoryBuilder WithLLamaSharpTextGeneration(this IKernelMemo /// /// The KernelMemoryBuilder instance. /// The LLamaSharpConfig instance. - /// The KernelMemoryBuilder instance with LLamaSharpTextEmbeddingGeneration and LLamaSharpTextGeneration added. - public static IKernelMemoryBuilder WithLLamaSharpDefaults(this IKernelMemoryBuilder builder, LLamaSharpConfig config) + /// + /// + /// The KernelMemoryBuilder instance with LLamaSharpTextEmbeddingGeneration and LLamaSharpTextGeneration added. + public static IKernelMemoryBuilder WithLLamaSharpDefaults(this IKernelMemoryBuilder builder, LLamaSharpConfig config, LLamaWeights? weights=null, LLamaContext? context=null) { var parameters = new ModelParams(config.ModelPath) { @@ -84,15 +86,20 @@ public static IKernelMemoryBuilder WithLLamaSharpDefaults(this IKernelMemoryBuil GpuLayerCount = config?.GpuLayerCount ?? 20, EmbeddingMode = true, MainGpu = config?.MainGpu ?? 0, - SplitMode = config?.SplitMode ?? GPUSplitMode.None + SplitMode = config?.SplitMode ?? GPUSplitMode.None, }; - var weights = LLamaWeights.LoadFromFile(parameters); - var context = weights.CreateContext(parameters); + + if (weights == null) + { + weights = LLamaWeights.LoadFromFile(parameters); + context = weights.CreateContext(parameters); + } + var executor = new StatelessExecutor(weights, parameters); var embedder = new LLamaEmbedder(weights, parameters); builder.WithLLamaSharpTextEmbeddingGeneration(new LLamaSharpTextEmbeddingGenerator(embedder)); builder.WithLLamaSharpTextGeneration(new LlamaSharpTextGenerator(weights, context, executor, config?.DefaultInferenceParams)); return builder; - } + } } }