Skip to content

Commit 586fe9b

Browse files
stephentoub and jeffhandley
authored and committed
Avoid caching in CachingChatClient when ConversationId is set (#6400)
1 parent e88529f commit 586fe9b

File tree

2 files changed

+50
-18
lines changed

2 files changed

+50
-18
lines changed

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using Microsoft.Shared.Diagnostics;
1010

1111
#pragma warning disable S127 // "for" loop stop conditions should be invariant
12+
#pragma warning disable SA1202 // Elements should be ordered by access
1213

1314
namespace Microsoft.Extensions.AI;
1415

@@ -45,11 +46,19 @@ protected CachingChatClient(IChatClient innerClient)
4546
public bool CoalesceStreamingUpdates { get; set; } = true;
4647

4748
/// <inheritdoc />
public override Task<ChatResponse> GetResponseAsync(
    IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
    _ = Throw.IfNull(messages);

    // Stateful requests (those carrying a ConversationId) bypass the cache entirely,
    // since a cached response can't reflect server-side conversation state.
    if (!UseCaching(options))
    {
        return base.GetResponseAsync(messages, options, cancellationToken);
    }

    return GetCachedResponseAsync(messages, options, cancellationToken);
}
58+
59+
private async Task<ChatResponse> GetCachedResponseAsync(
60+
IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
61+
{
5362
// We're only storing the final result, not the in-flight task, so that we can avoid caching failures
5463
// or having problems when one of the callers cancels but others don't. This has the drawback that
5564
// concurrent callers might trigger duplicate requests, but that's acceptable.
@@ -65,11 +74,19 @@ public override async Task<ChatResponse> GetResponseAsync(
6574
}
6675

6776
/// <inheritdoc />
public override IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
    IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
    _ = Throw.IfNull(messages);

    // As with the non-streaming path, skip the cache for stateful requests
    // (ConversationId set) and delegate straight to the inner client.
    if (!UseCaching(options))
    {
        return base.GetStreamingResponseAsync(messages, options, cancellationToken);
    }

    return GetCachedStreamingResponseAsync(messages, options, cancellationToken);
}
86+
87+
private async IAsyncEnumerable<ChatResponseUpdate> GetCachedStreamingResponseAsync(
88+
IEnumerable<ChatMessage> messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
89+
{
7390
if (CoalesceStreamingUpdates)
7491
{
7592
// When coalescing updates, we cache non-streaming results coalesced from streaming ones. That means
@@ -178,4 +195,13 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
178195
/// <exception cref="ArgumentNullException"><paramref name="key"/> is <see langword="null"/>.</exception>
179196
/// <exception cref="ArgumentNullException"><paramref name="value"/> is <see langword="null"/>.</exception>
180197
protected abstract Task WriteCacheStreamingAsync(string key, IReadOnlyList<ChatResponseUpdate> value, CancellationToken cancellationToken);
198+
199+
/// <summary>Determines whether caching should be used for the request described by <paramref name="options"/>.</summary>
private static bool UseCaching(ChatOptions? options)
{
    // A non-null ConversationId implies server-side state that influences the response
    // but isn't represented in the messages. Because that state can change even when
    // the ID stays the same, cached results can't be trusted; cache only when absent.
    return options is null || options.ConversationId is null;
}
181207
}

test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,13 @@ public void Ctor_ExpectedDefaults()
3232
Assert.True(cachingClient.CoalesceStreamingUpdates);
3333
}
3434

35-
[Fact]
36-
public async Task CachesSuccessResultsAsync()
35+
[Theory]
36+
[InlineData(false)]
37+
[InlineData(true)]
38+
public async Task CachesSuccessResultsAsync(bool conversationIdSet)
3739
{
3840
// Arrange
41+
ChatOptions options = new() { ConversationId = conversationIdSet ? "123" : null };
3942

4043
// Verify that all the expected properties will round-trip through the cache,
4144
// even if this involves serialization
@@ -82,20 +85,20 @@ public async Task CachesSuccessResultsAsync()
8285
};
8386

8487
// Make the initial request and do a quick sanity check
85-
var result1 = await outer.GetResponseAsync("some input");
88+
var result1 = await outer.GetResponseAsync("some input", options);
8689
Assert.Same(expectedResponse, result1);
8790
Assert.Equal(1, innerCallCount);
8891

8992
// Act
90-
var result2 = await outer.GetResponseAsync("some input");
93+
var result2 = await outer.GetResponseAsync("some input", options);
9194

9295
// Assert
93-
Assert.Equal(1, innerCallCount);
96+
Assert.Equal(conversationIdSet ? 2 : 1, innerCallCount);
9497
AssertResponsesEqual(expectedResponse, result2);
9598

9699
// Act/Assert 2: Cache misses do not return cached results
97-
await outer.GetResponseAsync("some modified input");
98-
Assert.Equal(2, innerCallCount);
100+
await outer.GetResponseAsync("some modified input", options);
101+
Assert.Equal(conversationIdSet ? 3 : 2, innerCallCount);
99102
}
100103

101104
[Fact]
@@ -207,10 +210,13 @@ public async Task DoesNotCacheCanceledResultsAsync()
207210
Assert.Equal("A good result", result2.Text);
208211
}
209212

210-
[Fact]
211-
public async Task StreamingCachesSuccessResultsAsync()
213+
[Theory]
214+
[InlineData(false)]
215+
[InlineData(true)]
216+
public async Task StreamingCachesSuccessResultsAsync(bool conversationIdSet)
212217
{
213218
// Arrange
219+
ChatOptions options = new() { ConversationId = conversationIdSet ? "123" : null };
214220

215221
// Verify that all the expected properties will round-trip through the cache,
216222
// even if this involves serialization
@@ -255,20 +261,20 @@ public async Task StreamingCachesSuccessResultsAsync()
255261
};
256262

257263
// Make the initial request and do a quick sanity check
258-
var result1 = outer.GetStreamingResponseAsync("some input");
264+
var result1 = outer.GetStreamingResponseAsync("some input", options);
259265
await AssertResponsesEqualAsync(actualUpdate, result1);
260266
Assert.Equal(1, innerCallCount);
261267

262268
// Act
263-
var result2 = outer.GetStreamingResponseAsync("some input");
269+
var result2 = outer.GetStreamingResponseAsync("some input", options);
264270

265271
// Assert
266-
Assert.Equal(1, innerCallCount);
267-
await AssertResponsesEqualAsync(expectedCachedResponse, result2);
272+
Assert.Equal(conversationIdSet ? 2 : 1, innerCallCount);
273+
await AssertResponsesEqualAsync(conversationIdSet ? actualUpdate : expectedCachedResponse, result2);
268274

269275
// Act/Assert 2: Cache misses do not return cached results
270-
await ToListAsync(outer.GetStreamingResponseAsync("some modified input"));
271-
Assert.Equal(2, innerCallCount);
276+
await ToListAsync(outer.GetStreamingResponseAsync("some modified input", options));
277+
Assert.Equal(conversationIdSet ? 3 : 2, innerCallCount);
272278
}
273279

274280
[Theory]

0 commit comments

Comments (0)