Skip to content

Commit 54c6a59

Browse files
committed
Add AITool.GetService
Following the same pattern as elsewhere in M.E.AI, this enables a consumer to reach through layers of delegating tools to grab information from inner ones, such as whether they were marked as requiring approval.
1 parent 53ef115 commit 54c6a59

File tree

9 files changed

+107
-7
lines changed

9 files changed

+107
-7
lines changed

src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/DelegatingAIFunction.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,14 @@ protected DelegatingAIFunction(AIFunction innerFunction)
5858
/// <inheritdoc />
5959
protected override ValueTask<object?> InvokeCoreAsync(AIFunctionArguments arguments, CancellationToken cancellationToken) =>
6060
InnerFunction.InvokeAsync(arguments, cancellationToken);
61+
62+
/// <inheritdoc />
63+
public override object? GetService(Type serviceType, object? serviceKey = null)
64+
{
65+
_ = Throw.IfNull(serviceType);
66+
67+
return
68+
serviceKey is null && serviceType.IsInstanceOfType(this) ? this :
69+
InnerFunction.GetService(serviceType, serviceKey);
70+
}
6171
}

src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/DelegatingAIFunctionDeclaration.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,14 @@ protected DelegatingAIFunctionDeclaration(AIFunctionDeclaration innerFunction)
4545

4646
/// <inheritdoc />
4747
public override string ToString() => InnerFunction.ToString();
48+
49+
/// <inheritdoc />
50+
public override object? GetService(Type serviceType, object? serviceKey = null)
51+
{
52+
_ = Throw.IfNull(serviceType);
53+
54+
return
55+
serviceKey is null && serviceType.IsInstanceOfType(this) ? this :
56+
InnerFunction.GetService(serviceType, serviceKey);
57+
}
4858
}

src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,14 @@
685685
"Member": "Microsoft.Extensions.AI.AITool.AITool();",
686686
"Stage": "Stable"
687687
},
688+
{
689+
"Member": "virtual object? Microsoft.Extensions.AI.AITool.GetService(System.Type serviceType, object? serviceKey = null);",
690+
"Stage": "Stable"
691+
},
692+
{
693+
"Member": "TService? Microsoft.Extensions.AI.AITool.GetService<TService>(object? serviceKey = null);",
694+
"Stage": "Stable"
695+
},
688696
{
689697
"Member": "override string Microsoft.Extensions.AI.AITool.ToString();",
690698
"Stage": "Stable"
@@ -1477,6 +1485,10 @@
14771485
"Member": "Microsoft.Extensions.AI.DelegatingAIFunction.DelegatingAIFunction(Microsoft.Extensions.AI.AIFunction innerFunction);",
14781486
"Stage": "Stable"
14791487
},
1488+
{
1489+
"Member": "override object? Microsoft.Extensions.AI.DelegatingAIFunction.GetService(System.Type serviceType, object? serviceKey = null);",
1490+
"Stage": "Stable"
1491+
},
14801492
{
14811493
"Member": "override System.Threading.Tasks.ValueTask<object?> Microsoft.Extensions.AI.DelegatingAIFunction.InvokeCoreAsync(Microsoft.Extensions.AI.AIFunctionArguments arguments, System.Threading.CancellationToken cancellationToken);",
14821494
"Stage": "Stable"

src/Libraries/Microsoft.Extensions.AI.Abstractions/Tools/AITool.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System;
45
using System.Collections.Generic;
56
using System.Diagnostics;
67
using System.Text;
78
using Microsoft.Shared.Collections;
9+
using Microsoft.Shared.Diagnostics;
810

911
namespace Microsoft.Extensions.AI;
1012

@@ -31,6 +33,35 @@ protected AITool()
3133
/// <inheritdoc/>
3234
public override string ToString() => Name;
3335

36+
/// <summary>Asks the <see cref="AITool"/> for an object of the specified type <paramref name="serviceType"/>.</summary>
37+
/// <param name="serviceType">The type of object being requested.</param>
38+
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
39+
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
40+
/// <exception cref="ArgumentNullException"><paramref name="serviceType"/> is <see langword="null"/>.</exception>
41+
/// <remarks>
42+
/// The purpose of this method is to allow for the retrieval of strongly-typed services that might be provided by the <see cref="AITool"/>,
43+
/// including itself or any services it might be wrapping.
44+
/// </remarks>
45+
public virtual object? GetService(Type serviceType, object? serviceKey = null)
46+
{
47+
_ = Throw.IfNull(serviceType);
48+
49+
return
50+
serviceKey is null && serviceType.IsInstanceOfType(this) ? this :
51+
null;
52+
}
53+
54+
/// <summary>Asks the <see cref="AITool"/> for an object of type <typeparamref name="TService"/>.</summary>
55+
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
56+
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
57+
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
58+
/// <remarks>
59+
/// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the <see cref="AITool"/>,
60+
/// including itself or any services it might be wrapping.
61+
/// </remarks>
62+
public TService? GetService<TService>(object? serviceKey = null) =>
63+
GetService(typeof(TService), serviceKey) is TService service ? service : default;
64+
3465
/// <summary>Gets the string to display in the debugger for this instance.</summary>
3566
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
3667
private string DebuggerDisplay

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIResponsesChatClient.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,5 +910,15 @@ internal sealed class ResponseToolAITool(ResponseTool tool) : AITool
910910
{
911911
public ResponseTool Tool => tool;
912912
public override string Name => Tool.GetType().Name;
913+
914+
/// <inheritdoc />
915+
public override object? GetService(Type serviceType, object? serviceKey = null)
916+
{
917+
_ = Throw.IfNull(serviceType);
918+
919+
return
920+
serviceKey is null && serviceType.IsInstanceOfType(Tool) ? Tool :
921+
base.GetService(serviceType, serviceKey);
922+
}
913923
}
914924
}

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
430430
List<ChatMessage> originalMessages = [.. messages];
431431
messages = originalMessages;
432432

