Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public override Task<ChatResponse> GetResponseAsync(
{
_ = Throw.IfNull(messages);

return UseCaching(options) ?
return EnableCaching(messages, options) ?
GetCachedResponseAsync(messages, options, cancellationToken) :
base.GetResponseAsync(messages, options, cancellationToken);
}
Expand Down Expand Up @@ -79,7 +79,7 @@ public override IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
{
_ = Throw.IfNull(messages);

return UseCaching(options) ?
return EnableCaching(messages, options) ?
GetCachedStreamingResponseAsync(messages, options, cancellationToken) :
base.GetStreamingResponseAsync(messages, options, cancellationToken);
}
Expand Down Expand Up @@ -196,12 +196,25 @@ private async IAsyncEnumerable<ChatResponseUpdate> GetCachedStreamingResponseAsy
/// <exception cref="ArgumentNullException"><paramref name="value"/> is <see langword="null"/>.</exception>
protected abstract Task WriteCacheStreamingAsync(string key, IReadOnlyList<ChatResponseUpdate> value, CancellationToken cancellationToken);

/// <summary>Determine whether to use caching with the request.</summary>
private static bool UseCaching(ChatOptions? options)
/// <summary>Determines whether caching should be used with the specified request.</summary>
/// <param name="messages">The sequence of chat messages included in the request.</param>
/// <param name="options">The chat options included in the request.</param>
/// <returns>
/// <see langword="true"/> if caching should be used for the request, such that the <see cref="CachingChatClient"/>
/// will try to satisfy the request from the cache, or if it can't, will try to cache the fetched response.
/// <see langword="false"/> if caching should not be used for the request, such that the request will
/// be passed through to the inner <see cref="IChatClient"/> without attempting to read from or write to the cache.
/// </returns>
/// <remarks>
/// The default implementation returns <see langword="true"/> as long as the <paramref name="options"/>
/// does not have a <see cref="ChatOptions.ConversationId"/> set.
/// </remarks>
protected virtual bool EnableCaching(IEnumerable<ChatMessage> messages, ChatOptions? options)
{
// We want to skip caching if options.ConversationId is set. If it's set, that implies there's
// some state that will impact the response and that's not represented in the messages. Since
// that state could change even with the same ID, we have to assume caching isn't valid.
// that state could change even with the same ID (e.g. if it's a thread ID representing the
// mutable state of a conversation), we have to assume caching isn't valid.
return options?.ConversationId is null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
"Member": "abstract string Microsoft.Extensions.AI.CachingChatClient.GetCacheKey(System.Collections.Generic.IEnumerable<Microsoft.Extensions.AI.ChatMessage> messages, Microsoft.Extensions.AI.ChatOptions? options, params System.ReadOnlySpan<object?> additionalValues);",
"Stage": "Stable"
},
{
"Member": "virtual bool Microsoft.Extensions.AI.CachingChatClient.EnableCaching(System.Collections.Generic.IEnumerable<Microsoft.Extensions.AI.ChatMessage> messages, Microsoft.Extensions.AI.ChatOptions? options);",
"Stage": "Stable"
},
{
"Member": "override System.Threading.Tasks.Task<Microsoft.Extensions.AI.ChatResponse> Microsoft.Extensions.AI.CachingChatClient.GetResponseAsync(System.Collections.Generic.IEnumerable<Microsoft.Extensions.AI.ChatMessage> messages, Microsoft.Extensions.AI.ChatOptions? options = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken));",
"Stage": "Stable"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@ public void Ctor_ExpectedDefaults()
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task CachesSuccessResultsAsync(bool conversationIdSet)
[InlineData(false, false)]
[InlineData(false, true)]
[InlineData(true, false)]
[InlineData(true, true)]
public async Task CachesSuccessResultsAsync(bool conversationIdSet, bool customCaching)
{
// Arrange
ChatOptions options = new() { ConversationId = conversationIdSet ? "123" : null };
Expand Down Expand Up @@ -79,10 +81,16 @@ public async Task CachesSuccessResultsAsync(bool conversationIdSet)
return Task.FromResult(expectedResponse);
}
};
using var outer = new DistributedCachingChatClient(testClient, _storage)
{
JsonSerializerOptions = TestJsonSerializerContext.Default.Options
};

int enableCachingInvocations = 0;
using var outer = customCaching ?
new CustomCachingChatClient(testClient, _storage, (m, o) =>
{
return ++enableCachingInvocations % 2 == 0;
}) :
new DistributedCachingChatClient(testClient, _storage);

outer.JsonSerializerOptions = TestJsonSerializerContext.Default.Options;

// Make the initial request and do a quick sanity check
var result1 = await outer.GetResponseAsync("some input", options);
Expand All @@ -93,12 +101,28 @@ public async Task CachesSuccessResultsAsync(bool conversationIdSet)
var result2 = await outer.GetResponseAsync("some input", options);

// Assert
Assert.Equal(conversationIdSet ? 2 : 1, innerCallCount);
if (customCaching)
{
Assert.Equal(enableCachingInvocations % 2 == 0 ? 2 : 1, innerCallCount);
}
else
{
Assert.Equal(conversationIdSet ? 2 : 1, innerCallCount);
}

AssertResponsesEqual(expectedResponse, result2);

// Act/Assert 2: Cache misses do not return cached results
await outer.GetResponseAsync("some modified input", options);
Assert.Equal(conversationIdSet ? 3 : 2, innerCallCount);
Assert.Equal(conversationIdSet || customCaching ? 3 : 2, innerCallCount);

Assert.Equal(customCaching ? 3 : 0, enableCachingInvocations);
}

private sealed class CustomCachingChatClient(IChatClient innerClient, IDistributedCache storage, Func<IEnumerable<ChatMessage>, ChatOptions?, bool> enableCaching) :
DistributedCachingChatClient(innerClient, storage)
{
protected override bool EnableCaching(IEnumerable<ChatMessage> messages, ChatOptions? options) => enableCaching(messages, options);
}

[Fact]
Expand Down
Loading