From cefb091d6cac4d50af1e232e382a917442bc343a Mon Sep 17 00:00:00 2001 From: bmazzarol-bunnings Date: Mon, 29 Sep 2025 22:34:26 +0800 Subject: [PATCH] fix: Allow externally managed contexts with LLamaEmbedder Fixes #1259 and potentially #1247 with changes to how the caller manages the LLamaEmbedder. --- LLama.Unittest/LLamaEmbedderTests.cs | 71 +++++++++++------------ LLama/LLamaEmbedder.EmbeddingGenerator.cs | 42 ++++++-------- LLama/LLamaEmbedder.cs | 70 ++++++++++++++++------ 3 files changed, 106 insertions(+), 77 deletions(-) diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs index 7d7654126..5c01984b0 100644 --- a/LLama.Unittest/LLamaEmbedderTests.cs +++ b/LLama.Unittest/LLamaEmbedderTests.cs @@ -41,43 +41,40 @@ private async Task CompareEmbeddings(string modelPath) var spoon = (await embedder.GetEmbeddings("The spoon is not real")).Single().EuclideanNormalization(); Assert.DoesNotContain(float.NaN, spoon); - - if (false) - { - //TODO: the below does not work with the new memory efficient context handling - we probably need to define Microsoft.Extensions.AI.IEmbeddingGenerator GetService interface that creates the context on the fly - - var generator = (IEmbeddingGenerator>)embedder; - Assert.NotNull(generator.GetService()); - Assert.Equal(nameof(LLamaEmbedder), generator.GetService()?.ProviderName); - Assert.NotNull(generator.GetService()?.DefaultModelId); - Assert.NotEmpty(generator.GetService()?.DefaultModelId!); - Assert.Same(embedder, generator.GetService()); - Assert.Same(generator, generator.GetService>>()); - Assert.Null(generator.GetService()); - - var embeddings = await generator.GenerateAsync( - [ - "The cat is cute", - "The kitten is cute", - "The spoon is not real" - ]); - Assert.All(cat.Zip(embeddings[0].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001)); - Assert.All(kitten.Zip(embeddings[1].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001)); - Assert.All(spoon.Zip(embeddings[2].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001)); - - _testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]"); - _testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]"); - _testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]"); - - var close = 1 - Dot(cat, kitten); - var far = 1 - Dot(cat, spoon); - - _testOutputHelper.WriteLine(""); - _testOutputHelper.WriteLine($"Cat.Kitten (Close): {close:F4}"); - _testOutputHelper.WriteLine($"Cat.Spoon (Far): {far:F4}"); - - Assert.True(close < far); - } + + using var context = new LLamaContext(weights, @params); + var managedEmbedder = new LLamaEmbedder(context); + IEmbeddingGenerator> generator = managedEmbedder; + Assert.NotNull(generator.GetService()); + Assert.Equal(nameof(LLamaEmbedder), generator.GetService()?.ProviderName); + Assert.NotNull(generator.GetService()?.DefaultModelId); + Assert.NotEmpty(generator.GetService()?.DefaultModelId!); + Assert.Same(managedEmbedder, generator.GetService()); + Assert.Same(generator, generator.GetService>>()); + Assert.Null(generator.GetService()); + + var embeddings = await generator.GenerateAsync( + [ + "The cat is cute", + "The kitten is cute", + "The spoon is not real" + ]); + Assert.All(cat.Zip(embeddings[0].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001)); + Assert.All(kitten.Zip(embeddings[1].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001)); + Assert.All(spoon.Zip(embeddings[2].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001)); + + _testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]"); + _testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]"); + _testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]"); + + var close = 1 - Dot(cat, kitten); + var far = 1 - Dot(cat, spoon); + + _testOutputHelper.WriteLine(""); + _testOutputHelper.WriteLine($"Cat.Kitten (Close): {close:F4}"); + _testOutputHelper.WriteLine($"Cat.Spoon (Far): {far:F4}"); + + Assert.True(close < far); } [Fact] diff --git a/LLama/LLamaEmbedder.EmbeddingGenerator.cs b/LLama/LLamaEmbedder.EmbeddingGenerator.cs index bce9f8d8b..3960dd227 100644 --- a/LLama/LLamaEmbedder.EmbeddingGenerator.cs +++ b/LLama/LLamaEmbedder.EmbeddingGenerator.cs @@ -3,7 +3,6 @@ using System.Diagnostics; using System.Threading; using System.Threading.Tasks; -using LLama.Native; using Microsoft.Extensions.AI; namespace LLama; @@ -16,25 +15,27 @@ public partial class LLamaEmbedder /// object? IEmbeddingGenerator.GetService(Type serviceType, object? serviceKey) { - if (serviceKey is null) + if (serviceKey is not null) { - if (serviceType == typeof(EmbeddingGeneratorMetadata)) - { - return _metadata ??= new( - nameof(LLamaEmbedder), - defaultModelId: Context.NativeHandle.ModelHandle.ReadMetadata().TryGetValue("general.name", out var name) ? name : null, - defaultModelDimensions: EmbeddingSize); - } + return null; + } + + if (_hasExternalContext && serviceType == typeof(EmbeddingGeneratorMetadata)) + { + return _metadata ??= new( + nameof(LLamaEmbedder), + defaultModelId: Context.NativeHandle.ModelHandle.ReadMetadata().TryGetValue("general.name", out var name) ? name : null, + defaultModelDimensions: EmbeddingSize); + } - if (serviceType?.IsInstanceOfType(Context) is true) - { - return Context; - } + if (_hasExternalContext && serviceType?.IsInstanceOfType(Context) is true) + { + return Context; + } - if (serviceType?.IsInstanceOfType(this) is true) - { - return this; - } + if (serviceType?.IsInstanceOfType(this) is true) + { + return this; } return null; @@ -43,11 +44,6 @@ public partial class LLamaEmbedder /// async Task>> IEmbeddingGenerator>.GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options, CancellationToken cancellationToken) { - if (Context.NativeHandle.PoolingType == LLamaPoolingType.None) - { - throw new NotSupportedException($"Embedding generation is not supported with {nameof(LLamaPoolingType)}.{nameof(LLamaPoolingType.None)}."); - } - GeneratedEmbeddings> results = new() { Usage = new() { InputTokenCount = 0 }, @@ -56,7 +52,7 @@ async Task>> IEmbeddingGenerator(embeddings[0]) { CreatedAt = DateTime.UtcNow }); diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index eee9a01e9..e831a1724 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -1,14 +1,11 @@ using System; using System.Collections.Generic; -using System.Linq; using System.Threading; using System.Threading.Tasks; using LLama.Abstractions; using LLama.Exceptions; using LLama.Native; -using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; -using static System.Net.Mime.MediaTypeNames; namespace LLama; @@ -26,18 +23,26 @@ public sealed partial class LLamaEmbedder /// /// LLama Context /// + /// + /// If the context was not provided externally, the returned context will be in a disposed state. + /// public LLamaContext Context { get; private set; } - private LLamaWeights _weights; - private IContextParams _params; - private ILogger? _logger; + private readonly LLamaWeights? _weights; + private readonly IContextParams _params; + private readonly ILogger? _logger; + private readonly bool _hasExternalContext; /// - /// Create a new embedder, using the given LLamaWeights + /// Create a new embedder, using the given . + /// This will create and dispose a new for each embedding request. + /// If you want to manage the context lifetime yourself, consider using the other constructor that takes a . /// - /// - /// - /// + /// weights to use for generating embeddings. The weights must be for a model that supports embeddings (i.e. it must have an encoder or a decoder, but not both). + /// context parameters to use when creating the context + /// optional logger + /// raised if the provided context has batch size different from ubatch size + /// raised if the provided context is for an encoder-decoder model public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null) { if (@params.UBatchSize != @params.BatchSize) @@ -51,12 +56,39 @@ public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logg _weights = weights; _params = @params; _logger = logger; + _hasExternalContext = false; + } + + /// + /// Creates a new embedder using the given . + /// The caller is responsible for managing the lifetime of the context, and must ensure that the context remains valid + /// for the entire lifetime of this . The context will not be disposed when this embedder is disposed. + /// + /// context to use for generating embeddings. The context must be configured with a model that supports embeddings (i.e. it must have an encoder or a decoder, but not both). + /// optional logger + /// raised if the provided context has batch size different from ubatch size + /// raised if the provided context is for an encoder-decoder model + public LLamaEmbedder(LLamaContext context, ILogger? logger = null) + { + if(context.Params.UBatchSize != context.Params.BatchSize) + throw new ArgumentException("For non-causal models, batch size must be equal to ubatch size", nameof(context)); + + if (context.NativeHandle.ModelHandle is { HasEncoder: true, HasDecoder: true }) + throw new NotSupportedException("Computing embeddings in encoder-decoder models is not supported"); + + Context = context; + EmbeddingSize = Context.EmbeddingSize; + NativeApi.llama_set_embeddings(Context.NativeHandle, true); + _params = context.Params; + _logger = logger; + _hasExternalContext = true; } /// public void Dispose() { - Context.Dispose(); + if(!_hasExternalContext && !Context.NativeHandle.IsClosed) + Context.Dispose(); } /// @@ -72,14 +104,17 @@ public void Dispose() public async Task> GetEmbeddings(string input, CancellationToken cancellationToken = default) => (await GetEmbeddingsWithTokenCount(input, cancellationToken).ConfigureAwait(false)).Embeddings; + private async Task<(IReadOnlyList Embeddings, int Tokens)> GetEmbeddingsWithTokenCount(string input, CancellationToken cancellationToken = default) { - // Ensure the context from last time is disposed (it always should be) - if (!Context.NativeHandle.IsClosed) - Context.Dispose(); + if (!_hasExternalContext) + { + if (!Context.NativeHandle.IsClosed) + Context.Dispose(); - Context = _weights.CreateContext(_params, _logger); - NativeApi.llama_set_embeddings(Context.NativeHandle, true); + Context = _weights!.CreateContext(_params, _logger); + NativeApi.llama_set_embeddings(Context.NativeHandle, true); + } // Add all of the tokens to the batch var tokens = Context.Tokenize(input, special: true); @@ -150,7 +185,8 @@ public async Task> GetEmbeddings(string input, Cancellati embedding.EuclideanNormalization(); } - Context.Dispose(); + if (!_hasExternalContext) + Context.Dispose(); return (results, tokens.Length); }