433-
ApprovalRequiredAIFunction[]? approvalRequiredFunctions = null; // available tools that require approval
433+
AITool[]? approvalRequiredFunctions = null; // available tools that require approval
434434
List<ChatMessage>? augmentedHistory = null; // the actual history of messages sent on turns other than the first
435435
List<FunctionCallContent>? functionCallContents = null; // function call contents that need responding to in the current turn
436436
List<ChatMessage>? responseMessages = null; // tracked list of messages, across multiple turns, to be used in fallback cases to reconstitute history
@@ -539,7 +539,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
539539
approvalRequiredFunctions =
540540
(options?.Tools ?? Enumerable.Empty<AITool>())
541541
.Concat(AdditionalTools ?? Enumerable.Empty<AITool>())
542-
.OfType<ApprovalRequiredAIFunction>()
542+
.Where(t => t.GetService<ApprovalRequiredAIFunction>() is not null)
543543
.ToArray();
544544
}
545545

@@ -741,7 +741,7 @@ private static (Dictionary<string, AITool>? ToolMap, bool AnyRequireApproval) Cr
741741
for (int i = 0; i < count; i++)
742742
{
743743
AITool tool = toolList[i];
744-
anyRequireApproval |= tool is ApprovalRequiredAIFunction;
744+
anyRequireApproval |= tool.GetService<ApprovalRequiredAIFunction>() is not null;
745745
map[tool.Name] = tool;
746746
}
747747
}
@@ -1455,7 +1455,7 @@ private static ChatMessage ConvertToFunctionCallContentMessage(ApprovalResultWit
14551455
/// </summary>
14561456
private static (bool hasApprovalRequiringFcc, int lastApprovalCheckedFCCIndex) CheckForApprovalRequiringFCC(
14571457
List<FunctionCallContent>? functionCallContents,
1458-
ApprovalRequiredAIFunction[] approvalRequiredFunctions,
1458+
AITool[] approvalRequiredFunctions,
14591459
bool hasApprovalRequiringFcc,
14601460
int lastApprovalCheckedFCCIndex)
14611461
{
@@ -1536,7 +1536,7 @@ private static IList<ChatMessage> ReplaceFunctionCallsWithApprovalRequests(
15361536
{
15371537
foreach (var t in toolMap)
15381538
{
1539-
if (t.Value is ApprovalRequiredAIFunction araf && araf.Name == functionCall.Name)
1539+
if (t.Value.GetService<ApprovalRequiredAIFunction>() is { } araf && araf.Name == functionCall.Name)
15401540
{
15411541
anyApprovalRequired = true;
15421542
break;

test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/DelegatingAIFunctionTests.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public void Constructor_NullInnerFunction_ThrowsArgumentNullException()
2020
[Fact]
2121
public void DefaultOverrides_DelegateToInnerFunction()
2222
{
23-
AIFunction expected = AIFunctionFactory.Create(() => 42);
23+
AIFunction expected = new ApprovalRequiredAIFunction(AIFunctionFactory.Create(() => 42));
2424
DerivedFunction actual = new(expected);
2525

2626
Assert.Same(expected, actual.InnerFunction);
@@ -32,6 +32,7 @@ public void DefaultOverrides_DelegateToInnerFunction()
3232
Assert.Same(expected.UnderlyingMethod, actual.UnderlyingMethod);
3333
Assert.Same(expected.AdditionalProperties, actual.AdditionalProperties);
3434
Assert.Equal(expected.ToString(), actual.ToString());
35+
Assert.Same(expected, actual.GetService<ApprovalRequiredAIFunction>());
3536
}
3637

3738
private sealed class DerivedFunction(AIFunction innerFunction) : DelegatingAIFunction(innerFunction)

test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Tools/AIToolTests.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System;
45
using Xunit;
56

67
namespace Microsoft.Extensions.AI;
@@ -17,5 +18,25 @@ public void Constructor_Roundtrips()
1718
Assert.Empty(tool.AdditionalProperties);
1819
}
1920

21+
[Fact]
22+
public void GetService_ReturnsExpectedObject()
23+
{
24+
DerivedAITool tool = new();
25+
26+
Assert.Throws<ArgumentNullException>("serviceType", () => tool.GetService(null!));
27+
28+
Assert.Same(tool, tool.GetService(typeof(object)));
29+
Assert.Same(tool, tool.GetService(typeof(AITool)));
30+
Assert.Same(tool, tool.GetService(typeof(DerivedAITool)));
31+
32+
Assert.Same(tool, tool.GetService<object>());
33+
Assert.Same(tool, tool.GetService<AITool>());
34+
Assert.Same(tool, tool.GetService<DerivedAITool>());
35+
36+
Assert.Null(tool.GetService<object>("key"));
37+
Assert.Null(tool.GetService<AITool>("key"));
38+
Assert.Null(tool.GetService<DerivedAITool>("key"));
39+
}
40+
2041
private sealed class DerivedAITool : AITool;
2142
}

test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIConversionTests.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1193,12 +1193,17 @@ public void ListAddResponseTool_AddsToolCorrectly()
11931193
Assert.Single(options.Tools);
11941194
Assert.NotNull(options.Tools[0]);
11951195

1196+
var rawSearchTool = ResponseTool.CreateWebSearchTool();
11961197
options = new()
11971198
{
1198-
Tools = [ResponseTool.CreateWebSearchTool().AsAITool()],
1199+
Tools = [rawSearchTool.AsAITool()],
11991200
};
12001201
Assert.Single(options.Tools);
12011202
Assert.NotNull(options.Tools[0]);
1203+
1204+
Assert.Same(rawSearchTool, options.Tools[0].GetService<ResponseTool>());
1205+
Assert.Same(rawSearchTool, options.Tools[0].GetService<WebSearchTool>());
1206+
Assert.Null(options.Tools[0].GetService<ResponseTool>("key"));
12021207
}
12031208

12041209
private static async IAsyncEnumerable<T> CreateAsyncEnumerable<T>(IEnumerable<T> source)

0 commit comments

Comments
 (0)