diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs
index 211fc39ec85..2923b0ad62d 100644
--- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs
@@ -51,7 +51,7 @@ public override Task<ChatResponse> GetResponseAsync(
     {
         _ = Throw.IfNull(messages);
 
-        return UseCaching(options) ?
+        return EnableCaching(messages, options) ?
             GetCachedResponseAsync(messages, options, cancellationToken) :
             base.GetResponseAsync(messages, options, cancellationToken);
     }
@@ -79,7 +79,7 @@ public override IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
     {
         _ = Throw.IfNull(messages);
 
-        return UseCaching(options) ?
+        return EnableCaching(messages, options) ?
             GetCachedStreamingResponseAsync(messages, options, cancellationToken) :
             base.GetStreamingResponseAsync(messages, options, cancellationToken);
     }
@@ -196,12 +196,25 @@ private async IAsyncEnumerable<ChatResponseUpdate> GetCachedStreamingResponseAsy
     /// <paramref name="value"/> is <see langword="null"/>.</exception>
     protected abstract Task WriteCacheStreamingAsync(string key, IReadOnlyList<ChatResponseUpdate> value, CancellationToken cancellationToken);
 
-    /// <summary>Determine whether to use caching with the request.</summary>
-    private static bool UseCaching(ChatOptions? options)
+    /// <summary>Determines whether caching should be used with the specified request.</summary>
+    /// <param name="messages">The sequence of chat messages included in the request.</param>
+    /// <param name="options">The chat options included in the request.</param>
+    /// <returns>
+    /// <see langword="true"/> if caching should be used for the request, such that the <see cref="CachingChatClient"/>
+    /// will try to satisfy the request from the cache, or if it can't, will try to cache the fetched response.
+    /// <see langword="false"/> if caching should not be used for the request, such that the request will
+    /// be passed through to the inner <see cref="IChatClient"/> without attempting to read from or write to the cache.
+    /// </returns>
+    /// <remarks>
+    /// The default implementation returns <see langword="true"/> as long as the <paramref name="options"/>
+    /// does not have a <see cref="ChatOptions.ConversationId"/> set.
+    /// </remarks>
+    protected virtual bool EnableCaching(IEnumerable<ChatMessage> messages, ChatOptions? options)
     {
         // We want to skip caching if options.ConversationId is set. If it's set, that implies there's
         // some state that will impact the response and that's not represented in the messages. Since
-        // that state could change even with the same ID, we have to assume caching isn't valid.
+        // that state could change even with the same ID (e.g. if it's a thread ID representing the
+        // mutable state of a conversation), we have to assume caching isn't valid.
         return options?.ConversationId is null;
     }
 }
diff --git a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json
index 5dcad329637..f7f246eb35c 100644
--- a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json
+++ b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json
@@ -16,6 +16,10 @@
         "Member": "abstract string Microsoft.Extensions.AI.CachingChatClient.GetCacheKey(System.Collections.Generic.IEnumerable<Microsoft.Extensions.AI.ChatMessage> messages, Microsoft.Extensions.AI.ChatOptions? options, params System.ReadOnlySpan<object?> additionalValues);",
         "Stage": "Stable"
       },
+      {
+        "Member": "virtual bool Microsoft.Extensions.AI.CachingChatClient.EnableCaching(System.Collections.Generic.IEnumerable<Microsoft.Extensions.AI.ChatMessage> messages, Microsoft.Extensions.AI.ChatOptions? options);",
+        "Stage": "Stable"
+      },
       {
         "Member": "override System.Threading.Tasks.Task<Microsoft.Extensions.AI.ChatResponse> Microsoft.Extensions.AI.CachingChatClient.GetResponseAsync(System.Collections.Generic.IEnumerable<Microsoft.Extensions.AI.ChatMessage> messages, Microsoft.Extensions.AI.ChatOptions? options = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken));",
         "Stage": "Stable"
diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs
index 4f2427d133c..2c755da7be9 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs
@@ -33,9 +33,11 @@ public void Ctor_ExpectedDefaults()
     }
 
     [Theory]
-    [InlineData(false)]
-    [InlineData(true)]
-    public async Task CachesSuccessResultsAsync(bool conversationIdSet)
+    [InlineData(false, false)]
+    [InlineData(false, true)]
+    [InlineData(true, false)]
+    [InlineData(true, true)]
+    public async Task CachesSuccessResultsAsync(bool conversationIdSet, bool customCaching)
     {
         // Arrange
         ChatOptions options = new() { ConversationId = conversationIdSet ? "123" : null };
@@ -79,10 +81,16 @@ public async Task CachesSuccessResultsAsync(bool conversationIdSet)
                 return Task.FromResult(expectedResponse);
             }
         };
-        using var outer = new DistributedCachingChatClient(testClient, _storage)
-        {
-            JsonSerializerOptions = TestJsonSerializerContext.Default.Options
-        };
+
+        int enableCachingInvocations = 0;
+        using var outer = customCaching ?
+            new CustomCachingChatClient(testClient, _storage, (m, o) =>
+            {
+                return ++enableCachingInvocations % 2 == 0;
+            }) :
+            new DistributedCachingChatClient(testClient, _storage);
+
+        outer.JsonSerializerOptions = TestJsonSerializerContext.Default.Options;
 
         // Make the initial request and do a quick sanity check
         var result1 = await outer.GetResponseAsync("some input", options);
@@ -93,12 +101,28 @@ public async Task CachesSuccessResultsAsync(bool conversationIdSet)
         var result2 = await outer.GetResponseAsync("some input", options);
 
         // Assert
-        Assert.Equal(conversationIdSet ? 2 : 1, innerCallCount);
+        if (customCaching)
+        {
+            Assert.Equal(enableCachingInvocations % 2 == 0 ? 2 : 1, innerCallCount);
+        }
+        else
+        {
+            Assert.Equal(conversationIdSet ? 2 : 1, innerCallCount);
+        }
+
         AssertResponsesEqual(expectedResponse, result2);
 
         // Act/Assert 2: Cache misses do not return cached results
         await outer.GetResponseAsync("some modified input", options);
-        Assert.Equal(conversationIdSet ? 3 : 2, innerCallCount);
+        Assert.Equal(conversationIdSet || customCaching ? 3 : 2, innerCallCount);
+
+        Assert.Equal(customCaching ? 3 : 0, enableCachingInvocations);
+    }
+
+    private sealed class CustomCachingChatClient(IChatClient innerClient, IDistributedCache storage, Func<IEnumerable<ChatMessage>, ChatOptions?, bool> enableCaching) :
+        DistributedCachingChatClient(innerClient, storage)
+    {
+        protected override bool EnableCaching(IEnumerable<ChatMessage> messages, ChatOptions? options) => enableCaching(messages, options);
     }
 
     [Fact]
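For reviewers, a minimal sketch (not part of the diff) of how a consumer might use the new `EnableCaching` extensibility point: a subclass that layers an extra policy on top of the default `ConversationId` check. The `ToolAwareCachingChatClient` name and the tools-based policy are illustrative assumptions; only `DistributedCachingChatClient`, `EnableCaching`, and `ChatOptions.Tools` come from the library.

```csharp
using System.Collections.Generic;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Caching.Distributed;

// Hypothetical consumer subclass: skip caching whenever the request carries tools,
// in addition to the "no ConversationId" check inherited from the base implementation.
public sealed class ToolAwareCachingChatClient(IChatClient innerClient, IDistributedCache storage) :
    DistributedCachingChatClient(innerClient, storage)
{
    protected override bool EnableCaching(IEnumerable<ChatMessage> messages, ChatOptions? options) =>
        base.EnableCaching(messages, options) &&
        options?.Tools is null or { Count: 0 };
}
```

A caller would register such a subclass wherever `DistributedCachingChatClient` would otherwise be wired into the pipeline; no other part of the caching infrastructure needs to change.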