Skip to content

Commit c86b7ea

Browse files
authored
Add ToChatCompletion{Async} methods for combining StreamingChatCompleteUpdates (#5605)
1 parent d8f84d7 commit c86b7ea

File tree

2 files changed

+412
-0
lines changed

2 files changed

+412
-0
lines changed
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System.Collections.Generic;
5+
#if NET
6+
using System.Runtime.InteropServices;
7+
#endif
8+
using System.Text;
9+
using System.Threading;
10+
using System.Threading.Tasks;
11+
using Microsoft.Shared.Diagnostics;
12+
13+
#pragma warning disable S109 // Magic numbers should not be used
14+
#pragma warning disable S127 // "for" loop stop conditions should be invariant
15+
16+
namespace Microsoft.Extensions.AI;
17+
18+
/// <summary>
19+
/// Provides extension methods for working with <see cref="StreamingChatCompletionUpdate"/> instances.
20+
/// </summary>
21+
public static class StreamingChatCompletionUpdateExtensions
22+
{
23+
/// <summary>Combines <see cref="StreamingChatCompletionUpdate"/> instances into a single <see cref="ChatCompletion"/>.</summary>
24+
/// <param name="updates">The updates to be combined.</param>
25+
/// <param name="coalesceContent">
26+
/// <see langword="true"/> to attempt to coalesce contiguous <see cref="AIContent"/> items, where applicable,
27+
/// into a single <see cref="AIContent"/>, in order to reduce the number of individual content items that are included in
28+
/// the manufactured <see cref="ChatMessage"/> instances. When <see langword="false"/>, the original content items are used.
29+
/// The default is <see langword="true"/>.
30+
/// </param>
31+
/// <returns>The combined <see cref="ChatCompletion"/>.</returns>
32+
public static ChatCompletion ToChatCompletion(
33+
this IEnumerable<StreamingChatCompletionUpdate> updates, bool coalesceContent = true)
34+
{
35+
_ = Throw.IfNull(updates);
36+
37+
ChatCompletion completion = new([]);
38+
Dictionary<int, ChatMessage> messages = [];
39+
40+
foreach (var update in updates)
41+
{
42+
ProcessUpdate(update, messages, completion);
43+
}
44+
45+
AddMessagesToCompletion(messages, completion, coalesceContent);
46+
47+
return completion;
48+
}
49+
50+
/// <summary>Combines <see cref="StreamingChatCompletionUpdate"/> instances into a single <see cref="ChatCompletion"/>.</summary>
51+
/// <param name="updates">The updates to be combined.</param>
52+
/// <param name="coalesceContent">
53+
/// <see langword="true"/> to attempt to coalesce contiguous <see cref="AIContent"/> items, where applicable,
54+
/// into a single <see cref="AIContent"/>, in order to reduce the number of individual content items that are included in
55+
/// the manufactured <see cref="ChatMessage"/> instances. When <see langword="false"/>, the original content items are used.
56+
/// The default is <see langword="true"/>.
57+
/// </param>
58+
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
59+
/// <returns>The combined <see cref="ChatCompletion"/>.</returns>
60+
public static Task<ChatCompletion> ToChatCompletionAsync(
61+
this IAsyncEnumerable<StreamingChatCompletionUpdate> updates, bool coalesceContent = true, CancellationToken cancellationToken = default)
62+
{
63+
_ = Throw.IfNull(updates);
64+
65+
return ToChatCompletionAsync(updates, coalesceContent, cancellationToken);
66+
67+
static async Task<ChatCompletion> ToChatCompletionAsync(
68+
IAsyncEnumerable<StreamingChatCompletionUpdate> updates, bool coalesceContent, CancellationToken cancellationToken)
69+
{
70+
ChatCompletion completion = new([]);
71+
Dictionary<int, ChatMessage> messages = [];
72+
73+
await foreach (var update in updates.WithCancellation(cancellationToken).ConfigureAwait(false))
74+
{
75+
ProcessUpdate(update, messages, completion);
76+
}
77+
78+
AddMessagesToCompletion(messages, completion, coalesceContent);
79+
80+
return completion;
81+
}
82+
}
83+
84+
/// <summary>Processes the <see cref="StreamingChatCompletionUpdate"/>, incorporating its contents into <paramref name="messages"/> and <paramref name="completion"/>.</summary>
85+
/// <param name="update">The update to process.</param>
86+
/// <param name="messages">The dictionary mapping <see cref="StreamingChatCompletionUpdate.ChoiceIndex"/> to the <see cref="ChatMessage"/> being built for that choice.</param>
87+
/// <param name="completion">The <see cref="ChatCompletion"/> object whose properties should be updated based on <paramref name="update"/>.</param>
88+
private static void ProcessUpdate(StreamingChatCompletionUpdate update, Dictionary<int, ChatMessage> messages, ChatCompletion completion)
89+
{
90+
completion.CompletionId ??= update.CompletionId;
91+
completion.CreatedAt ??= update.CreatedAt;
92+
completion.FinishReason ??= update.FinishReason;
93+
completion.ModelId ??= update.ModelId;
94+
95+
#if NET
96+
ChatMessage message = CollectionsMarshal.GetValueRefOrAddDefault(messages, update.ChoiceIndex, out _) ??=
97+
new(default, new List<AIContent>());
98+
#else
99+
if (!messages.TryGetValue(update.ChoiceIndex, out ChatMessage? message))
100+
{
101+
messages[update.ChoiceIndex] = message = new(default, new List<AIContent>());
102+
}
103+
#endif
104+
105+
((List<AIContent>)message.Contents).AddRange(update.Contents);
106+
107+
message.AuthorName ??= update.AuthorName;
108+
if (update.Role is ChatRole role && message.Role == default)
109+
{
110+
message.Role = role;
111+
}
112+
113+
if (update.AdditionalProperties is not null)
114+
{
115+
if (message.AdditionalProperties is null)
116+
{
117+
message.AdditionalProperties = new(update.AdditionalProperties);
118+
}
119+
else
120+
{
121+
foreach (var entry in update.AdditionalProperties)
122+
{
123+
// Use first-wins behavior to match the behavior of the other properties.
124+
_ = message.AdditionalProperties.TryAdd(entry.Key, entry.Value);
125+
}
126+
}
127+
}
128+
}
129+
130+
/// <summary>Finalizes the <paramref name="completion"/> object by transferring the <paramref name="messages"/> into it.</summary>
131+
/// <param name="messages">The messages to process further and transfer into <paramref name="completion"/>.</param>
132+
/// <param name="completion">The result <see cref="ChatCompletion"/> being built.</param>
133+
/// <param name="coalesceContent">The corresponding option value provided to <see cref="ToChatCompletion"/> or <see cref="ToChatCompletionAsync"/>.</param>
134+
private static void AddMessagesToCompletion(Dictionary<int, ChatMessage> messages, ChatCompletion completion, bool coalesceContent)
135+
{
136+
foreach (var entry in messages)
137+
{
138+
if (entry.Value.Role == default)
139+
{
140+
entry.Value.Role = ChatRole.Assistant;
141+
}
142+
143+
if (coalesceContent)
144+
{
145+
CoalesceTextContent((List<AIContent>)entry.Value.Contents);
146+
}
147+
148+
completion.Choices.Add(entry.Value);
149+
150+
if (completion.Usage is null)
151+
{
152+
foreach (var content in entry.Value.Contents)
153+
{
154+
if (content is UsageContent c)
155+
{
156+
completion.Usage = c.Details;
157+
break;
158+
}
159+
}
160+
}
161+
}
162+
}
163+
164+
/// <summary>Coalesces sequential <see cref="TextContent"/> content elements.</summary>
165+
private static void CoalesceTextContent(List<AIContent> contents)
166+
{
167+
StringBuilder? coalescedText = null;
168+
169+
// Iterate through all of the items in the list looking for contiguous items that can be coalesced.
170+
int start = 0;
171+
while (start < contents.Count - 1)
172+
{
173+
// We need at least two TextContents in a row to be able to coalesce.
174+
if (contents[start] is not TextContent firstText)
175+
{
176+
start++;
177+
continue;
178+
}
179+
180+
if (contents[start + 1] is not TextContent secondText)
181+
{
182+
start += 2;
183+
continue;
184+
}
185+
186+
// Append the text from those nodes and continue appending subsequent TextContents until we run out.
187+
// We null out nodes as their text is appended so that we can later remove them all in one O(N) operation.
188+
coalescedText ??= new();
189+
_ = coalescedText.Clear().Append(firstText.Text).Append(secondText.Text);
190+
contents[start + 1] = null!;
191+
int i = start + 2;
192+
for (; i < contents.Count && contents[i] is TextContent next; i++)
193+
{
194+
_ = coalescedText.Append(next.Text);
195+
contents[i] = null!;
196+
}
197+
198+
// Store the replacement node.
199+
contents[start] = new TextContent(coalescedText.ToString())
200+
{
201+
// We inherit the properties of the first text node. We don't currently propagate additional
202+
// properties from the subsequent nodes. If we ever need to, we can add that here.
203+
AdditionalProperties = firstText.AdditionalProperties?.Clone(),
204+
};
205+
206+
start = i;
207+
}
208+
209+
// Remove all of the null slots left over from the coalescing process.
210+
_ = contents.RemoveAll(u => u is null);
211+
}
212+
}

0 commit comments

Comments
 (0)