diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs index 74913b87e8a..02bf1880427 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs @@ -24,7 +24,7 @@ namespace Microsoft.Extensions.AI; /// are used which might employ such mutation. /// /// -public interface IEmbeddingGenerator : IDisposable +public interface IEmbeddingGenerator : IDisposable where TEmbedding : Embedding { /// Generates embeddings for each of the supplied . diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/GeneratedEmbeddingsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/GeneratedEmbeddingsTests.cs index b7dffb1c46c..c13730fe604 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/GeneratedEmbeddingsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/GeneratedEmbeddingsTests.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading.Tasks; using Xunit; #pragma warning disable xUnit2013 // Do not use equality check to check for collection size. @@ -243,4 +244,21 @@ public void Indexer_InvalidIndex_Throws() Assert.Throws("index", () => embeddings[-1]); Assert.Throws("index", () => embeddings[2]); } + + [Fact] + public async Task Generator_SupportsCovariantInput() + { + var expectedGeneratedEmbeddings = new GeneratedEmbeddings>([new Embedding(new float[] { 1, 2, 3 })]); + + using IEmbeddingGenerator> acceptsObject = new TestEmbeddingGenerator> + { + GenerateAsyncCallback = (values, options, cancellationToken) => Task.FromResult(expectedGeneratedEmbeddings), + }; + + IEmbeddingGenerator> acceptsString = acceptsObject; + + var actual = await acceptsString.GenerateAsync(["hello"]); + + Assert.Same(expectedGeneratedEmbeddings, actual); + } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestEmbeddingGenerator.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestEmbeddingGenerator.cs index e0d747cfc9d..dbb69304007 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestEmbeddingGenerator.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestEmbeddingGenerator.cs @@ -6,23 +6,27 @@ using System.Threading; using System.Threading.Tasks; +#pragma warning disable SA1402 // File may only contain a single type +#pragma warning disable CA1816 // Dispose methods should call SuppressFinalize + namespace Microsoft.Extensions.AI; -public sealed class TestEmbeddingGenerator : IEmbeddingGenerator> +public class TestEmbeddingGenerator : IEmbeddingGenerator + where TEmbedding : Embedding { public TestEmbeddingGenerator() { GetServiceCallback = DefaultGetServiceCallback; } - public Func, EmbeddingGenerationOptions?, CancellationToken, Task>>>? GenerateAsyncCallback { get; set; } + public Func, EmbeddingGenerationOptions?, CancellationToken, Task>>? GenerateAsyncCallback { get; set; } public Func GetServiceCallback { get; set; } private object? DefaultGetServiceCallback(Type serviceType, object? serviceKey) => serviceType is not null && serviceKey is null && serviceType.IsInstanceOfType(this) ? this : null; - public Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + public Task> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) => GenerateAsyncCallback!.Invoke(values, options, cancellationToken); public object? GetService(Type serviceType, object? serviceKey = null) @@ -33,3 +37,5 @@ void IDisposable.Dispose() // No resources to dispose } } + +public sealed class TestEmbeddingGenerator : TestEmbeddingGenerator>;