diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/DelegatingAIFunction.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/DelegatingAIFunction.cs new file mode 100644 index 00000000000..a52c5acd959 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/DelegatingAIFunction.cs @@ -0,0 +1,61 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Reflection; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable SA1202 // Elements should be ordered by access + +namespace Microsoft.Extensions.AI; + +/// +/// Provides an optional base class for an that passes through calls to another instance. +/// +public class DelegatingAIFunction : AIFunction +{ + /// + /// Initializes a new instance of the class as a wrapper around . + /// + /// The inner AI function to which all calls are delegated by default. + /// is . + protected DelegatingAIFunction(AIFunction innerFunction) + { + InnerFunction = Throw.IfNull(innerFunction); + } + + /// Gets the inner . + protected AIFunction InnerFunction { get; } + + /// + public override string Name => InnerFunction.Name; + + /// + public override string Description => InnerFunction.Description; + + /// + public override JsonElement JsonSchema => InnerFunction.JsonSchema; + + /// + public override JsonElement? ReturnJsonSchema => InnerFunction.ReturnJsonSchema; + + /// + public override JsonSerializerOptions JsonSerializerOptions => InnerFunction.JsonSerializerOptions; + + /// + public override MethodInfo? UnderlyingMethod => InnerFunction.UnderlyingMethod; + + /// + public override IReadOnlyDictionary AdditionalProperties => InnerFunction.AdditionalProperties; + + /// + public override string ToString() => InnerFunction.ToString(); + + /// + protected override ValueTask InvokeCoreAsync(AIFunctionArguments arguments, CancellationToken cancellationToken) => + InnerFunction.InvokeAsync(arguments, cancellationToken); +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json index 79776b0ecb4..5e87edc01f9 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json @@ -1355,6 +1355,58 @@ } ] }, + { + "Type": "class Microsoft.Extensions.AI.DelegatingAIFunction : Microsoft.Extensions.AI.AIFunction", + "Stage": "Stable", + "Methods": [ + { + "Member": "Microsoft.Extensions.AI.DelegatingAIFunction.DelegatingAIFunction(Microsoft.Extensions.AI.AIFunction innerFunction);", + "Stage": "Stable" + }, + { + "Member": "override System.Threading.Tasks.ValueTask Microsoft.Extensions.AI.DelegatingAIFunction.InvokeCoreAsync(Microsoft.Extensions.AI.AIFunctionArguments arguments, System.Threading.CancellationToken cancellationToken);", + "Stage": "Stable" + }, + { + "Member": "override string Microsoft.Extensions.AI.DelegatingAIFunction.ToString();", + "Stage": "Experimental" + } + ], + "Properties": [ + { + "Member": "Microsoft.Extensions.AI.AIFunction Microsoft.Extensions.AI.DelegatingAIFunction.InnerFunction { get; }", + "Stage": "Stable" + }, + { + "Member": "override System.Collections.Generic.IReadOnlyDictionary Microsoft.Extensions.AI.DelegatingAIFunction.AdditionalProperties { get; }", + "Stage": "Stable" + }, + { + "Member": "override string Microsoft.Extensions.AI.DelegatingAIFunction.Description { get; }", + "Stage": "Stable" + }, + { + "Member": "override System.Text.Json.JsonElement Microsoft.Extensions.AI.DelegatingAIFunction.JsonSchema { get; }", + "Stage": "Stable" + }, + { + "Member": "override System.Text.Json.JsonSerializerOptions Microsoft.Extensions.AI.DelegatingAIFunction.JsonSerializerOptions { get; }", + "Stage": "Stable" + }, + { + "Member": "override string Microsoft.Extensions.AI.DelegatingAIFunction.Name { get; }", + "Stage": "Stable" + }, + { + "Member": "override System.Text.Json.JsonElement? Microsoft.Extensions.AI.DelegatingAIFunction.ReturnJsonSchema { get; }", + "Stage": "Stable" + }, + { + "Member": "override System.Reflection.MethodInfo? Microsoft.Extensions.AI.DelegatingAIFunction.UnderlyingMethod { get; }", + "Stage": "Stable" + } + ] + }, { "Type": "class Microsoft.Extensions.AI.DelegatingChatClient : Microsoft.Extensions.AI.IChatClient, System.IDisposable", "Stage": "Stable", diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/DelegatingAIFunctionTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/DelegatingAIFunctionTests.cs new file mode 100644 index 00000000000..cfad15efdc0 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/DelegatingAIFunctionTests.cs @@ -0,0 +1,92 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Reflection; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class DelegatingAIFunctionTests +{ + [Fact] + public void Constructor_NullInnerFunction_ThrowsArgumentNullException() + { + Assert.Throws("innerFunction", () => new DerivedFunction(null!)); + } + + [Fact] + public void DefaultOverrides_DelegateToInnerFunction() + { + AIFunction expected = AIFunctionFactory.Create(() => 42); + DerivedFunction actual = new(expected); + + Assert.Same(expected, actual.InnerFunction); + Assert.Equal(expected.Name, actual.Name); + Assert.Equal(expected.Description, actual.Description); + Assert.Equal(expected.JsonSchema, actual.JsonSchema); + Assert.Equal(expected.ReturnJsonSchema, actual.ReturnJsonSchema); + Assert.Same(expected.JsonSerializerOptions, actual.JsonSerializerOptions); + Assert.Same(expected.UnderlyingMethod, actual.UnderlyingMethod); + Assert.Same(expected.AdditionalProperties, actual.AdditionalProperties); + Assert.Equal(expected.ToString(), actual.ToString()); + } + + private sealed class DerivedFunction(AIFunction innerFunction) : DelegatingAIFunction(innerFunction) + { + public new AIFunction InnerFunction => base.InnerFunction; + } + + [Fact] + public void Virtuals_AllOverridden() + { + Assert.All(typeof(DelegatingAIFunction).GetMembers(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance), m => + { + switch (m) + { + case MethodInfo methodInfo when methodInfo.IsVirtual && methodInfo.Name is not ("Finalize" or "Equals" or "GetHashCode"): + Assert.True(methodInfo.DeclaringType == typeof(DelegatingAIFunction), $"{methodInfo.Name} not overridden"); + break; + + case PropertyInfo propertyInfo when propertyInfo.GetMethod?.IsVirtual is true: + Assert.True(propertyInfo.DeclaringType == typeof(DelegatingAIFunction), $"{propertyInfo.Name} not overridden"); + break; + } + }); + } + + [Fact] + public async Task OverriddenInvocation_SuccessfullyInvoked() + { + bool innerInvoked = false; + AIFunction inner = AIFunctionFactory.Create(int () => + { + innerInvoked = true; + throw new Exception("uh oh"); + }, "TestFunction", "A test function for DelegatingAIFunction"); + + AIFunction actual = new OverridesInvocation(inner, (args, ct) => new ValueTask(84)); + + Assert.Equal(inner.Name, actual.Name); + Assert.Equal(inner.Description, actual.Description); + Assert.Equal(inner.JsonSchema, actual.JsonSchema); + Assert.Equal(inner.ReturnJsonSchema, actual.ReturnJsonSchema); + Assert.Same(inner.JsonSerializerOptions, actual.JsonSerializerOptions); + Assert.Same(inner.UnderlyingMethod, actual.UnderlyingMethod); + Assert.Same(inner.AdditionalProperties, actual.AdditionalProperties); + Assert.Equal(inner.ToString(), actual.ToString()); + + object? result = await actual.InvokeAsync(new(), CancellationToken.None); + Assert.Contains("84", result?.ToString()); + + Assert.False(innerInvoked); + } + + private sealed class OverridesInvocation(AIFunction innerFunction, Func> invokeAsync) : DelegatingAIFunction(innerFunction) + { + protected override ValueTask InvokeCoreAsync(AIFunctionArguments arguments, CancellationToken cancellationToken) => + invokeAsync(arguments, cancellationToken); + } +}