From 092d9417577c7b0163b325aa3c47675c91b4d71a Mon Sep 17 00:00:00 2001 From: Jose Perez Rodriguez Date: Wed, 5 Mar 2025 03:35:31 +0000 Subject: [PATCH 1/7] Merged PR 48132: Flowing dependencies and getting ready for 9.3 release #### AI description (iteration 1) #### PR Classification Release preparation and dependency updates. #### PR Summary This pull request updates dependencies and prepares the project for the 9.3 release. - Updated multiple dependencies to version 9.0.3 in `eng/Version.Details.xml` and `eng/Versions.props`. - Removed code coverage stage and related jobs from `azure-pipelines.yml`. - Added setup for private feeds credentials in `eng/pipelines/templates/BuildAndTest.yml`. - Disabled NU1507 warning in `Directory.Build.props`. --- Directory.Build.props | 5 + NuGet.config | 104 ++++------ azure-pipelines.yml | 48 +---- eng/MSBuild/ProjectStaging.targets | 24 +-- eng/Version.Details.xml | 190 +++++++++--------- eng/Versions.props | 121 +++++------ eng/pipelines/templates/BuildAndTest.yml | 17 ++ .../Directory.Build.props | 10 + .../Microsoft.Gen.MetricsReports.csproj | 1 - .../Directory.Build.props | 11 + .../Microsoft.AspNetCore.Testing.csproj | 2 - ...icrosoft.Extensions.AI.Abstractions.csproj | 3 +- ...soft.Extensions.AI.AzureAIInference.csproj | 3 +- .../Microsoft.Extensions.AI.Ollama.csproj | 3 +- .../Microsoft.Extensions.AI.OpenAI.csproj | 3 +- .../Microsoft.Extensions.AI.csproj | 3 +- .../Directory.Build.props | 12 ++ ...osoft.Extensions.Diagnostics.Probes.csproj | 2 - .../Directory.Build.props | 12 ++ ...icrosoft.Extensions.Hosting.Testing.csproj | 2 - .../Directory.Build.props | 12 ++ ...osoft.Extensions.Options.Contextual.csproj | 2 - .../Microsoft.Extensions.AI.Templates.csproj | 1 + 23 files changed, 290 insertions(+), 301 deletions(-) create mode 100644 src/Generators/Microsoft.Gen.MetricsReports/Directory.Build.props create mode 100644 src/Libraries/Microsoft.AspNetCore.Testing/Directory.Build.props create mode 100644 src/Libraries/Microsoft.Extensions.Diagnostics.Probes/Directory.Build.props create mode 100644 src/Libraries/Microsoft.Extensions.Hosting.Testing/Directory.Build.props create mode 100644 src/Libraries/Microsoft.Extensions.Options.Contextual/Directory.Build.props diff --git a/Directory.Build.props b/Directory.Build.props index 2cdce43e08e..67440d3bbc8 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -34,6 +34,11 @@ $(NetCoreTargetFrameworks) + + + $(NoWarn);NU1507 + + false latest diff --git a/NuGet.config b/NuGet.config index 7d0951ff63a..d857fb2de44 100644 --- a/NuGet.config +++ b/NuGet.config @@ -1,33 +1,29 @@ - + - - - - - - - - - - - - - - - - + + + + + + + + + + + + - - - - - - + + + + + + @@ -38,55 +34,29 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + - - - - - - + + + + + + diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 66022fdc249..2737b095fd7 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -145,7 +145,7 @@ extends: parameters: enableMicrobuild: true enableTelemetry: true - enableSourceIndex: true + enableSourceIndex: false runAsPublic: ${{ variables['runAsPublic'] }} # Publish build logs enablePublishBuildArtifacts: true @@ -230,51 +230,6 @@ extends: isWindows: false warnAsError: 0 - # ---------------------------------------------------------------- - # This stage performs quality gates enforcements - # ---------------------------------------------------------------- - - stage: codecoverage - displayName: CodeCoverage - dependsOn: - - build 
- condition: and(succeeded('build'), ne(variables['SkipQualityGates'], 'true')) - variables: - - template: /eng/common/templates-official/variables/pool-providers.yml@self - jobs: - - template: /eng/common/templates-official/jobs/jobs.yml@self - parameters: - enableMicrobuild: true - enableTelemetry: true - runAsPublic: ${{ variables['runAsPublic'] }} - workspace: - clean: all - - # ---------------------------------------------------------------- - # This stage downloads the code coverage reports from the build jobs, - # merges those and validates the combined test coverage. - # ---------------------------------------------------------------- - jobs: - - job: CodeCoverageReport - timeoutInMinutes: 180 - - pool: - name: NetCore1ESPool-Internal - image: 1es-mariner-2 - os: linux - - preSteps: - - checkout: self - clean: true - persistCredentials: true - fetchDepth: 1 - - steps: - - script: $(Build.SourcesDirectory)/build.sh --ci --restore - displayName: Init toolset - - - template: /eng/pipelines/templates/VerifyCoverageReport.yml - - # ---------------------------------------------------------------- # This stage only performs a build treating warnings as errors # to detect any kind of code style violations @@ -330,7 +285,6 @@ extends: parameters: validateDependsOn: - build - - codecoverage - correctness publishingInfraVersion: 3 enableSymbolValidation: false diff --git a/eng/MSBuild/ProjectStaging.targets b/eng/MSBuild/ProjectStaging.targets index dcb8ab80ccc..e3a89d03542 100644 --- a/eng/MSBuild/ProjectStaging.targets +++ b/eng/MSBuild/ProjectStaging.targets @@ -4,19 +4,6 @@ true $(NoWarn);LA0003 - - - Experimental package. $(Description) - Obsolete Package. $(Description) - - - - - - true @@ -35,11 +22,10 @@ - - - <_ExpectedVersionSuffix>$(_PreReleaseLabel)$(_BuildNumberLabels) - + + + Experimental package. $(Description) + Obsolete Package. 
$(Description) + - - diff --git a/eng/Version.Details.xml b/eng/Version.Details.xml index 0cd7fa3add9..bc3fe9e81c1 100644 --- a/eng/Version.Details.xml +++ b/eng/Version.Details.xml @@ -1,196 +1,196 @@ - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 
80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime - 80aa709f5d919c6814726788dc6dabe23e79e672 + 831d23e56149cd59c40fc00c7feb7c5334bd19c4 - + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore - 704f7cb1d2cea33afb00c2097731216f121c2c73 + b96167fbfe8bd45d94e4dcda42c7d09eb5745459 - + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore - 704f7cb1d2cea33afb00c2097731216f121c2c73 + b96167fbfe8bd45d94e4dcda42c7d09eb5745459 - + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore - 704f7cb1d2cea33afb00c2097731216f121c2c73 + b96167fbfe8bd45d94e4dcda42c7d09eb5745459 - + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore - 704f7cb1d2cea33afb00c2097731216f121c2c73 + b96167fbfe8bd45d94e4dcda42c7d09eb5745459 - + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore - 704f7cb1d2cea33afb00c2097731216f121c2c73 + b96167fbfe8bd45d94e4dcda42c7d09eb5745459 - + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore - 704f7cb1d2cea33afb00c2097731216f121c2c73 + b96167fbfe8bd45d94e4dcda42c7d09eb5745459 - + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore - 704f7cb1d2cea33afb00c2097731216f121c2c73 + b96167fbfe8bd45d94e4dcda42c7d09eb5745459 - + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore - 704f7cb1d2cea33afb00c2097731216f121c2c73 + b96167fbfe8bd45d94e4dcda42c7d09eb5745459 - + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore - 704f7cb1d2cea33afb00c2097731216f121c2c73 + b96167fbfe8bd45d94e4dcda42c7d09eb5745459 - - 
https://github.com/dotnet/efcore - + + https://dev.azure.com/dnceng/internal/_git/dotnet-efcore + 68c7e19496df80819410fc6de1682a194aad33d3 diff --git a/eng/Versions.props b/eng/Versions.props index baf6e933cfb..9c74d5ae9d4 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -11,8 +11,11 @@ $(MajorVersion).$(MinorVersion).0.0 - + release true @@ -28,55 +31,55 @@ --> - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 - 9.0.2 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 + 9.0.3 - 9.0.2 + 9.0.3 9.0.0-beta.25111.5 @@ -102,8 +105,8 @@ 8.0.1 8.0.0 8.0.2 - 8.0.13 - 8.0.13 + 8.0.14 + 8.0.14 8.0.0 8.0.1 8.0.1 @@ -120,15 +123,15 @@ 8.0.5 8.0.0 - 8.0.13 - 8.0.13 - 8.0.13 - 8.0.13 - 8.0.13 - 8.0.13 - 8.0.13 - 8.0.13 - 8.0.13 + 8.0.14 + 8.0.14 + 8.0.14 + 8.0.14 + 8.0.14 + 8.0.14 + 8.0.14 + 8.0.14 + 8.0.14 1.0.0-beta.3 2.2.0-beta.1 diff --git a/eng/pipelines/templates/BuildAndTest.yml b/eng/pipelines/templates/BuildAndTest.yml index a1541c34c75..35a7104b40a 100644 --- a/eng/pipelines/templates/BuildAndTest.yml +++ b/eng/pipelines/templates/BuildAndTest.yml @@ -28,6 +28,23 @@ steps: inputs: versionSpec: "20.x" checkLatest: true + - task: PowerShell@2 + displayName: Setup Private Feeds Credentials + condition: eq(variables['Agent.OS'], 'Windows_NT') + inputs: + filePath: $(Build.SourcesDirectory)/eng/common/SetupNugetSources.ps1 + arguments: -ConfigFile $(Build.SourcesDirectory)/NuGet.config -Password $Env:Token + env: + Token: $(dn-bot-dnceng-artifact-feeds-rw) + + - task: Bash@3 + displayName: Setup Private Feeds Credentials + condition: ne(variables['Agent.OS'], 'Windows_NT') + inputs: + filePath: $(Build.SourcesDirectory)/eng/common/SetupNugetSources.sh + arguments: $(Build.SourcesDirectory)/NuGet.config $Token + env: + Token: $(dn-bot-dnceng-artifact-feeds-rw) - script: ${{ parameters.buildScript }} -restore diff --git a/src/Generators/Microsoft.Gen.MetricsReports/Directory.Build.props b/src/Generators/Microsoft.Gen.MetricsReports/Directory.Build.props new file mode 100644 index 00000000000..f739a758633 --- /dev/null +++ b/src/Generators/Microsoft.Gen.MetricsReports/Directory.Build.props @@ -0,0 +1,10 @@ + + + + dev + true + + + \ No newline at end of file diff --git a/src/Generators/Microsoft.Gen.MetricsReports/Microsoft.Gen.MetricsReports.csproj b/src/Generators/Microsoft.Gen.MetricsReports/Microsoft.Gen.MetricsReports.csproj index ecfe0d4059a..47ffba75388 100644 --- a/src/Generators/Microsoft.Gen.MetricsReports/Microsoft.Gen.MetricsReports.csproj +++ b/src/Generators/Microsoft.Gen.MetricsReports/Microsoft.Gen.MetricsReports.csproj @@ -12,7 +12,6 @@ - dev 67 85 diff --git a/src/Libraries/Microsoft.AspNetCore.Testing/Directory.Build.props b/src/Libraries/Microsoft.AspNetCore.Testing/Directory.Build.props new file mode 100644 index 00000000000..4503261437d --- /dev/null +++ b/src/Libraries/Microsoft.AspNetCore.Testing/Directory.Build.props @@ -0,0 +1,11 @@ + + + + dev + 
EXTEXP0014 + + + + \ No newline at end of file diff --git a/src/Libraries/Microsoft.AspNetCore.Testing/Microsoft.AspNetCore.Testing.csproj b/src/Libraries/Microsoft.AspNetCore.Testing/Microsoft.AspNetCore.Testing.csproj index d135a3472db..01eb8bb974e 100644 --- a/src/Libraries/Microsoft.AspNetCore.Testing/Microsoft.AspNetCore.Testing.csproj +++ b/src/Libraries/Microsoft.AspNetCore.Testing/Microsoft.AspNetCore.Testing.csproj @@ -12,8 +12,6 @@ - dev - EXTEXP0014 100 100 diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj index da24217861e..123a9a23334 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj @@ -1,4 +1,4 @@ - + Microsoft.Extensions.AI @@ -8,6 +8,7 @@ preview + true false 82 0 diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/Microsoft.Extensions.AI.AzureAIInference.csproj b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/Microsoft.Extensions.AI.AzureAIInference.csproj index 5384a7992d7..1f14a18d823 100644 --- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/Microsoft.Extensions.AI.AzureAIInference.csproj +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/Microsoft.Extensions.AI.AzureAIInference.csproj @@ -1,4 +1,4 @@ - + Microsoft.Extensions.AI @@ -8,6 +8,7 @@ preview + true false 86 0 diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/Microsoft.Extensions.AI.Ollama.csproj b/src/Libraries/Microsoft.Extensions.AI.Ollama/Microsoft.Extensions.AI.Ollama.csproj index 4189a7fb466..b8e47a28fad 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Ollama/Microsoft.Extensions.AI.Ollama.csproj +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/Microsoft.Extensions.AI.Ollama.csproj @@ -1,4 +1,4 @@ - + Microsoft.Extensions.AI @@ -8,6 +8,7 @@ preview + true false 78 0 diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj b/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj index 18bfe009184..f9e83e3ce88 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj @@ -1,4 +1,4 @@ - + Microsoft.Extensions.AI @@ -8,6 +8,7 @@ preview + true false 49 0 diff --git a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj index 10f590639ec..72bfb799ae7 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj +++ b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj @@ -1,4 +1,4 @@ - + Microsoft.Extensions.AI @@ -10,6 +10,7 @@ preview + true false 89 0 diff --git a/src/Libraries/Microsoft.Extensions.Diagnostics.Probes/Directory.Build.props b/src/Libraries/Microsoft.Extensions.Diagnostics.Probes/Directory.Build.props new file mode 100644 index 00000000000..0ea108580da --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.Diagnostics.Probes/Directory.Build.props @@ -0,0 +1,12 @@ + + + + true + dev + EXTEXP0015 + + + + \ No newline at end of file diff --git a/src/Libraries/Microsoft.Extensions.Diagnostics.Probes/Microsoft.Extensions.Diagnostics.Probes.csproj 
b/src/Libraries/Microsoft.Extensions.Diagnostics.Probes/Microsoft.Extensions.Diagnostics.Probes.csproj index 8eb6ac3b9dc..4336188ced0 100644 --- a/src/Libraries/Microsoft.Extensions.Diagnostics.Probes/Microsoft.Extensions.Diagnostics.Probes.csproj +++ b/src/Libraries/Microsoft.Extensions.Diagnostics.Probes/Microsoft.Extensions.Diagnostics.Probes.csproj @@ -13,8 +13,6 @@ - dev - EXTEXP0015 76 75 diff --git a/src/Libraries/Microsoft.Extensions.Hosting.Testing/Directory.Build.props b/src/Libraries/Microsoft.Extensions.Hosting.Testing/Directory.Build.props new file mode 100644 index 00000000000..77a9a53a9e9 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.Hosting.Testing/Directory.Build.props @@ -0,0 +1,12 @@ + + + + true + dev + EXTEXP0016 + + + + \ No newline at end of file diff --git a/src/Libraries/Microsoft.Extensions.Hosting.Testing/Microsoft.Extensions.Hosting.Testing.csproj b/src/Libraries/Microsoft.Extensions.Hosting.Testing/Microsoft.Extensions.Hosting.Testing.csproj index 63f7651746f..fdc40c84838 100644 --- a/src/Libraries/Microsoft.Extensions.Hosting.Testing/Microsoft.Extensions.Hosting.Testing.csproj +++ b/src/Libraries/Microsoft.Extensions.Hosting.Testing/Microsoft.Extensions.Hosting.Testing.csproj @@ -13,8 +13,6 @@ - dev - EXTEXP0016 100 90 diff --git a/src/Libraries/Microsoft.Extensions.Options.Contextual/Directory.Build.props b/src/Libraries/Microsoft.Extensions.Options.Contextual/Directory.Build.props new file mode 100644 index 00000000000..59864c9c658 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.Options.Contextual/Directory.Build.props @@ -0,0 +1,12 @@ + + + + true + dev + EXTEXP0017 + + + + \ No newline at end of file diff --git a/src/Libraries/Microsoft.Extensions.Options.Contextual/Microsoft.Extensions.Options.Contextual.csproj b/src/Libraries/Microsoft.Extensions.Options.Contextual/Microsoft.Extensions.Options.Contextual.csproj index e15687a9ee7..d80898ce3ae 100644 --- a/src/Libraries/Microsoft.Extensions.Options.Contextual/Microsoft.Extensions.Options.Contextual.csproj +++ b/src/Libraries/Microsoft.Extensions.Options.Contextual/Microsoft.Extensions.Options.Contextual.csproj @@ -11,8 +11,6 @@ - dev - EXTEXP0017 100 80 diff --git a/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/Microsoft.Extensions.AI.Templates.csproj b/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/Microsoft.Extensions.AI.Templates.csproj index 60078685015..2762f3148df 100644 --- a/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/Microsoft.Extensions.AI.Templates.csproj +++ b/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/Microsoft.Extensions.AI.Templates.csproj @@ -7,6 +7,7 @@ dotnet-new;templates;ai preview + true AI 0 0 From be25c10b63b183526f007c05726cd7b221d7950e Mon Sep 17 00:00:00 2001 From: Jeff Handley Date: Wed, 5 Mar 2025 05:39:18 +0000 Subject: [PATCH 2/7] Merged PR 48137: Force EF Sqlite down to version 9.0.2 in the MEAI.Templates project We are publishing MEAI.Templates before we publish EF.Sqlite 9.0.2, so we need to force a downgrade of the package version in the project template to prevent unresolved packages. This must be undone or made more durable immediately after initial publication. 
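For context, the change in the diff below reduces to replacing the flowed version reference for the EF Sqlite package in `ChatWithCustomData.Web-CSharp.csproj.in` with an explicit pin. A minimal sketch of what the pinned reference looks like, assuming the template references the `Microsoft.EntityFrameworkCore.Sqlite` package (the surrounding `ItemGroup` and comment are illustrative, not copied from the template):

```xml
<ItemGroup>
  <!-- Temporarily pinned to 9.0.2 so the published template restores against already-published
       packages. To be reverted (or made more durable) immediately after initial publication,
       per the commit message above. -->
  <PackageReference Include="Microsoft.EntityFrameworkCore.Sqlite" Version="9.0.2" />
</ItemGroup>
```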
--- .../ChatWithCustomData.Web-CSharp.csproj.in | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/src/ChatWithCustomData/ChatWithCustomData.Web-CSharp/ChatWithCustomData.Web-CSharp.csproj.in b/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/src/ChatWithCustomData/ChatWithCustomData.Web-CSharp/ChatWithCustomData.Web-CSharp.csproj.in index 92cb05131b3..707e0a73855 100644 --- a/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/src/ChatWithCustomData/ChatWithCustomData.Web-CSharp/ChatWithCustomData.Web-CSharp.csproj.in +++ b/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/src/ChatWithCustomData/ChatWithCustomData.Web-CSharp/ChatWithCustomData.Web-CSharp.csproj.in @@ -24,7 +24,7 @@ - + From 9fea0aa5674a8914f6e80130579f761324c0f1f1 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 6 Mar 2025 10:50:41 -0500 Subject: [PATCH 3/7] Update IChatClient to support multiple return messages - IChatClient no longer bakes mutation of the messages into the contract. The messages are now an `IEnumerable` rather than an `IList`. - The purpose for mutation was to allow for multiple messages to be generated as part of an operation. All messages generated are now returned as part of the ChatResponse, which has a Messages rather than Message property. - Choices have been removed from the surface area, e.g. no ChatResponse.Choices and no ChatResponseUpdate.ChoiceIndex. --- .../AdditionalPropertiesDictionary{TValue}.cs | 12 + .../ChatCompletion/ChatMessage.cs | 36 +- .../ChatCompletion/ChatResponse.cs | 107 ++-- .../ChatCompletion/ChatResponseExtensions.cs | 304 +++++++++++ .../ChatCompletion/ChatResponseUpdate.cs | 71 +-- .../ChatResponseUpdateExtensions.cs | 248 --------- .../ChatCompletion/DelegatingChatClient.cs | 16 +- .../ChatCompletion/IChatClient.cs | 37 +- .../ChatCompletion/RequiredChatToolMode.cs | 1 + .../Contents/AIContentExtensions.cs | 94 +++- .../Contents/FunctionCallContent.cs | 6 +- .../DelegatingEmbeddingGenerator.cs | 2 +- .../EmbeddingGeneratorExtensions.cs | 14 +- .../Embeddings/IEmbeddingGenerator.cs | 15 +- .../README.md | 191 ++----- .../UsageDetails.cs | 3 + .../Utilities/AIJsonUtilities.Schema.cs | 2 + .../AzureAIInferenceChatClient.cs | 34 +- .../AzureAIInferenceEmbeddingGenerator.cs | 3 + .../ChatConversationEvaluator.cs | 47 +- .../CoherenceEvaluator.cs | 5 +- .../EquivalenceEvaluator.cs | 5 +- .../FluencyEvaluator.cs | 5 +- .../GroundednessEvaluator.cs | 5 +- .../RelevanceTruthAndCompletenessEvaluator.cs | 13 +- .../SingleNumericMetricEvaluator.cs | 8 +- .../Utilities/JsonOutputFixer.cs | 4 +- .../Storage/AzureStorageResultStore.cs | 2 +- .../CSharp/ScenarioRun.cs | 2 +- .../CSharp/ScenarioRunExtensions.cs | 67 ++- .../CSharp/ScenarioRunResult.cs | 10 +- .../CSharp/ScenarioRunResultExtensions.cs | 2 +- .../CSharp/Storage/DiskBasedResultStore.cs | 2 +- .../CompositeEvaluator.cs | 6 +- .../EvaluationMetricExtensions.cs | 2 +- .../EvaluationResult.cs | 2 +- .../EvaluationResultExtensions.cs | 8 +- .../EvaluatorExtensions.cs | 98 +++- .../IEvaluator.cs | 2 +- .../TokenizerExtensions.cs | 2 +- .../OllamaChatClient.cs | 21 +- .../OllamaEmbeddingGenerator.cs | 2 + .../OpenAIAssistantClient.cs | 28 +- .../OpenAIChatClient.cs | 15 +- .../OpenAIEmbeddingGenerator.cs | 5 + .../OpenAIModelMapper.ChatCompletion.cs | 36 +- .../OpenAIRealtimeExtensions.cs | 4 + .../OpenAISerializationHelpers.cs | 5 + .../AnonymousDelegatingChatClient.cs | 45 +- .../ChatCompletion/CachingChatClient.cs | 37 +- 
.../ChatCompletion/ChatClientBuilder.cs | 16 +- .../ChatClientBuilderChatClientExtensions.cs | 2 + ...lientBuilderServiceCollectionExtensions.cs | 23 +- .../ChatClientStructuredOutputExtensions.cs | 45 +- .../ChatCompletion/ChatResponse{T}.cs | 15 +- .../ConfigureOptionsChatClient.cs | 11 +- ...igureOptionsChatClientBuilderExtensions.cs | 2 + ...butedCachingChatClientBuilderExtensions.cs | 1 + .../FunctionInvocationContext.cs | 11 +- .../FunctionInvokingChatClient.cs | 504 +++++++++--------- ...tionInvokingChatClientBuilderExtensions.cs | 1 + .../ChatCompletion/LoggingChatClient.cs | 16 +- .../LoggingChatClientBuilderExtensions.cs | 1 + .../ChatCompletion/OpenTelemetryChatClient.cs | 45 +- .../Embeddings/CachingEmbeddingGenerator.cs | 2 +- ...ionsEmbeddingGeneratorBuilderExtensions.cs | 2 + .../DistributedCachingEmbeddingGenerator.cs | 2 + ...hingEmbeddingGeneratorBuilderExtensions.cs | 1 + .../Embeddings/EmbeddingGeneratorBuilder.cs | 10 +- ...atorBuilderEmbeddingGeneratorExtensions.cs | 2 + ...ratorBuilderServiceCollectionExtensions.cs | 23 +- ...gingEmbeddingGeneratorBuilderExtensions.cs | 1 + .../Functions/AIFunctionFactory.cs | 4 + .../Ingestion/IngestionCacheDbContext.cs | 2 +- .../Services/JsonVectorStore.cs | 2 +- .../ChatClientExtensionsTests.cs | 14 +- .../ChatCompletion/ChatMessageTests.cs | 86 ++- .../ChatCompletion/ChatResponseTests.cs | 204 ++----- .../ChatResponseUpdateExtensionsTests.cs | 104 +--- .../ChatCompletion/ChatResponseUpdateTests.cs | 69 +-- .../DelegatingChatClientTests.cs | 6 +- .../EmbeddingGeneratorExtensionsTests.cs | 2 +- .../TestChatClient.cs | 12 +- .../Utilities/AIJsonUtilitiesTests.cs | 16 +- .../AzureAIInferenceChatClientTests.cs | 35 +- .../AdditionalContextTests.cs | 5 +- .../EndToEndTests.cs | 4 +- ...vanceTruthAndCompletenessEvaluatorTests.cs | 4 +- .../TestEvaluator.cs | 2 +- .../ResultStoreTester.cs | 2 +- .../ScenarioRunResultTests.cs | 23 +- .../CallCountingChatClient.cs | 8 +- .../ChatClientIntegrationTests.cs | 52 +- .../PromptBasedFunctionCallingChatClient.cs | 38 +- .../ReducingChatClientTests.cs | 86 +-- .../OllamaChatClientIntegrationTests.cs | 10 +- .../OllamaChatClientTests.cs | 25 +- .../OpenAIChatClientTests.cs | 39 +- .../OpenAISerializationTests.cs | 24 +- ...atClientStructuredOutputExtensionsTests.cs | 3 +- .../ConfigureOptionsChatClientTests.cs | 2 +- .../DistributedCachingChatClientTest.cs | 125 ++--- .../FunctionInvocationContextTests.cs | 14 +- .../FunctionInvokingChatClientTests.cs | 293 ++++------ .../ChatCompletion/LoggingChatClientTests.cs | 6 +- .../OpenTelemetryChatClientTests.cs | 12 +- .../UseDelegateChatClientTests.cs | 93 ++-- .../UseDelegateEmbeddingGeneratorTests.cs | 2 +- .../Functions/AIFunctionFactoryTest.cs | 2 +- 109 files changed, 1881 insertions(+), 1976 deletions(-) create mode 100644 src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseExtensions.cs delete mode 100644 src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary{TValue}.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary{TValue}.cs index 21d1daf2820..14125e95b76 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary{TValue}.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary{TValue}.cs @@ -201,6 +201,18 @@ public bool TryGetValue(string key, [NotNullWhen(true)] out T? 
value) /// bool IReadOnlyDictionary.TryGetValue(string key, out TValue value) => _dictionary.TryGetValue(key, out value!); + /// Copies all of the entries from into the dictionary, overwriting any existing items in the dictionary with the same key. + /// The items to add. + internal void SetAll(IEnumerable> items) + { + _ = Throw.IfNull(items); + + foreach (var item in items) + { + _dictionary[item.Key] = item.Value; + } + } + /// Enumerates the elements of an . public struct Enumerator : IEnumerator> { diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs index 32e2159950c..049536cecd8 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs @@ -1,11 +1,11 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; -using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI; @@ -17,6 +17,7 @@ public class ChatMessage private string? _authorName; /// Initializes a new instance of the class. + /// The instance defaults to having a role of . [JsonConstructor] public ChatMessage() { @@ -24,7 +25,7 @@ public ChatMessage() /// Initializes a new instance of the class. /// The role of the author of the message. - /// The contents of the message. + /// The text content of the message. public ChatMessage(ChatRole role, string? content) : this(role, content is null ? [] : [new TextContent(content)]) { @@ -33,12 +34,10 @@ public ChatMessage(ChatRole role, string? content) /// Initializes a new instance of the class. /// The role of the author of the message. /// The contents for this message. - public ChatMessage( - ChatRole role, - IList contents) + public ChatMessage(ChatRole role, IList? contents) { Role = role; - _contents = Throw.IfNull(contents); + _contents = contents; } /// Clones the to a new instance. @@ -67,29 +66,12 @@ public string? AuthorName /// Gets or sets the role of the author of the message. public ChatRole Role { get; set; } = ChatRole.User; - /// - /// Gets or sets the text of the first instance in . - /// + /// Gets the text of this message. /// - /// If there is no instance in , then the getter returns , - /// and the setter adds a new instance with the provided value. + /// This property concatenates the text of all objects in . /// [JsonIgnore] - public string? Text - { - get => Contents.FindFirst()?.Text; - set - { - if (Contents.FindFirst() is { } textContent) - { - textContent.Text = value; - } - else if (value is not null) - { - Contents.Add(new TextContent(value)); - } - } - } + public string Text => Contents.ConcatText(); /// Gets or sets the chat message content items. [AllowNull] @@ -112,7 +94,7 @@ public IList Contents public AdditionalPropertiesDictionary? AdditionalProperties { get; set; } /// - public override string ToString() => Contents.ConcatText(); + public override string ToString() => Text; /// Gets a object to display in the debugger display. 
[DebuggerBrowsable(DebuggerBrowsableState.Never)] diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs index f789fc7f974..6babae1258f 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs @@ -3,60 +3,62 @@ using System; using System.Collections.Generic; -using System.Text; +using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI; /// Represents the response to a chat request. +/// +/// provides one or more response messages and metadata about the response. +/// A typical response will contain a single message, however a response may contain multiple messages +/// in a variety of scenarios. For example, if automatic function calling is employed, such that a single +/// request to a may actually generate multiple roundtrips to an inner +/// it uses, all of the involved messages may be surfaced as part of the final . +/// public class ChatResponse { - /// The list of choices in the response. - private IList _choices; + /// The response messages. + private IList? _messages; /// Initializes a new instance of the class. - /// The list of choices in the response, one message per choice. - [JsonConstructor] - public ChatResponse(IList choices) + public ChatResponse() { - _choices = Throw.IfNull(choices); } /// Initializes a new instance of the class. - /// The chat message representing the singular choice in the response. + /// The response message. + /// is . public ChatResponse(ChatMessage message) { _ = Throw.IfNull(message); - _choices = [message]; + + Messages.Add(message); + } + + /// Initializes a new instance of the class. + /// The response messages. + public ChatResponse(IList? messages) + { + _messages = messages; } - /// Gets or sets the list of chat response choices. - public IList Choices + /// Gets or sets the chat response messages. + [AllowNull] + public IList Messages { - get => _choices; - set => _choices = Throw.IfNull(value); + get => _messages ??= new List(1); + set => _messages = value; } - /// Gets the chat response message. + /// Gets the text of the response. /// - /// If there are multiple choices, this property returns the first choice. - /// If is empty, this property will throw. Use to access all choices directly. + /// This property concatenates the of all + /// instances in . /// [JsonIgnore] - public ChatMessage Message - { - get - { - var choices = Choices; - if (choices.Count == 0) - { - throw new InvalidOperationException($"The {nameof(ChatResponse)} instance does not contain any {nameof(ChatMessage)} choices."); - } - - return choices[0]; - } - } + public string Text => _messages?.ConcatText() ?? string.Empty; /// Gets or sets the ID of the chat response. public string? ResponseId { get; set; } @@ -67,7 +69,7 @@ public ChatMessage Message /// the input messages supplied to need only be the additional messages beyond /// what's already stored. If this property is non-, it represents an identifier for that state, /// and it should be used in a subsequent instead of supplying the same messages - /// (and this 's message) as part of the chatMessages parameter. + /// (and this 's message) as part of the messages parameter. /// public string? 
ChatThreadId { get; set; } @@ -96,26 +98,7 @@ public ChatMessage Message public AdditionalPropertiesDictionary? AdditionalProperties { get; set; } /// - public override string ToString() - { - if (Choices.Count == 1) - { - return Choices[0].ToString(); - } - - StringBuilder sb = new(); - for (int i = 0; i < Choices.Count; i++) - { - if (i > 0) - { - _ = sb.AppendLine().AppendLine(); - } - - _ = sb.Append("Choice ").Append(i).AppendLine(":").Append(Choices[i]); - } - - return sb.ToString(); - } + public override string ToString() => Text; /// Creates an array of instances that represent this . /// An array of instances that may be used to represent this . @@ -135,22 +118,22 @@ public ChatResponseUpdate[] ToChatResponseUpdates() } } - int choicesCount = Choices.Count; - var updates = new ChatResponseUpdate[choicesCount + (extra is null ? 0 : 1)]; + int messageCount = _messages?.Count ?? 0; + var updates = new ChatResponseUpdate[messageCount + (extra is not null ? 1 : 0)]; - for (int choiceIndex = 0; choiceIndex < choicesCount; choiceIndex++) + int i; + for (i = 0; i < messageCount; i++) { - ChatMessage choice = Choices[choiceIndex]; - updates[choiceIndex] = new ChatResponseUpdate + ChatMessage message = _messages![i]; + updates[i] = new ChatResponseUpdate { ChatThreadId = ChatThreadId, - ChoiceIndex = choiceIndex, - AdditionalProperties = choice.AdditionalProperties, - AuthorName = choice.AuthorName, - Contents = choice.Contents, - RawRepresentation = choice.RawRepresentation, - Role = choice.Role, + AdditionalProperties = message.AdditionalProperties, + AuthorName = message.AuthorName, + Contents = message.Contents, + RawRepresentation = message.RawRepresentation, + Role = message.Role, ResponseId = ResponseId, CreatedAt = CreatedAt, @@ -161,7 +144,7 @@ public ChatResponseUpdate[] ToChatResponseUpdates() if (extra is not null) { - updates[choicesCount] = extra; + updates[i] = extra; } return updates; diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseExtensions.cs new file mode 100644 index 00000000000..16eed49db93 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseExtensions.cs @@ -0,0 +1,304 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable S109 // Magic numbers should not be used +#pragma warning disable S1121 // Assignments should not be made from within sub-expressions + +namespace Microsoft.Extensions.AI; + +/// +/// Provides extension methods for working with and instances. +/// +public static class ChatResponseExtensions +{ + /// Adds all of the messages from into . + /// The destination list to which the messages from should be added. + /// The response containing the messages to add. + /// is . + /// is . + public static void AddMessages(this IList list, ChatResponse response) + { + _ = Throw.IfNull(list); + _ = Throw.IfNull(response); + + if (list is List listConcrete) + { + listConcrete.AddRange(response.Messages); + } + else + { + foreach (var message in response.Messages) + { + list.Add(message); + } + } + } + + /// Converts the into instances and adds them to . 
+ /// The destination list to which the newly constructed messages should be added. + /// The instances to convert to messages and add to the list. + /// is . + /// is . + /// + /// As part of combining into a series of instances, tne + /// method may use to determine message boundaries, as well as coalesce + /// contiguous items where applicable, e.g. multiple + /// instances in a row may be combined into a single . + /// + public static void AddMessages(this IList list, IEnumerable updates) + { + _ = Throw.IfNull(list); + _ = Throw.IfNull(updates); + + if (updates is ICollection { Count: 0 }) + { + return; + } + + list.AddMessages(updates.ToChatResponse()); + } + + /// Converts the into instances and adds them to . + /// The list to which the newly constructed messages should be added. + /// The instances to convert to messages and add to the list. + /// The to monitor for cancellation requests. The default is . + /// A representing the completion of the operation. + /// is . + /// is . + /// + /// As part of combining into a series of instances, tne + /// method may use to determine message boundaries, as well as coalesce + /// contiguous items where applicable, e.g. multiple + /// instances in a row may be combined into a single . + /// + public static Task AddMessagesAsync( + this IList list, IAsyncEnumerable updates, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(list); + _ = Throw.IfNull(updates); + + return AddMessagesAsync(list, updates, cancellationToken); + + static async Task AddMessagesAsync( + IList list, IAsyncEnumerable updates, CancellationToken cancellationToken) => + list.AddMessages(await updates.ToChatResponseAsync(cancellationToken).ConfigureAwait(false)); + } + + /// Combines instances into a single . + /// The updates to be combined. + /// The combined . + /// is . + /// + /// As part of combining into a single , the method will attempt to reconstruct + /// instances. This includes using to determine + /// message boundaries, as well as coalescing contiguous items where applicable, e.g. multiple + /// instances in a row may be combined into a single . + /// + public static ChatResponse ToChatResponse( + this IEnumerable updates) + { + _ = Throw.IfNull(updates); + + ChatResponse response = new(); + + foreach (var update in updates) + { + ProcessUpdate(update, response); + } + + FinalizeResponse(response); + + return response; + } + + /// Combines instances into a single . + /// The updates to be combined. + /// The to monitor for cancellation requests. The default is . + /// The combined . + /// is . + /// + /// As part of combining into a single , the method will attempt to reconstruct + /// instances. This includes using to determine + /// message boundaries, as well as coalescing contiguous items where applicable, e.g. multiple + /// instances in a row may be combined into a single . + /// + public static Task ToChatResponseAsync( + this IAsyncEnumerable updates, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(updates); + + return ToChatResponseAsync(updates, cancellationToken); + + static async Task ToChatResponseAsync( + IAsyncEnumerable updates, CancellationToken cancellationToken) + { + ChatResponse response = new(); + + await foreach (var update in updates.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + ProcessUpdate(update, response); + } + + FinalizeResponse(response); + + return response; + } + } + + /// Finalizes the object. 
+ private static void FinalizeResponse(ChatResponse response) + { + int count = response.Messages.Count; + for (int i = 0; i < count; i++) + { + CoalesceTextContent((List)response.Messages[i].Contents); + } + } + + /// Processes the , incorporating its contents into . + /// The update to process. + /// The object that should be updated based on . + private static void ProcessUpdate(ChatResponseUpdate update, ChatResponse response) + { + // If there is no message created yet, or if the last update we saw had a different + // response ID than the newest update, create a new message. + ChatMessage message; + if (response.Messages.Count == 0 || + (update.ResponseId is string updateId && response.ResponseId is string responseId && updateId != responseId)) + { + message = new ChatMessage(ChatRole.Assistant, []); + response.Messages.Add(message); + } + else + { + message = response.Messages[response.Messages.Count - 1]; + } + + // Some members on ChatResponseUpdate map to members of ChatMessage. + // Incorporate those into the latest message; in cases where the message + // stores a single value, prefer the latest update's value over anything + // stored in the message. + if (update.AuthorName is not null) + { + message.AuthorName = update.AuthorName; + } + + if (update.Role is ChatRole role) + { + message.Role = role; + } + + foreach (var content in update.Contents) + { + switch (content) + { + // Usage content is treated specially and propagated to the response's Usage. + case UsageContent usage: + (response.Usage ??= new()).Add(usage.Details); + break; + + default: + message.Contents.Add(content); + break; + } + } + + // Other members on a ChatResponseUpdate map to members of the ChatResponse. + // Update the response object with those, preferring the values from later updates. + + if (update.ResponseId is not null) + { + // Note that this must come after the message checks earlier, as they depend + // on this value for change detection. + response.ResponseId = update.ResponseId; + } + + if (update.ChatThreadId is not null) + { + response.ChatThreadId = update.ChatThreadId; + } + + if (update.CreatedAt is not null) + { + response.CreatedAt = update.CreatedAt; + } + + if (update.FinishReason is not null) + { + response.FinishReason = update.FinishReason; + } + + if (update.ModelId is not null) + { + response.ModelId = update.ModelId; + } + + if (update.AdditionalProperties is not null) + { + if (response.AdditionalProperties is null) + { + response.AdditionalProperties = new(update.AdditionalProperties); + } + else + { + response.AdditionalProperties.SetAll(update.AdditionalProperties); + } + } + } + + /// Coalesces sequential content elements. + private static void CoalesceTextContent(List contents) + { + StringBuilder? coalescedText = null; + + // Iterate through all of the items in the list looking for contiguous items that can be coalesced. + int start = 0; + while (start < contents.Count - 1) + { + // We need at least two TextContents in a row to be able to coalesce. + if (contents[start] is not TextContent firstText) + { + start++; + continue; + } + + if (contents[start + 1] is not TextContent secondText) + { + start += 2; + continue; + } + + // Append the text from those nodes and continue appending subsequent TextContents until we run out. + // We null out nodes as their text is appended so that we can later remove them all in one O(N) operation. 
+ coalescedText ??= new(); + _ = coalescedText.Clear().Append(firstText.Text).Append(secondText.Text); + contents[start + 1] = null!; + int i = start + 2; + for (; i < contents.Count && contents[i] is TextContent next; i++) + { + _ = coalescedText.Append(next.Text); + contents[i] = null!; + } + + // Store the replacement node. + contents[start] = new TextContent(coalescedText.ToString()) + { + // We inherit the properties of the first text node. We don't currently propagate additional + // properties from the subsequent nodes. If we ever need to, we can add that here. + AdditionalProperties = firstText.AdditionalProperties?.Clone(), + }; + + start = i; + } + + // Remove all of the null slots left over from the coalescing process. + _ = contents.RemoveAll(u => u is null); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs index 8bf9e57ece2..24610ac76fc 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; @@ -15,21 +16,20 @@ namespace Microsoft.Extensions.AI; /// /// is so named because it represents updates /// that layer on each other to form a single chat response. Conceptually, this combines the roles of -/// and in streaming output. For ease of consumption, -/// it also flattens the nested structure you see on streaming chunks in some AI services, so instead of a -/// dictionary of choices, each update is part of a single choice (and hence has its own role, choice ID, etc.). +/// and in streaming output. /// /// /// The relationship between and is -/// codified in the and +/// codified in the and /// , which enable bidirectional conversions /// between the two. Note, however, that the provided conversions may be lossy, for example if multiple /// updates all have different objects whereas there's only one slot for /// such an object available in . Similarly, if different -/// updates that are part of the same choice provide different values for properties like , +/// updates provide different values for properties like , /// only one of the values will be used to populate . /// /// +[DebuggerDisplay("[{Role}] {ContentForDebuggerDisplay}{EllipsesForDebuggerDisplay,nq}")] public class ChatResponseUpdate { /// The response update content items. @@ -38,6 +38,29 @@ public class ChatResponseUpdate /// The name of the author of the update. private string? _authorName; + /// Initializes a new instance of the class. + [JsonConstructor] + public ChatResponseUpdate() + { + } + + /// Initializes a new instance of the class. + /// The role of the author of the update. + /// The text content of the update. + public ChatResponseUpdate(ChatRole? role, string? content) + : this(role, content is null ? null : [new TextContent(content)]) + { + } + + /// Initializes a new instance of the class. + /// The role of the author of the update. + /// The contents of the update. + public ChatResponseUpdate(ChatRole? role, IList? contents) + { + Role = role; + _contents = contents; + } + /// Gets or sets the name of the author of the response update. public string? AuthorName { @@ -48,29 +71,12 @@ public string? 
AuthorName /// Gets or sets the role of the author of the response update. public ChatRole? Role { get; set; } - /// - /// Gets or sets the text of the first instance in . - /// + /// Gets the text of this update. /// - /// If there is no instance in , then the getter returns , - /// and the setter will add new instance with the provided value. + /// This property concatenates the text of all objects in . /// [JsonIgnore] - public string? Text - { - get => Contents.FindFirst()?.Text; - set - { - if (Contents.FindFirst() is { } textContent) - { - textContent.Text = value; - } - else if (value is not null) - { - Contents.Add(new TextContent(value)); - } - } - } + public string Text => _contents is not null ? _contents.ConcatText() : string.Empty; /// Gets or sets the chat response update content items. [AllowNull] @@ -101,16 +107,13 @@ public IList Contents /// the input messages supplied to need only be the additional messages beyond /// what's already stored. If this property is non-, it represents an identifier for that state, /// and it should be used in a subsequent instead of supplying the same messages - /// (and this streaming message) as part of the chatMessages parameter. + /// (and this streaming message) as part of the messages parameter. /// public string? ChatThreadId { get; set; } /// Gets or sets a timestamp for the response update. public DateTimeOffset? CreatedAt { get; set; } - /// Gets or sets the zero-based index of the choice with which this update is associated in the streaming sequence. - public int ChoiceIndex { get; set; } - /// Gets or sets the finish reason for the operation. public ChatFinishReason? FinishReason { get; set; } @@ -118,5 +121,13 @@ public IList Contents public string? ModelId { get; set; } /// - public override string ToString() => Contents.ConcatText(); + public override string ToString() => Text; + + /// Gets a object to display in the debugger display. + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private AIContent? ContentForDebuggerDisplay => _contents is { Count: > 0 } ? _contents[0] : null; + + /// Gets an indication for the debugger display of whether there's more content. + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private string EllipsesForDebuggerDisplay => _contents is { Count: > 1 } ? ", ..." : string.Empty; } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs deleted file mode 100644 index 25104461cd9..00000000000 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdateExtensions.cs +++ /dev/null @@ -1,248 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Collections.Generic; -using System.Linq; -#if NET -using System.Runtime.InteropServices; -#endif -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Shared.Diagnostics; - -#pragma warning disable S109 // Magic numbers should not be used -#pragma warning disable S127 // "for" loop stop conditions should be invariant -#pragma warning disable S1121 // Assignments should not be made from within sub-expressions - -namespace Microsoft.Extensions.AI; - -/// -/// Provides extension methods for working with instances. -/// -public static class ChatResponseUpdateExtensions -{ - /// Combines instances into a single . 
- /// The updates to be combined. - /// - /// to attempt to coalesce contiguous items, where applicable, - /// into a single , in order to reduce the number of individual content items that are included in - /// the manufactured instances. When , the original content items are used. - /// The default is . - /// - /// The combined . - public static ChatResponse ToChatResponse( - this IEnumerable updates, bool coalesceContent = true) - { - _ = Throw.IfNull(updates); - - ChatResponse response = new([]); - Dictionary messages = []; - - foreach (var update in updates) - { - ProcessUpdate(update, messages, response); - } - - AddMessagesToResponse(messages, response, coalesceContent); - - return response; - } - - /// Combines instances into a single . - /// The updates to be combined. - /// - /// to attempt to coalesce contiguous items, where applicable, - /// into a single , in order to reduce the number of individual content items that are included in - /// the manufactured instances. When , the original content items are used. - /// The default is . - /// - /// The to monitor for cancellation requests. The default is . - /// The combined . - public static Task ToChatResponseAsync( - this IAsyncEnumerable updates, bool coalesceContent = true, CancellationToken cancellationToken = default) - { - _ = Throw.IfNull(updates); - - return ToChatResponseAsync(updates, coalesceContent, cancellationToken); - - static async Task ToChatResponseAsync( - IAsyncEnumerable updates, bool coalesceContent, CancellationToken cancellationToken) - { - ChatResponse response = new([]); - Dictionary messages = []; - - await foreach (var update in updates.WithCancellation(cancellationToken).ConfigureAwait(false)) - { - ProcessUpdate(update, messages, response); - } - - AddMessagesToResponse(messages, response, coalesceContent); - - return response; - } - } - - /// Processes the , incorporating its contents into and . - /// The update to process. - /// The dictionary mapping to the being built for that choice. - /// The object whose properties should be updated based on . - private static void ProcessUpdate(ChatResponseUpdate update, Dictionary messages, ChatResponse response) - { - response.ChatThreadId ??= update.ChatThreadId; - response.CreatedAt ??= update.CreatedAt; - response.FinishReason ??= update.FinishReason; - response.ModelId ??= update.ModelId; - response.ResponseId ??= update.ResponseId; - -#if NET - ChatMessage message = CollectionsMarshal.GetValueRefOrAddDefault(messages, update.ChoiceIndex, out _) ??= - new(default, new List()); -#else - if (!messages.TryGetValue(update.ChoiceIndex, out ChatMessage? message)) - { - messages[update.ChoiceIndex] = message = new(default, new List()); - } -#endif - - // Incorporate all content from the update into the response. - foreach (var content in update.Contents) - { - switch (content) - { - // Usage content is treated specially and propagated to the response's Usage. - case UsageContent usage: - (response.Usage ??= new()).Add(usage.Details); - break; - - default: - message.Contents.Add(content); - break; - } - } - - message.AuthorName ??= update.AuthorName; - if (update.Role is ChatRole role && message.Role == default) - { - message.Role = role; - } - - if (update.AdditionalProperties is not null) - { - if (message.AdditionalProperties is null) - { - message.AdditionalProperties = new(update.AdditionalProperties); - } - else - { - foreach (var entry in update.AdditionalProperties) - { - // Use first-wins behavior to match the behavior of the other properties. 
- _ = message.AdditionalProperties.TryAdd(entry.Key, entry.Value); - } - } - } - } - - /// Finalizes the object by transferring the into it. - /// The messages to process further and transfer into . - /// The result being built. - /// The corresponding option value provided to or . - private static void AddMessagesToResponse(Dictionary messages, ChatResponse response, bool coalesceContent) - { - if (messages.Count <= 1) - { - // Add the single message if there is one. - foreach (var entry in messages) - { - AddMessage(response, coalesceContent, entry); - } - - // In the vast majority case where there's only one choice, promote any additional properties - // from the single message to the chat response, making them more discoverable and more similar - // to how they're typically surfaced from non-streaming services. - if (response.Choices.Count == 1 && - response.Choices[0].AdditionalProperties is { } messageProps) - { - response.Choices[0].AdditionalProperties = null; - response.AdditionalProperties = messageProps; - } - } - else - { - // Add all of the messages, sorted by choice index. - foreach (var entry in messages.OrderBy(entry => entry.Key)) - { - AddMessage(response, coalesceContent, entry); - } - - // If there are multiple choices, we don't promote additional properties from the individual messages. - // At a minimum, we'd want to know which choice the additional properties applied to, and if there were - // conflicting values across the choices, it would be unclear which one should be used. - } - - static void AddMessage(ChatResponse response, bool coalesceContent, KeyValuePair entry) - { - if (entry.Value.Role == default) - { - entry.Value.Role = ChatRole.Assistant; - } - - if (coalesceContent) - { - CoalesceTextContent((List)entry.Value.Contents); - } - - response.Choices.Add(entry.Value); - } - } - - /// Coalesces sequential content elements. - private static void CoalesceTextContent(List contents) - { - StringBuilder? coalescedText = null; - - // Iterate through all of the items in the list looking for contiguous items that can be coalesced. - int start = 0; - while (start < contents.Count - 1) - { - // We need at least two TextContents in a row to be able to coalesce. - if (contents[start] is not TextContent firstText) - { - start++; - continue; - } - - if (contents[start + 1] is not TextContent secondText) - { - start += 2; - continue; - } - - // Append the text from those nodes and continue appending subsequent TextContents until we run out. - // We null out nodes as their text is appended so that we can later remove them all in one O(N) operation. - coalescedText ??= new(); - _ = coalescedText.Clear().Append(firstText.Text).Append(secondText.Text); - contents[start + 1] = null!; - int i = start + 2; - for (; i < contents.Count && contents[i] is TextContent next; i++) - { - _ = coalescedText.Append(next.Text); - contents[i] = null!; - } - - // Store the replacement node. - contents[start] = new TextContent(coalescedText.ToString()) - { - // We inherit the properties of the first text node. We don't currently propagate additional - // properties from the subsequent nodes. If we ever need to, we can add that here. - AdditionalProperties = firstText.AdditionalProperties?.Clone(), - }; - - start = i; - } - - // Remove all of the null slots left over from the coalescing process. 
- _ = contents.RemoveAll(u => u is null); - } -} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/DelegatingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/DelegatingChatClient.cs index 7882529ac85..23768dd8da7 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/DelegatingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/DelegatingChatClient.cs @@ -13,7 +13,7 @@ namespace Microsoft.Extensions.AI; /// Provides an optional base class for an that passes through calls to another instance. /// /// -/// This is recommended as a base type when building clients that can be chained in any order around an underlying . +/// This is recommended as a base type when building clients that can be chained around an underlying . /// The default implementation simply passes each call to the inner client instance. /// public class DelegatingChatClient : IChatClient @@ -38,16 +38,14 @@ public void Dispose() protected IChatClient InnerClient { get; } /// - public virtual Task GetResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) - { - return InnerClient.GetResponseAsync(chatMessages, options, cancellationToken); - } + public virtual Task GetResponseAsync( + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => + InnerClient.GetResponseAsync(messages, options, cancellationToken); /// - public virtual IAsyncEnumerable GetStreamingResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) - { - return InnerClient.GetStreamingResponseAsync(chatMessages, options, cancellationToken); - } + public virtual IAsyncEnumerable GetStreamingResponseAsync( + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => + InnerClient.GetStreamingResponseAsync(messages, options, cancellationToken); /// public virtual object? GetService(Type serviceType, object? serviceKey = null) diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs index 26a39f05105..0de18809bbc 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs @@ -13,43 +13,37 @@ namespace Microsoft.Extensions.AI; /// /// Unless otherwise specified, all members of are thread-safe for concurrent use. /// It is expected that all implementations of support being used by multiple requests concurrently. +/// Instances must not be disposed of while the instance is still in use. /// /// /// However, implementations of might mutate the arguments supplied to and -/// , such as by adding additional messages to the messages list or configuring the options -/// instance. Thus, consumers of the interface either should avoid using shared instances of these arguments for concurrent -/// invocations or should otherwise ensure by construction that no instances are used which might employ -/// such mutation. For example, the WithChatOptions method be provided with a callback that could mutate the supplied options -/// argument, and that should be avoided if using a singleton options instance. +/// , such as by configuring the options instance. 
Thus, consumers of the interface either +/// should avoid using shared instances of these arguments for concurrent invocations or should otherwise ensure by construction +/// that no instances are used which might employ such mutation. For example, the ConfigureOptions method is +/// provided with a callback that could mutate the supplied options argument, and that should be avoided if using a singleton options instance. /// /// public interface IChatClient : IDisposable { /// Sends chat messages and returns the response. - /// The chat content to send. - /// The chat options to configure the request. + /// The sequence of chat messages to send. + /// The chat options with which to configure the request. /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. - /// - /// The returned messages aren't added to . However, any intermediate messages generated implicitly - /// by the client, including any messages for roundtrips to the model as part of the implementation of this request, are included. - /// + /// is . Task GetResponseAsync( - IList chatMessages, + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default); /// Sends chat messages and streams the response. - /// The chat content to send. - /// The chat options to configure the request. + /// The sequence of chat messages to send. + /// The chat options with which to configure the request. /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. - /// - /// The returned messages aren't added to . However, any intermediate messages generated implicitly - /// by the client, including any messages for roundtrips to the model as part of the implementation of this request, are included. - /// + /// is . IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default); @@ -59,8 +53,9 @@ IAsyncEnumerable GetStreamingResponseAsync( /// The found object, otherwise . /// is . /// - /// The purpose of this method is to allow for the retrieval of strongly typed services that might be provided by the , - /// including itself or any services it might be wrapping. + /// The purpose of this method is to allow for the retrieval of strongly-typed services that might be provided by the , + /// including itself or any services it might be wrapping. For example, to access the for the instance, + /// may be used to request it. /// object? GetService(Type serviceType, object? serviceKey = null); } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/RequiredChatToolMode.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/RequiredChatToolMode.cs index 74858dfe89b..91397e67602 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/RequiredChatToolMode.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/RequiredChatToolMode.cs @@ -26,6 +26,7 @@ public sealed class RequiredChatToolMode : ChatToolMode /// Initializes a new instance of the class that requires a specific function to be called. /// /// The name of the function that must be called. + /// is empty or composed entirely of whitespace. /// /// can be . However, it's preferable to use /// when any function can be selected. 
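The signature updates above move `IChatClient` and `DelegatingChatClient` from `IList<ChatMessage>` to `IEnumerable<ChatMessage>`. As a rough sketch (not part of this patch), a middleware-style client built on the updated `DelegatingChatClient` base type might look like the following; the `CountingChatClient` name and its logging behavior are purely illustrative.

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;

// Hypothetical middleware client: counts the messages in each request before delegating.
public sealed class CountingChatClient(IChatClient innerClient) : DelegatingChatClient(innerClient)
{
    public override Task<ChatResponse> GetResponseAsync(
        IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
    {
        // Materialize the sequence once so counting doesn't enumerate it a second time.
        List<ChatMessage> list = [.. messages];
        Console.WriteLine($"Sending {list.Count} message(s).");
        return base.GetResponseAsync(list, options, cancellationToken);
    }

    public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
        IEnumerable<ChatMessage> messages, ChatOptions? options = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        Console.WriteLine("Starting streaming request.");
        await foreach (ChatResponseUpdate update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false))
        {
            yield return update;
        }
    }
}
```

Because `DelegatingChatClient` forwards every member to the wrapped instance by default, only the members that need custom behavior have to be overridden.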
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContentExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContentExtensions.cs index eb516e2a7c1..550a48ab6de 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContentExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContentExtensions.cs @@ -3,10 +3,11 @@ using System; using System.Collections.Generic; -#if !NET using System.Linq; -#else +#if NET using System.Runtime.CompilerServices; +#else +using System.Text; #endif namespace Microsoft.Extensions.AI; @@ -14,51 +15,102 @@ namespace Microsoft.Extensions.AI; /// Internal extensions for working with . internal static class AIContentExtensions { - /// Finds the first occurrence of a in the list. - public static T? FindFirst(this IList contents) - where T : AIContent + /// Concatenates the text of all instances in the list. + public static string ConcatText(this IEnumerable contents) { - int count = contents.Count; - for (int i = 0; i < count; i++) + if (contents is IList list) { - if (contents[i] is T t) + int count = list.Count; + switch (count) { - return t; + case 0: + return string.Empty; + + case 1: + return (list[0] as TextContent)?.Text ?? string.Empty; + + default: +#if NET + DefaultInterpolatedStringHandler builder = new(count, 0, null, stackalloc char[512]); + for (int i = 0; i < count; i++) + { + if (list[i] is TextContent text) + { + builder.AppendLiteral(text.Text); + } + } + + return builder.ToStringAndClear(); +#else + StringBuilder builder = new(); + for (int i = 0; i < count; i++) + { + if (list[i] is TextContent text) + { + builder.Append(text.Text); + } + } + + return builder.ToString(); +#endif } } - return null; + return string.Concat(contents.OfType()); } - /// Concatenates the text of all instances in the list. - public static string ConcatText(this IList contents) + /// Concatenates the of all instances in the list. + /// A newline separator is added between each non-empty piece of text. + public static string ConcatText(this IList messages) { - int count = contents.Count; + int count = messages.Count; switch (count) { case 0: - break; + return string.Empty; case 1: - return contents[0] is TextContent tc ? 
tc.Text : string.Empty; + return messages[0].Text; default: #if NET - DefaultInterpolatedStringHandler builder = new(0, 0, null, stackalloc char[512]); + DefaultInterpolatedStringHandler builder = new(count, 0, null, stackalloc char[512]); + bool needsSeparator = false; for (int i = 0; i < count; i++) { - if (contents[i] is TextContent text) + string text = messages[i].Text; + if (text.Length > 0) { - builder.AppendLiteral(text.Text); + if (needsSeparator) + { + builder.AppendLiteral(Environment.NewLine); + } + + builder.AppendLiteral(text); + + needsSeparator = true; } } return builder.ToStringAndClear(); #else - return string.Concat(contents.OfType()); + StringBuilder builder = new(); + for (int i = 0; i < count; i++) + { + string text = messages[i].Text; + if (text.Length > 0) + { + if (builder.Length > 0) + { + builder.AppendLine(); + } + + builder.Append(text); + } + } + + return builder.ToString(); #endif } - - return string.Empty; } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallContent.cs index 88e0a207127..d19988b2b76 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallContent.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallContent.cs @@ -65,15 +65,19 @@ public FunctionCallContent(string callId, string name, IDictionaryThe function name. /// The parsing implementation converting the encoding to a dictionary of arguments. /// A new instance of containing the parse result. + /// is . + /// is . + /// is . + /// is . public static FunctionCallContent CreateFromParsedArguments( TEncoding encodedArguments, string callId, string name, Func?> argumentParser) { + _ = Throw.IfNull(encodedArguments); _ = Throw.IfNull(callId); _ = Throw.IfNull(name); - _ = Throw.IfNull(encodedArguments); _ = Throw.IfNull(argumentParser); IDictionary? arguments = null; diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/DelegatingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/DelegatingEmbeddingGenerator.cs index f1a4c3aa7a2..e15c2981613 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/DelegatingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/DelegatingEmbeddingGenerator.cs @@ -15,7 +15,7 @@ namespace Microsoft.Extensions.AI; /// The type of the input passed to the generator. /// The type of the embedding instance produced by the generator. /// -/// This type is recommended as a base type when building generators that can be chained in any order around an underlying . +/// This type is recommended as a base type when building generators that can be chained around an underlying . /// The default implementation simply passes each call to the inner generator instance. 
/// public class DelegatingEmbeddingGenerator : IEmbeddingGenerator diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs index d69952598dd..35d8260e406 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs @@ -189,15 +189,21 @@ public static async Task GenerateEmbeddingAsync( if (embeddings is null) { - throw new InvalidOperationException("Embedding generator returned a null collection of embeddings."); + Throw.InvalidOperationException("Embedding generator returned a null collection of embeddings."); } if (embeddings.Count != 1) { - throw new InvalidOperationException($"Expected the number of embeddings ({embeddings.Count}) to match the number of inputs (1)."); + Throw.InvalidOperationException($"Expected the number of embeddings ({embeddings.Count}) to match the number of inputs (1)."); } - return embeddings[0] ?? throw new InvalidOperationException("Embedding generator generated a null embedding."); + TEmbedding embedding = embeddings[0]; + if (embedding is null) + { + Throw.InvalidOperationException("Embedding generator generated a null embedding."); + } + + return embedding; } /// @@ -235,7 +241,7 @@ public static async Task GenerateEmbeddingAsync( var embeddings = await generator.GenerateAsync(values, options, cancellationToken).ConfigureAwait(false); if (embeddings.Count != inputsCount) { - throw new InvalidOperationException($"Expected the number of embeddings ({embeddings.Count}) to match the number of inputs ({inputsCount})."); + Throw.InvalidOperationException($"Expected the number of embeddings ({embeddings.Count}) to match the number of inputs ({inputsCount})."); } var results = new (TInput, TEmbedding)[embeddings.Count]; diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs index c260708079c..59fcc9e2393 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs @@ -15,23 +15,24 @@ namespace Microsoft.Extensions.AI; /// /// Unless otherwise specified, all members of are thread-safe for concurrent use. /// It is expected that all implementations of support being used by multiple requests concurrently. +/// Instances must not be disposed of while the instance is still in use. /// /// /// However, implementations of may mutate the arguments supplied to -/// , such as by adding additional values to the values list or configuring the options -/// instance. Thus, consumers of the interface either should avoid using shared instances of these arguments for concurrent -/// invocations or should otherwise ensure by construction that no instances -/// are used which might employ such mutation. +/// , such as by configuring the options instance. Thus, consumers of the interface either should +/// avoid using shared instances of these arguments for concurrent invocations or should otherwise ensure by construction that +/// no instances are used which might employ such mutation. /// /// public interface IEmbeddingGenerator : IDisposable where TEmbedding : Embedding { /// Generates embeddings for each of the supplied . 
- /// The collection of values for which to generate embeddings. - /// The embedding generation options to configure the request. + /// The sequence of values for which to generate embeddings. + /// The embedding generation options with which to configure the request. /// The to monitor for cancellation requests. The default is . /// The generated embeddings. + /// is . Task> GenerateAsync( IEnumerable values, EmbeddingGenerationOptions? options = null, @@ -45,6 +46,8 @@ Task> GenerateAsync( /// /// The purpose of this method is to allow for the retrieval of strongly typed services that might be provided by the /// , including itself or any services it might be wrapping. + /// For example, to access the for the instance, may + /// be used to request it. /// object? GetService(Type serviceType, object? serviceKey = null); } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md b/src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md index b8a6cba944f..0d94cacc925 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md @@ -10,7 +10,7 @@ From the command-line: dotnet add package Microsoft.Extensions.AI.Abstractions ``` -Or directly in the C# project file: +or directly in the C# project file: ```xml @@ -27,105 +27,24 @@ of the abstractions. ### `IChatClient` -The `IChatClient` interface defines a client abstraction responsible for interacting with AI services that provide "chat" capabilities. It defines methods for sending and receiving messages comprised of multi-modal content (text, images, audio, etc.), with responses being either as a complete result or streamed incrementally. Additionally, it allows for retrieving strongly-typed services that may be provided by the client or its underlying services. - -#### Sample Implementation +The `IChatClient` interface defines a client abstraction responsible for interacting with AI services that provide "chat" capabilities. It defines methods for sending and receiving messages comprised of multi-modal content (text, images, audio, etc.), with responses providing either a complete result or one streamed incrementally. Additionally, it allows for retrieving strongly-typed services that may be provided by the client or its underlying services. .NET libraries that provide clients for language models and services may provide an implementation of the `IChatClient` interface. Any consumers of the interface are then able to interoperate seamlessly with these models and services via the abstractions. -Here is a sample implementation of an `IChatClient` to show the general structure. - -```csharp -using System.Runtime.CompilerServices; -using Microsoft.Extensions.AI; - -public class SampleChatClient : IChatClient -{ - private readonly ChatClientMetadata _metadata; - - public SampleChatClient(Uri endpoint, string modelId) => - _metadata = new("SampleChatClient", endpoint, modelId); - - public async Task GetResponseAsync( - IList chatMessages, - ChatOptions? options = null, - CancellationToken cancellationToken = default) - { - // Simulate some operation. - await Task.Delay(300, cancellationToken); - - // Return a sample chat response randomly. - string[] responses = - [ - "This is the first sample response.", - "Here is another example of a response message.", - "This is yet another response message." 
- ]; - - return new(new ChatMessage() - { - Role = ChatRole.Assistant, - Text = responses[Random.Shared.Next(responses.Length)], - }); - } - - public async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, - ChatOptions? options = null, - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - // Simulate streaming by yielding messages one by one. - string[] words = ["This ", "is ", "the ", "response ", "for ", "the ", "request."]; - foreach (string word in words) - { - // Simulate some operation. - await Task.Delay(100, cancellationToken); - - // Yield the next message in the response. - yield return new ChatResponseUpdate - { - Role = ChatRole.Assistant, - Text = word, - }; - } - } - - object? IChatClient.GetService(Type serviceType, object? serviceKey = null) => - serviceKey is not null ? null : - serviceType == typeof(ChatClientMetadata) ? _metadata : - serviceType?.IsInstanceOfType(this) is true ? this : - null; - - void IDisposable.Dispose() { } -} -``` - -As further examples, you can find other concrete implementations in the following packages (but many more such implementations for a large variety of services are available on NuGet): - -- [Microsoft.Extensions.AI.AzureAIInference](https://aka.ms/meai-azaiinference-nuget) -- [Microsoft.Extensions.AI.OpenAI](https://aka.ms/meai-openai-nuget) -- [Microsoft.Extensions.AI.Ollama](https://aka.ms/meai-ollama-nuget) - #### Requesting a Chat Response: `GetResponseAsync` With an instance of `IChatClient`, the `GetResponseAsync` method may be used to send a request and get a response. The request is composed of one or more messages, each of which is composed of one or more pieces of content. Accelerator methods exist to simplify common cases, such as constructing a request for a single piece of text content. ```csharp -using Microsoft.Extensions.AI; - -IChatClient client = new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"); - -var response = await client.GetResponseAsync("What is AI?"); +IChatClient client = ...; -Console.WriteLine(response.Message); +Console.WriteLine(await client.GetResponseAsync("What is AI?")); ``` -The core `GetResponseAsync` method on the `IChatClient` interface accepts a list of messages. This list represents the history of all messages that are part of the conversation. +The core `GetResponseAsync` method on the `IChatClient` interface accepts a list of messages. This list often represents the history of all messages that are part of the conversation. ```csharp -using Microsoft.Extensions.AI; - -IChatClient client = new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"); +IChatClient client = ...; Console.WriteLine(await client.GetResponseAsync( [ @@ -134,7 +53,10 @@ Console.WriteLine(await client.GetResponseAsync( ])); ``` -The `ChatResponse` that's returned from `GetResponseAsync` exposes a `ChatMessage` representing the response message. It may be added back into the history in order to provide this response back to the service in a subsequent request, e.g. +The `ChatResponse` that's returned from `GetResponseAsync` exposes a list of `ChatMessage` instances representing one or more messages generated as part of the operation. +In common cases, there is only one response message, but a variety of situations can result in their being multiple; the list is ordered, such that the last message in +the list represents the final message to the request. 
In order to provide all of those response messages back to the service in a subsequent request, the messages from +the response may be added back into the messages list. ```csharp List history = []; @@ -143,21 +65,20 @@ while (true) Console.Write("Q: "); history.Add(new(ChatRole.User, Console.ReadLine())); - ChatResponse response = await client.GetResponseAsync(history); - + var response = await client.GetResponseAsync(history); Console.WriteLine(response); - history.Add(response.Message); + + history.AddMessages(response); } ``` #### Requesting a Streaming Chat Response: `GetStreamingResponseAsync` -The inputs to `GetStreamingResponseAsync` are identical to those of `GetResponseAsync`. However, rather than returning the complete response as part of a `ChatResponse` object, the method returns an `IAsyncEnumerable`, providing a stream of updates that together form the single response. +The inputs to `GetStreamingResponseAsync` are identical to those of `GetResponseAsync`. However, rather than returning the complete response as part of a +`ChatResponse` object, the method returns an `IAsyncEnumerable`, providing a stream of updates that together form the single response. ```csharp -using Microsoft.Extensions.AI; - -IChatClient client = new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"); +IChatClient client = ...; await foreach (var update in client.GetStreamingResponseAsync("What is AI?")) { @@ -165,46 +86,45 @@ await foreach (var update in client.GetStreamingResponseAsync("What is AI?")) } ``` -Such a stream of response updates may be combined into a single response object via the `ToChatResponse` and `ToChatResponseAsync` helper methods, e.g. +As with `GetResponseAsync`, the updates from `IChatClient.GetStreamingResponseAsync` can be added back into the messages list. As the updates provided +are individual pieces of a response, helpers like `ToChatResponse` can be used to compose one or more updates back into a single `ChatResponse` instance. +Further helpers like `AddMessages` perform that same operation and then extract the composed messages from the response and add them into a list. ```csharp List history = []; -List updates = []; while (true) { Console.Write("Q: "); history.Add(new(ChatRole.User, Console.ReadLine())); - updates.Clear(); + List updates = []; await foreach (var update in client.GetStreamingResponseAsync(history)) { Console.Write(update); - updates.Add(update); } + Console.WriteLine(); - history.Add(updates.ToChatResponse().Message)); + history.AddMessages(updates); } ``` #### Tool Calling -Some models and services support the notion of tool calling, where requests may include information about tools (in particular .NET methods) that the model may request be invoked in order to gather additional information. Rather than sending back a response message that represents the final response to the input, the model sends back a request to invoke a given function with a given set of arguments; the client may then find and invoke the relevant function and send back the results to the model (along with all the rest of the history). The abstractions in Microsoft.Extensions.AI include representations for various forms of content that may be included in messages, and this includes representations for these function call requests and results. While it's possible for the consumer of the `IChatClient` to interact with this content directly, `Microsoft.Extensions.AI` supports automating these interactions. 
It provides an `AIFunction` that represents an invocable function along with metadata for describing the function to the AI model, along with an `AIFunctionFactory` for creating `AIFunction`s to represent .NET methods. It also provides a `FunctionInvokingChatClient` that both is an `IChatClient` and also wraps an `IChatClient`, enabling layering automatic function invocation capabilities around an arbitrary `IChatClient` implementation. +Some models and services support the notion of tool calling, where requests may include information about tools (in particular .NET methods) that the model may request be invoked in order to gather additional information. Rather than sending back a response message that represents the final response to the input, the model sends back a request to invoke a given function with a given set of arguments; the client may then find and invoke the relevant function and send back the results to the model (along with all the rest of the history). The abstractions in `Microsoft.Extensions.AI` include representations for various forms of content that may be included in messages, and this includes representations for these function call requests and results. While it's possible for the consumer of the `IChatClient` to interact with this content directly, `Microsoft.Extensions.AI` supports automating these interactions. It provides an `AIFunction` that represents an invocable function along with metadata for describing the function to the AI model, as well as an `AIFunctionFactory` for creating `AIFunction`s to represent .NET methods. It also provides a `FunctionInvokingChatClient` that both is an `IChatClient` and also wraps an `IChatClient`, enabling layering automatic function invocation capabilities around an arbitrary `IChatClient` implementation. ```csharp -using System.ComponentModel; using Microsoft.Extensions.AI; -[Description("Gets the current weather")] string GetCurrentWeather() => Random.Shared.NextDouble() > 0.5 ? 
"It's sunny" : "It's raining"; -IChatClient client = new ChatClientBuilder(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1")) +IChatClient client = new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1") + .AsBuilder() .UseFunctionInvocation() .Build(); -var response = client.GetStreamingResponseAsync( - "Should I wear a rain coat?", - new() { Tools = [AIFunctionFactory.Create(GetCurrentWeather)] }); +ChatOptions options = new() { Tools = [AIFunctionFactory.Create(GetCurrentWeather)] }; +var response = client.GetStreamingResponseAsync("Should I wear a rain coat?", options); await foreach (var update in response) { Console.Write(update); @@ -221,7 +141,7 @@ using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.Caching.Memory; using Microsoft.Extensions.Options; -IChatClient client = new ChatClientBuilder(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model")) +IChatClient client = new ChatClientBuilder(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1")) .UseDistributedCache(new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions()))) .Build(); @@ -252,11 +172,11 @@ var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() .AddConsoleExporter() .Build(); -IChatClient client = new ChatClientBuilder(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model")) +IChatClient client = new ChatClientBuilder(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1")) .UseOpenTelemetry(sourceName: sourceName, configure: c => c.EnableSensitiveData = true) .Build(); -Console.WriteLine((await client.GetResponseAsync("What is AI?")).Message); +Console.WriteLine(await client.GetResponseAsync("What is AI?")); ``` Alternatively, the `LoggingChatClient` and corresponding `UseLogging` method provide a simple way to write log entries to an `ILogger` for every request and response. @@ -269,7 +189,8 @@ Options may also be baked into an `IChatClient` via the `ConfigureOptions` exten ```csharp using Microsoft.Extensions.AI; -IChatClient client = new ChatClientBuilder(new OllamaChatClient(new Uri("http://localhost:11434"))) +IChatClient client = new OllamaChatClient(new Uri("http://localhost:11434")) + .AsBuilder() .ConfigureOptions(options => options.ModelId ??= "phi3") .Build(); @@ -335,23 +256,23 @@ using System.Threading.RateLimiting; public sealed class RateLimitingChatClient(IChatClient innerClient, RateLimiter rateLimiter) : DelegatingChatClient(innerClient) { public override async Task GetResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { using var lease = await rateLimiter.AcquireAsync(permitCount: 1, cancellationToken).ConfigureAwait(false); if (!lease.IsAcquired) throw new InvalidOperationException("Unable to acquire lease."); - return await base.GetResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + return await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); } public override async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? 
options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { using var lease = await rateLimiter.AcquireAsync(permitCount: 1, cancellationToken).ConfigureAwait(false); if (!lease.IsAcquired) throw new InvalidOperationException("Unable to acquire lease."); - await foreach (var update in base.GetStreamingResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) yield return update; } @@ -372,10 +293,10 @@ using Microsoft.Extensions.AI; using System.Threading.RateLimiting; var client = new RateLimitingChatClient( - new SampleChatClient(new Uri("http://localhost"), "test"), + new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1"), new ConcurrencyLimiter(new() { PermitLimit = 1, QueueLimit = int.MaxValue })); -await client.GetResponseAsync("What color is the sky?"); +Console.WriteLine(await client.GetResponseAsync("What color is the sky?")); ``` To make it easier to compose such components with others, the author of the component is recommended to create a "Use" extension method for registering this component into a pipeline, e.g. @@ -398,7 +319,7 @@ public static class RateLimitingChatClientExtensions The consumer can then easily use this in their pipeline, e.g. ```csharp -var client = new SampleChatClient(new Uri("http://localhost"), "test") +var client = new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1") .AsBuilder() .UseDistributedCache() .UseRateLimiting() @@ -412,16 +333,16 @@ need to do work before and after delegating to the next client in the pipeline. be used that accepts a delegate which is used for both `GetResponseAsync` and `GetStreamingResponseAsync`, reducing the boilderplate required: ```csharp RateLimiter rateLimiter = ...; -var client = new SampleChatClient(new Uri("http://localhost"), "test") +var client = new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1") .AsBuilder() .UseDistributedCache() - .Use(async (chatMessages, options, nextAsync, cancellationToken) => + .Use(async (messages, options, nextAsync, cancellationToken) => { using var lease = await rateLimiter.AcquireAsync(permitCount: 1, cancellationToken).ConfigureAwait(false); if (!lease.IsAcquired) throw new InvalidOperationException("Unable to acquire lease."); - await nextAsync(chatMessages, options, cancellationToken); + await nextAsync(messages, options, cancellationToken); }) .UseOpenTelemetry() .Build(); @@ -443,7 +364,7 @@ using Microsoft.Extensions.Hosting; // App Setup var builder = Host.CreateApplicationBuilder(); builder.Services.AddDistributedMemoryCache(); -builder.Services.AddChatClient(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model")) +builder.Services.AddChatClient(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1")) .UseDistributedCache(); var host = builder.Build(); @@ -459,7 +380,8 @@ What instance and configuration is injected may differ based on the current need "Stateless" services require all relevant conversation history to sent back on every request, while "stateful" services keep track of the history and instead require only additional messages be sent with a request. The `IChatClient` interface is designed to handle both stateless and stateful AI services. -If you know you're working with a stateless service (currently the most common form), responses may be added back into a message history for sending back to the server. 
+When working with a stateless service, callers maintain a list of all messages, adding in all received response messages, and providing the list +back on subsequent interactions. ```csharp List history = []; while (true) @@ -467,23 +389,23 @@ while (true) Console.Write("Q: "); history.Add(new(ChatRole.User, Console.ReadLine())); - ChatResponse response = await client.GetResponseAsync(history); - + var response = await client.GetResponseAsync(history); Console.WriteLine(response); - history.Add(response.Message); + + history.AddMessages(response); } ``` -For stateful services, you may know ahead of time an identifier used for the relevant conversation. That identifier can be put into `ChatOptions.ChatThreadId`: +For stateful services, you may know ahead of time an identifier used for the relevant conversation. That identifier can be put into `ChatOptions.ChatThreadId`. +Usage then follows the same pattern, except there's no need to maintain a history manually. ```csharp ChatOptions options = new() { ChatThreadId = "my-conversation-id" }; while (true) { Console.Write("Q: "); + ChatMessage message = new(ChatRole.User, Console.ReadLine()); - ChatResponse response = await client.GetResponseAsync(Console.ReadLine(), options); - - Console.WriteLine(response); + Console.WriteLine(await client.GetResponseAsync(message, options)); } ``` @@ -494,10 +416,11 @@ ChatOptions options = new(); while (true) { Console.Write("Q: "); + ChatMessage message = new(ChatRole.User, Console.ReadLine()); - ChatResponse response = await client.GetResponseAsync(Console.ReadLine(), options); - + ChatResponse response = await client.GetResponseAsync(message, options); Console.WriteLine(response); + options.ChatThreadId = response.ChatThreadId; } ``` @@ -515,8 +438,8 @@ while (true) history.Add(new(ChatRole.User, Console.ReadLine())); ChatResponse response = await client.GetResponseAsync(history); - Console.WriteLine(response); + options.ChatThreadId = response.ChatThreadId; if (response.ChatThreadId is not null) { @@ -524,7 +447,7 @@ while (true) } else { - history.Add(response.Message); + history.AddMessages(response); } } ``` diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/UsageDetails.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/UsageDetails.cs index 7d4e7ddbea2..b3c62cb67e0 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/UsageDetails.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/UsageDetails.cs @@ -29,9 +29,12 @@ public class UsageDetails public AdditionalPropertiesDictionary? AdditionalCounts { get; set; } /// Adds usage data from another into this instance. + /// The source with which to augment this instance. + /// is . public void Add(UsageDetails usage) { _ = Throw.IfNull(usage); + InputTokenCount = NullableSum(InputTokenCount, usage.InputTokenCount); OutputTokenCount = NullableSum(OutputTokenCount, usage.OutputTokenCount); TotalTokenCount = NullableSum(TotalTokenCount, usage.TotalTokenCount); diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs index e8a962a0be8..1d0224f7c9f 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs @@ -55,6 +55,7 @@ public static partial class AIJsonUtilities /// The options used to extract the schema from the specified type. 
/// The options controlling schema inference. /// A JSON schema document encoded as a . + /// is . public static JsonElement CreateFunctionJsonSchema( MethodBase method, string? title = null, @@ -63,6 +64,7 @@ public static JsonElement CreateFunctionJsonSchema( AIJsonSchemaCreateOptions? inferenceOptions = null) { _ = Throw.IfNull(method); + serializerOptions ??= DefaultOptions; inferenceOptions ??= AIJsonSchemaCreateOptions.Default; title ??= method.Name; diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs index 2f527612fab..db03a62f2a9 100644 --- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs @@ -39,9 +39,12 @@ public sealed class AzureAIInferenceChatClient : IChatClient /// Initializes a new instance of the class for the specified . /// The underlying client. /// The ID of the model to use. If null, it can be provided per request via . + /// is . + /// is empty or composed entirely of whitespace. public AzureAIInferenceChatClient(ChatCompletionsClient chatCompletionsClient, string? modelId = null) { _ = Throw.IfNull(chatCompletionsClient); + if (modelId is not null) { _ = Throw.IfNullOrWhitespace(modelId); @@ -81,30 +84,21 @@ public JsonSerializerOptions ToolCallJsonSerializerOptions /// public async Task GetResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); // Make the call. ChatCompletions response = (await _chatCompletionsClient.CompleteAsync( - ToAzureAIOptions(chatMessages, options), + ToAzureAIOptions(messages, options), cancellationToken: cancellationToken).ConfigureAwait(false)).Value; // Create the return message. - List returnMessages = []; - - // Populate its content from those in the response content. - ChatMessage message = new() + ChatMessage message = new(ToChatRole(response.Role), response.Content) { RawRepresentation = response, - Role = ToChatRole(response.Role), }; - if (response.Content is string content) - { - message.Text = content; - } - if (response.ToolCalls is { Count: > 0 } toolCalls) { foreach (var toolCall in toolCalls) @@ -119,8 +113,6 @@ public async Task GetResponseAsync( } } - returnMessages.Add(message); - UsageDetails? usage = null; if (response.Usage is CompletionsUsage completionsUsage) { @@ -133,7 +125,7 @@ public async Task GetResponseAsync( } // Wrap the content in a ChatResponse to return. - return new ChatResponse(returnMessages) + return new ChatResponse(message) { CreatedAt = response.Created, ModelId = response.Model, @@ -146,9 +138,9 @@ public async Task GetResponseAsync( /// public async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); Dictionary? functionCallInfos = null; ChatRole? 
streamedRole = default; @@ -159,7 +151,7 @@ public async IAsyncEnumerable GetStreamingResponseAsync( string lastCallId = string.Empty; // Process each update as it arrives - var updates = await _chatCompletionsClient.CompleteStreamingAsync(ToAzureAIOptions(chatMessages, options), cancellationToken).ConfigureAwait(false); + var updates = await _chatCompletionsClient.CompleteStreamingAsync(ToAzureAIOptions(messages, options), cancellationToken).ConfigureAwait(false); await foreach (StreamingChatCompletionsUpdate chatCompletionUpdate in updates.ConfigureAwait(false)) { // The role and finish reason may arrive during any update, but once they've arrived, the same value should be the same for all subsequent updates. @@ -289,7 +281,7 @@ private static ChatRole ToChatRole(global::Azure.AI.Inference.ChatRole role) => new(s); /// Converts an extensions options instance to an AzureAI options instance. - private ChatCompletionsOptions ToAzureAIOptions(IList chatContents, ChatOptions? options) + private ChatCompletionsOptions ToAzureAIOptions(IEnumerable chatContents, ChatOptions? options) { ChatCompletionsOptions result = new(ToAzureAIInferenceChatMessages(chatContents)) { @@ -417,7 +409,7 @@ private static ChatCompletionsToolDefinition ToAzureAIChatTool(AIFunction aiFunc } /// Converts an Extensions chat message enumerable to an AzureAI chat message enumerable. - private IEnumerable ToAzureAIInferenceChatMessages(IList inputs) + private IEnumerable ToAzureAIInferenceChatMessages(IEnumerable inputs) { // Maps all of the M.E.AI types to the corresponding AzureAI types. // Unrecognized or non-processable content is ignored. diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs index 17bd4fa4662..c0f4b2f4636 100644 --- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs @@ -41,6 +41,9 @@ public sealed class AzureAIInferenceEmbeddingGenerator : /// Either this parameter or must provide a valid model ID. /// /// The number of dimensions to generate in each embedding. + /// is . + /// is empty or composed entirely of whitespace. + /// is not positive. public AzureAIInferenceEmbeddingGenerator( EmbeddingsClient embeddingsClient, string? modelId = null, int? dimensions = null) { diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/ChatConversationEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/ChatConversationEvaluator.cs index 4113de75568..cbc904277ab 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/ChatConversationEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/ChatConversationEvaluator.cs @@ -1,8 +1,10 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Collections.Generic; using System.Linq; +using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.Shared.Diagnostics; @@ -35,13 +37,13 @@ public abstract class ChatConversationEvaluator : IEvaluator /// public async ValueTask EvaluateAsync( IEnumerable messages, - ChatMessage modelResponse, + ChatResponse modelResponse, ChatConfiguration? chatConfiguration = null, IEnumerable? 
additionalContext = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(modelResponse, nameof(modelResponse)); - _ = Throw.IfNull(chatConfiguration, nameof(chatConfiguration)); + _ = Throw.IfNull(modelResponse); + _ = Throw.IfNull(chatConfiguration); EvaluationResult result = InitializeResult(); @@ -211,8 +213,8 @@ await PerformEvaluationAsync( ChatConfiguration chatConfiguration, CancellationToken cancellationToken) { - _ = Throw.IfNull(message, nameof(message)); - _ = Throw.IfNull(chatConfiguration, nameof(chatConfiguration)); + _ = Throw.IfNull(message); + _ = Throw.IfNull(chatConfiguration); IEvaluationTokenCounter? tokenCounter = chatConfiguration.TokenCounter; if (tokenCounter is null) @@ -249,12 +251,41 @@ await PerformEvaluationAsync( } } + /// + /// Renders the supplied to a string that can be included as part of the evaluation + /// prompt that this uses. + /// + /// + /// Chat response being evaluated and that is to be rendered as part of the evaluation prompt. + /// + /// A that can cancel the operation. + /// + /// A string representation of the supplied that can be included as part of the + /// evaluation prompt. + /// + /// + /// The default implementation uses to render + /// each message in the response. + /// + protected virtual async ValueTask RenderAsync(ChatResponse response, CancellationToken cancellationToken) + { + _ = Throw.IfNull(response); + + StringBuilder sb = new(); + foreach (ChatMessage message in response.Messages) + { + _ = sb.Append(await RenderAsync(message, cancellationToken).ConfigureAwait(false)); + } + + return sb.ToString(); + } + /// /// Renders the supplied to a string that can be included as part of the evaluation /// prompt that this uses. /// /// - /// A message that is part of the conversation history for the response being evaluated and that is to be rendered + /// Message that is part of the conversation history for the response being evaluated and that is to be rendered /// as part of the evaluation prompt. /// /// A that can cancel the operation. @@ -264,7 +295,7 @@ await PerformEvaluationAsync( /// protected virtual ValueTask RenderAsync(ChatMessage message, CancellationToken cancellationToken) { - _ = Throw.IfNull(message, nameof(message)); + _ = Throw.IfNull(message); string? author = message.AuthorName; string role = message.Role.Value; @@ -296,7 +327,7 @@ protected virtual ValueTask RenderAsync(ChatMessage message, Cancellatio /// The evaluation prompt. protected abstract ValueTask RenderEvaluationPromptAsync( ChatMessage? userRequest, - ChatMessage modelResponse, + ChatResponse modelResponse, IEnumerable? includedHistory, IEnumerable? additionalContext, CancellationToken cancellationToken); diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/CoherenceEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/CoherenceEvaluator.cs index 8c31feb2dde..4122a063cf4 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/CoherenceEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/CoherenceEvaluator.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI.Evaluation.Quality; @@ -31,11 +32,13 @@ public sealed class CoherenceEvaluator : SingleNumericMetricEvaluator /// protected override async ValueTask RenderEvaluationPromptAsync( ChatMessage? userRequest, - ChatMessage modelResponse, + ChatResponse modelResponse, IEnumerable? 
includedHistory, IEnumerable? additionalContext, CancellationToken cancellationToken) { + _ = Throw.IfNull(modelResponse); + string renderedModelResponse = await RenderAsync(modelResponse, cancellationToken).ConfigureAwait(false); string renderedUserRequest = diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/EquivalenceEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/EquivalenceEvaluator.cs index ed482688e0c..5926d260374 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/EquivalenceEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/EquivalenceEvaluator.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI.Evaluation.Quality; @@ -35,11 +36,13 @@ public sealed class EquivalenceEvaluator : SingleNumericMetricEvaluator /// protected override async ValueTask RenderEvaluationPromptAsync( ChatMessage? userRequest, - ChatMessage modelResponse, + ChatResponse modelResponse, IEnumerable? includedHistory, IEnumerable? additionalContext, CancellationToken cancellationToken) { + _ = Throw.IfNull(modelResponse); + string renderedModelResponse = await RenderAsync(modelResponse, cancellationToken).ConfigureAwait(false); string renderedUserRequest = diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/FluencyEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/FluencyEvaluator.cs index 8c11cf0f0c0..d08a30a31b2 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/FluencyEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/FluencyEvaluator.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI.Evaluation.Quality; @@ -31,11 +32,13 @@ public sealed class FluencyEvaluator : SingleNumericMetricEvaluator /// protected override async ValueTask RenderEvaluationPromptAsync( ChatMessage? userRequest, - ChatMessage modelResponse, + ChatResponse modelResponse, IEnumerable? includedHistory, IEnumerable? additionalContext, CancellationToken cancellationToken) { + _ = Throw.IfNull(modelResponse); + string renderedModelResponse = await RenderAsync(modelResponse, cancellationToken).ConfigureAwait(false); string renderedUserRequest = diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/GroundednessEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/GroundednessEvaluator.cs index ddb3d522a44..cbb66657a68 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/GroundednessEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/GroundednessEvaluator.cs @@ -6,6 +6,7 @@ using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI.Evaluation.Quality; @@ -35,11 +36,13 @@ public sealed class GroundednessEvaluator : SingleNumericMetricEvaluator /// protected override async ValueTask RenderEvaluationPromptAsync( ChatMessage? userRequest, - ChatMessage modelResponse, + ChatResponse modelResponse, IEnumerable? includedHistory, IEnumerable? 
additionalContext, CancellationToken cancellationToken) { + _ = Throw.IfNull(modelResponse); + string renderedModelResponse = await RenderAsync(modelResponse, cancellationToken).ConfigureAwait(false); string renderedUserRequest = diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/RelevanceTruthAndCompletenessEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/RelevanceTruthAndCompletenessEvaluator.cs index 4fdccf03be9..419feb45743 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/RelevanceTruthAndCompletenessEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/RelevanceTruthAndCompletenessEvaluator.cs @@ -12,6 +12,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI.Evaluation.Quality.Utilities; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI.Evaluation.Quality; @@ -75,11 +76,13 @@ protected override EvaluationResult InitializeResult() /// protected override async ValueTask RenderEvaluationPromptAsync( ChatMessage? userRequest, - ChatMessage modelResponse, + ChatResponse modelResponse, IEnumerable? includedHistory, IEnumerable? additionalContext, CancellationToken cancellationToken) { + _ = Throw.IfNull(modelResponse); + string renderedModelResponse = await RenderAsync(modelResponse, cancellationToken).ConfigureAwait(false); string renderedUserRequest = @@ -125,10 +128,10 @@ await chatConfiguration.ChatClient.GetResponseAsync( _chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false); - string? evaluationResponseText = evaluationResponse.Message.Text?.Trim(); + string evaluationResponseText = evaluationResponse.Text.Trim(); Rating rating; - if (string.IsNullOrWhiteSpace(evaluationResponseText)) + if (string.IsNullOrEmpty(evaluationResponseText)) { rating = Rating.Inconclusive; result.AddDiagnosticToAllMetrics( @@ -145,13 +148,13 @@ await chatConfiguration.ChatClient.GetResponseAsync( { try { - string? repairedJson = + string repairedJson = await JsonOutputFixer.RepairJsonAsync( chatConfiguration, evaluationResponseText!, cancellationToken).ConfigureAwait(false); - if (string.IsNullOrWhiteSpace(repairedJson)) + if (string.IsNullOrEmpty(repairedJson)) { rating = Rating.Inconclusive; result.AddDiagnosticToAllMetrics( diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/SingleNumericMetricEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/SingleNumericMetricEvaluator.cs index 8b9367dbf32..437dde3eb1e 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/SingleNumericMetricEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/SingleNumericMetricEvaluator.cs @@ -62,8 +62,8 @@ protected sealed override async ValueTask PerformEvaluationAsync( EvaluationResult result, CancellationToken cancellationToken) { - _ = Throw.IfNull(chatConfiguration, nameof(chatConfiguration)); - _ = Throw.IfNull(result, nameof(result)); + _ = Throw.IfNull(chatConfiguration); + _ = Throw.IfNull(result); ChatResponse evaluationResponse = await chatConfiguration.ChatClient.GetResponseAsync( @@ -71,11 +71,11 @@ await chatConfiguration.ChatClient.GetResponseAsync( _chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false); - string? 
evaluationResponseText = evaluationResponse.Message.Text?.Trim(); + string evaluationResponseText = evaluationResponse.Text.Trim(); NumericMetric metric = result.Get(MetricName); - if (string.IsNullOrWhiteSpace(evaluationResponseText)) + if (string.IsNullOrEmpty(evaluationResponseText)) { metric.AddDiagnostic( EvaluationDiagnostic.Error( diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/Utilities/JsonOutputFixer.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/Utilities/JsonOutputFixer.cs index e6b10dedb84..b50d69bcebd 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/Utilities/JsonOutputFixer.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/Utilities/JsonOutputFixer.cs @@ -32,7 +32,7 @@ internal static ReadOnlySpan TrimMarkdownDelimiters(string json) return trimmed; } - internal static async ValueTask RepairJsonAsync( + internal static async ValueTask RepairJsonAsync( ChatConfiguration chatConfig, string json, CancellationToken cancellationToken) @@ -74,6 +74,6 @@ await chatConfig.ChatClient.GetResponseAsync( chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false); - return response.Message.Text?.Trim(); + return response.Text.Trim(); } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Azure/Storage/AzureStorageResultStore.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Azure/Storage/AzureStorageResultStore.cs index 1ff06c07467..fe1a3d91a9c 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Azure/Storage/AzureStorageResultStore.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Azure/Storage/AzureStorageResultStore.cs @@ -156,7 +156,7 @@ public async ValueTask WriteResultsAsync( IEnumerable results, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(results, nameof(results)); + _ = Throw.IfNull(results); foreach (ScenarioRunResult result in results) { diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRun.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRun.cs index 8dc189767f2..3793afff05b 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRun.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRun.cs @@ -126,7 +126,7 @@ internal ScenarioRun( /// An containing one or more s. public async ValueTask EvaluateAsync( IEnumerable messages, - ChatMessage modelResponse, + ChatResponse modelResponse, IEnumerable? additionalContext = null, CancellationToken cancellationToken = default) { diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunExtensions.cs index 3c9a8fd5d44..3b723a2d258 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunExtensions.cs @@ -87,7 +87,36 @@ public static ValueTask EvaluateAsync( IEnumerable? additionalContext = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(scenarioRun, nameof(scenarioRun)); + _ = Throw.IfNull(scenarioRun); + + return scenarioRun.EvaluateAsync( + messages: [], + new ChatResponse(modelResponse), + additionalContext, + cancellationToken); + } + + /// + /// Evaluates the supplied and returns an + /// containing one or more s. 
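Editor's note: the hunks above and below move the evaluation entry points from a single ChatMessage to a ChatResponse, with the existing ChatMessage-based overloads forwarding through new ChatResponse(modelResponse). A minimal calling sketch under stated assumptions (the ScenarioRun is assumed to come from a reporting configuration, and the message text is illustrative):

```csharp
using System;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.AI.Evaluation;
using Microsoft.Extensions.AI.Evaluation.Reporting;

static class EvaluationSample
{
    // 'scenarioRun' is assumed to be created elsewhere (for example from a reporting configuration).
    public static async Task EvaluateAsync(ScenarioRun scenarioRun)
    {
        ChatMessage userRequest = new(ChatRole.User, "Summarize the incident report in two sentences.");

        // The response under evaluation is now a ChatResponse rather than a single ChatMessage.
        ChatResponse modelResponse = new(new ChatMessage(ChatRole.Assistant, "The incident was caused by ..."));

        EvaluationResult result = await scenarioRun.EvaluateAsync(userRequest, modelResponse);

        // Existing ChatMessage-based callers keep working; the overloads in this patch wrap
        // the message for them via new ChatResponse(message).
        Console.WriteLine(result.ContainsDiagnostics() ? "Evaluation produced diagnostics." : "Evaluation completed.");
    }
}
```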
+ /// + /// The of which this evaluation is a part. + /// The response that is to be evaluated. + /// + /// Additional contextual information that the s included in this + /// may need to accurately evaluate the supplied . + /// + /// + /// A that can cancel the evaluation operation. + /// + /// An containing one or more s. + public static ValueTask EvaluateAsync( + this ScenarioRun scenarioRun, + ChatResponse modelResponse, + IEnumerable? additionalContext = null, + CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(scenarioRun); return scenarioRun.EvaluateAsync( messages: [], @@ -121,7 +150,41 @@ public static ValueTask EvaluateAsync( IEnumerable? additionalContext = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(scenarioRun, nameof(scenarioRun)); + _ = Throw.IfNull(scenarioRun); + + return scenarioRun.EvaluateAsync( + messages: [userRequest], + new ChatResponse(modelResponse), + additionalContext, + cancellationToken); + } + + /// + /// Evaluates the supplied and returns an + /// containing one or more s. + /// + /// The of which this evaluation is a part. + /// + /// The request that produced the that is to be evaluated. + /// + /// The response that is to be evaluated. + /// + /// Additional contextual information (beyond that which is available in ) that the + /// s included in this may need to accurately evaluate the + /// supplied . + /// + /// + /// A that can cancel the evaluation operation. + /// + /// An containing one or more s. + public static ValueTask EvaluateAsync( + this ScenarioRun scenarioRun, + ChatMessage userRequest, + ChatResponse modelResponse, + IEnumerable? additionalContext = null, + CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(scenarioRun); return scenarioRun.EvaluateAsync( messages: [userRequest], diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunResult.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunResult.cs index 22d9ff0167e..e1a4102e42c 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunResult.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunResult.cs @@ -17,7 +17,7 @@ namespace Microsoft.Extensions.AI.Evaluation.Reporting; /// Represents the results of a single execution of a particular iteration of a particular scenario under evaluation. /// In other words, represents the results of evaluating a /// and includes the that is produced when -/// +/// /// is invoked. /// /// @@ -44,7 +44,7 @@ public sealed class ScenarioRunResult( string executionName, DateTime creationTime, IList messages, - ChatMessage modelResponse, + ChatResponse modelResponse, EvaluationResult evaluationResult) { /// @@ -68,7 +68,7 @@ public ScenarioRunResult( string executionName, DateTime creationTime, IEnumerable messages, - ChatMessage modelResponse, + ChatResponse modelResponse, EvaluationResult evaluationResult) : this( scenarioName, @@ -115,7 +115,7 @@ public ScenarioRunResult( /// /// Gets or sets the response being evaluated in this . /// - public ChatMessage ModelResponse { get; set; } = modelResponse; + public ChatResponse ModelResponse { get; set; } = modelResponse; /// /// Gets or sets the for the corresponding to @@ -123,7 +123,7 @@ public ScenarioRunResult( /// /// /// This is the same that is returned when - /// + /// /// is invoked. 
/// public EvaluationResult EvaluationResult { get; set; } = evaluationResult; diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunResultExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunResultExtensions.cs index 8b82a7336cf..ecc3dcb80e8 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunResultExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ScenarioRunResultExtensions.cs @@ -30,7 +30,7 @@ public static bool ContainsDiagnostics( this ScenarioRunResult result, Func? predicate = null) { - _ = Throw.IfNull(result, nameof(result)); + _ = Throw.IfNull(result); return result.EvaluationResult.ContainsDiagnostics(predicate); } diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/Storage/DiskBasedResultStore.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/Storage/DiskBasedResultStore.cs index 3ab62df05f8..422bcab2fb2 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/Storage/DiskBasedResultStore.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/Storage/DiskBasedResultStore.cs @@ -83,7 +83,7 @@ public async ValueTask WriteResultsAsync( IEnumerable results, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(results, nameof(results)); + _ = Throw.IfNull(results); foreach (ScenarioRunResult result in results) { diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation/CompositeEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation/CompositeEvaluator.cs index af14851ff92..7dc544c66c8 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation/CompositeEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation/CompositeEvaluator.cs @@ -42,7 +42,7 @@ public CompositeEvaluator(params IEvaluator[] evaluators) /// An enumeration of s that are to be composed. public CompositeEvaluator(IEnumerable evaluators) { - _ = Throw.IfNull(evaluators, nameof(evaluators)); + _ = Throw.IfNull(evaluators); var metricNames = new HashSet(); @@ -102,7 +102,7 @@ public CompositeEvaluator(IEnumerable evaluators) /// An containing one or more s. public async ValueTask EvaluateAsync( IEnumerable messages, - ChatMessage modelResponse, + ChatResponse modelResponse, ChatConfiguration? chatConfiguration = null, IEnumerable? additionalContext = null, CancellationToken cancellationToken = default) @@ -127,7 +127,7 @@ public async ValueTask EvaluateAsync( private IAsyncEnumerable EvaluateAndStreamResultsAsync( IEnumerable messages, - ChatMessage modelResponse, + ChatResponse modelResponse, ChatConfiguration? chatConfiguration = null, IEnumerable? additionalContext = null, CancellationToken cancellationToken = default) diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationMetricExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationMetricExtensions.cs index 9af7a5a2427..9b6f5e05104 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationMetricExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationMetricExtensions.cs @@ -31,7 +31,7 @@ public static bool ContainsDiagnostics( this EvaluationMetric metric, Func? predicate = null) { - _ = Throw.IfNull(metric, nameof(metric)); + _ = Throw.IfNull(metric); return predicate is null ? 
metric.Diagnostics.Any() : metric.Diagnostics.Any(predicate); } diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationResult.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationResult.cs index 668422d349e..a49cfeb8463 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationResult.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationResult.cs @@ -50,7 +50,7 @@ public EvaluationResult(IDictionary metrics) /// public EvaluationResult(IEnumerable metrics) { - _ = Throw.IfNull(metrics, nameof(metrics)); + _ = Throw.IfNull(metrics); var metricsDictionary = new Dictionary(); diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationResultExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationResultExtensions.cs index 18b7181c7aa..30305327c8d 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationResultExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluationResultExtensions.cs @@ -22,7 +22,7 @@ public static class EvaluationResultExtensions /// The that is to be added. public static void AddDiagnosticToAllMetrics(this EvaluationResult result, EvaluationDiagnostic diagnostic) { - _ = Throw.IfNull(result, nameof(result)); + _ = Throw.IfNull(result); foreach (EvaluationMetric metric in result.Metrics.Values) { @@ -49,7 +49,7 @@ public static bool ContainsDiagnostics( this EvaluationResult result, Func? predicate = null) { - _ = Throw.IfNull(result, nameof(result)); + _ = Throw.IfNull(result); return result.Metrics.Values.Any(m => m.ContainsDiagnostics(predicate)); } @@ -69,8 +69,8 @@ public static void Interpret( this EvaluationResult result, Func interpretationProvider) { - _ = Throw.IfNull(result, nameof(result)); - _ = Throw.IfNull(interpretationProvider, nameof(interpretationProvider)); + _ = Throw.IfNull(result); + _ = Throw.IfNull(interpretationProvider); foreach (EvaluationMetric metric in result.Metrics.Values) { diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluatorExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluatorExtensions.cs index efda72a5c39..cfef4121af4 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluatorExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation/EvaluatorExtensions.cs @@ -133,7 +133,52 @@ public static ValueTask EvaluateAsync( IEnumerable? additionalContext = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(evaluator, nameof(evaluator)); + _ = Throw.IfNull(evaluator); + + return evaluator.EvaluateAsync( + messages: [], + new ChatResponse(modelResponse), + chatConfiguration, + additionalContext, + cancellationToken); + } + + /// + /// Evaluates the supplied and returns an + /// containing one or more s. + /// + /// + /// + /// The s of the s contained in the returned + /// should match . + /// + /// + /// Also note that must not be omitted if the evaluation is performed using an + /// AI model. + /// + /// + /// The that should perform the evaluation. + /// The response that is to be evaluated. + /// + /// A that specifies the and the + /// that should be used if the evaluation is performed using an AI model. + /// + /// + /// Additional contextual information that the may need to accurately evaluate the + /// supplied . + /// + /// + /// A that can cancel the evaluation operation. + /// + /// An containing one or more s. + public static ValueTask EvaluateAsync( + this IEvaluator evaluator, + ChatResponse modelResponse, + ChatConfiguration? 
chatConfiguration = null, + IEnumerable? additionalContext = null, + CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(evaluator); return evaluator.EvaluateAsync( messages: [], @@ -182,7 +227,56 @@ public static ValueTask EvaluateAsync( IEnumerable? additionalContext = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(evaluator, nameof(evaluator)); + _ = Throw.IfNull(evaluator); + + return evaluator.EvaluateAsync( + messages: [userRequest], + new ChatResponse(modelResponse), + chatConfiguration, + additionalContext, + cancellationToken); + } + + /// + /// Evaluates the supplied and returns an + /// containing one or more s. + /// + /// + /// + /// The s of the s contained in the returned + /// should match . + /// + /// + /// Also note that must not be omitted if the evaluation is performed using an + /// AI model. + /// + /// + /// The that should perform the evaluation. + /// + /// The request that produced the that is to be evaluated. + /// + /// The response that is to be evaluated. + /// + /// A that specifies the and the + /// that should be used if the evaluation is performed using an AI model. + /// + /// + /// Additional contextual information (beyond that which is available in ) that the + /// may need to accurately evaluate the supplied . + /// + /// + /// A that can cancel the evaluation operation. + /// + /// An containing one or more s. + public static ValueTask EvaluateAsync( + this IEvaluator evaluator, + ChatMessage userRequest, + ChatResponse modelResponse, + ChatConfiguration? chatConfiguration = null, + IEnumerable? additionalContext = null, + CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(evaluator); return evaluator.EvaluateAsync( messages: [userRequest], diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation/IEvaluator.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation/IEvaluator.cs index d30e4b92df7..9528d4132d3 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation/IEvaluator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation/IEvaluator.cs @@ -50,7 +50,7 @@ public interface IEvaluator /// An containing one or more s. ValueTask EvaluateAsync( IEnumerable messages, - ChatMessage modelResponse, + ChatResponse modelResponse, ChatConfiguration? chatConfiguration = null, IEnumerable? additionalContext = null, CancellationToken cancellationToken = default); diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation/TokenizerExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation/TokenizerExtensions.cs index a9ef5e0c508..681d69ed1e1 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation/TokenizerExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation/TokenizerExtensions.cs @@ -37,7 +37,7 @@ public int CountTokens(string content) /// public static IEvaluationTokenCounter ToTokenCounter(this Tokenizer tokenizer, int inputTokenLimit) { - _ = Throw.IfNull(tokenizer, nameof(tokenizer)); + _ = Throw.IfNull(tokenizer); return new TokenCounter(tokenizer, inputTokenLimit); } diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs index d3f45358d10..ed1448c8b69 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs @@ -56,6 +56,8 @@ public OllamaChatClient(string endpoint, string? modelId = null, HttpClient? htt /// Either this parameter or must provide a valid model ID. 
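Editor's note: alongside the evaluator changes, the IChatClient surface below switches from IList&lt;ChatMessage&gt; to IEnumerable&lt;ChatMessage&gt;. A hedged usage sketch against the updated OllamaChatClient; the endpoint and model name are placeholders, not values taken from this patch:

```csharp
using System;
using System.Collections.Generic;
using Microsoft.Extensions.AI;

// Placeholder endpoint/model; adjust to your local Ollama setup.
using IChatClient client = new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.2");

// GetResponseAsync now accepts any IEnumerable<ChatMessage>, not just IList<ChatMessage>.
List<ChatMessage> messages =
[
    new(ChatRole.System, "You are a concise assistant."),
    new(ChatRole.User, "What does a delegating chat client do?"),
];

ChatResponse response = await client.GetResponseAsync(messages);
Console.WriteLine(response.Text);
```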
/// /// An instance to use for HTTP operations. + /// is . + /// is empty or composed entirely of whitespace. public OllamaChatClient(Uri endpoint, string? modelId = null, HttpClient? httpClient = null) { _ = Throw.IfNull(endpoint); @@ -78,13 +80,14 @@ public JsonSerializerOptions ToolCallJsonSerializerOptions } /// - public async Task GetResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + public async Task GetResponseAsync( + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); using var httpResponse = await _httpClient.PostAsJsonAsync( _apiChatEndpoint, - ToOllamaChatRequest(chatMessages, options, stream: false), + ToOllamaChatRequest(messages, options, stream: false), JsonContext.Default.OllamaChatRequest, cancellationToken).ConfigureAwait(false); @@ -102,7 +105,7 @@ public async Task GetResponseAsync(IList chatMessages throw new InvalidOperationException($"Ollama error: {response.Error}"); } - return new([FromOllamaMessage(response.Message!)]) + return new(FromOllamaMessage(response.Message!)) { CreatedAt = DateTimeOffset.TryParse(response.CreatedAt, CultureInfo.InvariantCulture, DateTimeStyles.None, out DateTimeOffset createdAt) ? createdAt : null, FinishReason = ToFinishReason(response), @@ -114,13 +117,13 @@ public async Task GetResponseAsync(IList chatMessages /// public async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); using HttpRequestMessage request = new(HttpMethod.Post, _apiChatEndpoint) { - Content = JsonContent.Create(ToOllamaChatRequest(chatMessages, options, stream: true), JsonContext.Default.OllamaChatRequest) + Content = JsonContent.Create(ToOllamaChatRequest(messages, options, stream: true), JsonContext.Default.OllamaChatRequest) }; using var httpResponse = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false); @@ -290,12 +293,12 @@ private static FunctionCallContent ToFunctionCallContent(OllamaFunctionToolCall } } - private OllamaChatRequest ToOllamaChatRequest(IList chatMessages, ChatOptions? options, bool stream) + private OllamaChatRequest ToOllamaChatRequest(IEnumerable messages, ChatOptions? options, bool stream) { OllamaChatRequest request = new() { Format = ToOllamaChatResponseFormat(options?.ResponseFormat), - Messages = chatMessages.SelectMany(ToOllamaChatRequestMessages).ToArray(), + Messages = messages.SelectMany(ToOllamaChatRequestMessages).ToArray(), Model = options?.ModelId ?? _metadata.ModelId ?? string.Empty, Stream = stream, Tools = options?.ToolMode is not NoneChatToolMode && options?.Tools is { Count: > 0 } tools ? tools.OfType().Select(ToOllamaTool) : null, diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs index 3d869f3f278..6056753dd26 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs @@ -45,6 +45,8 @@ public OllamaEmbeddingGenerator(string endpoint, string? 
modelId = null, HttpCli /// Either this parameter or must provide a valid model ID. /// /// An instance to use for HTTP operations. + /// is . + /// is empty or composed entirely of whitespace. public OllamaEmbeddingGenerator(Uri endpoint, string? modelId = null, HttpClient? httpClient = null) { _ = Throw.IfNull(endpoint); diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs index 7b68ce5e15e..1e5afb6d529 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs @@ -70,21 +70,21 @@ public OpenAIAssistantClient(AssistantClient assistantClient, string assistantId /// public Task GetResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) => - GetStreamingResponseAsync(chatMessages, options, cancellationToken).ToChatResponseAsync(coalesceContent: true, cancellationToken); + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => + GetStreamingResponseAsync(messages, options, cancellationToken).ToChatResponseAsync(cancellationToken); /// public async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - // Extract necessary state from chatMessages and options. - (RunCreationOptions runOptions, List? toolResults) = CreateRunOptions(chatMessages, options); + // Extract necessary state from messages and options. + (RunCreationOptions runOptions, List? toolResults) = CreateRunOptions(messages, options); // Get the thread ID. string? threadId = options?.ChatThreadId ?? _threadId; if (threadId is null && toolResults is not null) { - Throw.ArgumentException(nameof(chatMessages), "No thread ID was provided, but chat messages includes tool results."); + Throw.ArgumentException(nameof(messages), "No thread ID was provided, but chat messages includes tool results."); } // Get the updates to process from the assistant. If we have any tool results, this means submitting those and ignoring @@ -112,17 +112,17 @@ public async IAsyncEnumerable GetStreamingResponseAsync( } // Process each update. + string? responseId = null; await foreach (var update in updates.ConfigureAwait(false)) { switch (update) { case MessageContentUpdate mcu: - yield return new() + yield return new(mcu.Role == MessageRole.User ? ChatRole.User : ChatRole.Assistant, mcu.Text) { ChatThreadId = threadId, RawRepresentation = mcu, - Role = mcu.Role == MessageRole.User ? ChatRole.User : ChatRole.Assistant, - Text = mcu.Text, + ResponseId = responseId, }; break; @@ -132,6 +132,7 @@ public async IAsyncEnumerable GetStreamingResponseAsync( case RunUpdate ru: threadId ??= ru.Value.ThreadId; + responseId ??= ru.Value.Id; ChatResponseUpdate ruUpdate = new() { @@ -140,7 +141,7 @@ public async IAsyncEnumerable GetStreamingResponseAsync( CreatedAt = ru.Value.CreatedAt, ModelId = ru.Value.Model, RawRepresentation = ru, - ResponseId = ru.Value.Id, + ResponseId = responseId, Role = ChatRole.Assistant, }; @@ -176,9 +177,10 @@ void IDisposable.Dispose() } /// Adds the provided messages to the thread and returns the options to use for the request. - private static (RunCreationOptions RunOptions, List? 
ToolResults) CreateRunOptions(IList chatMessages, ChatOptions? options) + private static (RunCreationOptions RunOptions, List? ToolResults) CreateRunOptions( + IEnumerable messages, ChatOptions? options) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); RunCreationOptions runOptions = new(); @@ -273,7 +275,7 @@ strictObj is bool strictValue ? // Handle ChatMessages. System messages are turned into additional instructions. StringBuilder? instructions = null; List? functionResults = null; - foreach (var chatMessage in chatMessages) + foreach (var chatMessage in messages) { List messageContents = []; diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs index ba584cc1734..7852a87c2e1 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs @@ -39,6 +39,8 @@ public sealed class OpenAIChatClient : IChatClient /// Initializes a new instance of the class for the specified . /// The underlying client. /// The model to use. + /// is . + /// is empty or composed entirely of whitespace. public OpenAIChatClient(OpenAIClient openAIClient, string modelId) { _ = Throw.IfNull(openAIClient); @@ -59,6 +61,7 @@ public OpenAIChatClient(OpenAIClient openAIClient, string modelId) /// Initializes a new instance of the class for the specified . /// The underlying client. + /// is . public OpenAIChatClient(ChatClient chatClient) { _ = Throw.IfNull(chatClient); @@ -100,11 +103,11 @@ public JsonSerializerOptions ToolCallJsonSerializerOptions /// public async Task GetResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); - var openAIChatMessages = OpenAIModelMappers.ToOpenAIChatMessages(chatMessages, ToolCallJsonSerializerOptions); + var openAIChatMessages = OpenAIModelMappers.ToOpenAIChatMessages(messages, ToolCallJsonSerializerOptions); var openAIOptions = OpenAIModelMappers.ToOpenAIOptions(options); // Make the call to OpenAI. @@ -115,11 +118,11 @@ public async Task GetResponseAsync( /// public IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); - var openAIChatMessages = OpenAIModelMappers.ToOpenAIChatMessages(chatMessages, ToolCallJsonSerializerOptions); + var openAIChatMessages = OpenAIModelMappers.ToOpenAIChatMessages(messages, ToolCallJsonSerializerOptions); var openAIOptions = OpenAIModelMappers.ToOpenAIOptions(options); // Make the call to OpenAI. diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs index 55c887ba108..7cf0be18fb0 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs @@ -38,6 +38,9 @@ public sealed class OpenAIEmbeddingGenerator : IEmbeddingGeneratorThe underlying client. /// The model to use. /// The number of dimensions to generate in each embedding. + /// is . 
+ /// is empty or composed entirely of whitespace. + /// is not positive. public OpenAIEmbeddingGenerator( OpenAIClient openAIClient, string modelId, int? dimensions = null) { @@ -66,6 +69,8 @@ public OpenAIEmbeddingGenerator( /// Initializes a new instance of the class. /// The underlying client. /// The number of dimensions to generate in each embedding. + /// is . + /// is not positive. public OpenAIEmbeddingGenerator(EmbeddingClient embeddingClient, int? dimensions = null) { _ = Throw.IfNull(embeddingClient); diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs index f5c21be3678..59727d38f00 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs @@ -30,23 +30,25 @@ public static ChatCompletion ToOpenAIChatCompletion(ChatResponse response, JsonS { _ = Throw.IfNull(response); - if (response.Choices.Count > 1) - { - throw new NotSupportedException("Creating OpenAI ChatCompletion models with multiple choices is currently not supported."); - } - List? toolCalls = null; - foreach (AIContent content in response.Message.Contents) + ChatRole? role = null; + List allContents = []; + foreach (ChatMessage message in response.Messages) { - if (content is FunctionCallContent callRequest) + role = message.Role; + foreach (AIContent content in message.Contents) { - toolCalls ??= []; - toolCalls.Add(ChatToolCall.CreateFunctionToolCall( - callRequest.CallId, - callRequest.Name, - new(JsonSerializer.SerializeToUtf8Bytes( - callRequest.Arguments, - options.GetTypeInfo(typeof(IDictionary)))))); + allContents.Add(content); + if (content is FunctionCallContent callRequest) + { + toolCalls ??= []; + toolCalls.Add(ChatToolCall.CreateFunctionToolCall( + callRequest.CallId, + callRequest.Name, + new(JsonSerializer.SerializeToUtf8Bytes( + callRequest.Arguments, + options.GetTypeInfo(typeof(IDictionary)))))); + } } } @@ -60,9 +62,9 @@ public static ChatCompletion ToOpenAIChatCompletion(ChatResponse response, JsonS id: response.ResponseId ?? CreateCompletionId(), model: response.ModelId, createdAt: response.CreatedAt ?? DateTimeOffset.UtcNow, - role: ToOpenAIChatRole(response.Message.Role).Value, + role: ToOpenAIChatRole(role) ?? ChatMessageRole.Assistant, finishReason: ToOpenAIFinishReason(response.FinishReason), - content: new(ToOpenAIChatContent(response.Message.Contents)), + content: new(ToOpenAIChatContent(allContents)), toolCalls: toolCalls, refusal: response.AdditionalProperties.GetValueOrDefault(nameof(ChatCompletion.Refusal)), contentTokenLogProbabilities: response.AdditionalProperties.GetValueOrDefault>(nameof(ChatCompletion.ContentTokenLogProbabilities)), @@ -138,7 +140,7 @@ public static ChatResponse FromOpenAIChatCompletion(ChatCompletion openAIComplet } // Wrap the content in a ChatResponse to return. 
- var response = new ChatResponse([returnMessage]) + var response = new ChatResponse(returnMessage) { CreatedAt = openAICompletion.CreatedAt, FinishReason = FromOpenAIFinishReason(openAICompletion.FinishReason), diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIRealtimeExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIRealtimeExtensions.cs index 8a652a71766..d74505e64f8 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIRealtimeExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIRealtimeExtensions.cs @@ -23,6 +23,7 @@ public static class OpenAIRealtimeExtensions /// it can be used with . /// /// A that can be used with . + /// is . public static ConversationFunctionTool ToConversationFunctionTool(this AIFunction aiFunction) { _ = Throw.IfNull(aiFunction); @@ -53,6 +54,9 @@ public static ConversationFunctionTool ToConversationFunctionTool(this AIFunctio /// An optional that controls JSON handling. /// An optional . /// A that represents the completion of processing, including invoking any asynchronous tools. + /// is . + /// is . + /// is . public static async Task HandleToolCallsAsync( this RealtimeConversationSession session, ConversationUpdate update, diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAISerializationHelpers.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAISerializationHelpers.cs index 659db4ed3bd..e736d110650 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAISerializationHelpers.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAISerializationHelpers.cs @@ -25,6 +25,7 @@ public static class OpenAISerializationHelpers /// The stream containing a message using the OpenAI wire format. /// A token used to cancel the operation. /// The deserialized list of chat messages and chat options. + /// is . public static async Task DeserializeChatCompletionRequestAsync( Stream stream, CancellationToken cancellationToken = default) { @@ -43,6 +44,8 @@ public static async Task DeserializeChatCompletionR /// The governing function call content serialization. /// A token used to cancel the serialization operation. /// A task tracking the serialization operation. + /// is . + /// is . public static async Task SerializeAsync( Stream stream, ChatResponse response, @@ -66,6 +69,8 @@ public static async Task SerializeAsync( /// The governing function call content serialization. /// A token used to cancel the serialization operation. /// A task tracking the serialization operation. + /// is . + /// is . public static Task SerializeStreamingAsync( Stream stream, IAsyncEnumerable updates, diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs index 8193e841536..dbc3114ec25 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs @@ -17,7 +17,7 @@ namespace Microsoft.Extensions.AI; internal sealed class AnonymousDelegatingChatClient : DelegatingChatClient { /// The delegate to use as the implementation of . - private readonly Func, ChatOptions?, IChatClient, CancellationToken, Task>? _getResponseFunc; + private readonly Func, ChatOptions?, IChatClient, CancellationToken, Task>? _getResponseFunc; /// The delegate to use as the implementation of . 
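Editor's note: the mapper hunks above now build a ChatResponse directly from a single ChatMessage (new ChatResponse(returnMessage) rather than wrapping it in a list), and a response exposes its messages via a Messages collection used later in this patch. A small consumption sketch, assuming the response comes from any of the clients in this change:

```csharp
using System;
using Microsoft.Extensions.AI;

static class ResponseSample
{
    // 'response' is assumed to come from an IChatClient call such as GetResponseAsync.
    public static void Print(ChatResponse response)
    {
        // A response can carry multiple messages; iterate them individually...
        foreach (ChatMessage message in response.Messages)
        {
            Console.WriteLine($"[{message.Role}] {message.Text}");
        }

        // ...or read the concatenated text directly, as the evaluators in this patch now do.
        Console.WriteLine(response.Text);
    }
}
```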
/// @@ -25,10 +25,10 @@ internal sealed class AnonymousDelegatingChatClient : DelegatingChatClient /// will be invoked with the same arguments as the method itself, along with a reference to the inner client. /// When , will delegate directly to the inner client. /// - private readonly Func, ChatOptions?, IChatClient, CancellationToken, IAsyncEnumerable>? _getStreamingResponseFunc; + private readonly Func, ChatOptions?, IChatClient, CancellationToken, IAsyncEnumerable>? _getStreamingResponseFunc; /// The delegate to use as the implementation of both and . - private readonly Func, ChatOptions?, Func, ChatOptions?, CancellationToken, Task>, CancellationToken, Task>? _sharedFunc; + private readonly Func, ChatOptions?, Func, ChatOptions?, CancellationToken, Task>, CancellationToken, Task>? _sharedFunc; /// /// Initializes a new instance of the class. @@ -47,7 +47,7 @@ internal sealed class AnonymousDelegatingChatClient : DelegatingChatClient /// is . public AnonymousDelegatingChatClient( IChatClient innerClient, - Func, ChatOptions?, Func, ChatOptions?, CancellationToken, Task>, CancellationToken, Task> sharedFunc) + Func, ChatOptions?, Func, ChatOptions?, CancellationToken, Task>, CancellationToken, Task> sharedFunc) : base(innerClient) { _ = Throw.IfNull(sharedFunc); @@ -73,8 +73,8 @@ public AnonymousDelegatingChatClient( /// Both and are . public AnonymousDelegatingChatClient( IChatClient innerClient, - Func, ChatOptions?, IChatClient, CancellationToken, Task>? getResponseFunc, - Func, ChatOptions?, IChatClient, CancellationToken, IAsyncEnumerable>? getStreamingResponseFunc) + Func, ChatOptions?, IChatClient, CancellationToken, Task>? getResponseFunc, + Func, ChatOptions?, IChatClient, CancellationToken, IAsyncEnumerable>? getStreamingResponseFunc) : base(innerClient) { ThrowIfBothDelegatesNull(getResponseFunc, getStreamingResponseFunc); @@ -85,25 +85,26 @@ public AnonymousDelegatingChatClient( /// public override Task GetResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); if (_sharedFunc is not null) { - return GetResponseViaSharedAsync(chatMessages, options, cancellationToken); + return GetResponseViaSharedAsync(messages, options, cancellationToken); - async Task GetResponseViaSharedAsync(IList chatMessages, ChatOptions? options, CancellationToken cancellationToken) + async Task GetResponseViaSharedAsync( + IEnumerable messages, ChatOptions? options, CancellationToken cancellationToken) { ChatResponse? 
response = null; - await _sharedFunc(chatMessages, options, async (chatMessages, options, cancellationToken) => + await _sharedFunc(messages, options, async (messages, options, cancellationToken) => { - response = await InnerClient.GetResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + response = await InnerClient.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); }, cancellationToken).ConfigureAwait(false); if (response is null) { - throw new InvalidOperationException("The wrapper completed successfully without producing a ChatResponse."); + Throw.InvalidOperationException("The wrapper completed successfully without producing a ChatResponse."); } return response; @@ -111,21 +112,21 @@ await _sharedFunc(chatMessages, options, async (chatMessages, options, cancellat } else if (_getResponseFunc is not null) { - return _getResponseFunc(chatMessages, options, InnerClient, cancellationToken); + return _getResponseFunc(messages, options, InnerClient, cancellationToken); } else { Debug.Assert(_getStreamingResponseFunc is not null, "Expected non-null streaming delegate."); - return _getStreamingResponseFunc!(chatMessages, options, InnerClient, cancellationToken) - .ToChatResponseAsync(coalesceContent: true, cancellationToken); + return _getStreamingResponseFunc!(messages, options, InnerClient, cancellationToken) + .ToChatResponseAsync(cancellationToken); } } /// public override IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); if (_sharedFunc is not null) { @@ -138,9 +139,9 @@ public override IAsyncEnumerable GetStreamingResponseAsync( Exception? 
error = null; try { - await _sharedFunc(chatMessages, options, async (chatMessages, options, cancellationToken) => + await _sharedFunc(messages, options, async (messages, options, cancellationToken) => { - await foreach (var update in InnerClient.GetStreamingResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + await foreach (var update in InnerClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) { await updates.Writer.WriteAsync(update, cancellationToken).ConfigureAwait(false); } @@ -161,12 +162,12 @@ await _sharedFunc(chatMessages, options, async (chatMessages, options, cancellat } else if (_getStreamingResponseFunc is not null) { - return _getStreamingResponseFunc(chatMessages, options, InnerClient, cancellationToken); + return _getStreamingResponseFunc(messages, options, InnerClient, cancellationToken); } else { Debug.Assert(_getResponseFunc is not null, "Expected non-null non-streaming delegate."); - return GetStreamingResponseAsyncViaGetResponseAsync(_getResponseFunc!(chatMessages, options, InnerClient, cancellationToken)); + return GetStreamingResponseAsyncViaGetResponseAsync(_getResponseFunc!(messages, options, InnerClient, cancellationToken)); static async IAsyncEnumerable GetStreamingResponseAsyncViaGetResponseAsync(Task task) { diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs index 79f41d1790e..7d7b2b58403 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs @@ -45,18 +45,19 @@ protected CachingChatClient(IChatClient innerClient) public bool CoalesceStreamingUpdates { get; set; } = true; /// - public override async Task GetResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + public override async Task GetResponseAsync( + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); // We're only storing the final result, not the in-flight task, so that we can avoid caching failures // or having problems when one of the callers cancels but others don't. This has the drawback that // concurrent callers might trigger duplicate requests, but that's acceptable. - var cacheKey = GetCacheKey(_boxedFalse, chatMessages, options); + var cacheKey = GetCacheKey(_boxedFalse, messages, options); if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is not { } result) { - result = await base.GetResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + result = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); await WriteCacheAsync(cacheKey, result, cancellationToken).ConfigureAwait(false); } @@ -65,9 +66,9 @@ public override async Task GetResponseAsync(IList cha /// public override async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? 
options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); if (CoalesceStreamingUpdates) { @@ -75,7 +76,7 @@ public override async IAsyncEnumerable GetStreamingResponseA // we make a streaming request, yielding those results, but then convert those into a non-streaming // result and cache it. When we get a cache hit, we yield the non-streaming result as a streaming one. - var cacheKey = GetCacheKey(_boxedTrue, chatMessages, options); + var cacheKey = GetCacheKey(_boxedTrue, messages, options); if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } chatResponse) { // Yield all of the cached items. @@ -88,7 +89,7 @@ public override async IAsyncEnumerable GetStreamingResponseA { // Yield and store all of the items. List capturedItems = []; - await foreach (var chunk in base.GetStreamingResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + await foreach (var chunk in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) { capturedItems.Add(chunk); yield return chunk; @@ -100,12 +101,14 @@ public override async IAsyncEnumerable GetStreamingResponseA } else { - var cacheKey = GetCacheKey(_boxedTrue, chatMessages, options); + var cacheKey = GetCacheKey(_boxedTrue, messages, options); if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks) { // Yield all of the cached items. + string? chatThreadId = null; foreach (var chunk in existingChunks) { + chatThreadId ??= chunk.ChatThreadId; yield return chunk; } } @@ -113,7 +116,7 @@ public override async IAsyncEnumerable GetStreamingResponseA { // Yield and store all of the items. List capturedItems = []; - await foreach (var chunk in base.GetStreamingResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + await foreach (var chunk in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) { capturedItems.Add(chunk); yield return chunk; @@ -132,39 +135,45 @@ public override async IAsyncEnumerable GetStreamingResponseA /// /// Returns a previously cached , if available. - /// This is used when there is a call to . + /// This is used when there is a call to . /// /// The cache key. /// The to monitor for cancellation requests. /// The previously cached data, if available, otherwise . + /// is . protected abstract Task ReadCacheAsync(string key, CancellationToken cancellationToken); /// /// Returns a previously cached list of values, if available. - /// This is used when there is a call to . + /// This is used when there is a call to . /// /// The cache key. /// The to monitor for cancellation requests. /// The previously cached data, if available, otherwise . + /// is . protected abstract Task?> ReadCacheStreamingAsync(string key, CancellationToken cancellationToken); /// /// Stores a in the underlying cache. - /// This is used when there is a call to . + /// This is used when there is a call to . /// /// The cache key. /// The to be stored. /// The to monitor for cancellation requests. /// A representing the completion of the operation. + /// is . + /// is . protected abstract Task WriteCacheAsync(string key, ChatResponse value, CancellationToken cancellationToken); /// /// Stores a list of values in the underlying cache. - /// This is used when there is a call to . + /// This is used when there is a call to . /// /// The cache key. /// The to be stored. 
/// The to monitor for cancellation requests. /// A representing the completion of the operation. + /// is . + /// is . protected abstract Task WriteCacheStreamingAsync(string key, IReadOnlyList value, CancellationToken cancellationToken); } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs index ecd6d04914b..8789810b601 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs @@ -19,6 +19,7 @@ public sealed class ChatClientBuilder /// Initializes a new instance of the class. /// The inner that represents the underlying backend. + /// is . public ChatClientBuilder(IChatClient innerClient) { _ = Throw.IfNull(innerClient); @@ -48,10 +49,13 @@ public IChatClient Build(IServiceProvider? services = null) { for (var i = _clientFactories.Count - 1; i >= 0; i--) { - chatClient = _clientFactories[i](chatClient, services) ?? - throw new InvalidOperationException( + chatClient = _clientFactories[i](chatClient, services); + if (chatClient is null) + { + Throw.InvalidOperationException( $"The {nameof(ChatClientBuilder)} entry at index {i} returned null. " + $"Ensure that the callbacks passed to {nameof(Use)} return non-null {nameof(IChatClient)} instances."); + } } } @@ -61,6 +65,7 @@ public IChatClient Build(IServiceProvider? services = null) /// Adds a factory for an intermediate chat client to the chat client pipeline. /// The client factory function. /// The updated instance. + /// is . public ChatClientBuilder Use(Func clientFactory) { _ = Throw.IfNull(clientFactory); @@ -71,6 +76,7 @@ public ChatClientBuilder Use(Func clientFactory) /// Adds a factory for an intermediate chat client to the chat client pipeline. /// The client factory function. /// The updated instance. + /// is . public ChatClientBuilder Use(Func clientFactory) { _ = Throw.IfNull(clientFactory); @@ -96,7 +102,7 @@ public ChatClientBuilder Use(Func cl /// need to interact with the results of the operation, which will come from the inner client. /// /// is . - public ChatClientBuilder Use(Func, ChatOptions?, Func, ChatOptions?, CancellationToken, Task>, CancellationToken, Task> sharedFunc) + public ChatClientBuilder Use(Func, ChatOptions?, Func, ChatOptions?, CancellationToken, Task>, CancellationToken, Task> sharedFunc) { _ = Throw.IfNull(sharedFunc); @@ -130,8 +136,8 @@ public ChatClientBuilder Use(Func, ChatOptions?, Func /// Both and are . public ChatClientBuilder Use( - Func, ChatOptions?, IChatClient, CancellationToken, Task>? getResponseFunc, - Func, ChatOptions?, IChatClient, CancellationToken, IAsyncEnumerable>? getStreamingResponseFunc) + Func, ChatOptions?, IChatClient, CancellationToken, Task>? getResponseFunc, + Func, ChatOptions?, IChatClient, CancellationToken, IAsyncEnumerable>? 
getStreamingResponseFunc) { AnonymousDelegatingChatClient.ThrowIfBothDelegatesNull(getResponseFunc, getStreamingResponseFunc); diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderChatClientExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderChatClientExtensions.cs index 87983bf2367..b4e1e7f280f 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderChatClientExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderChatClientExtensions.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using Microsoft.Extensions.AI; using Microsoft.Shared.Diagnostics; @@ -16,6 +17,7 @@ public static class ChatClientBuilderChatClientExtensions /// This method is equivalent to using the constructor directly, /// specifying as the inner client. /// + /// is . public static ChatClientBuilder AsBuilder(this IChatClient innerClient) { _ = Throw.IfNull(innerClient); diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderServiceCollectionExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderServiceCollectionExtensions.cs index c1be6406d1a..d1e6761f317 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderServiceCollectionExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderServiceCollectionExtensions.cs @@ -16,11 +16,18 @@ public static class ChatClientBuilderServiceCollectionExtensions /// The service lifetime for the client. Defaults to . /// A that can be used to build a pipeline around the inner client. /// The client is registered as a singleton service. + /// is . + /// is . public static ChatClientBuilder AddChatClient( this IServiceCollection serviceCollection, IChatClient innerClient, ServiceLifetime lifetime = ServiceLifetime.Singleton) - => AddChatClient(serviceCollection, _ => innerClient, lifetime); + { + _ = Throw.IfNull(serviceCollection); + _ = Throw.IfNull(innerClient); + + return AddChatClient(serviceCollection, _ => innerClient, lifetime); + } /// Registers a singleton in the . /// The to which the client should be added. @@ -28,6 +35,8 @@ public static ChatClientBuilder AddChatClient( /// The service lifetime for the client. Defaults to . /// A that can be used to build a pipeline around the inner client. /// The client is registered as a singleton service. + /// is . + /// is . public static ChatClientBuilder AddChatClient( this IServiceCollection serviceCollection, Func innerClientFactory, @@ -48,12 +57,19 @@ public static ChatClientBuilder AddChatClient( /// The service lifetime for the client. Defaults to . /// A that can be used to build a pipeline around the inner client. /// The client is registered as a scoped service. + /// is . + /// is . public static ChatClientBuilder AddKeyedChatClient( this IServiceCollection serviceCollection, object? serviceKey, IChatClient innerClient, ServiceLifetime lifetime = ServiceLifetime.Singleton) - => AddKeyedChatClient(serviceCollection, serviceKey, _ => innerClient, lifetime); + { + _ = Throw.IfNull(serviceCollection); + _ = Throw.IfNull(innerClient); + + return AddKeyedChatClient(serviceCollection, serviceKey, _ => innerClient, lifetime); + } /// Registers a keyed singleton in the . /// The to which the client should be added. 
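Editor's note: the registration helpers above now null-check their arguments eagerly and still return a ChatClientBuilder, so a pipeline can be composed at registration time. A hedged sketch combining AddChatClient with the delegate-based Use overload and UseDistributedCache (both updated elsewhere in this patch); the in-memory cache, endpoint, and model name are illustrative assumptions:

```csharp
using System;
using System.Linq;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options;

ServiceCollection services = new();

// An in-memory IDistributedCache stands in for a real cache in this sketch.
services.AddSingleton<IDistributedCache>(
    new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())));

services.AddChatClient(_ => new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.2"))
    .Use(async (messages, options, innerClient, cancellationToken) =>
    {
        // The callback now receives IEnumerable<ChatMessage> rather than IList<ChatMessage>.
        Console.WriteLine($"Sending {messages.Count()} message(s).");
        return await innerClient.GetResponseAsync(messages, options, cancellationToken);
    },
    getStreamingResponseFunc: null)
    .UseDistributedCache();

using ServiceProvider provider = services.BuildServiceProvider();
IChatClient chatClient = provider.GetRequiredService<IChatClient>();
```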
@@ -62,6 +78,8 @@ public static ChatClientBuilder AddKeyedChatClient( /// The service lifetime for the client. Defaults to . /// A that can be used to build a pipeline around the inner client. /// The client is registered as a scoped service. + /// is . + /// is . public static ChatClientBuilder AddKeyedChatClient( this IServiceCollection serviceCollection, object? serviceKey, @@ -69,7 +87,6 @@ public static ChatClientBuilder AddKeyedChatClient( ServiceLifetime lifetime = ServiceLifetime.Singleton) { _ = Throw.IfNull(serviceCollection); - _ = Throw.IfNull(serviceKey); _ = Throw.IfNull(innerClientFactory); var builder = new ChatClientBuilder(innerClientFactory); diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs index 778150f1ac1..59bd70eefc6 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs @@ -11,6 +11,8 @@ using System.Threading.Tasks; using Microsoft.Shared.Diagnostics; +#pragma warning disable SA1118 // Parameter should not span multiple lines + namespace Microsoft.Extensions.AI; /// @@ -27,7 +29,7 @@ public static class ChatClientStructuredOutputExtensions /// Sends chat messages, requesting a response matching the type . /// The . - /// The chat content to send. + /// The chat content to send. /// The chat options to configure the request. /// /// Optionally specifies whether to set a JSON schema on the . @@ -36,18 +38,14 @@ public static class ChatClientStructuredOutputExtensions /// /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. - /// - /// The returned messages will not have been added to . However, any intermediate messages generated implicitly - /// by the client, including any messages for roundtrips to the model as part of the implementation of this request, will be included. - /// /// The type of structured output to request. public static Task> GetResponseAsync( this IChatClient chatClient, - IList chatMessages, + IEnumerable messages, ChatOptions? options = null, bool? useNativeJsonSchema = null, CancellationToken cancellationToken = default) => - GetResponseAsync(chatClient, chatMessages, AIJsonUtilities.DefaultOptions, options, useNativeJsonSchema, cancellationToken); + GetResponseAsync(chatClient, messages, AIJsonUtilities.DefaultOptions, options, useNativeJsonSchema, cancellationToken); /// Sends a user chat text message, requesting a response matching the type . /// The . @@ -135,7 +133,7 @@ public static Task> GetResponseAsync( /// Sends chat messages, requesting a response matching the type . /// The . - /// The chat content to send. + /// The chat content to send. /// The JSON serialization options to use. /// The chat options to configure the request. /// @@ -145,21 +143,20 @@ public static Task> GetResponseAsync( /// /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. - /// - /// The returned messages will not have been added to . However, any intermediate messages generated implicitly - /// by the client, including any messages for roundtrips to the model as part of the implementation of this request, will be included. - /// /// The type of structured output to request. + /// is . + /// is . + /// is . 
public static async Task> GetResponseAsync( this IChatClient chatClient, - IList chatMessages, + IEnumerable messages, JsonSerializerOptions serializerOptions, ChatOptions? options = null, bool? useNativeJsonSchema = null, CancellationToken cancellationToken = default) { _ = Throw.IfNull(chatClient); - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); _ = Throw.IfNull(serializerOptions); serializerOptions.MakeReadOnly(); @@ -193,7 +190,7 @@ public static async Task> GetResponseAsync( } ChatMessage? promptAugmentation = null; - options = (options ?? new()).Clone(); + options = options is not null ? options.Clone() : new(); // Currently there's no way for the inner IChatClient to specify whether structured output // is supported, so we always default to false. In the future, some mechanism of declaring @@ -212,30 +209,18 @@ public static async Task> GetResponseAsync( options.ResponseFormat = ChatResponseFormat.Json; // When not using native structured output, augment the chat messages with a schema prompt -#pragma warning disable SA1118 // Parameter should not span multiple lines promptAugmentation = new ChatMessage(ChatRole.User, $$""" Respond with a JSON value conforming to the following schema: ``` {{schema}} ``` """); -#pragma warning restore SA1118 // Parameter should not span multiple lines - chatMessages.Add(promptAugmentation); + messages = [.. messages, promptAugmentation]; } - try - { - var result = await chatClient.GetResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); - return new ChatResponse(result, serializerOptions) { IsWrappedInObject = isWrappedInObject }; - } - finally - { - if (promptAugmentation is not null) - { - _ = chatMessages.Remove(promptAugmentation); - } - } + var result = await chatClient.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); + return new ChatResponse(result, serializerOptions) { IsWrappedInObject = isWrappedInObject }; } private static bool SchemaRepresentsObject(JsonElement schemaElement) diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatResponse{T}.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatResponse{T}.cs index e78a0acf1f5..2a9fca23fae 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatResponse{T}.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatResponse{T}.cs @@ -3,7 +3,6 @@ using System; using System.Buffers; -using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Text; using System.Text.Json; @@ -17,7 +16,7 @@ namespace Microsoft.Extensions.AI; /// /// Language models are not guaranteed to honor the requested schema. If the model's output is not /// parseable as the expected type, then will return . -/// You can access the underlying JSON response on the property. +/// You can access the underlying JSON response on the property. /// public class ChatResponse : ChatResponse { @@ -31,10 +30,11 @@ public class ChatResponse : ChatResponse /// The unstructured that is being wrapped. /// The to use when deserializing the result. 
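Editor's note: the structured-output extension above now takes IEnumerable&lt;ChatMessage&gt; and appends its schema prompt to a copied message collection instead of mutating and then restoring the caller's list. A hedged usage sketch; the record shape and prompt text are illustrative, not part of this patch:

```csharp
using System;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;

static class StructuredOutputSample
{
    // Illustrative output shape; any JSON-serializable type works.
    private record WeatherReport(string City, double TemperatureCelsius, string Summary);

    // 'chatClient' is assumed to be any configured IChatClient.
    public static async Task RunAsync(IChatClient chatClient)
    {
        ChatResponse<WeatherReport> response = await chatClient.GetResponseAsync<WeatherReport>(
            [new ChatMessage(ChatRole.User, "Report the current weather in Paris as JSON.")]);

        if (response.TryGetResult(out WeatherReport? report))
        {
            Console.WriteLine($"{report.City}: {report.TemperatureCelsius} C ({report.Summary})");
        }
        else
        {
            // The raw, unparsed text remains available on the base response.
            Console.WriteLine($"Model did not return the requested shape: {response.Text}");
        }
    }
}
```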
public ChatResponse(ChatResponse response, JsonSerializerOptions serializerOptions) - : base(Throw.IfNull(response).Choices) + : base(Throw.IfNull(response).Messages) { _serializerOptions = Throw.IfNull(serializerOptions); AdditionalProperties = response.AdditionalProperties; + ChatThreadId = response.ChatThreadId; CreatedAt = response.CreatedAt; FinishReason = response.FinishReason; ModelId = response.ModelId; @@ -114,13 +114,6 @@ public bool TryGetResult([NotNullWhen(true)] out T? result) /// internal bool IsWrappedInObject { get; set; } - private string? GetResultAsJson() - { - var choice = Choices.Count == 1 ? Choices[0] : null; - var content = choice?.Contents.Count == 1 ? choice.Contents[0] : null; - return (content as TextContent)?.Text; - } - private T? GetResultCore(out FailureReason? failureReason) { if (_hasDeserializedResult) @@ -129,7 +122,7 @@ public bool TryGetResult([NotNullWhen(true)] out T? result) return _deserializedResult; } - var json = GetResultAsJson(); + var json = Text; if (string.IsNullOrEmpty(json)) { failureReason = FailureReason.ResultDidNotContainJson; diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs index 551441139a7..5a5dfea06c3 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs @@ -34,16 +34,15 @@ public ConfigureOptionsChatClient(IChatClient innerClient, Action c } /// - public override async Task GetResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) - { - return await base.GetResponseAsync(chatMessages, Configure(options), cancellationToken).ConfigureAwait(false); - } + public override async Task GetResponseAsync( + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => + await base.GetResponseAsync(messages, Configure(options), cancellationToken).ConfigureAwait(false); /// public override async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - await foreach (var update in base.GetStreamingResponseAsync(chatMessages, Configure(options), cancellationToken).ConfigureAwait(false)) + await foreach (var update in base.GetStreamingResponseAsync(messages, Configure(options), cancellationToken).ConfigureAwait(false)) { yield return update; } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs index ea990d09a85..d76b2ba1a2e 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs @@ -25,6 +25,8 @@ public static class ConfigureOptionsChatClientBuilderExtensions /// of the caller-supplied instance if one was supplied. /// /// The . + /// is . + /// is . 
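A wiring sketch for the `ConfigureOptions` extension declared below; the fallback model id and temperature are illustrative defaults, and `innerClient` is assumed to be an existing `IChatClient`:

```csharp
IChatClient configured = innerClient
    .AsBuilder()
    .ConfigureOptions(options =>
    {
        // The callback receives a clone of the caller-supplied options (or a fresh
        // instance when none were supplied), so mutations here never leak back out.
        options.ModelId ??= "gpt-4o-mini";
        options.Temperature ??= 0.2f;
    })
    .Build();
```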
public static ChatClientBuilder ConfigureOptions( this ChatClientBuilder builder, Action configure) { diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClientBuilderExtensions.cs index 6396459c09c..6a9474b751d 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClientBuilderExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClientBuilderExtensions.cs @@ -22,6 +22,7 @@ public static class DistributedCachingChatClientBuilderExtensions /// /// An optional callback that can be used to configure the instance. /// The provided as . + /// is . public static ChatClientBuilder UseDistributedCache(this ChatClientBuilder builder, IDistributedCache? storage = null, Action? configure = null) { _ = Throw.IfNull(builder); diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs index 690af275761..8dca904ccd0 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs @@ -18,7 +18,7 @@ public sealed class FunctionInvocationContext private static readonly AIFunction _nopFunction = AIFunctionFactory.Create(() => { }, nameof(FunctionInvocationContext)); /// The chat contents associated with the operation that initiated this function call request. - private IList _chatMessages = Array.Empty(); + private IList _messages = Array.Empty(); /// The AI function to be invoked. private AIFunction _function = _nopFunction; @@ -39,12 +39,15 @@ public FunctionCallContent CallContent } /// Gets or sets the chat contents associated with the operation that initiated this function call request. - public IList ChatMessages + public IList Messages { - get => _chatMessages; - set => _chatMessages = Throw.IfNull(value); + get => _messages; + set => _messages = Throw.IfNull(value); } + /// Gets or sets the chat options associated with the operation that initiated this function call request. + public ChatOptions? Options { get; set; } + /// Gets or sets the AI function to be invoked. 
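To show how the renamed `Messages` property and the new `Options` property can be consumed from inside a tool, a sketch that reads the ambient `FunctionInvokingChatClient.CurrentContext`; the weather function itself is purely illustrative:

```csharp
AIFunction getWeather = AIFunctionFactory.Create((string city) =>
{
    // CurrentContext is populated while FunctionInvokingChatClient is invoking the tool.
    FunctionInvocationContext? context = FunctionInvokingChatClient.CurrentContext;
    if (context is not null)
    {
        int historyLength = context.Messages.Count;  // renamed from ChatMessages
        string? modelId = context.Options?.ModelId;  // request options are now exposed too
        Console.WriteLine($"get_weather called with {historyLength} messages on '{modelId}'.");
    }

    return $"Sunny in {city}";
}, name: "get_weather");
```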
public AIFunction Function { diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index a64ebf7d61d..67e6249298c 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; using System.Threading; @@ -11,8 +12,10 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Shared.Diagnostics; +using static Microsoft.Extensions.AI.OpenTelemetryConsts.GenAI; #pragma warning disable CA2213 // Disposable fields should be disposed +#pragma warning disable EA0002 // Use 'System.TimeProvider' to make the code easier to test namespace Microsoft.Extensions.AI; @@ -23,8 +26,10 @@ namespace Microsoft.Extensions.AI; /// /// /// When this client receives a in a chat response, it responds -/// by calling the corresponding defined in , -/// producing a . +/// by calling the corresponding defined in , +/// producing a that it sends back to the inner client. This loop +/// is repeated until there are no more function calls to make, or until another stop condition is met, +/// such as hitting . /// /// /// The provided implementation of is thread-safe for concurrent use so long as the @@ -139,40 +144,6 @@ public static FunctionInvocationContext? CurrentContext /// public bool AllowConcurrentInvocation { get; set; } - /// - /// Gets or sets a value indicating whether to keep intermediate function calling request - /// and response messages in the chat history. - /// - /// - /// if intermediate messages persist in the list provided - /// to and by the caller. - /// if intermediate messages are removed prior to completing the operation. - /// The default value is . - /// - /// - /// - /// When the inner returns to the - /// , the adds - /// those messages to the list of messages, along with instances - /// it creates with the results of invoking the requested functions. The resulting augmented - /// list of messages is then passed to the inner client in order to send the results back. - /// By default, those messages persist in the list provided to - /// and by the caller, such that those - /// messages are available to the caller. Set to avoid including - /// those messages in the caller-provided . - /// - /// - /// Changing the value of this property while the client is in use might result in inconsistencies - /// as to whether function calling messages are kept during an in-flight request. - /// - /// - /// If the underlying responds with - /// set to a non- value, this property may be ignored and behave as if it is - /// , with any such intermediate messages not stored in the messages list. - /// - /// - public bool KeepFunctionCallingContent { get; set; } = true; - /// /// Gets or sets the maximum number of iterations per request. /// @@ -209,214 +180,169 @@ public int? MaximumIterationsPerRequest } /// - public override async Task GetResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + public override async Task GetResponseAsync( + IEnumerable messages, ChatOptions? 
options = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); // A single request into this GetResponseAsync may result in multiple requests to the inner client. // Create an activity to group them together for better observability. using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient)); - ChatResponse? response = null; - UsageDetails? totalUsage = null; - IList originalChatMessages = chatMessages; - try - { - for (int iteration = 0; ; iteration++) - { - // Make the call to the handler. - response = await base.GetResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + IEnumerable originalMessages = messages; // the original messages, tracked for the rare case where we need to know what was originally provided + List? augmentedHistory = null; // the actual history of messages sent on turns other than the first + ChatResponse? response = null; // the response from the inner client, which is possibly modified and then eventually returned + List? responseMessages = null; // tracked list of messages, across multiple turns, to be used for the final response + UsageDetails? totalUsage = null; // tracked usage across all turns, to be used for the final response + List? functionCallContents = null; // function call contents that need responding to in the current turn + bool lastIterationHadThreadId = false; // whether the last iteration's response had a ChatThreadId set - // Aggregate usage data over all calls - if (response.Usage is not null) - { - totalUsage ??= new(); - totalUsage.Add(response.Usage); - } + for (int iteration = 0; ; iteration++) + { + functionCallContents?.Clear(); - // If there are no tools to call, or for any other reason we should stop, return the response. - if (options is null - || options.Tools is not { Count: > 0 } - || response.Choices.Count == 0 - || (MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations)) - { - break; - } + // Make the call to the inner client. + response = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); + if (response is null) + { + Throw.InvalidOperationException($"The inner {nameof(IChatClient)} returned a null {nameof(ChatResponse)}."); + } - // If there's more than one choice, we don't know which one to add to chat history, or which - // of their function calls to process. This should not happen except if the developer has - // explicitly requested multiple choices. We fail aggressively to avoid cases where a developer - // doesn't realize this and is wasting their budget requesting extra choices we'd never use. - if (response.Choices.Count > 1) - { - ThrowForMultipleChoices(); - } + // Any function call work to do? If yes, ensure we're tracking that work in functionCallContents. + bool requiresFunctionInvocation = + options?.Tools is { Count: > 0 } && + (!MaximumIterationsPerRequest.HasValue || iteration < MaximumIterationsPerRequest.GetValueOrDefault()) && + CopyFunctionCalls(response.Messages, ref functionCallContents); - // Extract any function call contents on the first choice. If there are none, we're done. - // We don't have any way to express a preference to use a different choice, since this - // is a niche case especially with function calling. 
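Since the loop above is bounded by `MaximumIterationsPerRequest` and can parallelize tool calls, a configuration sketch; the values are illustrative and `innerClient` is assumed to exist:

```csharp
IChatClient client = innerClient
    .AsBuilder()
    .UseFunctionInvocation(configure: c =>
    {
        c.MaximumIterationsPerRequest = 3;  // cap the number of round trips to the inner client
        c.AllowConcurrentInvocation = true; // run multiple requested functions in parallel
    })
    .Build();
```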
- FunctionCallContent[] functionCallContents = response.Message.Contents.OfType().ToArray(); - if (functionCallContents.Length == 0) - { - break; - } + // In a common case where we make a request and there's no function calling work required, + // fast path out by just returning the original response. + if (iteration == 0 && !requiresFunctionInvocation) + { + return response; + } - // Update the chat history. If the underlying client is tracking the state, then we want to avoid re-sending - // what we already sent as well as this response message, so create a new list to store the response message(s). - if (response.ChatThreadId is not null) + // Track aggregatable details from the response, including all of the response messages and usage details. + (responseMessages ??= []).AddRange(response.Messages); + if (response.Usage is not null) + { + if (totalUsage is not null) { - if (chatMessages == originalChatMessages) - { - chatMessages = []; - } - else - { - chatMessages.Clear(); - } + totalUsage.Add(response.Usage); } else { - // Otherwise, we need to add the response message to the history we're sending back. However, if the caller - // doesn't want the intermediate messages, create a new list that we mutate instead of mutating the original. - if (!KeepFunctionCallingContent) - { - // Create a new list that will include the message with the function call contents. - if (chatMessages == originalChatMessages) - { - chatMessages = [.. chatMessages]; - } - - // We want to include any non-functional calling content, if there is any, - // in the caller's list so that they don't lose out on actual content. - // This can happen but is relatively rare. - if (response.Message.Contents.Any(c => c is not FunctionCallContent)) - { - var clone = response.Message.Clone(); - clone.Contents = clone.Contents.Where(c => c is not FunctionCallContent).ToList(); - originalChatMessages.Add(clone); - } - } - - // Add the original response message into the history. - chatMessages.Add(response.Message); + totalUsage = response.Usage; } + } - // Add the responses from the function calls into the history. - var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false); - if (UpdateOptionsForMode(modeAndMessages.Mode, ref options, response.ChatThreadId)) - { - // Terminate - return response; - } + // If there are no tools to call, or for any other reason we should stop, we're done. + // Break out of the loop and allow the handling at the end to configure the response + // with aggregated data from previous requests. + if (!requiresFunctionInvocation) + { + break; } - return response; - } - finally - { - if (response is not null) + // Prepare the history for the next iteration. + FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId); + + // Add the responses from the function calls into the augmented history and also into the tracked + // list of response messages. 
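To illustrate what the aggregation above means for callers, a sketch in which the returned response carries every message produced across the tool-calling turns; `client` and `getWeather` are the assumed names from the earlier sketches:

```csharp
List<ChatMessage> history = [new ChatMessage(ChatRole.User, "What's the weather in Oslo?")];

ChatResponse response = await client.GetResponseAsync(
    history,
    new ChatOptions { Tools = [getWeather] });

// response.Messages now includes the assistant's function-call message, the tool-result
// message, and the final answer. The caller decides whether to fold them into its history:
history.AddMessages(response);
Console.WriteLine(response.Text);
```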
+ var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, cancellationToken).ConfigureAwait(false); + responseMessages.AddRange(modeAndMessages.MessagesAdded); + + if (UpdateOptionsForMode(modeAndMessages.Mode, ref options!, response.ChatThreadId)) { - response.Usage = totalUsage; + // Terminate + break; } } + + Debug.Assert(responseMessages is not null, "Expected to only be here if we have response messages."); + response.Messages = responseMessages!; + response.Usage = totalUsage; + + return response; } /// public override async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); // A single request into this GetStreamingResponseAsync may result in multiple requests to the inner client. // Create an activity to group them together for better observability. using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient)); - List functionCallContents = []; - int? choice; - IList originalChatMessages = chatMessages; + IEnumerable originalMessages = messages; // the original messages, tracked for the rare case where we need to know what was originally provided + List? augmentedHistory = null; // the actual history of messages sent on turns other than the first + List? functionCallContents = null; // function call contents that need responding to in the current turn + List? responseMessages = null; // tracked list of messages, across multiple turns, to be used in fallback cases to reconstitute history + bool lastIterationHadThreadId = false; // whether the last iteration's response had a ChatThreadId set + List updates = []; // updates from the current response + for (int iteration = 0; ; iteration++) { - choice = null; - string? chatThreadId = null; - functionCallContents.Clear(); - await foreach (var update in base.GetStreamingResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + updates.Clear(); + functionCallContents?.Clear(); + + await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) { - // We're going to emit all ChatResponseUpdates upstream, even ones that contain function call - // content, because a given ChatResponseUpdate can contain other content/metadata. But if we - // yield the function calls, and the consumer adds all the content into a message that's then - // added into history, they'll end up with function call contents that aren't directly paired - // with function result contents, which may cause issues for some models when the history is - // later sent again. We thus remove the FunctionCallContent instances from the updates before - // yielding them, tracking those FunctionCallContents separately so they can be processed and - // added to the chat history. - - // Find all the FCCs. We need to track these separately in order to be able to process them later. - int preFccCount = functionCallContents.Count; - functionCallContents.AddRange(update.Contents.OfType()); - - // If there were any, remove them from the update. We do this before yielding the update so - // that we're not modifying an instance already provided back to the caller. 
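A consumption sketch for the streaming path, using the same assumed `client` and tool as above; note that with this change function-call content is no longer stripped from the yielded updates, and generated tool results are streamed as additional updates:

```csharp
await foreach (ChatResponseUpdate update in client.GetStreamingResponseAsync(
    history, new ChatOptions { Tools = [getWeather] }))
{
    foreach (AIContent content in update.Contents)
    {
        switch (content)
        {
            case TextContent text:
                Console.Write(text.Text);
                break;
            case FunctionCallContent call:
                Console.WriteLine($"[calling {call.Name}]");
                break;
            case FunctionResultContent result:
                Console.WriteLine($"[call {result.CallId} completed]");
                break;
        }
    }
}
```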
- int addedFccs = functionCallContents.Count - preFccCount; - if (addedFccs > 0) + if (update is null) { - update.Contents = addedFccs == update.Contents.Count ? - [] : update.Contents.Where(c => c is not FunctionCallContent).ToList(); + Throw.InvalidOperationException($"The inner {nameof(IChatClient)} streamed a null {nameof(ChatResponseUpdate)}."); } - // Only one choice is allowed with automatic function calling. - if (choice is null) - { - choice = update.ChoiceIndex; - } - else if (choice != update.ChoiceIndex) - { - ThrowForMultipleChoices(); - } + updates.Add(update); - chatThreadId ??= update.ChatThreadId; + _ = CopyFunctionCalls(update.Contents, ref functionCallContents); yield return update; Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 } // If there are no tools to call, or for any other reason we should stop, return the response. - if (options is null - || options.Tools is not { Count: > 0 } - || (MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations) - || functionCallContents is not { Count: > 0 }) + if (functionCallContents is not { Count: > 0 } || + options?.Tools is not { Count: > 0 } || + (MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations)) { break; } - // Update the chat history. If the underlying client is tracking the state, then we want to avoid re-sending - // what we already sent as well as this response message, so create a new list to store the response message(s). - if (chatThreadId is not null) - { - if (chatMessages == originalChatMessages) - { - chatMessages = []; - } - else - { - chatMessages.Clear(); - } - } - else + // Reconsistitue a response from the response updates. + var response = updates.ToChatResponse(); + (responseMessages ??= []).AddRange(response.Messages); + + // Prepare the history for the next iteration. + FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId); + + // Process all of the functions, adding their results into the history. + var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false); + responseMessages.AddRange(modeAndMessages.MessagesAdded); + + // Stream any generated function results. This mirrors what's done for GetResponseAsync, where the returned messages + // includes all activitys, including generated function results. + string toolResponseId = Guid.NewGuid().ToString("N"); + foreach (var message in modeAndMessages.MessagesAdded) { - // Otherwise, we need to add the response message to the history we're sending back. However, if the caller - // doesn't want the intermediate messages, create a new list that we mutate instead of mutating the original. - if (chatMessages == originalChatMessages && !KeepFunctionCallingContent) + var toolResultUpdate = new ChatResponseUpdate { - chatMessages = [.. chatMessages]; - } + AdditionalProperties = message.AdditionalProperties, + AuthorName = message.AuthorName, + ChatThreadId = response.ChatThreadId, + CreatedAt = DateTimeOffset.UtcNow, + Contents = message.Contents, + RawRepresentation = message.RawRepresentation, + ResponseId = toolResponseId, + Role = message.Role, + }; - // Add a manufactured response message containing the function call contents to the chat history. - chatMessages.Add(new(ChatRole.Assistant, [.. 
functionCallContents])); + yield return toolResultUpdate; + Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 } - // Process all of the functions, adding their results into the history. - var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false); - if (UpdateOptionsForMode(modeAndMessages.Mode, ref options, chatThreadId)) + if (UpdateOptionsForMode(modeAndMessages.Mode, ref options, response.ChatThreadId)) { // Terminate yield break; @@ -424,14 +350,98 @@ public override async IAsyncEnumerable GetStreamingResponseA } } - /// Throws an exception when multiple choices are received. - private static void ThrowForMultipleChoices() + /// Prepares the various chat message lists after a response from the inner client and before invoking functions. + /// The original messages provided by the caller. + /// The messages reference passed to the inner client. + /// The augmented history containing all the messages to be sent. + /// The most recent response being handled. + /// A list of all response messages received up until this point. + /// Whether the previous iteration's response had a thread id. + private static void FixupHistories( + IEnumerable originalMessages, + ref IEnumerable messages, + [NotNull] ref List? augmentedHistory, + ChatResponse response, + List allTurnsResponseMessages, + ref bool lastIterationHadThreadId) { - // If there's more than one choice, we don't know which one to add to chat history, or which - // of their function calls to process. This should not happen except if the developer has - // explicitly requested multiple choices. We fail aggressively to avoid cases where a developer - // doesn't realize this and is wasting their budget requesting extra choices we'd never use. - throw new InvalidOperationException("Automatic function call invocation only accepts a single choice, but multiple choices were received."); + // We're now going to need to augment the history with function result contents. + // That means we need a separate list to store the augmented history. + if (response.ChatThreadId is not null) + { + // The response indicates the inner client is tracking the history, so we don't want to send + // anything we've already sent or received. + if (augmentedHistory is not null) + { + augmentedHistory.Clear(); + } + else + { + augmentedHistory = []; + } + + lastIterationHadThreadId = true; + } + else if (lastIterationHadThreadId) + { + // In the very rare case where the inner client returned a response with a thread ID but then + // returned a subsequent response without one, we want to reconstitue the full history. To do that, + // we can populate the history with the original chat messages and then all of the response + // messages up until this point, which includes the most recent ones. + augmentedHistory ??= []; + augmentedHistory.Clear(); + augmentedHistory.AddRange(originalMessages); + augmentedHistory.AddRange(allTurnsResponseMessages); + + lastIterationHadThreadId = false; + } + else + { + // If augmentedHistory is already non-null, then we've already populated it with everything up + // until this point (except for the most recent response). If it's null, we need to seed it with + // the chat history provided by the caller. + augmentedHistory ??= originalMessages.ToList(); + + // Now add the most recent response messages. 
+ augmentedHistory.AddMessages(response); + + lastIterationHadThreadId = false; + } + + // Use the augmented history as the new set of messages to send. + messages = augmentedHistory; + } + + /// Copies any from to . + private static bool CopyFunctionCalls( + IList messages, [NotNullWhen(true)] ref List? functionCalls) + { + bool any = false; + int count = messages.Count; + for (int i = 0; i < count; i++) + { + any |= CopyFunctionCalls(messages[i].Contents, ref functionCalls); + } + + return any; + } + + /// Copies any from to . + private static bool CopyFunctionCalls( + IList content, [NotNullWhen(true)] ref List? functionCalls) + { + bool any = false; + int count = content.Count; + for (int i = 0; i < count; i++) + { + if (content[i] is FunctionCallContent functionCall) + { + (functionCalls ??= []).Add(functionCall); + any = true; + } + } + + return any; } /// Updates for the response. @@ -445,10 +455,7 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti // as otherwise we'll be in an infinite loop. options = options.Clone(); options.ToolMode = null; - if (chatThreadId is not null) - { - options.ChatThreadId = chatThreadId; - } + options.ChatThreadId = chatThreadId; break; @@ -457,10 +464,7 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti options = options.Clone(); options.Tools = null; options.ToolMode = null; - if (chatThreadId is not null) - { - options.ChatThreadId = chatThreadId; - } + options.ChatThreadId = chatThreadId; break; @@ -471,7 +475,7 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti default: // As with the other modes, ensure we've propagated the chat thread ID to the options. // We only need to clone the options if we're actually mutating it. - if (chatThreadId is not null && options.ChatThreadId != chatThreadId) + if (options.ChatThreadId != chatThreadId) { options = options.Clone(); options.ChatThreadId = chatThreadId; @@ -486,26 +490,30 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti /// /// Processes the function calls in the list. /// - /// The current chat contents, inclusive of the function call contents being processed. + /// The current chat contents, inclusive of the function call contents being processed. /// The options used for the response being processed. /// The function call contents representing the functions to be invoked. /// The iteration number of how many roundtrips have been made to the inner client. /// The to monitor for cancellation requests. /// A value indicating how the caller should proceed. private async Task<(ContinueMode Mode, IList MessagesAdded)> ProcessFunctionCallsAsync( - IList chatMessages, ChatOptions options, IReadOnlyList functionCallContents, int iteration, CancellationToken cancellationToken) + List messages, ChatOptions options, List functionCallContents, int iteration, CancellationToken cancellationToken) { // We must add a response for every tool call, regardless of whether we successfully executed it or not. // If we successfully execute it, we'll add the result. If we don't, we'll add an error. - int functionCount = functionCallContents.Count; - Debug.Assert(functionCount > 0, $"Expecteded {nameof(functionCount)} to be > 0, got {functionCount}."); + Debug.Assert(functionCallContents.Count > 0, "Expecteded at least one function call."); // Process all functions. If there's more than one and concurrent invocation is enabled, do so in parallel. 
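Because the thread id is now propagated into the cloned options unconditionally, a sketch of a caller that relies on service-managed conversation state rather than resending history; the prompts are illustrative and `client` and `getWeather` are assumed from earlier:

```csharp
ChatOptions options = new() { Tools = [getWeather] };

ChatResponse first = await client.GetResponseAsync(
    [new ChatMessage(ChatRole.User, "Start planning a two-day Oslo trip.")], options);

// If the underlying service tracks the conversation, flow its thread id back instead of
// resending the accumulated messages on the next turn.
options.ChatThreadId = first.ChatThreadId;

ChatResponse second = await client.GetResponseAsync(
    [new ChatMessage(ChatRole.User, "Add a museum to day two.")], options);
```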
- if (functionCount == 1) + if (functionCallContents.Count == 1) { - FunctionInvocationResult result = await ProcessFunctionCallAsync(chatMessages, options, functionCallContents[0], iteration, 0, 1, cancellationToken).ConfigureAwait(false); - IList added = AddResponseMessages(chatMessages, [result]); + FunctionInvocationResult result = await ProcessFunctionCallAsync( + messages, options, functionCallContents, iteration, 0, cancellationToken).ConfigureAwait(false); + + IList added = CreateResponseMessages([result]); + ThrowIfNoFunctionResultsAdded(added); + + messages.AddRange(added); return (result.ContinueMode, added); } else @@ -516,21 +524,29 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti { // Schedule the invocation of every function. results = await Task.WhenAll( - from i in Enumerable.Range(0, functionCount) - select Task.Run(() => ProcessFunctionCallAsync(chatMessages, options, functionCallContents[i], iteration, i, functionCount, cancellationToken))).ConfigureAwait(false); + from i in Enumerable.Range(0, functionCallContents.Count) + select Task.Run(() => ProcessFunctionCallAsync( + messages, options, functionCallContents, + iteration, i, cancellationToken))).ConfigureAwait(false); } else { // Invoke each function serially. - results = new FunctionInvocationResult[functionCount]; - for (int i = 0; i < functionCount; i++) + results = new FunctionInvocationResult[functionCallContents.Count]; + for (int i = 0; i < results.Length; i++) { - results[i] = await ProcessFunctionCallAsync(chatMessages, options, functionCallContents[i], iteration, i, functionCount, cancellationToken).ConfigureAwait(false); + results[i] = await ProcessFunctionCallAsync( + messages, options, functionCallContents, + iteration, i, cancellationToken).ConfigureAwait(false); } } ContinueMode continueMode = ContinueMode.Continue; - IList added = AddResponseMessages(chatMessages, results); + + IList added = CreateResponseMessages(results); + ThrowIfNoFunctionResultsAdded(added); + + messages.AddRange(added); foreach (FunctionInvocationResult fir in results) { if (fir.ContinueMode > continueMode) @@ -543,19 +559,31 @@ from i in Enumerable.Range(0, functionCount) } } - /// Processes the function call described in . - /// The current chat contents, inclusive of the function call contents being processed. + /// + /// Throws an exception if doesn't create any messages. + /// + private void ThrowIfNoFunctionResultsAdded(IList? messages) + { + if (messages is null || messages.Count == 0) + { + Throw.InvalidOperationException($"{GetType().Name}.{nameof(CreateResponseMessages)} returned null or an empty collection of messages."); + } + } + + /// Processes the function call described in []. + /// The current chat contents, inclusive of the function call contents being processed. /// The options used for the response being processed. - /// The function call content representing the function to be invoked. + /// The function call contents representing all the functions being invoked. /// The iteration number of how many roundtrips have been made to the inner client. - /// The 0-based index of the function being called out of total functions. - /// The number of function call requests made, of which this is one. + /// The 0-based index of the function being called out of . /// The to monitor for cancellation requests. /// A value indicating how the caller should proceed. 
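Since each call ultimately funnels through the virtual `InvokeFunctionAsync` declared further below, a sketch of a subclass that adds per-invocation timing, assuming the `Task<object?>`-returning signature this file declares; the subclass name is an illustrative assumption:

```csharp
public sealed class TimedFunctionInvokingChatClient(IChatClient innerClient)
    : FunctionInvokingChatClient(innerClient)
{
    protected override async Task<object?> InvokeFunctionAsync(
        FunctionInvocationContext context, CancellationToken cancellationToken)
    {
        var stopwatch = System.Diagnostics.Stopwatch.StartNew();
        try
        {
            return await base.InvokeFunctionAsync(context, cancellationToken);
        }
        finally
        {
            Console.WriteLine(
                $"{context.Function.Name} ({context.FunctionCallIndex + 1}/{context.FunctionCount}) " +
                $"finished in {stopwatch.ElapsedMilliseconds} ms.");
        }
    }
}
```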
private async Task ProcessFunctionCallAsync( - IList chatMessages, ChatOptions options, FunctionCallContent callContent, - int iteration, int functionCallIndex, int totalFunctionCount, CancellationToken cancellationToken) + List messages, ChatOptions options, List callContents, + int iteration, int functionCallIndex, CancellationToken cancellationToken) { + var callContent = callContents[functionCallIndex]; + // Look up the AIFunction for the function call. If the requested function isn't available, send back an error. AIFunction? function = options.Tools!.OfType().FirstOrDefault(t => t.Name == callContent.Name); if (function is null) @@ -565,12 +593,13 @@ private async Task ProcessFunctionCallAsync( FunctionInvocationContext context = new() { - ChatMessages = chatMessages, + Messages = messages, + Options = options, CallContent = callContent, Function = function, Iteration = iteration, FunctionCallIndex = functionCallIndex, - FunctionCount = totalFunctionCount, + FunctionCount = callContents.Count, }; object? result; @@ -610,23 +639,19 @@ internal enum ContinueMode Terminate = 2, } - /// Adds one or more response messages for function invocation results. - /// The chat to which to add the one or more response messages. + /// Creates one or more response messages for function invocation results. /// Information about the function call invocations and results. - /// A list of all chat messages added to . - protected virtual IList AddResponseMessages(IList chatMessages, ReadOnlySpan results) + /// A list of all chat messages created from . + protected virtual IList CreateResponseMessages( + ReadOnlySpan results) { - _ = Throw.IfNull(chatMessages); - - var contents = new AIContent[results.Length]; + var contents = new List(results.Length); for (int i = 0; i < results.Length; i++) { - contents[i] = CreateFunctionResultContent(results[i]); + contents.Add(CreateFunctionResultContent(results[i])); } - ChatMessage message = new(ChatRole.Tool, contents); - chatMessages.Add(message); - return [message]; + return [new(ChatRole.Tool, contents)]; FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult result) { @@ -664,6 +689,7 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul /// /// The to monitor for cancellation requests. The default is . /// The result of the function invocation, or if the function invocation returned . + /// is . protected virtual async Task InvokeFunctionAsync(FunctionInvocationContext context, CancellationToken cancellationToken) { _ = Throw.IfNull(context); diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs index 0d2d6f8bc9b..f2a60718ea9 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs @@ -21,6 +21,7 @@ public static class FunctionInvokingChatClientBuilderExtensions /// An optional to use to create a logger for logging function invocations. /// An optional callback that can be used to configure the instance. /// The supplied . + /// is . public static ChatClientBuilder UseFunctionInvocation( this ChatClientBuilder builder, ILoggerFactory? 
loggerFactory = null, diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs index b8e15718b78..51ca5a8f6d1 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs @@ -44,13 +44,13 @@ public JsonSerializerOptions JsonSerializerOptions /// public override async Task GetResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { if (_logger.IsEnabled(LogLevel.Debug)) { if (_logger.IsEnabled(LogLevel.Trace)) { - LogInvokedSensitive(nameof(GetResponseAsync), AsJson(chatMessages), AsJson(options), AsJson(this.GetService())); + LogInvokedSensitive(nameof(GetResponseAsync), AsJson(messages), AsJson(options), AsJson(this.GetService())); } else { @@ -60,7 +60,7 @@ public override async Task GetResponseAsync( try { - var response = await base.GetResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + var response = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); if (_logger.IsEnabled(LogLevel.Debug)) { @@ -90,13 +90,13 @@ public override async Task GetResponseAsync( /// public override async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { if (_logger.IsEnabled(LogLevel.Debug)) { if (_logger.IsEnabled(LogLevel.Trace)) { - LogInvokedSensitive(nameof(GetStreamingResponseAsync), AsJson(chatMessages), AsJson(options), AsJson(this.GetService())); + LogInvokedSensitive(nameof(GetStreamingResponseAsync), AsJson(messages), AsJson(options), AsJson(this.GetService())); } else { @@ -107,7 +107,7 @@ public override async IAsyncEnumerable GetStreamingResponseA IAsyncEnumerator e; try { - e = base.GetStreamingResponseAsync(chatMessages, options, cancellationToken).GetAsyncEnumerator(cancellationToken); + e = base.GetStreamingResponseAsync(messages, options, cancellationToken).GetAsyncEnumerator(cancellationToken); } catch (OperationCanceledException) { @@ -173,8 +173,8 @@ public override async IAsyncEnumerable GetStreamingResponseA [LoggerMessage(LogLevel.Debug, "{MethodName} invoked.")] private partial void LogInvoked(string methodName); - [LoggerMessage(LogLevel.Trace, "{MethodName} invoked: {ChatMessages}. Options: {ChatOptions}. Metadata: {ChatClientMetadata}.")] - private partial void LogInvokedSensitive(string methodName, string chatMessages, string chatOptions, string chatClientMetadata); + [LoggerMessage(LogLevel.Trace, "{MethodName} invoked: {Messages}. Options: {ChatOptions}. 
Metadata: {ChatClientMetadata}.")] + private partial void LogInvokedSensitive(string methodName, string messages, string chatOptions, string chatClientMetadata); [LoggerMessage(LogLevel.Debug, "{MethodName} completed.")] private partial void LogCompleted(string methodName); diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClientBuilderExtensions.cs index 6ae8d176e5e..d34716ed886 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClientBuilderExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClientBuilderExtensions.cs @@ -20,6 +20,7 @@ public static class LoggingChatClientBuilderExtensions /// /// An optional callback that can be used to configure the instance. /// The . + /// is . public static ChatClientBuilder UseLogging( this ChatClientBuilder builder, ILoggerFactory? loggerFactory = null, diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs index 1ae5f83b4b2..b3e4f8b86bf 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs @@ -121,22 +121,23 @@ protected override void Dispose(bool disposing) base.GetService(serviceType, serviceKey); /// - public override async Task GetResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + public override async Task GetResponseAsync( + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); _jsonSerializerOptions.MakeReadOnly(); using Activity? activity = CreateAndConfigureActivity(options); Stopwatch? stopwatch = _operationDurationHistogram.Enabled ? Stopwatch.StartNew() : null; string? requestModelId = options?.ModelId ?? _modelId; - LogChatMessages(chatMessages); + LogChatMessages(messages); ChatResponse? response = null; Exception? error = null; try { - response = await base.GetResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + response = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); return response; } catch (Exception ex) @@ -152,21 +153,21 @@ public override async Task GetResponseAsync(IList cha /// public override async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); _jsonSerializerOptions.MakeReadOnly(); using Activity? activity = CreateAndConfigureActivity(options); Stopwatch? stopwatch = _operationDurationHistogram.Enabled ? Stopwatch.StartNew() : null; string? requestModelId = options?.ModelId ?? 
_modelId; - LogChatMessages(chatMessages); + LogChatMessages(messages); IAsyncEnumerable updates; try { - updates = base.GetStreamingResponseAsync(chatMessages, options, cancellationToken); + updates = base.GetStreamingResponseAsync(messages, options, cancellationToken); } catch (Exception ex) { @@ -446,7 +447,7 @@ private void LogChatMessages(IEnumerable messages) if (message.Role == ChatRole.Assistant) { Log(new(1, OpenTelemetryConsts.GenAI.Assistant.Message), - JsonSerializer.Serialize(CreateAssistantEvent(message), OtelContext.Default.AssistantEvent)); + JsonSerializer.Serialize(CreateAssistantEvent(message.Contents), OtelContext.Default.AssistantEvent)); } else if (message.Role == ChatRole.Tool) { @@ -468,7 +469,7 @@ private void LogChatMessages(IEnumerable messages) JsonSerializer.Serialize(new() { Role = message.Role != ChatRole.System && message.Role != ChatRole.User && !string.IsNullOrWhiteSpace(message.Role.Value) ? message.Role.Value : null, - Content = GetMessageContent(message), + Content = GetMessageContent(message.Contents), }, OtelContext.Default.SystemOrUserEvent)); } } @@ -482,16 +483,12 @@ private void LogChatResponse(ChatResponse response) } EventId id = new(1, OpenTelemetryConsts.GenAI.Choice); - int choiceCount = response.Choices.Count; - for (int choiceIndex = 0; choiceIndex < choiceCount; choiceIndex++) + Log(id, JsonSerializer.Serialize(new() { - Log(id, JsonSerializer.Serialize(new() - { - FinishReason = response.FinishReason?.Value ?? "error", - Index = choiceIndex, - Message = CreateAssistantEvent(response.Choices[choiceIndex]), - }, OtelContext.Default.ChoiceEvent)); - } + FinishReason = response.FinishReason?.Value ?? "error", + Index = 0, + Message = CreateAssistantEvent(response.Messages is { Count: 1 } ? response.Messages[0].Contents : response.Messages.SelectMany(m => m.Contents)), + }, OtelContext.Default.ChoiceEvent)); } private void Log(EventId id, [StringSyntax(StringSyntaxAttribute.Json)] string eventBodyJson) @@ -509,9 +506,9 @@ private void Log(EventId id, [StringSyntax(StringSyntaxAttribute.Json)] string e _logger.Log(EventLogLevel, id, tags, null, (_, __) => eventBodyJson); } - private AssistantEvent CreateAssistantEvent(ChatMessage message) + private AssistantEvent CreateAssistantEvent(IEnumerable contents) { - var toolCalls = message.Contents.OfType().Select(fc => new ToolCall + var toolCalls = contents.OfType().Select(fc => new ToolCall { Id = fc.CallId, Function = new() @@ -525,16 +522,16 @@ private AssistantEvent CreateAssistantEvent(ChatMessage message) return new() { - Content = GetMessageContent(message), + Content = GetMessageContent(contents), ToolCalls = toolCalls.Length > 0 ? toolCalls : null, }; } - private string? GetMessageContent(ChatMessage message) + private string? 
GetMessageContent(IEnumerable contents) { if (EnableSensitiveData) { - string content = string.Concat(message.Contents.OfType()); + string content = string.Concat(contents.OfType()); if (content.Length > 0) { return content; diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs index 688e4b2353d..43a983d7fd4 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs @@ -51,7 +51,7 @@ public override async Task> GenerateAsync( var generated = await base.GenerateAsync(valuesList, options, cancellationToken).ConfigureAwait(false); if (generated.Count != 1) { - throw new InvalidOperationException($"Expected exactly one embedding to be generated, but received {generated.Count}."); + Throw.InvalidOperationException($"Expected exactly one embedding to be generated, but received {generated.Count}."); } await WriteCacheAsync(cacheKey, generated[0], cancellationToken).ConfigureAwait(false); diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGeneratorBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGeneratorBuilderExtensions.cs index 51f1804c2df..73867e4b2f7 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGeneratorBuilderExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGeneratorBuilderExtensions.cs @@ -28,6 +28,8 @@ public static class ConfigureOptionsEmbeddingGeneratorBuilderExtensions /// of the caller-supplied instance if one was supplied. /// /// The . + /// is . + /// is . public static EmbeddingGeneratorBuilder ConfigureOptions( this EmbeddingGeneratorBuilder builder, Action configure) diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs index bd801911257..d6c20ffb2f5 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs @@ -30,6 +30,7 @@ public class DistributedCachingEmbeddingGenerator : CachingE /// Initializes a new instance of the class. /// The underlying . /// A instance that will be used as the backing store for the cache. + /// is . public DistributedCachingEmbeddingGenerator(IEmbeddingGenerator innerGenerator, IDistributedCache storage) : base(innerGenerator) { @@ -39,6 +40,7 @@ public DistributedCachingEmbeddingGenerator(IEmbeddingGeneratorGets or sets JSON serialization options to use when serializing cache data. + /// is . 
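A wiring sketch for the distributed-caching embedding layer documented above; `MemoryDistributedCache` is just a convenient in-process stand-in, and `innerGenerator` is assumed to be an existing `IEmbeddingGenerator<string, Embedding<float>>`:

```csharp
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Options;

IDistributedCache cache = new MemoryDistributedCache(
    Options.Create(new MemoryDistributedCacheOptions()));

IEmbeddingGenerator<string, Embedding<float>> generator = innerGenerator
    .AsBuilder()
    .UseDistributedCache(cache)
    .Build();

// Repeated inputs are served from the cache instead of hitting the underlying service.
var first = await generator.GenerateAsync(["hello world"]);
var second = await generator.GenerateAsync(["hello world"]);
```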
public JsonSerializerOptions JsonSerializerOptions { get => _jsonSerializerOptions; diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGeneratorBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGeneratorBuilderExtensions.cs index 7d42407d930..c2bbdbd1ded 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGeneratorBuilderExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGeneratorBuilderExtensions.cs @@ -25,6 +25,7 @@ public static class DistributedCachingEmbeddingGeneratorBuilderExtensions /// /// An optional callback that can be used to configure the instance. /// The provided as . + /// is . public static EmbeddingGeneratorBuilder UseDistributedCache( this EmbeddingGeneratorBuilder builder, IDistributedCache? storage = null, diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilder.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilder.cs index 1baa64d2a20..dcb33d37c3c 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilder.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilder.cs @@ -22,6 +22,7 @@ public sealed class EmbeddingGeneratorBuilder /// Initializes a new instance of the class. /// The inner that represents the underlying backend. + /// is . public EmbeddingGeneratorBuilder(IEmbeddingGenerator innerGenerator) { _ = Throw.IfNull(innerGenerator); @@ -53,10 +54,13 @@ public IEmbeddingGenerator Build(IServiceProvider? services { for (var i = _generatorFactories.Count - 1; i >= 0; i--) { - embeddingGenerator = _generatorFactories[i](embeddingGenerator, services) ?? - throw new InvalidOperationException( + embeddingGenerator = _generatorFactories[i](embeddingGenerator, services); + if (embeddingGenerator is null) + { + Throw.InvalidOperationException( $"The {nameof(IEmbeddingGenerator)} entry at index {i} returned null. " + $"Ensure that the callbacks passed to {nameof(Use)} return non-null {nameof(IEmbeddingGenerator)} instances."); + } } } @@ -66,6 +70,7 @@ public IEmbeddingGenerator Build(IServiceProvider? services /// Adds a factory for an intermediate embedding generator to the embedding generator pipeline. /// The generator factory function. /// The updated instance. + /// is . public EmbeddingGeneratorBuilder Use(Func, IEmbeddingGenerator> generatorFactory) { _ = Throw.IfNull(generatorFactory); @@ -76,6 +81,7 @@ public EmbeddingGeneratorBuilder Use(FuncAdds a factory for an intermediate embedding generator to the embedding generator pipeline. /// The generator factory function. /// The updated instance. + /// is . public EmbeddingGeneratorBuilder Use( Func, IServiceProvider, IEmbeddingGenerator> generatorFactory) { diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderEmbeddingGeneratorExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderEmbeddingGeneratorExtensions.cs index 73784f56916..84d4815cb23 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderEmbeddingGeneratorExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderEmbeddingGeneratorExtensions.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. 
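Given the Build-time null check above, a sketch that composes a generator pipeline through the builder explicitly; `innerGenerator`, `cache`, `loggerFactory`, and `serviceProvider` are assumed to exist:

```csharp
IEmbeddingGenerator<string, Embedding<float>> pipeline =
    new EmbeddingGeneratorBuilder<string, Embedding<float>>(innerGenerator)
        .UseDistributedCache(cache)
        .UseLogging(loggerFactory)
        .Use((inner, services) => inner)  // each factory must return non-null, or Build throws
        .Build(serviceProvider);
```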
+using System; using Microsoft.Extensions.AI; using Microsoft.Shared.Diagnostics; @@ -22,6 +23,7 @@ public static class EmbeddingGeneratorBuilderEmbeddingGeneratorExtensions /// This method is equivalent to using the /// constructor directly, specifying as the inner generator. /// + /// is . public static EmbeddingGeneratorBuilder AsBuilder( this IEmbeddingGenerator innerGenerator) where TEmbedding : Embedding diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs index 2000e71cf03..b84e8ac6e60 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs @@ -18,12 +18,19 @@ public static class EmbeddingGeneratorBuilderServiceCollectionExtensions /// The service lifetime for the client. Defaults to . /// An that can be used to build a pipeline around the inner generator. /// The generator is registered as a singleton service. + /// is . + /// is . public static EmbeddingGeneratorBuilder AddEmbeddingGenerator( this IServiceCollection serviceCollection, IEmbeddingGenerator innerGenerator, ServiceLifetime lifetime = ServiceLifetime.Singleton) where TEmbedding : Embedding - => AddEmbeddingGenerator(serviceCollection, _ => innerGenerator, lifetime); + { + _ = Throw.IfNull(serviceCollection); + _ = Throw.IfNull(innerGenerator); + + return AddEmbeddingGenerator(serviceCollection, _ => innerGenerator, lifetime); + } /// Registers a singleton embedding generator in the . /// The type from which embeddings will be generated. @@ -33,6 +40,8 @@ public static EmbeddingGeneratorBuilder AddEmbeddingGenerato /// The service lifetime for the client. Defaults to . /// An that can be used to build a pipeline around the inner generator. /// The generator is registered as a singleton service. + /// is . + /// is . public static EmbeddingGeneratorBuilder AddEmbeddingGenerator( this IServiceCollection serviceCollection, Func> innerGeneratorFactory, @@ -56,13 +65,20 @@ public static EmbeddingGeneratorBuilder AddEmbeddingGenerato /// The service lifetime for the client. Defaults to . /// An that can be used to build a pipeline around the inner generator. /// The generator is registered as a singleton service. + /// is . + /// is . public static EmbeddingGeneratorBuilder AddKeyedEmbeddingGenerator( this IServiceCollection serviceCollection, object? serviceKey, IEmbeddingGenerator innerGenerator, ServiceLifetime lifetime = ServiceLifetime.Singleton) where TEmbedding : Embedding - => AddKeyedEmbeddingGenerator(serviceCollection, serviceKey, _ => innerGenerator, lifetime); + { + _ = Throw.IfNull(serviceCollection); + _ = Throw.IfNull(innerGenerator); + + return AddKeyedEmbeddingGenerator(serviceCollection, serviceKey, _ => innerGenerator, lifetime); + } /// Registers a keyed singleton embedding generator in the . /// The type from which embeddings will be generated. @@ -73,6 +89,8 @@ public static EmbeddingGeneratorBuilder AddKeyedEmbeddingGen /// The service lifetime for the client. Defaults to . /// An that can be used to build a pipeline around the inner generator. /// The generator is registered as a singleton service. + /// is . + /// is . public static EmbeddingGeneratorBuilder AddKeyedEmbeddingGenerator( this IServiceCollection serviceCollection, object? 
serviceKey, @@ -81,7 +99,6 @@ public static EmbeddingGeneratorBuilder AddKeyedEmbeddingGen where TEmbedding : Embedding { _ = Throw.IfNull(serviceCollection); - _ = Throw.IfNull(serviceKey); _ = Throw.IfNull(innerGeneratorFactory); var builder = new EmbeddingGeneratorBuilder(innerGeneratorFactory); diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGeneratorBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGeneratorBuilderExtensions.cs index 52fb7dd1ca3..eb472fb1e0e 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGeneratorBuilderExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGeneratorBuilderExtensions.cs @@ -22,6 +22,7 @@ public static class LoggingEmbeddingGeneratorBuilderExtensions /// /// An optional callback that can be used to configure the instance. /// The . + /// is . public static EmbeddingGeneratorBuilder UseLogging( this EmbeddingGeneratorBuilder builder, ILoggerFactory? loggerFactory = null, diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs index d8be8e9f128..4d16ac6ae6b 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs @@ -39,6 +39,7 @@ public static partial class AIFunctionFactory /// it is round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. /// /// + /// is . public static AIFunction Create(Delegate method, AIFunctionFactoryOptions? options) { _ = Throw.IfNull(method); @@ -61,6 +62,7 @@ public static AIFunction Create(Delegate method, AIFunctionFactoryOptions? optio /// round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. /// /// + /// is . public static AIFunction Create(Delegate method, string? name = null, string? description = null, JsonSerializerOptions? serializerOptions = null) { _ = Throw.IfNull(method); @@ -98,6 +100,7 @@ public static AIFunction Create(Delegate method, string? name = null, string? de /// it is round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. /// /// + /// is . public static AIFunction Create(MethodInfo method, object? target, AIFunctionFactoryOptions? options) { _ = Throw.IfNull(method); @@ -126,6 +129,7 @@ public static AIFunction Create(MethodInfo method, object? target, AIFunctionFac /// round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. /// /// + /// is . public static AIFunction Create(MethodInfo method, object? target, string? name = null, string? description = null, JsonSerializerOptions? 
serializerOptions = null) { _ = Throw.IfNull(method); diff --git a/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/src/ChatWithCustomData/ChatWithCustomData.Web-CSharp/Services/Ingestion/IngestionCacheDbContext.cs b/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/src/ChatWithCustomData/ChatWithCustomData.Web-CSharp/Services/Ingestion/IngestionCacheDbContext.cs index aeaf4ccd52d..78842253abe 100644 --- a/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/src/ChatWithCustomData/ChatWithCustomData.Web-CSharp/Services/Ingestion/IngestionCacheDbContext.cs +++ b/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/src/ChatWithCustomData/ChatWithCustomData.Web-CSharp/Services/Ingestion/IngestionCacheDbContext.cs @@ -34,7 +34,7 @@ public class IngestedDocument public required string Id { get; set; } public required string SourceId { get; set; } public required string Version { get; set; } - public List Records { get; set; } = new(); + public List Records { get; set; } = []; } public class IngestedRecord diff --git a/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/src/ChatWithCustomData/ChatWithCustomData.Web-CSharp/Services/JsonVectorStore.cs b/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/src/ChatWithCustomData/ChatWithCustomData.Web-CSharp/Services/JsonVectorStore.cs index 9dba51c7692..cb787c3bbef 100644 --- a/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/src/ChatWithCustomData/ChatWithCustomData.Web-CSharp/Services/JsonVectorStore.cs +++ b/src/ProjectTemplates/Microsoft.Extensions.AI.Templates/src/ChatWithCustomData/ChatWithCustomData.Web-CSharp/Services/JsonVectorStore.cs @@ -51,7 +51,7 @@ public Task CollectionExistsAsync(CancellationToken cancellationToken = de public async Task CreateCollectionAsync(CancellationToken cancellationToken = default) { - _records = new(); + _records = []; await WriteToDiskAsync(cancellationToken); } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs index 5a95f2b3fd0..c74c50813f4 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs @@ -32,7 +32,7 @@ public void GetService_ValidService_Returned() { using var client = new TestChatClient { - GetServiceCallback = (Type serviceType, object? 
serviceKey) => + GetServiceCallback = (serviceType, serviceKey) => { if (serviceType == typeof(string)) { @@ -100,15 +100,15 @@ public void GetStreamingResponseAsync_InvalidArgs_Throws() [Fact] public async Task GetResponseAsync_CreatesTextMessageAsync() { - var expectedResponse = new ChatResponse([new ChatMessage()]); + var expectedResponse = new ChatResponse(); var expectedOptions = new ChatOptions(); using var cts = new CancellationTokenSource(); using TestChatClient client = new() { - GetResponseAsyncCallback = (chatMessages, options, cancellationToken) => + GetResponseAsyncCallback = (messages, options, cancellationToken) => { - ChatMessage m = Assert.Single(chatMessages); + ChatMessage m = Assert.Single(messages); Assert.Equal(ChatRole.User, m.Role); Assert.Equal("hello", m.Text); @@ -133,9 +133,9 @@ public async Task GetStreamingResponseAsync_CreatesTextMessageAsync() using TestChatClient client = new() { - GetStreamingResponseAsyncCallback = (chatMessages, options, cancellationToken) => + GetStreamingResponseAsyncCallback = (messages, options, cancellationToken) => { - ChatMessage m = Assert.Single(chatMessages); + ChatMessage m = Assert.Single(messages); Assert.Equal(ChatRole.User, m.Role); Assert.Equal("hello", m.Text); @@ -143,7 +143,7 @@ public async Task GetStreamingResponseAsync_CreatesTextMessageAsync() Assert.Equal(cts.Token, cancellationToken); - return YieldAsync([new ChatResponseUpdate { Text = "world" }]); + return YieldAsync([new ChatResponseUpdate(ChatRole.Assistant, "world")]); }, }; diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs index b67fb1de4a5..7174d2a70c8 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Text.Json; using Xunit; +using static System.Net.Mime.MediaTypeNames; namespace Microsoft.Extensions.AI; @@ -18,7 +19,7 @@ public void Constructor_Parameterless_PropsDefaulted() Assert.Null(message.AuthorName); Assert.Empty(message.Contents); Assert.Equal(ChatRole.User, message.Role); - Assert.Null(message.Text); + Assert.Empty(message.Text); Assert.NotNull(message.Contents); Assert.Same(message.Contents, message.Contents); Assert.Empty(message.Contents); @@ -55,9 +56,25 @@ public void Constructor_RoleString_PropsRoundtrip(string? 
text) } [Fact] - public void Constructor_RoleList_InvalidArgs_Throws() + public void Constructor_NullEmptyArgs_Valid() { - Assert.Throws("contents", () => new ChatMessage(ChatRole.User, (IList)null!)); + ChatMessage message; + + message = new(); + Assert.Empty(message.Text); + Assert.Empty(message.Contents); + + message = new(ChatRole.User, (string?)null); + Assert.Empty(message.Text); + Assert.Empty(message.Contents); + + message = new(ChatRole.User, (IList?)null); + Assert.Empty(message.Text); + Assert.Empty(message.Contents); + + message = new ChatMessage(ChatRole.User, Array.Empty()); + Assert.Empty(message.Text); + Assert.Empty(message.Contents); } [Theory] @@ -80,7 +97,7 @@ public void Constructor_RoleList_PropsRoundtrip(int messageCount) if (messageCount == 0) { Assert.Empty(message.Contents); - Assert.Null(message.Text); + Assert.Empty(message.Text); } else { @@ -91,7 +108,7 @@ public void Constructor_RoleList_PropsRoundtrip(int messageCount) Assert.Equal($"text-{i}", tc.Text); } - Assert.Equal("text-0", message.Text); + Assert.Equal(string.Concat(Enumerable.Range(0, messageCount).Select(i => $"text-{i}")), message.Text); Assert.Equal(string.Concat(Enumerable.Range(0, messageCount).Select(i => $"text-{i}")), message.ToString()); } @@ -120,7 +137,7 @@ public void AuthorName_InvalidArg_UsesNull(string? authorName) } [Fact] - public void Text_GetSet_UsesFirstTextContent() + public void Text_ConcatsAllTextContent() { ChatMessage message = new(ChatRole.User, [ @@ -134,57 +151,15 @@ public void Text_GetSet_UsesFirstTextContent() TextContent textContent = Assert.IsType(message.Contents[3]); Assert.Equal("text-1", textContent.Text); - Assert.Equal("text-1", message.Text); + Assert.Equal("text-1text-2", message.Text); Assert.Equal("text-1text-2", message.ToString()); - message.Text = "text-3"; - Assert.Equal("text-3", message.Text); - Assert.Equal("text-3", message.Text); - Assert.Same(textContent, message.Contents[3]); + ((TextContent)message.Contents[3]).Text = "text-3"; + Assert.Equal("text-3", textContent.Text); + Assert.Equal("text-3text-2", message.Text); Assert.Equal("text-3text-2", message.ToString()); } - [Fact] - public void Text_Set_AddsTextMessageToEmptyList() - { - ChatMessage message = new(ChatRole.User, []); - Assert.Empty(message.Contents); - - message.Text = "text-1"; - Assert.Equal("text-1", message.Text); - - Assert.Single(message.Contents); - TextContent textContent = Assert.IsType(message.Contents[0]); - Assert.Equal("text-1", textContent.Text); - } - - [Fact] - public void Text_Set_AddsTextMessageToListWithNoText() - { - ChatMessage message = new(ChatRole.User, - [ - new DataContent("http://localhost/audio"), - new DataContent("http://localhost/image"), - new FunctionCallContent("callId1", "fc1"), - ]); - Assert.Equal(3, message.Contents.Count); - - message.Text = "text-1"; - Assert.Equal("text-1", message.Text); - Assert.Equal(4, message.Contents.Count); - - message.Text = "text-2"; - Assert.Equal("text-2", message.Text); - Assert.Equal(4, message.Contents.Count); - - message.Contents.RemoveAt(3); - Assert.Equal(3, message.Contents.Count); - - message.Text = "text-3"; - Assert.Equal("text-3", message.Text); - Assert.Equal(4, message.Contents.Count); - } - [Fact] public void Contents_InitializesToList() { @@ -282,12 +257,13 @@ public void ItCanBeSerializeAndDeserialized() ]; // Act - var chatMessageJson = JsonSerializer.Serialize(new ChatMessage(ChatRole.User, contents: items) + var chatMessage = new ChatMessage(ChatRole.User, contents: items) { - Text = 
"content-1-override", // Override the content of the first text content item that has the "content-1" content AuthorName = "Fred", AdditionalProperties = new() { ["message-metadata-key-1"] = "message-metadata-value-1" }, - }, TestJsonSerializerContext.Default.Options); + }; + ((TextContent)chatMessage.Contents[0]).Text = "content-1-override"; // Override the content of the first text content item that has the "content-1" content + var chatMessageJson = JsonSerializer.Serialize(chatMessage, TestJsonSerializerContext.Default.Options); var deserializedMessage = JsonSerializer.Deserialize(chatMessageJson, TestJsonSerializerContext.Default.Options)!; diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs index e222b6d5215..ee719ee5647 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text.Json; using Xunit; @@ -11,80 +12,42 @@ namespace Microsoft.Extensions.AI; public class ChatResponseTests { [Fact] - public void Constructor_InvalidArgs_Throws() + public void Constructor_NullEmptyArgs_Valid() { - Assert.Throws("message", () => new ChatResponse((ChatMessage)null!)); - Assert.Throws("choices", () => new ChatResponse((IList)null!)); - } - - [Fact] - public void Constructor_Message_Roundtrips() - { - ChatMessage message = new(); - - ChatResponse response = new(message); - Assert.Same(message, response.Message); - Assert.Same(message, Assert.Single(response.Choices)); - } - - [Fact] - public void Constructor_Choices_Roundtrips() - { - List messages = - [ - new ChatMessage(), - new ChatMessage(), - new ChatMessage(), - ]; - - ChatResponse response = new(messages); - Assert.Same(messages, response.Choices); - Assert.Equal(3, messages.Count); - } - - [Fact] - public void Message_EmptyChoices_Throws() - { - ChatResponse response = new([]); + ChatResponse response; - Assert.Empty(response.Choices); - Assert.Throws(() => response.Message); - } - - [Fact] - public void Message_SingleChoice_Returned() - { - ChatMessage message = new(); - ChatResponse response = new([message]); + response = new(); + Assert.Empty(response.Messages); + Assert.Empty(response.Text); - Assert.Same(message, response.Message); - Assert.Same(message, response.Choices[0]); - } - - [Fact] - public void Message_MultipleChoices_ReturnsFirst() - { - ChatMessage first = new(); - ChatResponse response = new([ - first, - new ChatMessage(), - ]); + response = new((IList?)null); + Assert.Empty(response.Messages); + Assert.Empty(response.Text); - Assert.Same(first, response.Message); - Assert.Same(first, response.Choices[0]); + Assert.Throws("message", () => new ChatResponse((ChatMessage)null!)); } [Fact] - public void Choices_SetNull_Throws() + public void Constructor_Messages_Roundtrips() { - ChatResponse response = new([]); - Assert.Throws("value", () => response.Choices = null!); + ChatResponse response = new(); + Assert.NotNull(response.Messages); + Assert.Same(response.Messages, response.Messages); + + List messages = []; + response = new(messages); + Assert.Same(messages, response.Messages); + + messages = []; + Assert.NotSame(messages, response.Messages); + response.Messages = messages; + Assert.Same(messages, response.Messages); } [Fact] 
public void Properties_Roundtrip() { - ChatResponse response = new([]); + ChatResponse response = new(); Assert.Null(response.ResponseId); response.ResponseId = "id"; @@ -116,22 +79,12 @@ public void Properties_Roundtrip() AdditionalPropertiesDictionary additionalProps = []; response.AdditionalProperties = additionalProps; Assert.Same(additionalProps, response.AdditionalProperties); - - List newChoices = [new ChatMessage(), new ChatMessage()]; - response.Choices = newChoices; - Assert.Same(newChoices, response.Choices); } [Fact] public void JsonSerialization_Roundtrips() { - ChatResponse original = new( - [ - new ChatMessage(ChatRole.Assistant, "Choice1"), - new ChatMessage(ChatRole.Assistant, "Choice2"), - new ChatMessage(ChatRole.Assistant, "Choice3"), - new ChatMessage(ChatRole.Assistant, "Choice4"), - ]) + ChatResponse original = new(new ChatMessage(ChatRole.Assistant, "the message")) { ResponseId = "id", ModelId = "modelId", @@ -147,13 +100,8 @@ public void JsonSerialization_Roundtrips() ChatResponse? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatResponse); Assert.NotNull(result); - Assert.Equal(4, result.Choices.Count); - - for (int i = 0; i < original.Choices.Count; i++) - { - Assert.Equal(ChatRole.Assistant, result.Choices[i].Role); - Assert.Equal($"Choice{i + 1}", result.Choices[i].Text); - } + Assert.Equal(ChatRole.Assistant, result.Messages.Single().Role); + Assert.Equal("the message", result.Messages.Single().Text); Assert.Equal("id", result.ResponseId); Assert.Equal("modelId", result.ModelId); @@ -169,41 +117,15 @@ public void JsonSerialization_Roundtrips() } [Fact] - public void ToString_OneChoice_OutputsChatMessageToString() + public void ToString_OutputsText() { - ChatResponse response = new( - [ - new ChatMessage(ChatRole.Assistant, "This is a test." + Environment.NewLine + "It's multiple lines.") - ]); + ChatResponse response = new(new ChatMessage(ChatRole.Assistant, $"This is a test.{Environment.NewLine}It's multiple lines.")); - Assert.Equal(response.Choices[0].Text, response.ToString()); + Assert.Equal(response.Text, response.ToString()); } [Fact] - public void ToString_MultipleChoices_OutputsAllChoicesWithPrefix() - { - ChatResponse response = new( - [ - new ChatMessage(ChatRole.Assistant, "This is a test." 
+ Environment.NewLine + "It's multiple lines."), - new ChatMessage(ChatRole.Assistant, "So is" + Environment.NewLine + " this."), - new ChatMessage(ChatRole.Assistant, "And this."), - ]); - - Assert.Equal( - "Choice 0:" + Environment.NewLine + - response.Choices[0] + Environment.NewLine + Environment.NewLine + - - "Choice 1:" + Environment.NewLine + - response.Choices[1] + Environment.NewLine + Environment.NewLine + - - "Choice 2:" + Environment.NewLine + - response.Choices[2], - - response.ToString()); - } - - [Fact] - public void ToChatResponseUpdates_SingleChoice() + public void ToChatResponseUpdates() { ChatResponse response = new(new ChatMessage(new ChatRole("customRole"), "Text")) { @@ -230,68 +152,4 @@ public void ToChatResponseUpdates_SingleChoice() Assert.Equal("value1", update1.AdditionalProperties?["key1"]); Assert.Equal(42, update1.AdditionalProperties?["key2"]); } - - [Fact] - public void ToChatResponseUpdates_MultiChoice() - { - ChatResponse response = new( - [ - new ChatMessage(ChatRole.Assistant, - [ - new TextContent("Hello, "), - new DataContent("http://localhost/image.png", mediaType: "image/png"), - new TextContent("world!"), - ]) - { - AdditionalProperties = new() { ["choice1Key"] = "choice1Value" }, - }, - - new ChatMessage(ChatRole.System, - [ - new FunctionCallContent("call123", "name"), - new FunctionResultContent("call123", 42), - ]) - { - AdditionalProperties = new() { ["choice2Key"] = "choice2Value" }, - }, - ]) - { - ResponseId = "12345", - ModelId = "someModel", - FinishReason = ChatFinishReason.ContentFilter, - CreatedAt = new DateTimeOffset(2024, 11, 10, 9, 20, 0, TimeSpan.Zero), - AdditionalProperties = new() { ["key1"] = "value1", ["key2"] = 42 }, - Usage = new UsageDetails { TotalTokenCount = 123 }, - }; - - ChatResponseUpdate[] updates = response.ToChatResponseUpdates(); - Assert.NotNull(updates); - Assert.Equal(3, updates.Length); - - ChatResponseUpdate update0 = updates[0]; - Assert.Equal("12345", update0.ResponseId); - Assert.Equal("someModel", update0.ModelId); - Assert.Equal(ChatFinishReason.ContentFilter, update0.FinishReason); - Assert.Equal(new DateTimeOffset(2024, 11, 10, 9, 20, 0, TimeSpan.Zero), update0.CreatedAt); - Assert.Equal("assistant", update0.Role?.Value); - Assert.Equal("Hello, ", Assert.IsType(update0.Contents[0]).Text); - Assert.Equal("image/png", Assert.IsType(update0.Contents[1]).MediaType); - Assert.Equal("world!", Assert.IsType(update0.Contents[2]).Text); - Assert.Equal("choice1Value", update0.AdditionalProperties?["choice1Key"]); - - ChatResponseUpdate update1 = updates[1]; - Assert.Equal("12345", update1.ResponseId); - Assert.Equal("someModel", update1.ModelId); - Assert.Equal(ChatFinishReason.ContentFilter, update1.FinishReason); - Assert.Equal(new DateTimeOffset(2024, 11, 10, 9, 20, 0, TimeSpan.Zero), update1.CreatedAt); - Assert.Equal("system", update1.Role?.Value); - Assert.IsType(update1.Contents[0]); - Assert.IsType(update1.Contents[1]); - Assert.Equal("choice2Value", update1.AdditionalProperties?["choice2Key"]); - - ChatResponseUpdate update2 = updates[2]; - Assert.Equal("value1", update2.AdditionalProperties?["key1"]); - Assert.Equal(42, update2.AdditionalProperties?["key2"]); - Assert.Equal(123, Assert.IsType(Assert.Single(update2.Contents)).Details.TotalTokenCount); - } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs index 
fea25191aff..454c3c3cad3 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs @@ -21,44 +21,24 @@ public void InvalidArgs_Throws() Assert.Throws("updates", () => ((List)null!).ToChatResponse()); } - public static IEnumerable ToChatResponse_SuccessfullyCreatesResponse_MemberData() - { - foreach (bool useAsync in new[] { false, true }) - { - foreach (bool? coalesceContent in new bool?[] { null, false, true }) - { - yield return new object?[] { useAsync, coalesceContent }; - } - } - } - [Theory] - [MemberData(nameof(ToChatResponse_SuccessfullyCreatesResponse_MemberData))] - public async Task ToChatResponse_SuccessfullyCreatesResponse(bool useAsync, bool? coalesceContent) + [InlineData(false)] + [InlineData(true)] + public async Task ToChatResponse_SuccessfullyCreatesResponse(bool useAsync) { ChatResponseUpdate[] updates = [ - new() { ChoiceIndex = 0, Text = "Hello", ResponseId = "12345", CreatedAt = new DateTimeOffset(1, 2, 3, 4, 5, 6, TimeSpan.Zero), ModelId = "model123" }, - new() { ChoiceIndex = 1, Text = "Hey", ResponseId = "12345", CreatedAt = new DateTimeOffset(1, 2, 3, 4, 5, 6, TimeSpan.Zero), ModelId = "model124" }, - - new() { ChoiceIndex = 0, Text = ", ", AuthorName = "Someone", Role = ChatRole.User, AdditionalProperties = new() { ["a"] = "b" } }, - new() { ChoiceIndex = 1, Text = ", ", AuthorName = "Else", Role = ChatRole.System, ChatThreadId = "123", AdditionalProperties = new() { ["g"] = "h" } }, + new(ChatRole.Assistant, "Hello") { ResponseId = "12345", CreatedAt = new DateTimeOffset(1, 2, 3, 4, 5, 6, TimeSpan.Zero), ModelId = "model123" }, + new(new("human"), ", ") { AuthorName = "Someone", AdditionalProperties = new() { ["a"] = "b" } }, + new(null, "world!") { CreatedAt = new DateTimeOffset(2, 2, 3, 4, 5, 6, TimeSpan.Zero), ChatThreadId = "123", AdditionalProperties = new() { ["c"] = "d" } }, - new() { ChoiceIndex = 0, Text = "world!", CreatedAt = new DateTimeOffset(2, 2, 3, 4, 5, 6, TimeSpan.Zero), AdditionalProperties = new() { ["c"] = "d" } }, - new() { ChoiceIndex = 1, Text = "you!", Role = ChatRole.Tool, CreatedAt = new DateTimeOffset(3, 2, 3, 4, 5, 6, TimeSpan.Zero), AdditionalProperties = new() { ["e"] = "f", ["i"] = 42 } }, - - new() { ChoiceIndex = 0, Contents = new[] { new UsageContent(new() { InputTokenCount = 1, OutputTokenCount = 2 }) } }, - new() { ChoiceIndex = 3, Contents = new[] { new UsageContent(new() { InputTokenCount = 4, OutputTokenCount = 5 }) } }, + new() { Contents = [new UsageContent(new() { InputTokenCount = 1, OutputTokenCount = 2 })] }, + new() { Contents = [new UsageContent(new() { InputTokenCount = 4, OutputTokenCount = 5 })] }, ]; - ChatResponse response = (coalesceContent is bool, useAsync) switch - { - (false, false) => updates.ToChatResponse(), - (false, true) => await YieldAsync(updates).ToChatResponseAsync(), - - (true, false) => updates.ToChatResponse(coalesceContent.GetValueOrDefault()), - (true, true) => await YieldAsync(updates).ToChatResponseAsync(coalesceContent.GetValueOrDefault()), - }; + ChatResponse response = useAsync ? 
+ updates.ToChatResponse() : + await YieldAsync(updates).ToChatResponseAsync(); Assert.NotNull(response); Assert.NotNull(response.Usage); @@ -66,54 +46,22 @@ public async Task ToChatResponse_SuccessfullyCreatesResponse(bool useAsync, bool Assert.Equal(7, response.Usage.OutputTokenCount); Assert.Equal("12345", response.ResponseId); - Assert.Equal(new DateTimeOffset(1, 2, 3, 4, 5, 6, TimeSpan.Zero), response.CreatedAt); + Assert.Equal(new DateTimeOffset(2, 2, 3, 4, 5, 6, TimeSpan.Zero), response.CreatedAt); Assert.Equal("model123", response.ModelId); Assert.Equal("123", response.ChatThreadId); - Assert.Equal(3, response.Choices.Count); - - ChatMessage message = response.Choices[0]; - Assert.Equal(ChatRole.User, message.Role); + ChatMessage message = response.Messages.Last(); + Assert.Equal(new ChatRole("human"), message.Role); Assert.Equal("Someone", message.AuthorName); - Assert.NotNull(message.AdditionalProperties); - Assert.Equal(2, message.AdditionalProperties.Count); - Assert.Equal("b", message.AdditionalProperties["a"]); - Assert.Equal("d", message.AdditionalProperties["c"]); - - message = response.Choices[1]; - Assert.Equal(ChatRole.System, message.Role); - Assert.Equal("Else", message.AuthorName); - Assert.NotNull(message.AdditionalProperties); - Assert.Equal(3, message.AdditionalProperties.Count); - Assert.Equal("h", message.AdditionalProperties["g"]); - Assert.Equal("f", message.AdditionalProperties["e"]); - Assert.Equal(42, message.AdditionalProperties["i"]); - - message = response.Choices[2]; - Assert.Equal(ChatRole.Assistant, message.Role); - Assert.Null(message.AuthorName); Assert.Null(message.AdditionalProperties); - Assert.Empty(message.Contents); - - if (coalesceContent is null or true) - { - Assert.Equal("Hello, world!", response.Choices[0].Text); - Assert.Equal("Hey, you!", response.Choices[1].Text); - Assert.Null(response.Choices[2].Text); - } - else - { - Assert.Equal("Hello", response.Choices[0].Contents[0].ToString()); - Assert.Equal(", ", response.Choices[0].Contents[1].ToString()); - Assert.Equal("world!", response.Choices[0].Contents[2].ToString()); - Assert.Equal("Hey", response.Choices[1].Contents[0].ToString()); - Assert.Equal(", ", response.Choices[1].Contents[1].ToString()); - Assert.Equal("you!", response.Choices[1].Contents[2].ToString()); + Assert.NotNull(response.AdditionalProperties); + Assert.Equal(2, response.AdditionalProperties.Count); + Assert.Equal("b", response.AdditionalProperties["a"]); + Assert.Equal("d", response.AdditionalProperties["c"]); - Assert.Null(response.Choices[2].Text); - } + Assert.Equal("Hello, world!", response.Text); } public static IEnumerable ToChatResponse_Coalescing_VariousSequenceAndGapLengths_MemberData() @@ -155,7 +103,7 @@ public async Task ToChatResponse_Coalescing_VariousSequenceAndGapLengths(bool us for (int i = 0; i < sequenceLength; i++) { string text = $"{(char)('A' + sequenceNum)}{i}"; - updates.Add(new() { Text = text }); + updates.Add(new(null, text)); sb.Append(text); } @@ -181,9 +129,11 @@ void AddGap() } ChatResponse response = useAsync ? await YieldAsync(updates).ToChatResponseAsync() : updates.ToChatResponse(); - Assert.Single(response.Choices); + Assert.NotNull(response); + + ChatMessage message = response.Messages.Single(); + Assert.NotNull(message); - ChatMessage message = response.Message; Assert.Equal(expected.Count + (gapLength * ((numSequences - 1) + (gapBeginningEnd ? 
2 : 0))), message.Contents.Count); TextContent[] contents = message.Contents.OfType().ToArray(); @@ -199,8 +149,8 @@ public async Task ToChatResponse_UsageContentExtractedFromContents() { ChatResponseUpdate[] updates = { - new() { Text = "Hello, " }, - new() { Text = "world!" }, + new(null, "Hello, "), + new(null, "world!"), new() { Contents = [new UsageContent(new() { TotalTokenCount = 42 })] }, }; @@ -211,7 +161,7 @@ public async Task ToChatResponse_UsageContentExtractedFromContents() Assert.NotNull(response.Usage); Assert.Equal(42, response.Usage.TotalTokenCount); - Assert.Equal("Hello, world!", Assert.IsType(Assert.Single(response.Message.Contents)).Text); + Assert.Equal("Hello, world!", Assert.IsType(Assert.Single(Assert.Single(response.Messages).Contents)).Text); } private static async IAsyncEnumerable YieldAsync(IEnumerable updates) diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateTests.cs index be4108f8148..7e5ff6b1e84 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateTests.cs @@ -16,14 +16,13 @@ public void Constructor_PropsDefaulted() ChatResponseUpdate update = new(); Assert.Null(update.AuthorName); Assert.Null(update.Role); - Assert.Null(update.Text); + Assert.Empty(update.Text); Assert.Empty(update.Contents); Assert.Null(update.RawRepresentation); Assert.Null(update.AdditionalProperties); Assert.Null(update.ResponseId); Assert.Null(update.CreatedAt); Assert.Null(update.FinishReason); - Assert.Equal(0, update.ChoiceIndex); Assert.Equal(string.Empty, update.ToString()); } @@ -52,9 +51,7 @@ public void Properties_Roundtrip() Assert.NotNull(update.Contents); Assert.Empty(update.Contents); - Assert.Null(update.Text); - update.Text = "text"; - Assert.Equal("text", update.Text); + Assert.Empty(update.Text); Assert.Null(update.RawRepresentation); object raw = new(); @@ -74,17 +71,13 @@ public void Properties_Roundtrip() update.CreatedAt = new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero); Assert.Equal(new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), update.CreatedAt); - Assert.Equal(0, update.ChoiceIndex); - update.ChoiceIndex = 42; - Assert.Equal(42, update.ChoiceIndex); - Assert.Null(update.FinishReason); update.FinishReason = ChatFinishReason.ContentFilter; Assert.Equal(ChatFinishReason.ContentFilter, update.FinishReason); } [Fact] - public void Text_GetSet_UsesFirstTextContent() + public void Text_Get_UsesAllTextContent() { ChatResponseUpdate update = new() { @@ -102,63 +95,15 @@ public void Text_GetSet_UsesFirstTextContent() TextContent textContent = Assert.IsType(update.Contents[3]); Assert.Equal("text-1", textContent.Text); - Assert.Equal("text-1", update.Text); + Assert.Equal("text-1text-2", update.Text); Assert.Equal("text-1text-2", update.ToString()); - update.Text = "text-3"; - Assert.Equal("text-3", update.Text); - Assert.Equal("text-3", update.Text); + ((TextContent)update.Contents[3]).Text = "text-3"; + Assert.Equal("text-3text-2", update.Text); Assert.Same(textContent, update.Contents[3]); Assert.Equal("text-3text-2", update.ToString()); } - [Fact] - public void Text_Set_AddsTextMessageToEmptyList() - { - ChatResponseUpdate update = new() - { - Role = ChatRole.User, - }; - Assert.Empty(update.Contents); - - update.Text = "text-1"; - 
Assert.Equal("text-1", update.Text); - - Assert.Single(update.Contents); - TextContent textContent = Assert.IsType(update.Contents[0]); - Assert.Equal("text-1", textContent.Text); - } - - [Fact] - public void Text_Set_AddsTextMessageToListWithNoText() - { - ChatResponseUpdate update = new() - { - Contents = - [ - new DataContent("http://localhost/audio"), - new DataContent("http://localhost/image"), - new FunctionCallContent("callId1", "fc1"), - ] - }; - Assert.Equal(3, update.Contents.Count); - - update.Text = "text-1"; - Assert.Equal("text-1", update.Text); - Assert.Equal(4, update.Contents.Count); - - update.Text = "text-2"; - Assert.Equal("text-2", update.Text); - Assert.Equal(4, update.Contents.Count); - - update.Contents.RemoveAt(3); - Assert.Equal(3, update.Contents.Count); - - update.Text = "text-3"; - Assert.Equal("text-3", update.Text); - Assert.Equal(4, update.Contents.Count); - } - [Fact] public void JsonSerialization_Roundtrips() { @@ -179,7 +124,6 @@ public void JsonSerialization_Roundtrips() CreatedAt = new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), FinishReason = ChatFinishReason.ContentFilter, AdditionalProperties = new() { ["key"] = "value" }, - ChoiceIndex = 42, }; string json = JsonSerializer.Serialize(original, TestJsonSerializerContext.Default.ChatResponseUpdate); @@ -209,7 +153,6 @@ public void JsonSerialization_Roundtrips() Assert.Equal("id", result.ResponseId); Assert.Equal(new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), result.CreatedAt); Assert.Equal(ChatFinishReason.ContentFilter, result.FinishReason); - Assert.Equal(42, result.ChoiceIndex); Assert.NotNull(result.AdditionalProperties); Assert.Single(result.AdditionalProperties); diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs index d7d265018b0..bab36d7f91a 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs @@ -25,7 +25,7 @@ public async Task ChatAsyncDefaultsToInnerClientAsync() var expectedChatOptions = new ChatOptions(); var expectedCancellationToken = CancellationToken.None; var expectedResult = new TaskCompletionSource(); - var expectedResponse = new ChatResponse([]); + var expectedResponse = new ChatResponse(); using var inner = new TestChatClient { GetResponseAsyncCallback = (chatContents, options, cancellationToken) => @@ -58,8 +58,8 @@ public async Task ChatStreamingAsyncDefaultsToInnerClientAsync() var expectedCancellationToken = CancellationToken.None; ChatResponseUpdate[] expectedResults = [ - new() { Role = ChatRole.User, Text = "Message 1" }, - new() { Role = ChatRole.User, Text = "Message 2" } + new(ChatRole.User, "Message 1"), + new(ChatRole.User, "Message 2") ]; using var inner = new TestChatClient diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs index 8a61fbb0786..fe4af33cf23 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs @@ -33,7 +33,7 @@ public void GetService_ValidService_Returned() { 
using IEmbeddingGenerator> generator = new TestEmbeddingGenerator { - GetServiceCallback = (Type serviceType, object? serviceKey) => + GetServiceCallback = (serviceType, serviceKey) => { if (serviceType == typeof(string)) { diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs index 226612bcff4..95f89a79141 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs @@ -17,20 +17,20 @@ public TestChatClient() public IServiceProvider? Services { get; set; } - public Func, ChatOptions?, CancellationToken, Task>? GetResponseAsyncCallback { get; set; } + public Func, ChatOptions?, CancellationToken, Task>? GetResponseAsyncCallback { get; set; } - public Func, ChatOptions?, CancellationToken, IAsyncEnumerable>? GetStreamingResponseAsyncCallback { get; set; } + public Func, ChatOptions?, CancellationToken, IAsyncEnumerable>? GetStreamingResponseAsyncCallback { get; set; } public Func GetServiceCallback { get; set; } private object? DefaultGetServiceCallback(Type serviceType, object? serviceKey) => serviceType is not null && serviceKey is null && serviceType.IsInstanceOfType(this) ? this : null; - public Task GetResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) - => GetResponseAsyncCallback!.Invoke(chatMessages, options, cancellationToken); + public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + => GetResponseAsyncCallback!.Invoke(messages, options, cancellationToken); - public IAsyncEnumerable GetStreamingResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) - => GetStreamingResponseAsyncCallback!.Invoke(chatMessages, options, cancellationToken); + public IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + => GetStreamingResponseAsyncCallback!.Invoke(messages, options, cancellationToken); public object? GetService(Type serviceType, object? 
serviceKey = null) => GetServiceCallback(serviceType, serviceKey); diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs index 77209e0146c..d828365d8b5 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs @@ -387,9 +387,9 @@ public static void AddAIContentType_ReadOnlyJsonSerializerOptions_ThrowsInvalidO public static void AddAIContentType_NonAIContent_ThrowsArgumentException() { JsonSerializerOptions options = new(); - Assert.Throws(() => options.AddAIContentType(typeof(int), "discriminator")); - Assert.Throws(() => options.AddAIContentType(typeof(object), "discriminator")); - Assert.Throws(() => options.AddAIContentType(typeof(ChatMessage), "discriminator")); + Assert.Throws("contentType", () => options.AddAIContentType(typeof(int), "discriminator")); + Assert.Throws("contentType", () => options.AddAIContentType(typeof(object), "discriminator")); + Assert.Throws("contentType", () => options.AddAIContentType(typeof(ChatMessage), "discriminator")); } [Fact] @@ -415,11 +415,11 @@ public static void AddAIContentType_ConflictingIdentifier_ThrowsInvalidOperation public static void AddAIContentType_NullArguments_ThrowsArgumentNullException() { JsonSerializerOptions options = new(); - Assert.Throws(() => ((JsonSerializerOptions)null!).AddAIContentType("discriminator")); - Assert.Throws(() => ((JsonSerializerOptions)null!).AddAIContentType(typeof(DerivedAIContent), "discriminator")); - Assert.Throws(() => options.AddAIContentType(null!)); - Assert.Throws(() => options.AddAIContentType(typeof(DerivedAIContent), null!)); - Assert.Throws(() => options.AddAIContentType(null!, "discriminator")); + Assert.Throws("options", () => ((JsonSerializerOptions)null!).AddAIContentType("discriminator")); + Assert.Throws("options", () => ((JsonSerializerOptions)null!).AddAIContentType(typeof(DerivedAIContent), "discriminator")); + Assert.Throws("typeDiscriminatorId", () => options.AddAIContentType(null!)); + Assert.Throws("typeDiscriminatorId", () => options.AddAIContentType(typeof(DerivedAIContent), null!)); + Assert.Throws("contentType", () => options.AddAIContentType(null!, "discriminator")); } [Fact] diff --git a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs index c86dbd756b5..b8a68c913ed 100644 --- a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs @@ -152,11 +152,11 @@ public async Task BasicRequestResponse_NonStreaming(bool multiContent) using HttpClient httpClient = new(handler); using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); - List chatMessages = multiContent ? + List messages = multiContent ? 
[new ChatMessage(ChatRole.User, "hello".Select(c => (AIContent)new TextContent(c.ToString())).ToList())] : [new ChatMessage(ChatRole.User, "hello")]; - var response = await client.GetResponseAsync(chatMessages, new() + var response = await client.GetResponseAsync(messages, new() { MaxOutputTokens = 10, Temperature = 0.5f, @@ -164,9 +164,9 @@ [new ChatMessage(ChatRole.User, "hello".Select(c => (AIContent)new TextContent(c Assert.NotNull(response); Assert.Equal("chatcmpl-ADx3PvAnCwJg0woha4pYsBTi3ZpOI", response.ResponseId); - Assert.Equal("Hello! How can I assist you today?", response.Message.Text); - Assert.Single(response.Message.Contents); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("Hello! How can I assist you today?", response.Text); + Assert.Single(response.Messages.Single().Contents); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_888_631), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); @@ -224,12 +224,12 @@ public async Task BasicRequestResponse_Streaming(bool multiContent) using HttpClient httpClient = new(handler); using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); - List chatMessages = multiContent ? + List messages = multiContent ? [new ChatMessage(ChatRole.User, "hello".Select(c => (AIContent)new TextContent(c.ToString())).ToList())] : [new ChatMessage(ChatRole.User, "hello")]; List updates = []; - await foreach (var update in client.GetStreamingResponseAsync(chatMessages, new() + await foreach (var update in client.GetStreamingResponseAsync(messages, new() { MaxOutputTokens = 20, Temperature = 0.5f, @@ -551,9 +551,9 @@ public async Task MultipleMessages_NonStreaming() Assert.NotNull(response); Assert.Equal("chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", response.ResponseId); - Assert.Equal("I’m doing well, thank you! What’s on your mind today?", response.Message.Text); - Assert.Single(response.Message.Contents); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("I’m doing well, thank you! 
What’s on your mind today?", response.Text); + Assert.Single(response.Messages.Single().Contents); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_187), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); @@ -685,9 +685,9 @@ public async Task NullAssistantText_ContentEmpty_NonStreaming() Assert.NotNull(response); Assert.Equal("chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", response.ResponseId); - Assert.Equal("Hello.", response.Message.Text); - Assert.Single(response.Message.Contents); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("Hello.", response.Text); + Assert.Single(response.Messages.Single().Contents); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_187), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); @@ -801,9 +801,9 @@ public async Task FunctionCallContent_NonStreaming(ChatToolMode mode) }); Assert.NotNull(response); - Assert.Null(response.Message.Text); + Assert.Empty(response.Text); Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_702), response.CreatedAt); Assert.Equal(ChatFinishReason.ToolCalls, response.FinishReason); Assert.NotNull(response.Usage); @@ -811,9 +811,8 @@ public async Task FunctionCallContent_NonStreaming(ChatToolMode mode) Assert.Equal(16, response.Usage.OutputTokenCount); Assert.Equal(77, response.Usage.TotalTokenCount); - Assert.Single(response.Choices); - Assert.Single(response.Message.Contents); - FunctionCallContent fcc = Assert.IsType(response.Message.Contents[0]); + Assert.Single(response.Messages.Single().Contents); + FunctionCallContent fcc = Assert.IsType(response.Messages.Single().Contents[0]); Assert.Equal("GetPersonAge", fcc.Name); AssertExtensions.EqualFunctionCallParameters(new Dictionary { ["personName"] = "Alice" }, fcc.Arguments); } diff --git a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/AdditionalContextTests.cs b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/AdditionalContextTests.cs index f5ffd922816..cbc78ef6642 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/AdditionalContextTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/AdditionalContextTests.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; +using System.Linq; using System.Threading.Tasks; using FluentAssertions; using FluentAssertions.Execution; @@ -59,7 +60,7 @@ await _reportingConfiguration.CreateScenarioRunAsync( messages.Add(promptMessage); ChatResponse response = await chatClient.GetResponseAsync(messages, _chatOptions); - ChatMessage responseMessage = response.Message; + ChatMessage responseMessage = response.Messages.Single(); Assert.NotNull(responseMessage.Text); EvaluationResult result = @@ -94,7 +95,7 @@ await _reportingConfiguration.CreateScenarioRunAsync( messages.Add(promptMessage); ChatResponse response = await chatClient.GetResponseAsync(messages, _chatOptions); - ChatMessage responseMessage = response.Message; + ChatMessage responseMessage = response.Messages.Single(); 
Assert.NotNull(responseMessage.Text); var baselineResponseForEquivalenceEvaluator = diff --git a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/EndToEndTests.cs b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/EndToEndTests.cs index 4062a0c4fda..dbfdebc529c 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/EndToEndTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/EndToEndTests.cs @@ -71,7 +71,7 @@ await _reportingConfiguration.CreateScenarioRunAsync( messages.Add(promptMessage); ChatResponse response = await chatClient.GetResponseAsync(messages, _chatOptions); - ChatMessage responseMessage = response.Message; + ChatMessage responseMessage = response.Messages.Single(); Assert.NotNull(responseMessage.Text); EvaluationResult result = await scenarioRun.EvaluateAsync(promptMessage, responseMessage); @@ -122,7 +122,7 @@ await _reportingConfiguration.CreateScenarioRunAsync( messages.Add(promptMessage); ChatResponse response = await chatClient.GetResponseAsync(messages, _chatOptions); - ChatMessage responseMessage = response.Message; + ChatMessage responseMessage = response.Messages.Single(); Assert.NotNull(responseMessage.Text); EvaluationResult result = await scenarioRun.EvaluateAsync(promptMessage, responseMessage); diff --git a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/RelevanceTruthAndCompletenessEvaluatorTests.cs b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/RelevanceTruthAndCompletenessEvaluatorTests.cs index eac1f5ea228..8b479ea57cf 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/RelevanceTruthAndCompletenessEvaluatorTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/RelevanceTruthAndCompletenessEvaluatorTests.cs @@ -68,7 +68,7 @@ await _reportingConfigurationWithoutReasoning.CreateScenarioRunAsync( messages.Add(promptMessage); ChatResponse response = await chatClient.GetResponseAsync(messages, _chatOptions); - ChatMessage responseMessage = response.Message; + ChatMessage responseMessage = response.Messages.Single(); Assert.NotNull(responseMessage.Text); EvaluationResult result = await scenarioRun.EvaluateAsync(promptMessage, responseMessage); @@ -101,7 +101,7 @@ await _reportingConfigurationWithReasoning.CreateScenarioRunAsync( messages.Add(promptMessage); ChatResponse response = await chatClient.GetResponseAsync(messages, _chatOptions); - ChatMessage responseMessage = response.Message; + ChatMessage responseMessage = response.Messages.Single(); Assert.NotNull(responseMessage.Text); EvaluationResult result = await scenarioRun.EvaluateAsync(promptMessage, responseMessage); diff --git a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/TestEvaluator.cs b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/TestEvaluator.cs index 8584ada0853..f853b3ac030 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/TestEvaluator.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Integration.Tests/TestEvaluator.cs @@ -25,7 +25,7 @@ private ValueTask GetResultAsync() => async ValueTask IEvaluator.EvaluateAsync( IEnumerable messages, - ChatMessage modelResponse, + ChatResponse modelResponse, ChatConfiguration? chatConfiguration, IEnumerable? 
additionalContext, CancellationToken cancellationToken) diff --git a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Tests/ResultStoreTester.cs b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Tests/ResultStoreTester.cs index f68eb15380e..7ed0f31f3c5 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Tests/ResultStoreTester.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Tests/ResultStoreTester.cs @@ -32,7 +32,7 @@ private static ScenarioRunResult CreateTestResult(string scenarioName, string it executionName: executionName, creationTime: DateTime.UtcNow, messages: [new ChatMessage(ChatRole.User, "User prompt")], - modelResponse: new ChatMessage(ChatRole.Assistant, "LLM response"), + modelResponse: new ChatResponse(new ChatMessage(ChatRole.Assistant, "LLM response")), evaluationResult: new EvaluationResult(booleanMetric, numericMetric, stringMetric)); } diff --git a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Tests/ScenarioRunResultTests.cs b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Tests/ScenarioRunResultTests.cs index b522f797675..9418a5db359 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Tests/ScenarioRunResultTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Tests/ScenarioRunResultTests.cs @@ -36,7 +36,7 @@ public void SerializeScenarioRunResult() executionName: "Test Execution", creationTime: DateTime.UtcNow, messages: [new ChatMessage(ChatRole.User, "prompt")], - modelResponse: new ChatMessage(ChatRole.Assistant, "response"), + modelResponse: new ChatResponse(new ChatMessage(ChatRole.Assistant, "response")), evaluationResult: new EvaluationResult(booleanMetric, numericMetric, stringMetric, metricWithNoValue)); string json = JsonSerializer.Serialize(entry, SerializerContext.Default.ScenarioRunResult); @@ -48,7 +48,7 @@ public void SerializeScenarioRunResult() Assert.Equal(entry.ExecutionName, deserialized.ExecutionName); Assert.Equal(entry.CreationTime, deserialized.CreationTime); Assert.True(entry.Messages.SequenceEqual(deserialized.Messages, ChatMessageComparer.Instance)); - Assert.Equal(entry.ModelResponse, deserialized.ModelResponse, ChatMessageComparer.Instance); + Assert.Equal(entry.ModelResponse, deserialized.ModelResponse, ChatResponseComparer.Instance); ValidateEquivalence(entry.EvaluationResult, deserialized.EvaluationResult); } @@ -75,7 +75,7 @@ public void SerializeDatasetCompact() executionName: "Test Execution", creationTime: DateTime.UtcNow, messages: [new ChatMessage(ChatRole.User, "prompt")], - modelResponse: new ChatMessage(ChatRole.Assistant, "response"), + modelResponse: new ChatResponse(new ChatMessage(ChatRole.Assistant, "response")), evaluationResult: new EvaluationResult(booleanMetric, numericMetric, stringMetric, metricWithNoValue)); var dataset = new Dataset([entry], createdAt: DateTime.UtcNow, generatorVersion: "1.2.3.4"); @@ -89,7 +89,7 @@ public void SerializeDatasetCompact() Assert.Equal(entry.ExecutionName, deserialized.ScenarioRunResults[0].ExecutionName); Assert.Equal(entry.CreationTime, deserialized.ScenarioRunResults[0].CreationTime); Assert.True(entry.Messages.SequenceEqual(deserialized.ScenarioRunResults[0].Messages, ChatMessageComparer.Instance)); - Assert.Equal(entry.ModelResponse, deserialized.ScenarioRunResults[0].ModelResponse, ChatMessageComparer.Instance); + Assert.Equal(entry.ModelResponse, deserialized.ScenarioRunResults[0].ModelResponse, ChatResponseComparer.Instance); 
Assert.Single(deserialized.ScenarioRunResults); Assert.Equal(dataset.CreatedAt, deserialized.CreatedAt); @@ -155,7 +155,20 @@ public bool Equals(ChatMessage? x, ChatMessage? y) => x?.AuthorName == y?.AuthorName && x?.Role == y?.Role && x?.Text == y?.Text; public int GetHashCode(ChatMessage obj) - => obj.GetHashCode(); + => obj.Text.GetHashCode(); + } + + private class ChatResponseComparer : IEqualityComparer + { + public static ChatResponseComparer Instance { get; } = new ChatResponseComparer(); + + public bool Equals(ChatResponse? x, ChatResponse? y) + => + x is null ? y is null : + y is not null && x.Messages.SequenceEqual(y.Messages, ChatMessageComparer.Instance); + + public int GetHashCode(ChatResponse obj) + => obj.Text.GetHashCode(); } private class DiagnosticComparer : IEqualityComparer diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/CallCountingChatClient.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/CallCountingChatClient.cs index 853815ff033..c0045dc0f82 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/CallCountingChatClient.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/CallCountingChatClient.cs @@ -17,17 +17,17 @@ internal sealed class CallCountingChatClient(IChatClient innerClient) : Delegati public int CallCount => _callCount; public override Task GetResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { Interlocked.Increment(ref _callCount); - return base.GetResponseAsync(chatMessages, options, cancellationToken); + return base.GetResponseAsync(messages, options, cancellationToken); } public override IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? 
options = null, CancellationToken cancellationToken = default) { Interlocked.Increment(ref _callCount); - return base.GetStreamingResponseAsync(chatMessages, options, cancellationToken); + return base.GetStreamingResponseAsync(messages, options, cancellationToken); } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs index eaf4834e60d..55b840eea5f 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs @@ -21,6 +21,7 @@ #pragma warning disable CA2000 // Dispose objects before losing scope #pragma warning disable CA2214 // Do not call overridable methods in constructors +#pragma warning disable CA2249 // Consider using 'string.Contains' instead of 'string.IndexOf' namespace Microsoft.Extensions.AI; @@ -48,7 +49,7 @@ public virtual async Task GetResponseAsync_SingleRequestMessage() var response = await _chatClient.GetResponseAsync("What's the biggest animal?"); - Assert.Contains("whale", response.Message.Text, StringComparison.OrdinalIgnoreCase); + Assert.Contains("whale", response.Text, StringComparison.OrdinalIgnoreCase); } [ConditionalFact] @@ -65,13 +66,12 @@ public virtual async Task GetResponseAsync_MultipleRequestMessages() new(ChatRole.User, "What continent are they each in?"), ]); - Assert.Single(response.Choices); - Assert.Contains("America", response.Message.Text); - Assert.Contains("Asia", response.Message.Text); + Assert.Contains("America", response.Text); + Assert.Contains("Asia", response.Text); } [ConditionalFact] - public virtual async Task GetStreamingResponseAsync_SingleStreamingResponseChoice() + public virtual async Task GetStreamingResponseAsync() { SkipIfNotEnabled(); @@ -89,9 +89,6 @@ public virtual async Task GetStreamingResponseAsync_SingleStreamingResponseChoic string responseText = sb.ToString(); Assert.Contains("one small step", responseText, StringComparison.OrdinalIgnoreCase); Assert.Contains("one giant leap", responseText, StringComparison.OrdinalIgnoreCase); - - // The input list is left unaugmented. 
- Assert.Single(chatHistory); } [ConditionalFact] @@ -101,7 +98,6 @@ public virtual async Task GetResponseAsync_UsageDataAvailable() var response = await _chatClient.GetResponseAsync("Explain in 10 words how AI works"); - Assert.Single(response.Choices); Assert.True(response.Usage?.InputTokenCount > 1); Assert.True(response.Usage?.OutputTokenCount > 1); Assert.Equal(response.Usage?.InputTokenCount + response.Usage?.OutputTokenCount, response.Usage?.TotalTokenCount); @@ -151,8 +147,7 @@ public virtual async Task MultiModal_DescribeImage() ], new() { ModelId = GetModel_MultiModal_DescribeImage() }); - Assert.Single(response.Choices); - Assert.True(response.Message.Text?.IndexOf("net", StringComparison.OrdinalIgnoreCase) >= 0, response.Message.Text); + Assert.True(response.Text.IndexOf("net", StringComparison.OrdinalIgnoreCase) >= 0, response.Text); } [ConditionalFact] @@ -182,8 +177,7 @@ public virtual async Task FunctionInvocation_AutomaticallyInvokeFunction_Paramet Tools = [AIFunctionFactory.Create(() => secretNumber, "GetSecretNumber")] }); - Assert.Single(response.Choices); - Assert.Contains(secretNumber.ToString(), response.Message.Text); + Assert.Contains(secretNumber.ToString(), response.Text); // If the underlying IChatClient provides usage data, function invocation should aggregate the // usage data across all calls to produce a single Usage value on the final response @@ -208,8 +202,7 @@ public virtual async Task FunctionInvocation_AutomaticallyInvokeFunction_WithPar Tools = [AIFunctionFactory.Create((int a, int b) => a * b, "SecretComputation")] }); - Assert.Single(response.Choices); - Assert.Contains("3528", response.Message.Text); + Assert.Contains("3528", response.Text); } [ConditionalFact] @@ -261,8 +254,8 @@ public virtual async Task FunctionInvocation_SupportsMultipleParallelRequests() }); Assert.True( - Regex.IsMatch(response.Message.Text ?? "", @"\b(3|three)\b", RegexOptions.IgnoreCase), - $"Doesn't contain three: {response.Message.Text}"); + Regex.IsMatch(response.Text ?? "", @"\b(3|three)\b", RegexOptions.IgnoreCase), + $"Doesn't contain three: {response.Text}"); } [ConditionalFact] @@ -285,7 +278,6 @@ public virtual async Task FunctionInvocation_RequireAny() ToolMode = ChatToolMode.RequireAny, }); - Assert.Single(response.Choices); Assert.True(callCount >= 1); } @@ -317,10 +309,9 @@ public virtual async Task Caching_OutputVariesWithoutCaching() var message = new ChatMessage(ChatRole.User, "Pick a random number, uniformly distributed between 1 and 1000000"); var firstResponse = await _chatClient.GetResponseAsync([message]); - Assert.Single(firstResponse.Choices); var secondResponse = await _chatClient.GetResponseAsync([message]); - Assert.NotEqual(firstResponse.Message.Text, secondResponse.Message.Text); + Assert.NotEqual(firstResponse.Text, secondResponse.Text); } [ConditionalFact] @@ -334,19 +325,18 @@ public virtual async Task Caching_SamePromptResultsInCacheHit_NonStreaming() var message = new ChatMessage(ChatRole.User, "Pick a random number, uniformly distributed between 1 and 1000000"); var firstResponse = await chatClient.GetResponseAsync([message]); - Assert.Single(firstResponse.Choices); // No matter what it said before, we should see identical output due to caching for (int i = 0; i < 3; i++) { var secondResponse = await chatClient.GetResponseAsync([message]); - Assert.Equal(firstResponse.Message.Text, secondResponse.Message.Text); + Assert.Equal(firstResponse.Messages.Select(m => m.Text), secondResponse.Messages.Select(m => m.Text)); } // ... 
but if the conversation differs, we should see different output - message.Text += "!"; + ((TextContent)message.Contents[0]).Text += "!"; var thirdResponse = await chatClient.GetResponseAsync([message]); - Assert.NotEqual(firstResponse.Message.Text, thirdResponse.Message.Text); + Assert.NotEqual(firstResponse.Messages, thirdResponse.Messages); } [ConditionalFact] @@ -378,7 +368,7 @@ public virtual async Task Caching_SamePromptResultsInCacheHit_Streaming() } // ... but if the conversation differs, we should see different output - message.Text += "!"; + ((TextContent)message.Contents[0]).Text += "!"; StringBuilder third = new(); await foreach (var update in chatClient.GetStreamingResponseAsync([message])) { @@ -412,14 +402,14 @@ public virtual async Task Caching_BeforeFunctionInvocation_AvoidsExtraCalls() var llmCallCount = chatClient.GetService(); var message = new ChatMessage(ChatRole.User, "What is the temperature?"); var response = await chatClient.GetResponseAsync([message]); - Assert.Contains("101", response.Message.Text); + Assert.Contains("101", response.Text); // First LLM call tells us to call the function, second deals with the result Assert.Equal(2, llmCallCount!.CallCount); // Second call doesn't execute the function or call the LLM, but rather just returns the cached result var secondResponse = await chatClient.GetResponseAsync([message]); - Assert.Equal(response.Message.Text, secondResponse.Message.Text); + Assert.Equal(response.Text, secondResponse.Text); Assert.Equal(1, functionCallCount); Assert.Equal(2, llmCallCount!.CallCount); } @@ -451,7 +441,7 @@ public virtual async Task Caching_AfterFunctionInvocation_FunctionOutputUnchange var llmCallCount = chatClient.GetService(); var message = new ChatMessage(ChatRole.User, "What is the temperature?"); var response = await chatClient.GetResponseAsync([message]); - Assert.Contains("58", response.Message.Text); + Assert.Contains("58", response.Text); // First LLM call tells us to call the function, second deals with the result Assert.Equal(1, functionCallCount); @@ -459,7 +449,7 @@ public virtual async Task Caching_AfterFunctionInvocation_FunctionOutputUnchange // Second time, the calls to the LLM don't happen, but the function is called again var secondResponse = await chatClient.GetResponseAsync([message]); - Assert.Equal(response.Message.Text, secondResponse.Message.Text); + Assert.Equal(response.Text, secondResponse.Text); Assert.Equal(2, functionCallCount); Assert.Equal(2, llmCallCount!.CallCount); } @@ -491,7 +481,7 @@ public virtual async Task Caching_AfterFunctionInvocation_FunctionOutputChangedA var llmCallCount = chatClient.GetService(); var message = new ChatMessage(ChatRole.User, "What is the temperature?"); var response = await chatClient.GetResponseAsync([message]); - Assert.Contains("81", response.Message.Text); + Assert.Contains("81", response.Text); // First LLM call tells us to call the function, second deals with the result Assert.Equal(1, functionCallCount); @@ -500,7 +490,7 @@ public virtual async Task Caching_AfterFunctionInvocation_FunctionOutputChangedA // Second time, the first call to the LLM don't happen, but the function is called again, // and since its output now differs, we no longer hit the cache so the second LLM call does happen var secondResponse = await chatClient.GetResponseAsync([message]); - Assert.Contains("82", secondResponse.Message.Text); + Assert.Contains("82", secondResponse.Text); Assert.Equal(2, functionCallCount); Assert.Equal(3, llmCallCount!.CallCount); } diff --git 
a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs index 1cf786bb288..1b8b90f4a3a 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs @@ -7,6 +7,7 @@ using System.ComponentModel; using System.Linq; using System.Reflection; +using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Serialization; using System.Threading; @@ -18,8 +19,8 @@ namespace Microsoft.Extensions.AI; -// This isn't a feature we're planning to ship, but demonstrates how custom clients can -// layer in non-trivial functionality. In this case we're able to upgrade non-function-calling models +// Demonstrates how custom clients can layer in non-trivial functionality. +// In this case we're able to upgrade non-function-calling models // to behaving as if they do support function calling. // // In practice: @@ -39,13 +40,16 @@ internal sealed class PromptBasedFunctionCallingChatClient(IChatClient innerClie DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, }; - public override async Task GetResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + public override async Task GetResponseAsync( + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { + List chatMessageList = [.. messages]; + // Our goal is to convert tools into a prompt describing them, then to detect tool calls in the // response and convert those into FunctionCallContent. if (options?.Tools is { Count: > 0 }) { - AddOrUpdateToolPrompt(chatMessages, options.Tools); + AddOrUpdateToolPrompt(chatMessageList, options.Tools); options = options.Clone(); options.Tools = null; @@ -58,7 +62,7 @@ public override async Task GetResponseAsync(IList cha // Since the point of this client is to avoid relying on the underlying model having // native tool call support, we have to replace any "tool" or "toolcall" messages with // "user" or "assistant" ones. - foreach (var message in chatMessages) + foreach (var message in chatMessageList) { for (var itemIndex = 0; itemIndex < message.Contents.Count; itemIndex++) { @@ -80,12 +84,12 @@ public override async Task GetResponseAsync(IList cha } } - var result = await base.GetResponseAsync(chatMessages, options, cancellationToken); + var result = await base.GetResponseAsync(chatMessageList, options, cancellationToken); - if (result.Choices.FirstOrDefault()?.Text is { } content && content.IndexOf("", StringComparison.Ordinal) is int startPos + if (result.Text is { } content && content.IndexOf("", StringComparison.Ordinal) is int startPos && startPos >= 0) { - var message = result.Choices.First(); + var message = result.Messages.Last(); var contentItem = message.Contents.SingleOrDefault(); content = content.Substring(startPos); @@ -131,6 +135,16 @@ public override async Task GetResponseAsync(IList cha return result; } + public override async IAsyncEnumerable GetStreamingResponseAsync( + IEnumerable messages, ChatOptions? 
options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var response = await GetResponseAsync(messages, options, cancellationToken); + foreach (var update in response.ToChatResponseUpdates()) + { + yield return update; + } + } + private static void ParseArguments(IDictionary arguments) { // This is a simple implementation. A more robust answer is to use other schema information given by @@ -151,17 +165,17 @@ private static void ParseArguments(IDictionary arguments) } } - private static void AddOrUpdateToolPrompt(IList chatMessages, IList tools) + private static void AddOrUpdateToolPrompt(List messages, IList tools) { - var existingToolPrompt = chatMessages.FirstOrDefault(c => c.Text?.StartsWith(MessageIntro, StringComparison.Ordinal) is true); + var existingToolPrompt = messages.FirstOrDefault(c => c.Text.StartsWith(MessageIntro, StringComparison.Ordinal) is true); if (existingToolPrompt is null) { existingToolPrompt = new ChatMessage(ChatRole.System, (string?)null); - chatMessages.Insert(0, existingToolPrompt); + messages.Insert(0, existingToolPrompt); } var toolDescriptorsJson = JsonSerializer.Serialize(tools.OfType().Select(ToToolDescriptor), _jsonOptions); - existingToolPrompt.Text = $$""" + existingToolPrompt.Contents.OfType().First().Text = $$""" {{MessageIntro}} For each function call, return a JSON object with the function name and arguments within XML tags diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ReducingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ReducingChatClientTests.cs index eeccd609a93..99533e56f53 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ReducingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ReducingChatClientTests.cs @@ -29,11 +29,11 @@ public async Task Reduction_LimitsMessagesBasedOnTokenLimit() { GetResponseAsyncCallback = (messages, options, cancellationToken) => { - Assert.Equal(2, messages.Count); + Assert.Equal(2, messages.Count()); Assert.Collection(messages, m => Assert.StartsWith("Golden retrievers are quite active", m.Text, StringComparison.Ordinal), m => Assert.StartsWith("Are they good with kids?", m.Text, StringComparison.Ordinal)); - return Task.FromResult(new ChatResponse([])); + return Task.FromResult(new ChatResponse()); } }; @@ -61,69 +61,57 @@ public async Task Reduction_LimitsMessagesBasedOnTokenLimit() public sealed class ReducingChatClient : DelegatingChatClient { private readonly IChatReducer _reducer; - private readonly bool _inPlace; /// Initializes a new instance of the class. /// The inner client. /// The reducer to be used by this instance. - /// - /// true if the should perform any modifications directly on the supplied list of messages; - /// false if it should instead create a new list when reduction is necessary. - /// - public ReducingChatClient(IChatClient innerClient, IChatReducer reducer, bool inPlace = false) + public ReducingChatClient(IChatClient innerClient, IChatReducer reducer) : base(innerClient) { _reducer = Throw.IfNull(reducer); - _inPlace = inPlace; } /// public override async Task GetResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? 
options = null, CancellationToken cancellationToken = default) { - chatMessages = await GetChatMessagesToPropagate(chatMessages, cancellationToken).ConfigureAwait(false); + messages = await _reducer.ReduceAsync(messages, cancellationToken).ConfigureAwait(false); - return await base.GetResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + return await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); } /// public override async IAsyncEnumerable GetStreamingResponseAsync( - IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - chatMessages = await GetChatMessagesToPropagate(chatMessages, cancellationToken).ConfigureAwait(false); + messages = await _reducer.ReduceAsync(messages, cancellationToken).ConfigureAwait(false); - await foreach (var update in base.GetStreamingResponseAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) { yield return update; } } - - /// Runs the reducer and gets the chat message list to forward to the inner client. - private async Task> GetChatMessagesToPropagate(IList chatMessages, CancellationToken cancellationToken) => - await _reducer.ReduceAsync(chatMessages, _inPlace, cancellationToken).ConfigureAwait(false) ?? - chatMessages; } /// Represents a reducer capable of shrinking the size of a list of chat messages. public interface IChatReducer { /// Reduces the size of a list of chat messages. - /// The messages. - /// true if the reducer should modify the provided list; false if a new list should be returned. + /// The messages. /// The to monitor for cancellation requests. The default is . /// The new list of messages, or null if no reduction need be performed or was true. - Task?> ReduceAsync(IList chatMessages, bool inPlace, CancellationToken cancellationToken); + Task> ReduceAsync(IEnumerable messages, CancellationToken cancellationToken); } /// Provides extensions for configuring instances. 
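// --- Illustrative sketch (not part of this change) ------------------------------------
// The hunk above reshapes the test-local IChatReducer so that ReduceAsync takes an
// IEnumerable<ChatMessage> and returns the reduced sequence, instead of optionally
// mutating the supplied list in place. Assuming that contract (the generic parameters are
// inferred from the surrounding code), a minimal reducer that keeps only the most recent
// messages could look roughly like the following. "LastMessagesChatReducer" and
// "windowSize" are hypothetical names used only for illustration; the sketch relies on the
// file's existing using directives (System.Collections.Generic, System.Linq,
// System.Threading, System.Threading.Tasks).
internal sealed class LastMessagesChatReducer(int windowSize) : IChatReducer
{
    public Task<IEnumerable<ChatMessage>> ReduceAsync(
        IEnumerable<ChatMessage> messages, CancellationToken cancellationToken)
    {
        // Materialize once, then keep only the trailing 'windowSize' messages.
        List<ChatMessage> list = messages.ToList();
        IEnumerable<ChatMessage> reduced = list.Count <= windowSize
            ? list
            : list.Skip(list.Count - windowSize).ToList();
        return Task.FromResult(reduced);
    }
}
// With the UseChatReducer extension below, wiring it up would be along the lines of:
//     builder.UseChatReducer(new LastMessagesChatReducer(windowSize: 10));
// ---------------------------------------------------------------------------------------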
public static class ReducingChatClientExtensions { - public static ChatClientBuilder UseChatReducer(this ChatClientBuilder builder, IChatReducer reducer, bool inPlace = false) + public static ChatClientBuilder UseChatReducer(this ChatClientBuilder builder, IChatReducer reducer) { _ = Throw.IfNull(builder); _ = Throw.IfNull(reducer); - return builder.Use(innerClient => new ReducingChatClient(innerClient, reducer, inPlace)); + return builder.Use(innerClient => new ReducingChatClient(innerClient, reducer)); } } @@ -139,51 +127,29 @@ public TokenCountingChatReducer(Tokenizer tokenizer, int tokenLimit) _tokenLimit = Throw.IfLessThan(tokenLimit, 1); } - public async Task?> ReduceAsync(IList chatMessages, bool inPlace, CancellationToken cancellationToken) + public async Task> ReduceAsync( + IEnumerable messages, CancellationToken cancellationToken) { - _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(messages); + + List list = messages.ToList(); - if (chatMessages.Count > 1) + if (list.Count > 1) { - int totalCount = CountTokens(chatMessages[chatMessages.Count - 1]); + int totalCount = CountTokens(list[list.Count - 1]); - if (inPlace) - { - for (int i = chatMessages.Count - 2; i >= 0; i--) - { - totalCount += CountTokens(chatMessages[i]); - if (totalCount > _tokenLimit) - { - if (chatMessages is List list) - { - list.RemoveRange(0, i + 1); - } - else - { - for (int j = i; j >= 0; j--) - { - chatMessages.RemoveAt(j); - } - } - - break; - } - } - } - else + for (int i = list.Count - 2; i >= 0; i--) { - for (int i = chatMessages.Count - 2; i >= 0; i--) + totalCount += CountTokens(list[i]); + if (totalCount > _tokenLimit) { - totalCount += CountTokens(chatMessages[i]); - if (totalCount > _tokenLimit) - { - return chatMessages.Skip(i + 1).ToList(); - } + list.RemoveRange(0, i + 1); + break; } } } - return null; + return list; } private int CountTokens(ChatMessage message) diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs index 09328dd8ce6..83e84e49f5b 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs @@ -47,8 +47,7 @@ public async Task PromptBasedFunctionCalling_NoArgs() Seed = 0, }); - Assert.Single(response.Choices); - Assert.Contains(secretNumber.ToString(), response.Message.Text); + Assert.Contains(secretNumber.ToString(), response.Text); } [ConditionalFact] @@ -82,8 +81,7 @@ public async Task PromptBasedFunctionCalling_WithArgs() Seed = 0, }); - Assert.Single(response.Choices); - Assert.Contains("999", response.Message.Text); + Assert.Contains("999", response.Text); Assert.False(didCallIrrelevantTool); } @@ -108,10 +106,10 @@ public async Task InvalidModelParameter_ThrowsInvalidOperationException() private sealed class AssertNoToolsDefinedChatClient(IChatClient innerClient) : DelegatingChatClient(innerClient) { public override Task GetResponseAsync( - IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + IEnumerable messages, ChatOptions? 
options = null, CancellationToken cancellationToken = default) { Assert.Null(options?.Tools); - return base.GetResponseAsync(chatMessages, options, cancellationToken); + return base.GetResponseAsync(messages, options, cancellationToken); } } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs index 439ca29a3ec..8f7499aa272 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs @@ -118,9 +118,9 @@ public async Task BasicRequestResponse_NonStreaming() }); Assert.NotNull(response); - Assert.Equal("Hello! How are you today? Is there something", response.Message.Text); - Assert.Single(response.Message.Contents); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("Hello! How are you today? Is there something", response.Text); + Assert.Single(response.Messages.Single().Contents); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal("llama3.1", response.ModelId); Assert.Equal(DateTimeOffset.Parse("2024-10-01T15:46:10.5248793Z"), response.CreatedAt); Assert.Equal(ChatFinishReason.Length, response.FinishReason); @@ -281,9 +281,9 @@ public async Task MultipleMessages_NonStreaming() but I'm functioning properly and ready to help with any questions or tasks you may have! How about we chat about something in particular or just shoot the breeze ? Your choice! """), - VerbatimHttpHandler.RemoveWhiteSpace(response.Message.Text)); - Assert.Single(response.Message.Contents); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + VerbatimHttpHandler.RemoveWhiteSpace(response.Text)); + Assert.Single(response.Messages.Single().Contents); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal("llama3.1", response.ModelId); Assert.Equal(DateTimeOffset.Parse("2024-10-01T17:18:46.308987Z"), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); @@ -370,9 +370,9 @@ public async Task FunctionCallContent_NonStreaming() }); Assert.NotNull(response); - Assert.Null(response.Message.Text); + Assert.Empty(response.Text); Assert.Equal("llama3.1", response.ModelId); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal(DateTimeOffset.Parse("2024-10-01T18:48:30.2669578Z"), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); Assert.NotNull(response.Usage); @@ -380,9 +380,8 @@ public async Task FunctionCallContent_NonStreaming() Assert.Equal(19, response.Usage.OutputTokenCount); Assert.Equal(189, response.Usage.TotalTokenCount); - Assert.Single(response.Choices); - Assert.Single(response.Message.Contents); - FunctionCallContent fcc = Assert.IsType(response.Message.Contents[0]); + Assert.Single(response.Messages.Single().Contents); + FunctionCallContent fcc = Assert.IsType(response.Messages.Single().Contents[0]); Assert.Equal("GetPersonAge", fcc.Name); AssertExtensions.EqualFunctionCallParameters(new Dictionary { ["personName"] = "Alice" }, fcc.Arguments); } @@ -468,9 +467,9 @@ public async Task FunctionResultContent_NonStreaming() }); Assert.NotNull(response); - Assert.Equal("Alice is 42 years old.", response.Message.Text); + Assert.Equal("Alice is 42 years old.", response.Text); Assert.Equal("llama3.1", response.ModelId); - Assert.Equal(ChatRole.Assistant, 
response.Message.Role); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal(DateTimeOffset.Parse("2024-10-01T20:57:20.157266Z"), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); Assert.NotNull(response.Usage); diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs index 789f0abeb63..8cd53c55766 100644 --- a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs @@ -195,9 +195,9 @@ public async Task BasicRequestResponse_NonStreaming() Assert.NotNull(response); Assert.Equal("chatcmpl-ADx3PvAnCwJg0woha4pYsBTi3ZpOI", response.ResponseId); - Assert.Equal("Hello! How can I assist you today?", response.Message.Text); - Assert.Single(response.Message.Contents); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("Hello! How can I assist you today?", response.Text); + Assert.Single(response.Messages.Single().Contents); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_888_631), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); @@ -455,9 +455,9 @@ public async Task MultipleMessages_NonStreaming() Assert.NotNull(response); Assert.Equal("chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", response.ResponseId); - Assert.Equal("I’m doing well, thank you! What’s on your mind today?", response.Message.Text); - Assert.Single(response.Message.Contents); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("I’m doing well, thank you! What’s on your mind today?", response.Text); + Assert.Single(response.Messages.Single().Contents); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_187), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); @@ -555,9 +555,9 @@ public async Task MultiPartSystemMessage_NonStreaming() Assert.NotNull(response); Assert.Equal("chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", response.ResponseId); - Assert.Equal("Hi! It's so good to hear from you!", response.Message.Text); - Assert.Single(response.Message.Contents); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("Hi! It's so good to hear from you!", response.Text); + Assert.Single(response.Messages.Single().Contents); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_187), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); @@ -656,9 +656,9 @@ public async Task EmptyAssistantMessage_NonStreaming() Assert.NotNull(response); Assert.Equal("chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", response.ResponseId); - Assert.Equal("I’m doing well, thank you! What’s on your mind today?", response.Message.Text); - Assert.Single(response.Message.Contents); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("I’m doing well, thank you! 
What’s on your mind today?", response.Text); + Assert.Single(response.Messages.Single().Contents); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_187), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); @@ -771,9 +771,9 @@ public async Task FunctionCallContent_NonStreaming() }); Assert.NotNull(response); - Assert.Null(response.Message.Text); + Assert.Empty(response.Text); Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_702), response.CreatedAt); Assert.Equal(ChatFinishReason.ToolCalls, response.FinishReason); Assert.NotNull(response.Usage); @@ -791,9 +791,8 @@ public async Task FunctionCallContent_NonStreaming() { "OutputTokenDetails.RejectedPredictionTokenCount", 0 }, }, response.Usage.AdditionalCounts); - Assert.Single(response.Choices); - Assert.Single(response.Message.Contents); - FunctionCallContent fcc = Assert.IsType(response.Message.Contents[0]); + Assert.Single(response.Messages.Single().Contents); + FunctionCallContent fcc = Assert.IsType(response.Messages.Single().Contents[0]); Assert.Equal("GetPersonAge", fcc.Name); AssertExtensions.EqualFunctionCallParameters(new Dictionary { ["personName"] = "Alice" }, fcc.Arguments); @@ -1033,9 +1032,9 @@ public async Task AssistantMessageWithBothToolsAndContent_NonStreaming() Assert.NotNull(response); Assert.Equal("chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", response.ResponseId); - Assert.Equal("I’m doing well, thank you! What’s on your mind today?", response.Message.Text); - Assert.Single(response.Message.Contents); - Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("I’m doing well, thank you! What’s on your mind today?", response.Text); + Assert.Single(response.Messages.Single().Contents); + Assert.Equal(ChatRole.Assistant, response.Messages.Single().Role); Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_187), response.CreatedAt); Assert.Equal(ChatFinishReason.Stop, response.FinishReason); diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAISerializationTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAISerializationTests.cs index 51947ae0c8e..752e44dc388 100644 --- a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAISerializationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAISerializationTests.cs @@ -460,7 +460,7 @@ public static async Task RequestDeserialization_ToolChatMessage() } [Fact] - public static async Task SerializeResponse_SingleChoice() + public static async Task SerializeResponse() { ChatMessage message = new() { @@ -558,28 +558,6 @@ public static async Task SerializeResponse_SingleChoice() """, result); } - [Fact] - public static async Task SerializeResponse_ManyChoices_ThrowsNotSupportedException() - { - ChatMessage message1 = new() - { - Role = ChatRole.Assistant, - Text = "Hello! How can I assist you today?", - }; - - ChatMessage message2 = new() - { - Role = ChatRole.Assistant, - Text = "Hey there! 
How can I help?", - }; - - ChatResponse response = new([message1, message2]); - - using MemoryStream stream = new(); - var ex = await Assert.ThrowsAsync(() => OpenAISerializationHelpers.SerializeAsync(stream, response)); - Assert.Contains("multiple choices", ex.Message); - } - [Fact] public static async Task SerializeStreamingResponse() { diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs index ef1ac9718d0..0b8aca0785e 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.ComponentModel; +using System.Linq; using System.Text.Json; using System.Text.RegularExpressions; using System.Threading.Tasks; @@ -84,7 +85,7 @@ public async Task WrapsNonObjectValuesInDataProperty() { GetResponseAsyncCallback = (messages, options, cancellationToken) => { - var suppliedSchemaMatch = Regex.Match(messages[1].Text!, "```(.*?)```", RegexOptions.Singleline); + var suppliedSchemaMatch = Regex.Match(messages.Last().Text!, "```(.*?)```", RegexOptions.Singleline); Assert.True(suppliedSchemaMatch.Success); Assert.Equal(""" { diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs index d1ae1c21ebe..0cddb58e006 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs @@ -34,7 +34,7 @@ public async Task ConfigureOptions_ReturnedInstancePassedToNextClient(bool nullP { ChatOptions? providedOptions = nullProvidedOptions ? null : new() { ModelId = "test" }; ChatOptions? 
returnedOptions = null; - ChatResponse expectedResponse = new(Array.Empty()); + ChatResponse expectedResponse = new(); var expectedUpdates = Enumerable.Range(0, 3).Select(i => new ChatResponseUpdate()).ToArray(); using CancellationTokenSource cts = new(); diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs index b87b866c50f..611d3b9f45b 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs @@ -39,8 +39,8 @@ public async Task CachesSuccessResultsAsync() // Verify that all the expected properties will round-trip through the cache, // even if this involves serialization - var expectedResponse = new ChatResponse([ - new(new ChatRole("fakeRole"), "This is some content") + var expectedResponse = new ChatResponse( + new ChatMessage(new ChatRole("fakeRole"), "This is some content") { AdditionalProperties = new() { ["a"] = "b" }, Contents = [new FunctionCallContent("someCallId", "functionName", new Dictionary @@ -52,8 +52,7 @@ public async Task CachesSuccessResultsAsync() ["arg5"] = false, ["arg6"] = null })] - } - ]) + }) { ResponseId = "someId", Usage = new() @@ -111,7 +110,7 @@ public async Task AllowsConcurrentCallsAsync() { innerCallCount++; await completionTcs.Task; - return new ChatResponse([new(ChatRole.Assistant, "Hello")]); + return new ChatResponse(new ChatMessage(ChatRole.Assistant, "Hello")); } }; using var outer = new DistributedCachingChatClient(testClient, _storage) @@ -128,13 +127,13 @@ public async Task AllowsConcurrentCallsAsync() Assert.False(result1.IsCompleted); Assert.False(result2.IsCompleted); completionTcs.SetResult(true); - Assert.Equal("Hello", (await result1).Message.Text); - Assert.Equal("Hello", (await result2).Message.Text); + Assert.Equal("Hello", (await result1).Text); + Assert.Equal("Hello", (await result2).Text); // Act 2: Subsequent calls after completion are resolved from the cache var result3 = outer.GetResponseAsync("some input"); Assert.Equal(2, innerCallCount); - Assert.Equal("Hello", (await result3).Message.Text); + Assert.Equal("Hello", (await result3).Text); } [Fact] @@ -185,7 +184,7 @@ public async Task DoesNotCacheCanceledResultsAsync() await resolutionTcs.Task; } - return new ChatResponse([new(ChatRole.Assistant, "A good result")]); + return new ChatResponse(new ChatMessage(ChatRole.Assistant, "A good result")); } }; using var outer = new DistributedCachingChatClient(testClient, _storage) @@ -205,7 +204,7 @@ public async Task DoesNotCacheCanceledResultsAsync() // Act/Assert: Second call can succeed var result2 = await outer.GetResponseAsync([input]); Assert.Equal(2, innerCallCount); - Assert.Equal("A good result", result2.Message.Text); + Assert.Equal("A good result", result2.Text); } [Fact] @@ -217,13 +216,6 @@ public async Task StreamingCachesSuccessResultsAsync() // even if this involves serialization List actualUpdate = [ - new() - { - Role = new ChatRole("fakeRole1"), - ChoiceIndex = 1, - AdditionalProperties = new() { ["a"] = "b" }, - Contents = [new TextContent("Chunk1")] - }, new() { Role = new ChatRole("fakeRole2"), @@ -243,13 +235,6 @@ public async Task StreamingCachesSuccessResultsAsync() Contents = [new FunctionCallContent("someCallId", "someFn", new Dictionary { ["arg1"] = "value1" })], }, new() - { - Role = new 
ChatRole("fakeRole1"), - ChoiceIndex = 1, - AdditionalProperties = new() { ["a"] = "b" }, - Contents = [new TextContent("Chunk1")] - }, - new() { Contents = [new UsageContent(new() { InputTokenCount = 123, OutputTokenCount = 456, TotalTokenCount = 99999 })], }, @@ -295,12 +280,12 @@ public async Task StreamingCoalescesConsecutiveTextChunksAsync(bool? coalesce) // Arrange List expectedResponse = [ - new() { Role = ChatRole.Assistant, Text = "This" }, - new() { Role = ChatRole.Assistant, Text = " becomes one chunk" }, + new(ChatRole.Assistant, "This"), + new(ChatRole.Assistant, " becomes one chunk"), new() { Role = ChatRole.Assistant, Contents = [new FunctionCallContent("callId1", "separator")] }, - new() { Role = ChatRole.Assistant, Text = "... and this" }, - new() { Role = ChatRole.Assistant, Text = " becomes another" }, - new() { Role = ChatRole.Assistant, Text = " one." }, + new(ChatRole.Assistant, "... and this"), + new(ChatRole.Assistant, " becomes another"), + new(ChatRole.Assistant, " one."), ]; using var testClient = new TestChatClient @@ -416,7 +401,7 @@ public async Task StreamingAllowsConcurrentCallsAsync() var completionTcs = new TaskCompletionSource(); List expectedResponse = [ - new() { Role = ChatRole.Assistant, Text = "Chunk 1" }, + new(ChatRole.Assistant, "Chunk 1"), ]; using var testClient = new TestChatClient { @@ -464,7 +449,7 @@ public async Task StreamingDoesNotCacheExceptionResultsAsync() innerCallCount++; return ToAsyncEnumerableAsync(Task.CompletedTask, [ - () => new() { Role = ChatRole.Assistant, Text = "Chunk 1" }, + () => new(ChatRole.Assistant, "Chunk 1"), () => throw new InvalidTimeZoneException("some failure"), ]); } @@ -503,7 +488,7 @@ public async Task StreamingDoesNotCacheCanceledResultsAsync() innerCallCount++; return ToAsyncEnumerableAsync( innerCallCount == 1 ? 
completionTcs.Task : Task.CompletedTask, - [() => new() { Role = ChatRole.Assistant, Text = "A good result" }]); + [() => new(ChatRole.Assistant, "A good result")]); } }; using var outer = new DistributedCachingChatClient(testClient, _storage) @@ -539,7 +524,7 @@ public async Task CacheKeyVariesByChatOptionsAsync() { innerCallCount++; await Task.Yield(); - return new([new(ChatRole.Assistant, options!.AdditionalProperties!["someKey"]!.ToString())]); + return new(new ChatMessage(ChatRole.Assistant, options!.AdditionalProperties!["someKey"]!.ToString())); } }; using var outer = new DistributedCachingChatClient(testClient, _storage) @@ -559,8 +544,8 @@ public async Task CacheKeyVariesByChatOptionsAsync() // Assert: Same result Assert.Equal(1, innerCallCount); - Assert.Equal("value 1", result1.Message.Text); - Assert.Equal("value 1", result2.Message.Text); + Assert.Equal("value 1", result1.Text); + Assert.Equal("value 1", result2.Text); // Act: Call with two different ChatOptions that have different values var result3 = await outer.GetResponseAsync([], new ChatOptions @@ -574,8 +559,8 @@ public async Task CacheKeyVariesByChatOptionsAsync() // Assert: Different results Assert.Equal(2, innerCallCount); - Assert.Equal("value 1", result3.Message.Text); - Assert.Equal("value 2", result4.Message.Text); + Assert.Equal("value 1", result3.Text); + Assert.Equal("value 2", result4.Text); } [Fact] @@ -590,7 +575,7 @@ public async Task SubclassCanOverrideCacheKeyToVaryByChatOptionsAsync() { innerCallCount++; await Task.Yield(); - return new([new(ChatRole.Assistant, options!.AdditionalProperties!["someKey"]!.ToString())]); + return new(new ChatMessage(ChatRole.Assistant, options!.AdditionalProperties!["someKey"]!.ToString())); } }; using var outer = new CachingChatClientWithCustomKey(testClient, _storage) @@ -610,21 +595,20 @@ public async Task SubclassCanOverrideCacheKeyToVaryByChatOptionsAsync() // Assert: Different results Assert.Equal(2, innerCallCount); - Assert.Equal("value 1", result1.Message.Text); - Assert.Equal("value 2", result2.Message.Text); + Assert.Equal("value 1", result1.Text); + Assert.Equal("value 2", result2.Text); } [Fact] public async Task CanCacheCustomContentTypesAsync() { // Arrange - var expectedResponse = new ChatResponse([ - new(new ChatRole("fakeRole"), + var expectedResponse = new ChatResponse( + new ChatMessage(new ChatRole("fakeRole"), [ new CustomAIContent1("Hello", DateTime.Now), new CustomAIContent2("Goodbye", 42), - ]) - ]); + ])); var serializerOptions = new JsonSerializerOptions(TestJsonSerializerContext.Default.Options); serializerOptions.TypeInfoResolver = serializerOptions.TypeInfoResolver!.WithAddedModifier(typeInfo => @@ -663,8 +647,8 @@ public async Task CanCacheCustomContentTypesAsync() // Assert Assert.Equal(1, innerCallCount); AssertResponsesEqual(expectedResponse, result2); - Assert.NotSame(result2.Message.Contents[0], expectedResponse.Message.Contents[0]); - Assert.NotSame(result2.Message.Contents[1], expectedResponse.Message.Contents[1]); + Assert.NotSame(result2.Messages.Last().Contents[0], expectedResponse.Messages.Last().Contents[0]); + Assert.NotSame(result2.Messages.Last().Contents[1], expectedResponse.Messages.Last().Contents[1]); } [Fact] @@ -678,8 +662,8 @@ public async Task CanResolveIDistributedCacheFromDI() { GetResponseAsyncCallback = delegate { - return Task.FromResult(new ChatResponse([ - new(ChatRole.Assistant, [new TextContent("Hey")])])); + return Task.FromResult(new ChatResponse( + new ChatMessage(ChatRole.Assistant, [new 
TextContent("Hey")]))); } }; using var outer = testClient @@ -739,33 +723,31 @@ private static void AssertResponsesEqual(ChatResponse expected, ChatResponse act Assert.Equal( JsonSerializer.Serialize(expected.AdditionalProperties, TestJsonSerializerContext.Default.Options), JsonSerializer.Serialize(actual.AdditionalProperties, TestJsonSerializerContext.Default.Options)); - Assert.Equal(expected.Choices.Count, actual.Choices.Count); - for (var i = 0; i < expected.Choices.Count; i++) + ChatMessage expectedMessage = expected.Messages.Last(); + ChatMessage actualMessage = actual.Messages.Last(); + Assert.IsType(expectedMessage.GetType(), actualMessage); + Assert.Equal(expectedMessage.Role, actualMessage.Role); + Assert.Equal(expectedMessage.Text, actualMessage.Text); + Assert.Equal(expectedMessage.Contents.Count, actualMessage.Contents.Count); + + for (var itemIndex = 0; itemIndex < expectedMessage.Contents.Count; itemIndex++) { - Assert.IsType(expected.Choices[i].GetType(), actual.Choices[i]); - Assert.Equal(expected.Choices[i].Role, actual.Choices[i].Role); - Assert.Equal(expected.Choices[i].Text, actual.Choices[i].Text); - Assert.Equal(expected.Choices[i].Contents.Count, actual.Choices[i].Contents.Count); + var expectedItem = expectedMessage.Contents[itemIndex]; + var actualItem = actualMessage.Contents[itemIndex]; + Assert.IsType(expectedItem.GetType(), actualItem); - for (var itemIndex = 0; itemIndex < expected.Choices[i].Contents.Count; itemIndex++) + if (expectedItem is FunctionCallContent expectedFcc) { - var expectedItem = expected.Choices[i].Contents[itemIndex]; - var actualItem = actual.Choices[i].Contents[itemIndex]; - Assert.IsType(expectedItem.GetType(), actualItem); - - if (expectedItem is FunctionCallContent expectedFcc) - { - var actualFcc = (FunctionCallContent)actualItem; - Assert.Equal(expectedFcc.Name, actualFcc.Name); - Assert.Equal(expectedFcc.CallId, actualFcc.CallId); - - // The correct JSON-round-tripping of AIContent/AIContent is not - // the responsibility of CachingChatClient, so not testing that here. - Assert.Equal( - JsonSerializer.Serialize(expectedFcc.Arguments, TestJsonSerializerContext.Default.Options), - JsonSerializer.Serialize(actualFcc.Arguments, TestJsonSerializerContext.Default.Options)); - } + var actualFcc = (FunctionCallContent)actualItem; + Assert.Equal(expectedFcc.Name, actualFcc.Name); + Assert.Equal(expectedFcc.CallId, actualFcc.CallId); + + // The correct JSON-round-tripping of AIContent/AIContent is not + // the responsibility of CachingChatClient, so not testing that here. 
+ Assert.Equal( + JsonSerializer.Serialize(expectedFcc.Arguments, TestJsonSerializerContext.Default.Options), + JsonSerializer.Serialize(actualFcc.Arguments, TestJsonSerializerContext.Default.Options)); } } } @@ -780,7 +762,6 @@ private static async Task AssertResponsesEqualAsync(IReadOnlyList("value", () => ctx.CallContent = null!); - Assert.Throws("value", () => ctx.ChatMessages = null!); + Assert.Throws("value", () => ctx.Messages = null!); Assert.Throws("value", () => ctx.Function = null!); } @@ -44,9 +44,9 @@ public void Properties_Roundtrip() { FunctionInvocationContext ctx = new(); - List chatMessages = []; - ctx.ChatMessages = chatMessages; - Assert.Same(chatMessages, ctx.ChatMessages); + List messages = []; + ctx.Messages = messages; + Assert.Same(messages, ctx.Messages); AIFunction function = AIFunctionFactory.Create(() => { }, nameof(Properties_Roundtrip)); ctx.Function = function; diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs index dfba437c98f..c2b6b067c8a 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs @@ -19,9 +19,6 @@ namespace Microsoft.Extensions.AI; public class FunctionInvokingChatClientTests { - private readonly Func _keepMessagesConfigure = - b => b.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingContent = true }); - [Fact] public void InvalidArgs_Throws() { @@ -37,7 +34,6 @@ public void Ctor_HasExpectedDefaults() Assert.False(client.AllowConcurrentInvocation); Assert.False(client.IncludeDetailedErrors); - Assert.True(client.KeepFunctionCallingContent); Assert.Null(client.MaximumIterationsPerRequest); Assert.False(client.RetryOnError); } @@ -67,9 +63,9 @@ public async Task SupportsSingleFunctionCallPerRequestAsync() new ChatMessage(ChatRole.Assistant, "world"), ]; - await InvokeAndAssertAsync(options, plan, configurePipeline: _keepMessagesConfigure); + await InvokeAndAssertAsync(options, plan); - await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: _keepMessagesConfigure); + await InvokeAndAssertStreamingAsync(options, plan); } [Theory] @@ -115,7 +111,7 @@ public async Task SupportsMultipleFunctionCallsPerRequestAsync(bool concurrentIn ]; Func configure = b => b.Use( - s => new FunctionInvokingChatClient(s) { AllowConcurrentInvocation = concurrentInvocation, KeepFunctionCallingContent = true }); + s => new FunctionInvokingChatClient(s) { AllowConcurrentInvocation = concurrentInvocation }); await InvokeAndAssertAsync(options, plan, configurePipeline: configure); @@ -156,7 +152,7 @@ public async Task ParallelFunctionCallsMayBeInvokedConcurrentlyAsync() ]; Func configure = b => b.Use( - s => new FunctionInvokingChatClient(s) { AllowConcurrentInvocation = true, KeepFunctionCallingContent = true }); + s => new FunctionInvokingChatClient(s) { AllowConcurrentInvocation = true }); await InvokeAndAssertAsync(options, plan, configurePipeline: configure); @@ -199,68 +195,13 @@ public async Task ConcurrentInvocationOfParallelCallsDisabledByDefaultAsync() new ChatMessage(ChatRole.Assistant, "done"), ]; - await InvokeAndAssertAsync(options, plan, configurePipeline: _keepMessagesConfigure); + await InvokeAndAssertAsync(options, plan); - await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: 
_keepMessagesConfigure); + await InvokeAndAssertStreamingAsync(options, plan); } - [Theory] - [InlineData(false)] - [InlineData(true)] - public async Task RemovesFunctionCallingMessagesWhenRequestedAsync(bool keepFunctionCallingMessages) - { - var options = new ChatOptions - { - Tools = - [ - AIFunctionFactory.Create(() => "Result 1", "Func1"), - AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"), - AIFunctionFactory.Create((int i) => { }, "VoidReturn"), - ] - }; - - List plan = - [ - new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2: 42")]), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } })]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", result: "Success: Function completed.")]), - new ChatMessage(ChatRole.Assistant, "world"), - ]; - - List? expected = keepFunctionCallingMessages ? null : - [ - new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, "world") - ]; - - Func configure = b => b.Use( - client => new FunctionInvokingChatClient(client) { KeepFunctionCallingContent = keepFunctionCallingMessages }); - - Validate(await InvokeAndAssertAsync(options, plan, expected, configure)); - Validate(await InvokeAndAssertStreamingAsync(options, plan, expected, configure)); - - void Validate(List finalChat) - { - IEnumerable content = finalChat.SelectMany(m => m.Contents); - if (keepFunctionCallingMessages) - { - Assert.Contains(content, c => c is FunctionCallContent or FunctionResultContent); - } - else - { - Assert.All(content, c => Assert.False(c is FunctionCallContent or FunctionResultContent)); - } - } - } - - [Theory] - [InlineData(false)] - [InlineData(true)] - public async Task KeepsFunctionCallingContentWhenRequestedAsync(bool keepFunctionCallingMessages) + [Fact] + public async Task KeepsFunctionCallingContent() { var options = new ChatOptions { @@ -284,45 +225,15 @@ public async Task KeepsFunctionCallingContentWhenRequestedAsync(bool keepFunctio new ChatMessage(ChatRole.Assistant, "world"), ]; - Func configure = b => b.Use( - client => new FunctionInvokingChatClient(client) { KeepFunctionCallingContent = keepFunctionCallingMessages }); - #pragma warning disable SA1005, S125 - Validate(await InvokeAndAssertAsync(options, plan, keepFunctionCallingMessages ? null : - [ - new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, [new TextContent("extra"), new TextContent("stuff")]), - new ChatMessage(ChatRole.Assistant, "more"), - new ChatMessage(ChatRole.Assistant, "world"), - ], configure)); + Validate(await InvokeAndAssertAsync(options, plan)); - Validate(await InvokeAndAssertStreamingAsync(options, plan, keepFunctionCallingMessages ? 
- [ - new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 1")]), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2: 42")]), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } })]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", result: "Success: Function completed.")]), - new ChatMessage(ChatRole.Assistant, "extrastuffmoreworld"), - ] : - [ - new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, "extrastuffmoreworld"), - ], configure)); + Validate(await InvokeAndAssertStreamingAsync(options, plan)); - void Validate(List finalChat) + static void Validate(List finalChat) { IEnumerable content = finalChat.SelectMany(m => m.Contents); - if (keepFunctionCallingMessages) - { - Assert.Contains(content, c => c is FunctionCallContent or FunctionResultContent); - } - else - { - Assert.All(content, c => Assert.False(c is FunctionCallContent or FunctionResultContent)); - } + Assert.Contains(content, c => c is FunctionCallContent or FunctionResultContent); } } @@ -348,51 +259,13 @@ public async Task ExceptionDetailsOnlyReportedWhenRequestedAsync(bool detailedEr ]; Func configure = b => b.Use( - s => new FunctionInvokingChatClient(s) { IncludeDetailedErrors = detailedErrors, KeepFunctionCallingContent = true }); + s => new FunctionInvokingChatClient(s) { IncludeDetailedErrors = detailedErrors }); await InvokeAndAssertAsync(options, plan, configurePipeline: configure); await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure); } - [Fact] - public async Task RejectsMultipleChoicesAsync() - { - var func1 = AIFunctionFactory.Create(() => "Some result 1", "Func1"); - var func2 = AIFunctionFactory.Create(() => "Some result 2", "Func2"); - - var expected = new ChatResponse( - [ - new(ChatRole.Assistant, [new FunctionCallContent("callId1", func1.Name)]), - new(ChatRole.Assistant, [new FunctionCallContent("callId2", func2.Name)]), - ]); - - using var innerClient = new TestChatClient - { - GetResponseAsyncCallback = async (chatContents, options, cancellationToken) => - { - await Task.Yield(); - return expected; - }, - GetStreamingResponseAsyncCallback = (chatContents, options, cancellationToken) => - YieldAsync(expected.ToChatResponseUpdates()), - }; - - IChatClient service = innerClient.AsBuilder().UseFunctionInvocation().Build(); - - List chat = [new ChatMessage(ChatRole.User, "hello")]; - ChatOptions options = new() { Tools = [func1, func2] }; - - Validate(await Assert.ThrowsAsync(() => service.GetResponseAsync(chat, options))); - Validate(await Assert.ThrowsAsync(() => service.GetStreamingResponseAsync(chat, options).ToChatResponseAsync())); - - void Validate(Exception ex) - { - Assert.Contains("only accepts a single choice", ex.Message); - Assert.Single(chat); // It didn't add anything to the chat history - } - } - [Theory] [InlineData(LogLevel.Trace)] [InlineData(LogLevel.Debug)] @@ -413,10 +286,7 @@ public async Task FunctionInvocationsLogged(LogLevel level) }; Func configure = b => - b.Use((c, services) => new FunctionInvokingChatClient(c, services.GetRequiredService>()) - { - KeepFunctionCallingContent = true, - }); + b.Use((c, 
services) => new FunctionInvokingChatClient(c, services.GetRequiredService>())); await InvokeAsync(services => InvokeAndAssertAsync(options, plan, configurePipeline: configure, services: services)); @@ -472,10 +342,7 @@ public async Task FunctionInvocationTrackedWithActivity(bool enableTelemetry) }; Func configure = b => b.Use(c => - new FunctionInvokingChatClient(new OpenTelemetryChatClient(c, sourceName: sourceName)) - { - KeepFunctionCallingContent = true, - }); + new FunctionInvokingChatClient(new OpenTelemetryChatClient(c, sourceName: sourceName))); await InvokeAsync(() => InvokeAndAssertAsync(options, plan, configurePipeline: configure)); @@ -532,38 +399,84 @@ public async Task SupportsConsecutiveStreamingUpdatesWithFunctionCalls() GetStreamingResponseAsyncCallback = (chatContents, chatOptions, cancellationToken) => { // If the conversation is just starting, issue two consecutive updates with function calls - // Otherwise just end the conversation - return chatContents.Last().Text == "Hello" - ? YieldAsync( - new ChatResponseUpdate { Contents = [new FunctionCallContent("callId1", "Func1", new Dictionary { ["text"] = "Input 1" })] }, - new ChatResponseUpdate { Contents = [new FunctionCallContent("callId2", "Func1", new Dictionary { ["text"] = "Input 2" })] }) - : YieldAsync( - new ChatResponseUpdate { Contents = [new TextContent("OK bye")] }); + // Otherwise just end the conversation. + List updates; + string responseId = Guid.NewGuid().ToString("N"); + if (chatContents.Last().Text == "Hello") + { + updates = + [ + new() { Contents = [new FunctionCallContent("callId1", "Func1", new Dictionary { ["text"] = "Input 1" })] }, + new() { Contents = [new FunctionCallContent("callId2", "Func1", new Dictionary { ["text"] = "Input 2" })] } + ]; + } + else + { + updates = [new() { Contents = [new TextContent("OK bye")] }]; + } + + foreach (var update in updates) + { + update.ResponseId = responseId; + } + + return YieldAsync(updates); } }; - using var client = new FunctionInvokingChatClient(innerClient) { KeepFunctionCallingContent = true }; + using var client = new FunctionInvokingChatClient(innerClient); - var updates = new List(); - await foreach (var update in client.GetStreamingResponseAsync(messages, options, CancellationToken.None)) - { - updates.Add(update); - } + var response = await client.GetStreamingResponseAsync(messages, options, CancellationToken.None).ToChatResponseAsync(); - // Message history should now include the FCCs and FRCs - Assert.Collection(messages, - m => Assert.Equal("Hello", Assert.IsType(Assert.Single(m.Contents)).Text), + // The returned message should include the FCCs and FRCs. 
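// Illustrative aside (not part of this change): with the reworked FunctionInvokingChatClient,
// the intermediate function-call and function-result messages are no longer inserted into the
// caller-supplied message list; they are surfaced on the aggregated response instead, which is
// why the streamed updates are collected via ToChatResponseAsync() and the assertions below
// inspect response.Messages. A caller that wants those messages in its own history can append
// them after the call, e.g. (mirroring the helpers later in this file):
//     messages.AddRange(response.Messages);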
+ Assert.Collection(response.Messages, m => Assert.Collection(m.Contents, c => Assert.Equal("Input 1", Assert.IsType(c).Arguments!["text"]), c => Assert.Equal("Input 2", Assert.IsType(c).Arguments!["text"])), m => Assert.Collection(m.Contents, c => Assert.Equal("Result for Input 1", Assert.IsType(c).Result?.ToString()), - c => Assert.Equal("Result for Input 2", Assert.IsType(c).Result?.ToString()))); + c => Assert.Equal("Result for Input 2", Assert.IsType(c).Result?.ToString())), + m => Assert.Equal("OK bye", Assert.IsType(Assert.Single(m.Contents)).Text)); + } - // The returned updates should *not* include the FCCs and FRCs - var allUpdateContents = updates.SelectMany(updates => updates.Contents).ToList(); - var singleUpdateContent = Assert.IsType(Assert.Single(allUpdateContents)); - Assert.Equal("OK bye", singleUpdateContent.Text); + [Fact] + public async Task AllResponseMessagesReturned() + { + var options = new ChatOptions + { + Tools = [AIFunctionFactory.Create(() => "doesn't matter", "Func1")] + }; + + var messages = new List + { + new(ChatRole.User, "Hello"), + }; + + using var innerClient = new TestChatClient + { + GetResponseAsyncCallback = async (chatContents, chatOptions, cancellationToken) => + { + await Task.Yield(); + + ChatMessage message = chatContents.Count() is 1 or 3 ? + new(ChatRole.Assistant, [new FunctionCallContent($"callId{chatContents.Count()}", "Func1")]) : + new(ChatRole.Assistant, "The answer is 42."); + + return new(message); + } + }; + + using var client = new FunctionInvokingChatClient(innerClient); + + ChatResponse response = await client.GetResponseAsync(messages, options); + + Assert.Equal(5, response.Messages.Count); + Assert.Equal("The answer is 42.", response.Text); + Assert.IsType(Assert.Single(response.Messages[0].Contents)); + Assert.IsType(Assert.Single(response.Messages[1].Contents)); + Assert.IsType(Assert.Single(response.Messages[2].Contents)); + Assert.IsType(Assert.Single(response.Messages[3].Contents)); + Assert.IsType(Assert.Single(response.Messages[4].Contents)); } [Fact] @@ -610,21 +523,9 @@ public async Task CanAccesssFunctionInvocationContextFromFunctionCall() new ChatMessage(ChatRole.Assistant, "world"), ]; - await InvokeAsync(() => InvokeAndAssertAsync(options, plan, expected: [ - .. planBeforeTermination, - - // The last message is the one returned by the chat client - // This message's content should contain the last function call before the termination - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func1", new Dictionary { ["i"] = 42 })]), - ], configurePipeline: _keepMessagesConfigure)); - - await InvokeAsync(() => InvokeAndAssertStreamingAsync(options, plan, expected: [ - .. 
planBeforeTermination, + await InvokeAsync(() => InvokeAndAssertAsync(options, plan, planBeforeTermination)); - // The last message is the one returned by the chat client - // When streaming, function call content is removed from this message - new ChatMessage(ChatRole.Assistant, []), - ], configurePipeline: _keepMessagesConfigure)); + await InvokeAsync(() => InvokeAndAssertStreamingAsync(options, plan, planBeforeTermination)); // The current context should be null outside the async call stack for the function invocation Assert.Null(FunctionInvokingChatClient.CurrentContext); @@ -633,7 +534,7 @@ async Task InvokeAsync(Func>> work) { invocationContexts.Clear(); - var chatMessages = await work(); + var messages = await work(); Assert.Collection(invocationContexts, c => AssertInvocationContext(c, iteration: 0, terminate: false), @@ -642,7 +543,8 @@ async Task InvokeAsync(Func>> work) void AssertInvocationContext(FunctionInvocationContext context, int iteration, bool terminate) { Assert.NotNull(context); - Assert.Same(chatMessages, context.ChatMessages); + Assert.Equal(messages.Count, context.Messages.Count); + Assert.Equal(string.Concat(messages), string.Concat(context.Messages)); Assert.Same(function, context.Function); Assert.Equal("Func1", context.CallContent.Name); Assert.Equal(0, context.FunctionCallIndex); @@ -663,7 +565,7 @@ public async Task PropagatesResponseChatThreadIdToOptions() int iteration = 0; - Func, ChatOptions?, CancellationToken, ChatResponse> callback = + Func, ChatOptions?, CancellationToken, ChatResponse> callback = (chatContents, chatOptions, cancellationToken) => { iteration++; @@ -728,17 +630,20 @@ private static async Task> InvokeAndAssertAsync( var usage = CreateRandomUsage(); expectedTotalTokenCounts += usage.InputTokenCount!.Value; - return new ChatResponse(new ChatMessage(ChatRole.Assistant, [.. plan[contents.Count].Contents])) { Usage = usage }; + + var message = new ChatMessage(ChatRole.Assistant, [.. plan[contents.Count()].Contents]); + return new ChatResponse(message) { Usage = usage, ResponseId = Guid.NewGuid().ToString("N") }; } }; IChatClient service = configurePipeline(innerClient.AsBuilder()).Build(services); var result = await service.GetResponseAsync(chat, options, cts.Token); - chat.Add(result.Message); + Assert.NotNull(result); + + chat.AddRange(result.Messages); expected ??= plan; - Assert.NotNull(result); Assert.Equal(expected.Count, chat.Count); for (int i = 0; i < expected.Count; i++) { @@ -817,17 +722,19 @@ private static async Task> InvokeAndAssertStreamingAsync( { Assert.Equal(cts.Token, actualCancellationToken); - return YieldAsync(new ChatResponse(new ChatMessage(ChatRole.Assistant, [.. plan[contents.Count].Contents])).ToChatResponseUpdates()); + ChatMessage message = new(ChatRole.Assistant, [.. 
plan[contents.Count()].Contents]); + return YieldAsync(new ChatResponse(message) { ResponseId = Guid.NewGuid().ToString("N") }.ToChatResponseUpdates()); } }; IChatClient service = configurePipeline(innerClient.AsBuilder()).Build(services); var result = await service.GetStreamingResponseAsync(chat, options, cts.Token).ToChatResponseAsync(); - chat.Add(result.Message); + Assert.NotNull(result); + + chat.AddRange(result.Messages); expected ??= plan; - Assert.NotNull(result); Assert.Equal(expected.Count, chat.Count); for (int i = 0; i < expected.Count; i++) { @@ -863,7 +770,7 @@ private static async Task> InvokeAndAssertStreamingAsync( return chat; } - private static async IAsyncEnumerable YieldAsync(params T[] items) + private static async IAsyncEnumerable YieldAsync(params IEnumerable items) { await Task.Yield(); foreach (var item in items) diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs index 721768a5e08..51638d1a252 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs @@ -56,7 +56,7 @@ public async Task GetResponseAsync_LogsResponseInvocationAndCompletion(LogLevel { GetResponseAsyncCallback = (messages, options, cancellationToken) => { - return Task.FromResult(new ChatResponse([new(ChatRole.Assistant, "blue whale")])); + return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "blue whale"))); }, }; @@ -105,8 +105,8 @@ public async Task GetResponseStreamingStreamAsync_LogsUpdateReceived(LogLevel le static async IAsyncEnumerable GetUpdatesAsync() { await Task.Yield(); - yield return new ChatResponseUpdate { Role = ChatRole.Assistant, Text = "blue " }; - yield return new ChatResponseUpdate { Role = ChatRole.Assistant, Text = "whale" }; + yield return new(ChatRole.Assistant, "blue "); + yield return new(ChatRole.Assistant, "whale"); } using IChatClient client = innerClient diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/OpenTelemetryChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/OpenTelemetryChatClientTests.cs index bccba4cc65d..37ae545c04c 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/OpenTelemetryChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/OpenTelemetryChatClientTests.cs @@ -62,17 +62,15 @@ public async Task ExpectedInformationLogged_Async(bool enableSensitiveData, bool }; async static IAsyncEnumerable CallbackAsync( - IList messages, ChatOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken) + IEnumerable messages, ChatOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken) { await Task.Yield(); foreach (string text in new[] { "The ", "blue ", "whale,", " ", "", "I", " think." 
}) { await Task.Yield(); - yield return new ChatResponseUpdate + yield return new ChatResponseUpdate(ChatRole.Assistant, text) { - Role = ChatRole.Assistant, - Text = text, ResponseId = "id123", }; } @@ -107,7 +105,7 @@ async static IAsyncEnumerable CallbackAsync( }) .Build(); - List chatMessages = + List messages = [ new(ChatRole.System, "You are a close friend."), new(ChatRole.User, "Hey!"), @@ -138,14 +136,14 @@ async static IAsyncEnumerable CallbackAsync( if (streaming) { - await foreach (var update in chatClient.GetStreamingResponseAsync(chatMessages, options)) + await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options)) { await Task.Yield(); } } else { - await chatClient.GetResponseAsync(chatMessages, options); + await chatClient.GetResponseAsync(messages, options); } var activity = Assert.Single(activities); diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/UseDelegateChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/UseDelegateChatClientTests.cs index 3f94a47b7bd..18ad0c08bbd 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/UseDelegateChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/UseDelegateChatClientTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -32,43 +33,43 @@ public async Task Shared_ContextPropagated() using IChatClient innerClient = new TestChatClient { - GetResponseAsyncCallback = (chatMessages, options, cancellationToken) => + GetResponseAsyncCallback = (messages, options, cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); Assert.Equal(42, asyncLocal.Value); return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "hello"))); }, - GetStreamingResponseAsyncCallback = (chatMessages, options, cancellationToken) => + GetStreamingResponseAsyncCallback = (messages, options, cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); Assert.Equal(42, asyncLocal.Value); - return YieldUpdates(new ChatResponseUpdate { Text = "world" }); + return YieldUpdates(new ChatResponseUpdate(null, "world")); }, }; using IChatClient client = new ChatClientBuilder(innerClient) - .Use(async (chatMessages, options, next, cancellationToken) => + .Use(async (messages, options, next, cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); asyncLocal.Value = 42; - await next(chatMessages, options, cancellationToken); + await next(messages, options, cancellationToken); }) .Build(); Assert.Equal(0, asyncLocal.Value); ChatResponse response = await client.GetResponseAsync(expectedMessages, expectedOptions, expectedCts.Token); - Assert.Equal("hello", response.Message.Text); + Assert.Equal("hello", response.Text); Assert.Equal(0, asyncLocal.Value); response = await client.GetStreamingResponseAsync(expectedMessages, expectedOptions, expectedCts.Token).ToChatResponseAsync(); - Assert.Equal("world", response.Message.Text); + Assert.Equal("world", response.Text); } [Fact] 
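A rough sketch (not from the diff itself) of the shared-delegate middleware pattern this test exercises, assuming the `TestChatClient` double used throughout these tests; parameters are now typed as `IEnumerable<ChatMessage>` and results are read through `ChatResponse.Text` rather than the removed `response.Message.Text`:

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;

// Inner client stub that always answers "hello".
using IChatClient inner = new TestChatClient
{
    GetResponseAsyncCallback = (messages, options, cancellationToken) =>
        Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "hello"))),
};

// One shared delegate wraps both GetResponseAsync and GetStreamingResponseAsync.
using IChatClient client = new ChatClientBuilder(inner)
    .Use(async (messages, options, next, cancellationToken) =>
    {
        // Pre-processing could go here (e.g. setting an AsyncLocal, as the test does).
        await next(messages, options, cancellationToken);
    })
    .Build();

List<ChatMessage> messages = [new(ChatRole.User, "Hi")];
ChatResponse response = await client.GetResponseAsync(messages);
Console.WriteLine(response.Text); // "hello"
```

When only the shared overload is supplied, the same delegate also covers the streaming path, which is what the assertions on both code paths in this test verify.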
@@ -81,9 +82,9 @@ public async Task GetResponseFunc_ContextPropagated() using IChatClient innerClient = new TestChatClient { - GetResponseAsyncCallback = (chatMessages, options, cancellationToken) => + GetResponseAsyncCallback = (messages, options, cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); Assert.Equal(42, asyncLocal.Value); @@ -92,14 +93,14 @@ public async Task GetResponseFunc_ContextPropagated() }; using IChatClient client = new ChatClientBuilder(innerClient) - .Use(async (chatMessages, options, innerClient, cancellationToken) => + .Use(async (messages, options, innerClient, cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); asyncLocal.Value = 42; - var cc = await innerClient.GetResponseAsync(chatMessages, options, cancellationToken); - cc.Choices[0].Text += " world"; + var cc = await innerClient.GetResponseAsync(messages, options, cancellationToken); + cc.Messages.SelectMany(c => c.Contents).OfType().Last().Text += " world"; return cc; }, null) .Build(); @@ -107,10 +108,10 @@ public async Task GetResponseFunc_ContextPropagated() Assert.Equal(0, asyncLocal.Value); ChatResponse response = await client.GetResponseAsync(expectedMessages, expectedOptions, expectedCts.Token); - Assert.Equal("hello world", response.Message.Text); + Assert.Equal("hello world", response.Text); response = await client.GetStreamingResponseAsync(expectedMessages, expectedOptions, expectedCts.Token).ToChatResponseAsync(); - Assert.Equal("hello world", response.Message.Text); + Assert.Equal("hello world", response.Text); } [Fact] @@ -123,34 +124,34 @@ public async Task GetStreamingResponseFunc_ContextPropagated() using IChatClient innerClient = new TestChatClient { - GetStreamingResponseAsyncCallback = (chatMessages, options, cancellationToken) => + GetStreamingResponseAsyncCallback = (messages, options, cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); Assert.Equal(42, asyncLocal.Value); - return YieldUpdates(new ChatResponseUpdate { Text = "hello" }); + return YieldUpdates(new ChatResponseUpdate(null, "hello")); }, }; using IChatClient client = new ChatClientBuilder(innerClient) - .Use(null, (chatMessages, options, innerClient, cancellationToken) => + .Use(null, (messages, options, innerClient, cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); asyncLocal.Value = 42; - return Impl(chatMessages, options, innerClient, cancellationToken); + return Impl(messages, options, innerClient, cancellationToken); static async IAsyncEnumerable Impl( - IList chatMessages, ChatOptions? options, IChatClient innerClient, [EnumeratorCancellation] CancellationToken cancellationToken) + IEnumerable messages, ChatOptions? 
options, IChatClient innerClient, [EnumeratorCancellation] CancellationToken cancellationToken) { - await foreach (var update in innerClient.GetStreamingResponseAsync(chatMessages, options, cancellationToken)) + await foreach (var update in innerClient.GetStreamingResponseAsync(messages, options, cancellationToken)) { yield return update; } - yield return new() { Text = " world" }; + yield return new(null, " world"); } }) .Build(); @@ -158,10 +159,10 @@ static async IAsyncEnumerable Impl( Assert.Equal(0, asyncLocal.Value); ChatResponse response = await client.GetResponseAsync(expectedMessages, expectedOptions, expectedCts.Token); - Assert.Equal("hello world", response.Message.Text); + Assert.Equal("hello world", response.Text); response = await client.GetStreamingResponseAsync(expectedMessages, expectedOptions, expectedCts.Token).ToChatResponseAsync(); - Assert.Equal("hello world", response.Message.Text); + Assert.Equal("hello world", response.Text); } [Fact] @@ -174,54 +175,54 @@ public async Task BothGetResponseAndGetStreamingResponseFuncs_ContextPropagated( using IChatClient innerClient = new TestChatClient { - GetResponseAsyncCallback = (chatMessages, options, cancellationToken) => + GetResponseAsyncCallback = (messages, options, cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); Assert.Equal(42, asyncLocal.Value); return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "non-streaming hello"))); }, - GetStreamingResponseAsyncCallback = (chatMessages, options, cancellationToken) => + GetStreamingResponseAsyncCallback = (messages, options, cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); Assert.Equal(42, asyncLocal.Value); - return YieldUpdates(new ChatResponseUpdate { Text = "streaming hello" }); + return YieldUpdates(new ChatResponseUpdate(null, "streaming hello")); }, }; using IChatClient client = new ChatClientBuilder(innerClient) .Use( - async (chatMessages, options, innerClient, cancellationToken) => + async (messages, options, innerClient, cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); asyncLocal.Value = 42; - var cc = await innerClient.GetResponseAsync(chatMessages, options, cancellationToken); - cc.Choices[0].Text += " world (non-streaming)"; + var cc = await innerClient.GetResponseAsync(messages, options, cancellationToken); + cc.Messages.SelectMany(c => c.Contents).OfType().Last().Text += " world (non-streaming)"; return cc; }, - (chatMessages, options, innerClient, cancellationToken) => + (messages, options, innerClient, cancellationToken) => { - Assert.Same(expectedMessages, chatMessages); + Assert.Same(expectedMessages, messages); Assert.Same(expectedOptions, options); Assert.Equal(expectedCts.Token, cancellationToken); asyncLocal.Value = 42; - return Impl(chatMessages, options, innerClient, cancellationToken); + return Impl(messages, options, innerClient, cancellationToken); static async IAsyncEnumerable Impl( - IList chatMessages, ChatOptions? options, IChatClient innerClient, [EnumeratorCancellation] CancellationToken cancellationToken) + IEnumerable messages, ChatOptions? 
options, IChatClient innerClient, [EnumeratorCancellation] CancellationToken cancellationToken) { - await foreach (var update in innerClient.GetStreamingResponseAsync(chatMessages, options, cancellationToken)) + await foreach (var update in innerClient.GetStreamingResponseAsync(messages, options, cancellationToken)) { yield return update; } - yield return new() { Text = " world (streaming)" }; + yield return new(null, " world (streaming)"); } }) .Build(); @@ -229,10 +230,10 @@ static async IAsyncEnumerable Impl( Assert.Equal(0, asyncLocal.Value); ChatResponse response = await client.GetResponseAsync(expectedMessages, expectedOptions, expectedCts.Token); - Assert.Equal("non-streaming hello world (non-streaming)", response.Message.Text); + Assert.Equal("non-streaming hello world (non-streaming)", response.Text); response = await client.GetStreamingResponseAsync(expectedMessages, expectedOptions, expectedCts.Token).ToChatResponseAsync(); - Assert.Equal("streaming hello world (streaming)", response.Message.Text); + Assert.Equal("streaming hello world (streaming)", response.Text); } private static async IAsyncEnumerable YieldUpdates(params ChatResponseUpdate[] updates) diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/UseDelegateEmbeddingGeneratorTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/UseDelegateEmbeddingGeneratorTests.cs index 1109cbc581a..e71e6d9461c 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/UseDelegateEmbeddingGeneratorTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/UseDelegateEmbeddingGeneratorTests.cs @@ -24,7 +24,7 @@ public void InvalidArgs_Throws() [Fact] public async Task GenerateFunc_ContextPropagated() { - GeneratedEmbeddings> expectedEmbeddings = new(); + GeneratedEmbeddings> expectedEmbeddings = []; IList expectedValues = ["hello"]; EmbeddingGenerationOptions expectedOptions = new(); using CancellationTokenSource expectedCts = new(); diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs index a196823c5c5..3d94420063c 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs @@ -120,7 +120,7 @@ public void Metadata_DerivedFromLambda() Assert.Empty(func.Description); Assert.Same(dotnetFunc.Method, func.UnderlyingMethod); - Func dotnetFunc2 = (string a) => a + " " + a; + Func dotnetFunc2 = a => a + " " + a; func = AIFunctionFactory.Create(dotnetFunc2); Assert.Contains("Metadata_DerivedFromLambda", func.Name); Assert.Empty(func.Description); From 3ebaddf7584fb06a0c9f47dafda63890aad8a9d6 Mon Sep 17 00:00:00 2001 From: Igor Velikorossov Date: Mon, 10 Mar 2025 07:15:12 +0000 Subject: [PATCH 4/7] Merged PR 48215: ME.AI.Eval retain preview branding #### AI description (iteration 1) #### PR Classification Code configuration update to retain preview branding. #### PR Summary This pull request updates project files to retain the preview branding by suppressing the final package version. - `Microsoft.Extensions.AI.Evaluation.Console.csproj`: Added `SuppressFinalPackageVersion` property. - `Microsoft.Extensions.AI.Evaluation.Quality.csproj`: Added `SuppressFinalPackageVersion` property. - `Microsoft.Extensions.AI.Evaluation.Reporting.Azure.csproj`: Added `SuppressFinalPackageVersion` property. 
- `Microsoft.Extensions.AI.Evaluation.Reporting.csproj`: Added `SuppressFinalPackageVersion` property. - `Microsoft.Extensions.AI.Evaluation.csproj`: Added `SuppressFinalPackageVersion` property. --- .../Microsoft.Extensions.AI.Evaluation.Console.csproj | 1 + .../Microsoft.Extensions.AI.Evaluation.Quality.csproj | 1 + .../Microsoft.Extensions.AI.Evaluation.Reporting.Azure.csproj | 1 + .../CSharp/Microsoft.Extensions.AI.Evaluation.Reporting.csproj | 1 + .../Microsoft.Extensions.AI.Evaluation.csproj | 1 + 5 files changed, 5 insertions(+) diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Console/Microsoft.Extensions.AI.Evaluation.Console.csproj b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Console/Microsoft.Extensions.AI.Evaluation.Console.csproj index 0d9a6469640..c78f8b2d160 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Console/Microsoft.Extensions.AI.Evaluation.Console.csproj +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Console/Microsoft.Extensions.AI.Evaluation.Console.csproj @@ -14,6 +14,7 @@ AIEval preview + true true false 8 diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/Microsoft.Extensions.AI.Evaluation.Quality.csproj b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/Microsoft.Extensions.AI.Evaluation.Quality.csproj index 3d9a48e1a5e..f73422364bc 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/Microsoft.Extensions.AI.Evaluation.Quality.csproj +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Quality/Microsoft.Extensions.AI.Evaluation.Quality.csproj @@ -9,6 +9,7 @@ AIEval preview + true true false diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Azure/Microsoft.Extensions.AI.Evaluation.Reporting.Azure.csproj b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Azure/Microsoft.Extensions.AI.Evaluation.Reporting.Azure.csproj index f705add750e..b1dbb80dd38 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Azure/Microsoft.Extensions.AI.Evaluation.Reporting.Azure.csproj +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting.Azure/Microsoft.Extensions.AI.Evaluation.Reporting.Azure.csproj @@ -11,6 +11,7 @@ AIEval preview + true true false 88 diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/Microsoft.Extensions.AI.Evaluation.Reporting.csproj b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/Microsoft.Extensions.AI.Evaluation.Reporting.csproj index 878378d633e..1b24a8d5887 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/Microsoft.Extensions.AI.Evaluation.Reporting.csproj +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/Microsoft.Extensions.AI.Evaluation.Reporting.csproj @@ -18,6 +18,7 @@ AIEval preview + true true false 66 diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation/Microsoft.Extensions.AI.Evaluation.csproj b/src/Libraries/Microsoft.Extensions.AI.Evaluation/Microsoft.Extensions.AI.Evaluation.csproj index 0123cae0f0f..cd65f271c5c 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation/Microsoft.Extensions.AI.Evaluation.csproj +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation/Microsoft.Extensions.AI.Evaluation.csproj @@ -9,6 +9,7 @@ AIEval preview + true true false 56 From 6883e997ea86ead8637cb6c7437aefcd113763f0 Mon Sep 17 00:00:00 2001 From: Shyam Namboodiripad Date: Mon, 10 Mar 2025 20:33:48 +0000 Subject: [PATCH 5/7] Merged PR 48250: [internal/release/9.3] Fix report generation The .NET side code for 
the `ScenarioRunResult` was recently changed (#https://github.com/dotnet/extensions/pull/5998) to include `ChatResponse` (which can contain multiple `ChatMessage`s) in place of a single `ChatMessage`. Unfortunately, we missed updating the TypeScript reporting code to account for this. This change fixes the problem by updating the deserialization code in TypeScript to match what .NET code serializes. Cherry-picked from commit `41bbedd0` (https://github.com/dotnet/extensions/pull/6061) --- .../TypeScript/components/EvalTypes.d.ts | 6 +++++- .../TypeScript/components/Summary.ts | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/TypeScript/components/EvalTypes.d.ts b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/TypeScript/components/EvalTypes.d.ts index 20a3df81b21..1055df330df 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/TypeScript/components/EvalTypes.d.ts +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/TypeScript/components/EvalTypes.d.ts @@ -13,10 +13,14 @@ type ScenarioRunResult = { executionName: string; creationTime?: string; messages: ChatMessage[]; - modelResponse: ChatMessage; + modelResponse: ChatResponse; evaluationResult: EvaluationResult; }; +type ChatResponse = { + messages: ChatMessage[]; +} + type ChatMessage = { authorName?: string; role: string; diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/TypeScript/components/Summary.ts b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/TypeScript/components/Summary.ts index a6df92b36da..8cef12ce4f1 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/TypeScript/components/Summary.ts +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/TypeScript/components/Summary.ts @@ -145,7 +145,7 @@ const isTextContent = (content: AIContent): content is TextContent => { return (content as TextContent).text !== undefined; }; -export const getPromptDetails = (messages: ChatMessage[], modelResponse?: ChatMessage): {history:string, response: string}=> { +export const getPromptDetails = (messages: ChatMessage[], modelResponse?: ChatResponse): {history:string, response: string}=> { let history: string = ""; if (messages.length === 1) { history = messages[0].contents.map(c => (c as TextContent).text).join("\n"); @@ -163,7 +163,7 @@ export const getPromptDetails = (messages: ChatMessage[], modelResponse?: ChatMe history = historyItems.join("\n\n"); } - const response: string = modelResponse?.contents.map(c => (c as TextContent).text).join("\n") ?? ""; + const response: string = modelResponse?.messages.map(m => m.contents.map(c => (c as TextContent).text).join("\n") ?? "").join("\n") ?? 
""; return { history, response }; }; \ No newline at end of file From 5ebe9495fd52e48d0eaa4c33b465aadcd98b3651 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Mon, 10 Mar 2025 18:03:12 -0400 Subject: [PATCH 6/7] Address M.E.VectorData feedback for IEmbeddingGenerator (#6058) * Move GetService down to a non-generic IEmbeddingGenerator interface * Separate UriContent from DataContent * Address feedback --- .../Contents/AIContent.cs | 1 + .../Contents/DataContent.cs | 127 +++++++++-------- .../Contents/DataUriParser.cs | 44 ++++-- .../Contents/UriContent.cs | 92 +++++++++++++ .../EmbeddingGeneratorExtensions.cs | 57 +------- .../Embeddings/IEmbeddingGenerator.cs | 33 +---- .../IEmbeddingGenerator{TInput,TEmbedding}.cs | 40 ++++++ .../AzureAIInferenceChatClient.cs | 48 +++---- .../AzureAIInferenceEmbeddingGenerator.cs | 2 +- .../OllamaChatClient.cs | 4 +- .../OllamaEmbeddingGenerator.cs | 2 +- .../OpenAIAssistantClient.cs | 2 +- .../OpenAIEmbeddingGenerator.cs | 2 +- .../OpenAIModelMapper.ChatCompletion.cs | 9 +- .../OpenAIModelMapper.ChatMessage.cs | 28 ++-- ...ratorBuilderServiceCollectionExtensions.cs | 4 + .../Embeddings/LoggingEmbeddingGenerator.cs | 2 +- .../OpenTelemetryEmbeddingGenerator.cs | 2 +- .../ChatCompletion/ChatMessageTests.cs | 10 +- .../ChatResponseUpdateExtensionsTests.cs | 2 +- .../ChatCompletion/ChatResponseUpdateTests.cs | 12 +- .../Contents/DataContentTests.cs | 114 ++++++--------- .../Contents/UriContentTests.cs | 130 ++++++++++++++++++ .../EmbeddingGeneratorExtensionsTests.cs | 13 -- .../AzureAIInferenceChatClientTests.cs | 2 +- ...atClientStructuredOutputExtensionsTests.cs | 2 +- .../DependencyInjectionPatterns.cs | 50 ++++++- 27 files changed, 529 insertions(+), 305 deletions(-) create mode 100644 src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/UriContent.cs create mode 100644 src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator{TInput,TEmbedding}.cs create mode 100644 test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/UriContentTests.cs diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContent.cs index 6895d4c1e42..6562b7bcc42 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContent.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContent.cs @@ -11,6 +11,7 @@ namespace Microsoft.Extensions.AI; [JsonDerivedType(typeof(FunctionCallContent), typeDiscriminator: "functionCall")] [JsonDerivedType(typeof(FunctionResultContent), typeDiscriminator: "functionResult")] [JsonDerivedType(typeof(TextContent), typeDiscriminator: "text")] +[JsonDerivedType(typeof(UriContent), typeDiscriminator: "uri")] [JsonDerivedType(typeof(UsageContent), typeDiscriminator: "usage")] public class AIContent { diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataContent.cs index 041d33a9704..dc0c5db9289 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataContent.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataContent.cs @@ -8,17 +8,17 @@ using Microsoft.Shared.Diagnostics; #pragma warning disable S3996 // URI properties should not be strings +#pragma warning disable CA1054 // URI-like parameters should not be strings #pragma warning disable CA1056 // URI-like properties should not be strings namespace Microsoft.Extensions.AI; /// -/// Represents 
data content, such as an image or audio. +/// Represents binary content with an associated media type (also known as MIME type). /// /// /// -/// The represented content may either be the actual bytes stored in this instance, or it may -/// be a URI that references the location of the content. +/// The content represents in-memory data. For references to data at a remote URI, use instead. /// /// /// always returns a valid URI string, even if the instance was constructed from @@ -32,20 +32,27 @@ public class DataContent : AIContent // Ideally DataContent would be based in terms of Uri. However, Uri has a length limitation that makes it prohibitive // for the kinds of data URIs necessary to support here. As such, this type is based in strings. + /// Parsed data URI information. + private readonly DataUriParser.DataUri? _dataUri; + /// The string-based representation of the URI, including any data in the instance. private string? _uri; /// The data, lazily initialized if the data is provided in a data URI. private ReadOnlyMemory? _data; - /// Parsed data URI information. - private DataUriParser.DataUri? _dataUri; - /// /// Initializes a new instance of the class. /// - /// The URI of the content. This can be a data URI. - /// The media type (also known as MIME type) represented by the content. + /// The data URI containing the content. + /// + /// The media type (also known as MIME type) represented by the content. If not provided, + /// it must be provided as part of the . + /// + /// is . + /// is not a data URI. + /// did not contain a media type and was not supplied. + /// is an invalid media type. public DataContent(Uri uri, string? mediaType = null) : this(Throw.IfNull(uri).ToString(), mediaType) { @@ -54,42 +61,48 @@ public DataContent(Uri uri, string? mediaType = null) /// /// Initializes a new instance of the class. /// - /// The URI of the content. This can be a data URI. + /// The data URI containing the content. /// The media type (also known as MIME type) represented by the content. + /// is . + /// is not a data URI. + /// did not contain a media type and was not supplied. + /// is an invalid media type. [JsonConstructor] public DataContent([StringSyntax(StringSyntaxAttribute.Uri)] string uri, string? mediaType = null) { _uri = Throw.IfNullOrWhitespace(uri); - ValidateMediaType(ref mediaType); - MediaType = mediaType; - - if (uri.StartsWith(DataUriParser.Scheme, StringComparison.OrdinalIgnoreCase)) + if (!uri.StartsWith(DataUriParser.Scheme, StringComparison.OrdinalIgnoreCase)) { - _dataUri = DataUriParser.Parse(uri.AsMemory()); + Throw.ArgumentException(nameof(uri), "The provided URI is not a data URI."); + } - // If the data URI contains a media type that's different from a non-null media type - // explicitly provided, prefer the one explicitly provided as an override. - if (MediaType is not null) - { - if (MediaType != _dataUri.MediaType) - { - // Extract the bytes from the data URI and null out the uri. - // Then we'll lazily recreate it later if needed based on the updated media type. 
- _data = _dataUri.ToByteArray(); - _dataUri = null; - _uri = null; - } - } - else + _dataUri = DataUriParser.Parse(uri.AsMemory()); + + if (mediaType is null) + { + mediaType = _dataUri.MediaType; + if (mediaType is null) { - MediaType = _dataUri.MediaType; + Throw.ArgumentNullException(nameof(mediaType), $"{nameof(uri)} did not contain a media type, and {nameof(mediaType)} was not provided."); } } - else if (!System.Uri.TryCreate(uri, UriKind.Absolute, out _)) + else { - throw new UriFormatException("The URI is not well-formed."); + if (mediaType != _dataUri.MediaType) + { + // If the data URI contains a media type that's different from a non-null media type + // explicitly provided, prefer the one explicitly provided as an override. + + // Extract the bytes from the data URI and null out the uri. + // Then we'll lazily recreate it later if needed based on the updated media type. + _data = _dataUri.ToByteArray(); + _dataUri = null; + _uri = null; + } } + + MediaType = DataUriParser.ThrowIfInvalidMediaType(mediaType); } /// @@ -97,32 +110,29 @@ public DataContent([StringSyntax(StringSyntaxAttribute.Uri)] string uri, string? /// /// The byte contents. /// The media type (also known as MIME type) represented by the content. - public DataContent(ReadOnlyMemory data, string? mediaType = null) + /// is null. + /// is empty or composed entirely of whitespace. + public DataContent(ReadOnlyMemory data, string mediaType) { - ValidateMediaType(ref mediaType); - MediaType = mediaType; + MediaType = DataUriParser.ThrowIfInvalidMediaType(mediaType); _data = data; } /// - /// Determines whether the has the specified prefix. + /// Determines whether the 's top-level type matches the specified . /// - /// The media type prefix. - /// if the has the specified prefix, otherwise . - public bool MediaTypeStartsWith(string prefix) - => MediaType?.StartsWith(prefix, StringComparison.OrdinalIgnoreCase) is true; - - /// Sets to null if it's empty or composed entirely of whitespace. - private static void ValidateMediaType(ref string? mediaType) - { - if (!DataUriParser.IsValidMediaType(mediaType.AsSpan(), ref mediaType)) - { - Throw.ArgumentException(nameof(mediaType), "Invalid media type."); - } - } + /// The type to compare against . + /// if the type portion of matches the specified value; otherwise, false. + /// + /// A media type is primarily composed of two parts, a "type" and a "subtype", separated by a slash ("/"). + /// The type portion is also referred to as the "top-level type"; for example, + /// "image/png" has a top-level type of "image". compares + /// the specified against the type portion of . + /// + public bool HasTopLevelMediaType(string topLevelType) => DataUriParser.HasTopLevelMediaType(MediaType, topLevelType); - /// Gets the URI for this . + /// Gets the data URI for this . /// /// The returned URI is always a valid URI string, even if the instance was constructed from a /// or from a . 
In the case of a , this property returns a data URI containing @@ -137,8 +147,8 @@ public string Uri { if (_dataUri is null) { - Debug.Assert(Data is not null, "Expected Data to be initialized."); - _uri = string.Concat("data:", MediaType, ";base64,", Convert.ToBase64String(Data.GetValueOrDefault() + Debug.Assert(_data is not null, "Expected _data to be initialized."); + _uri = string.Concat("data:", MediaType, ";base64,", Convert.ToBase64String(_data.GetValueOrDefault() #if NET .Span)); #else @@ -167,10 +177,9 @@ public string Uri /// If the media type was explicitly specified, this property returns that value. /// If the media type was not explicitly specified, but a data URI was supplied and that data URI contained a non-default /// media type, that media type is returned. - /// Otherwise, this property returns null. /// - [JsonPropertyOrder(1)] - public string? MediaType { get; private set; } + [JsonIgnore] + public string MediaType { get; } /// Gets the data represented by this instance. /// @@ -181,16 +190,18 @@ public string Uri /// no attempt is made to retrieve the data from that URI. /// [JsonIgnore] - public ReadOnlyMemory? Data + public ReadOnlyMemory Data { get { - if (_dataUri is not null) + if (_data is null) { - _data ??= _dataUri.ToByteArray(); + Debug.Assert(_dataUri is not null, "Expected dataUri to be initialized."); + _data = _dataUri!.ToByteArray(); } - return _data; + Debug.Assert(_data is not null, "Expected data to be initialized."); + return _data.GetValueOrDefault(); } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataUriParser.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataUriParser.cs index 5cb33d1a55c..cff25e9c30b 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataUriParser.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataUriParser.cs @@ -5,10 +5,14 @@ #if NET8_0_OR_GREATER using System.Buffers.Text; #endif -using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Net; using System.Net.Http.Headers; +using System.Runtime.CompilerServices; using System.Text; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable CA1307 // Specify StringComparison for clarity namespace Microsoft.Extensions.AI; @@ -55,8 +59,9 @@ public static DataUri Parse(ReadOnlyMemory dataUri) } // Validate the media type, if present. + ReadOnlySpan span = metadata.Span.Trim(); string? mediaType = null; - if (!IsValidMediaType(metadata.Span.Trim(), ref mediaType)) + if (!span.IsEmpty && !IsValidMediaType(span, ref mediaType)) { throw new UriFormatException("Invalid data URI format: the media type is not a valid."); } @@ -64,20 +69,25 @@ public static DataUri Parse(ReadOnlyMemory dataUri) return new DataUri(data, isBase64, mediaType); } - /// Validates that a media type is valid, and if successful, ensures we have it as a string. - public static bool IsValidMediaType(ReadOnlySpan mediaTypeSpan, ref string? mediaType) + public static string ThrowIfInvalidMediaType( + string mediaType, [CallerArgumentExpression(nameof(mediaType))] string parameterName = "") { - Debug.Assert( - mediaType is null || mediaTypeSpan.Equals(mediaType.AsSpan(), StringComparison.Ordinal), - "mediaType string should either be null or the same as the span"); + _ = Throw.IfNullOrWhitespace(mediaType, parameterName); - // If the media type is empty or all whitespace, normalize it to null. 
- if (mediaTypeSpan.IsWhiteSpace()) + if (!IsValidMediaType(mediaType)) { - mediaType = null; - return true; + Throw.ArgumentException(parameterName, $"An invalid media type was specified: '{mediaType}'"); } + return mediaType; + } + + public static bool IsValidMediaType(string mediaType) => + IsValidMediaType(mediaType.AsSpan(), ref mediaType); + + /// Validates that a media type is valid, and if successful, ensures we have it as a string. + public static bool IsValidMediaType(ReadOnlySpan mediaTypeSpan, [NotNull] ref string? mediaType) + { // For common media types, we can avoid both allocating a string for the span and avoid parsing overheads. string? knownType = mediaTypeSpan switch { @@ -108,7 +118,7 @@ public static bool IsValidMediaType(ReadOnlySpan mediaTypeSpan, ref string }; if (knownType is not null) { - mediaType ??= knownType; + mediaType = knownType; return true; } @@ -117,6 +127,16 @@ public static bool IsValidMediaType(ReadOnlySpan mediaTypeSpan, ref string return MediaTypeHeaderValue.TryParse(mediaType, out _); } + public static bool HasTopLevelMediaType(string mediaType, string topLevelMediaType) + { + int slashIndex = mediaType.IndexOf('/'); + + ReadOnlySpan span = slashIndex < 0 ? mediaType.AsSpan() : mediaType.AsSpan(0, slashIndex); + span = span.Trim(); + + return span.Equals(topLevelMediaType.AsSpan(), StringComparison.OrdinalIgnoreCase); + } + /// Test whether the value is a base64 string without whitespace. private static bool IsValidBase64Data(ReadOnlySpan value) { diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/UriContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/UriContent.cs new file mode 100644 index 00000000000..7beaa40efdf --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/UriContent.cs @@ -0,0 +1,92 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics; +using System.Text.Json.Serialization; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Represents a URL, typically to hosted content such as an image, audio, or video. +/// +/// +/// This class is intended for use with HTTP or HTTPS URIs that reference hosted content. +/// For data URIs, use instead. +/// +[DebuggerDisplay("{DebuggerDisplay,nq}")] +public class UriContent : AIContent +{ + /// The URI represented. + private Uri _uri; + + /// The MIME type of the data at the referenced URI. + private string _mediaType; + + /// Initializes a new instance of the class. + /// The URI to the represented content. + /// The media type (also known as MIME type) represented by the content. + /// is . + /// is . + /// is an invalid media type. + /// is an invalid URL. + /// + /// A media type must be specified, so that consumers know what to do with the content. + /// If an exact media type is not known, but the category (e.g. image) is known, a wildcard + /// may be used (e.g. "image/*"). + /// + public UriContent(string uri, string mediaType) + : this(new Uri(Throw.IfNull(uri)), mediaType) + { + } + + /// Initializes a new instance of the class. + /// The URI to the represented content. + /// The media type (also known as MIME type) represented by the content. + /// is . + /// is . + /// is an invalid media type. + /// + /// A media type must be specified, so that consumers know what to do with the content. + /// If an exact media type is not known, but the category (e.g. 
image) is known, a wildcard + /// may be used (e.g. "image/*"). + /// + [JsonConstructor] + public UriContent(Uri uri, string mediaType) + { + _uri = Throw.IfNull(uri); + _mediaType = DataUriParser.ThrowIfInvalidMediaType(mediaType); + } + + /// Gets or sets the for this content. + public Uri Uri + { + get => _uri; + set => _uri = Throw.IfNull(value); + } + + /// Gets or sets the media type (also known as MIME type) for this content. + public string MediaType + { + get => _mediaType; + set => _mediaType = DataUriParser.ThrowIfInvalidMediaType(value); + } + + /// + /// Determines whether the 's top-level type matches the specified . + /// + /// The type to compare against . + /// if the type portion of matches the specified value; otherwise, false. + /// + /// A media type is primarily composed of two parts, a "type" and a "subtype", separated by a slash ("/"). + /// The type portion is also referred to as the "top-level type"; for example, + /// "image/png" has a top-level type of "image". compares + /// the specified against the type portion of . + /// + public bool HasTopLevelMediaType(string topLevelType) => DataUriParser.HasTopLevelMediaType(MediaType, topLevelType); + + /// Gets a string representing this instance to display in the debugger. + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private string DebuggerDisplay => $"Uri = {_uri}"; +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs index 35d8260e406..d8ed6967d71 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs @@ -17,8 +17,6 @@ namespace Microsoft.Extensions.AI; public static class EmbeddingGeneratorExtensions { /// Asks the for an object of type . - /// The type from which embeddings will be generated. - /// The numeric type of the embedding data. /// The type of the object to be retrieved. /// The generator. /// An optional key that can be used to help identify the target service. @@ -28,9 +26,8 @@ public static class EmbeddingGeneratorExtensions /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the /// , including itself or any services it might be wrapping. /// - public static TService? GetService( - this IEmbeddingGenerator generator, object? serviceKey = null) - where TEmbedding : Embedding + public static TService? GetService( + this IEmbeddingGenerator generator, object? serviceKey = null) { _ = Throw.IfNull(generator); @@ -41,8 +38,6 @@ public static class EmbeddingGeneratorExtensions /// Asks the for an object of the specified type /// and throws an exception if one isn't available. /// - /// The type from which embeddings will be generated. - /// The numeric type of the embedding data. /// The generator. /// The type of object being requested. /// An optional key that can be used to help identify the target service. @@ -54,9 +49,8 @@ public static class EmbeddingGeneratorExtensions /// The purpose of this method is to allow for the retrieval of services that are required to be provided by the /// , including itself or any services it might be wrapping. /// - public static object GetRequiredService( - this IEmbeddingGenerator generator, Type serviceType, object? 
serviceKey = null) - where TEmbedding : Embedding + public static object GetRequiredService( + this IEmbeddingGenerator generator, Type serviceType, object? serviceKey = null) { _ = Throw.IfNull(generator); _ = Throw.IfNull(serviceType); @@ -70,8 +64,6 @@ public static object GetRequiredService( /// Asks the for an object of type /// and throws an exception if one isn't available. /// - /// The type from which embeddings will be generated. - /// The numeric type of the embedding data. /// The type of the object to be retrieved. /// The generator. /// An optional key that can be used to help identify the target service. @@ -82,9 +74,8 @@ public static object GetRequiredService( /// The purpose of this method is to allow for the retrieval of strongly typed services that are required to be provided by the /// , including itself or any services it might be wrapping. /// - public static TService GetRequiredService( - this IEmbeddingGenerator generator, object? serviceKey = null) - where TEmbedding : Embedding + public static TService GetRequiredService( + this IEmbeddingGenerator generator, object? serviceKey = null) { _ = Throw.IfNull(generator); @@ -96,42 +87,6 @@ public static TService GetRequiredService( return service; } - // The following overloads exist purely to work around the lack of partial generic type inference. - // Given an IEmbeddingGenerator generator, to call GetService with TService, you still need - // to re-specify both TInput and TEmbedding, e.g. generator.GetService, TService>. - // The case of string/Embedding is by far the most common case today, so this overload exists as an - // accelerator to allow it to be written simply as generator.GetService. - - /// Asks the for an object of type . - /// The type of the object to be retrieved. - /// The generator. - /// An optional key that can be used to help identify the target service. - /// The found object, otherwise . - /// is . - /// - /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the - /// , including itself or any services it might be wrapping. - /// - public static TService? GetService(this IEmbeddingGenerator> generator, object? serviceKey = null) => - GetService, TService>(generator, serviceKey); - - /// - /// Asks the for an object of type - /// and throws an exception if one isn't available. - /// - /// The type of the object to be retrieved. - /// The generator. - /// An optional key that can be used to help identify the target service. - /// The found object. - /// is . - /// No service of the requested type for the specified key is available. - /// - /// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the - /// , including itself or any services it might be wrapping. - /// - public static TService GetRequiredService(this IEmbeddingGenerator> generator, object? serviceKey = null) => - GetRequiredService, TService>(generator, serviceKey); - /// Generates an embedding vector from the specified . /// The type from which embeddings will be generated. /// The numeric type of the embedding data. 
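A rough sketch of how the reshaped service-resolution surface is meant to be used, with `OllamaEmbeddingGenerator` standing in as the concrete generator and a placeholder endpoint/model id; because `GetService`/`GetRequiredService` now hang off the non-generic `IEmbeddingGenerator`, the `TInput`/`TEmbedding` arguments no longer need to be respecified:

```csharp
using System;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;

// Placeholder endpoint and model id; any IEmbeddingGenerator<string, Embedding<float>> works here.
using IEmbeddingGenerator<string, Embedding<float>> generator =
    new OllamaEmbeddingGenerator(new Uri("http://localhost:11434"), "all-minilm");

// Previously this required generator.GetService<string, Embedding<float>, EmbeddingGeneratorMetadata>()
// or the string/Embedding<float> accelerator overload removed above; now one type argument suffices.
EmbeddingGeneratorMetadata? metadata = generator.GetService<EmbeddingGeneratorMetadata>();

// Embedding generation itself is unchanged (this call needs a reachable Ollama endpoint).
GeneratedEmbeddings<Embedding<float>> embeddings = await generator.GenerateAsync(["hello world"]);
```

The dependency-injection changes later in this patch mirror the split: `AddEmbeddingGenerator`/`AddKeyedEmbeddingGenerator` now also register the non-generic `IEmbeddingGenerator` so it can be resolved without type arguments.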
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs index 59fcc9e2393..4f8174b6874 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs @@ -2,42 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; namespace Microsoft.Extensions.AI; /// Represents a generator of embeddings. -/// The type from which embeddings will be generated. -/// The type of embeddings to generate. /// -/// -/// Unless otherwise specified, all members of are thread-safe for concurrent use. -/// It is expected that all implementations of support being used by multiple requests concurrently. -/// Instances must not be disposed of while the instance is still in use. -/// -/// -/// However, implementations of may mutate the arguments supplied to -/// , such as by configuring the options instance. Thus, consumers of the interface either should -/// avoid using shared instances of these arguments for concurrent invocations or should otherwise ensure by construction that -/// no instances are used which might employ such mutation. -/// +/// This base interface is used to allow for embedding generators to be stored in a non-generic manner. +/// To use the generator to create embeddings, instances typed as this base interface first need to be +/// cast to the generic interface . /// -public interface IEmbeddingGenerator : IDisposable - where TEmbedding : Embedding +public interface IEmbeddingGenerator : IDisposable { - /// Generates embeddings for each of the supplied . - /// The sequence of values for which to generate embeddings. - /// The embedding generation options with which to configure the request. - /// The to monitor for cancellation requests. The default is . - /// The generated embeddings. - /// is . - Task> GenerateAsync( - IEnumerable values, - EmbeddingGenerationOptions? options = null, - CancellationToken cancellationToken = default); - /// Asks the for an object of the specified type . /// The type of object being requested. /// An optional key that can be used to help identify the target service. diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator{TInput,TEmbedding}.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator{TInput,TEmbedding}.cs new file mode 100644 index 00000000000..ff3910ae737 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator{TInput,TEmbedding}.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Extensions.AI; + +/// Represents a generator of embeddings. +/// The type from which embeddings will be generated. +/// The type of embeddings to generate. +/// +/// +/// Unless otherwise specified, all members of are thread-safe for concurrent use. +/// It is expected that all implementations of support being used by multiple requests concurrently. +/// Instances must not be disposed of while the instance is still in use. 
+/// +/// +/// However, implementations of may mutate the arguments supplied to +/// , such as by configuring the options instance. Thus, consumers of the interface either should +/// avoid using shared instances of these arguments for concurrent invocations or should otherwise ensure by construction that +/// no instances are used which might employ such mutation. +/// +/// +public interface IEmbeddingGenerator : IEmbeddingGenerator + where TEmbedding : Embedding +{ + /// Generates embeddings for each of the supplied . + /// The sequence of values for which to generate embeddings. + /// The embedding generation options with which to configure the request. + /// The to monitor for cancellation requests. The default is . + /// The generated embeddings. + /// is . + Task> GenerateAsync( + IEnumerable values, + EmbeddingGenerationOptions? options = null, + CancellationToken cancellationToken = default); +} diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs index db03a62f2a9..ed2cc991e8c 100644 --- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs @@ -490,42 +490,34 @@ private static List GetContentParts(IList con parts.Add(new ChatMessageTextContentItem(textContent.Text)); break; - case DataContent dataContent when dataContent.MediaTypeStartsWith("image/"): - if (dataContent.Data.HasValue) - { - parts.Add(new ChatMessageImageContentItem(BinaryData.FromBytes(dataContent.Data.Value), dataContent.MediaType)); - } - else if (dataContent.Uri is string uri) - { - parts.Add(new ChatMessageImageContentItem(new Uri(uri))); - } + case UriContent uriContent when uriContent.HasTopLevelMediaType("image"): + parts.Add(new ChatMessageImageContentItem(uriContent.Uri)); + break; + case DataContent dataContent when dataContent.HasTopLevelMediaType("image"): + parts.Add(new ChatMessageImageContentItem(BinaryData.FromBytes(dataContent.Data), dataContent.MediaType)); break; - case DataContent dataContent when dataContent.MediaTypeStartsWith("audio/"): - if (dataContent.Data.HasValue) - { - AudioContentFormat format; - if (dataContent.MediaTypeStartsWith("audio/mpeg")) - { - format = AudioContentFormat.Mp3; - } - else if (dataContent.MediaTypeStartsWith("audio/wav")) - { - format = AudioContentFormat.Wav; - } - else - { - break; - } + case UriContent uriContent when uriContent.HasTopLevelMediaType("audio"): + parts.Add(new ChatMessageAudioContentItem(uriContent.Uri)); + break; - parts.Add(new ChatMessageAudioContentItem(BinaryData.FromBytes(dataContent.Data.Value), format)); + case DataContent dataContent when dataContent.HasTopLevelMediaType("audio"): + AudioContentFormat format; + if (dataContent.MediaType.Equals("audio/mpeg", StringComparison.OrdinalIgnoreCase)) + { + format = AudioContentFormat.Mp3; + } + else if (dataContent.MediaType.Equals("audio/wav", StringComparison.OrdinalIgnoreCase)) + { + format = AudioContentFormat.Wav; } - else if (dataContent.Uri is string uri) + else { - parts.Add(new ChatMessageAudioContentItem(new Uri(uri))); + break; } + parts.Add(new ChatMessageAudioContentItem(BinaryData.FromBytes(dataContent.Data), format)); break; } } diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs 
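A rough sketch of the `DataContent`/`UriContent` split that the content mapping above switches on, using placeholder values; `UriContent` references hosted content and always carries a media type (a wildcard such as `"image/*"` appears to be accepted), while `DataContent` always holds in-memory bytes, exposed through the now non-nullable `Data` property and a `data:` URI:

```csharp
using System;
using Microsoft.Extensions.AI;

// Hosted content: a reference plus a (possibly wildcard) media type.
var hosted = new UriContent(new Uri("https://example.com/cat.png"), "image/*");

// In-memory content: bytes plus a required media type.
var inline = new DataContent(new byte[] { 1, 2, 3 }, "image/png");

Console.WriteLine(hosted.HasTopLevelMediaType("image")); // True
Console.WriteLine(inline.HasTopLevelMediaType("audio")); // False
Console.WriteLine(inline.Uri);                           // data:image/png;base64,AQID
```

In the mapping above, `UriContent` image/audio parts are forwarded as URI-based content items, while `DataContent` parts are forwarded as binary data together with their media type.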
index c0f4b2f4636..5cadc200869 100644 --- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs @@ -73,7 +73,7 @@ public AzureAIInferenceEmbeddingGenerator( } /// - object? IEmbeddingGenerator>.GetService(Type serviceType, object? serviceKey) + object? IEmbeddingGenerator.GetService(Type serviceType, object? serviceKey) { _ = Throw.IfNull(serviceType); diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs index ed1448c8b69..0af538b9802 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs @@ -392,10 +392,10 @@ private IEnumerable ToOllamaChatRequestMessages(ChatMe OllamaChatRequestMessage? currentTextMessage = null; foreach (var item in content.Contents) { - if (item is DataContent dataContent && dataContent.MediaTypeStartsWith("image/") && dataContent.Data.HasValue) + if (item is DataContent dataContent && dataContent.HasTopLevelMediaType("image")) { IList images = currentTextMessage?.Images ?? []; - images.Add(Convert.ToBase64String(dataContent.Data.Value + images.Add(Convert.ToBase64String(dataContent.Data #if NET .Span)); #else diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs index 6056753dd26..0b63491ddc2 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs @@ -61,7 +61,7 @@ public OllamaEmbeddingGenerator(Uri endpoint, string? modelId = null, HttpClient } /// - object? IEmbeddingGenerator>.GetService(Type serviceType, object? serviceKey) + object? IEmbeddingGenerator.GetService(Type serviceType, object? serviceKey) { _ = Throw.IfNull(serviceType); diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs index 1e5afb6d529..9aaad72ec3b 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs @@ -299,7 +299,7 @@ strictObj is bool strictValue ? messageContents.Add(MessageContent.FromText(tc.Text)); break; - case DataContent dc when dc.MediaTypeStartsWith("image/"): + case DataContent dc when dc.HasTopLevelMediaType("image"): messageContents.Add(MessageContent.FromImageUri(new(dc.Uri))); break; diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs index 7cf0be18fb0..8ae8a32b898 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs @@ -125,7 +125,7 @@ void IDisposable.Dispose() } /// - object? IEmbeddingGenerator>.GetService(Type serviceType, object? serviceKey) + object? IEmbeddingGenerator.GetService(Type serviceType, object? 
serviceKey) { _ = Throw.IfNull(serviceType); diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs index 59727d38f00..fdee45ea96d 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs @@ -566,15 +566,14 @@ private static ChatRole FromOpenAIChatRole(ChatMessageRole role) => } else if (contentPart.Kind == ChatMessageContentPartKind.Image) { - DataContent? imageContent; - aiContent = imageContent = - contentPart.ImageUri is not null ? new DataContent(contentPart.ImageUri, contentPart.ImageBytesMediaType) : + aiContent = + contentPart.ImageUri is not null ? new UriContent(contentPart.ImageUri, "image/*") : contentPart.ImageBytes is not null ? new DataContent(contentPart.ImageBytes.ToMemory(), contentPart.ImageBytesMediaType) : null; - if (imageContent is not null && contentPart.ImageDetailLevel?.ToString() is string detail) + if (aiContent is not null && contentPart.ImageDetailLevel?.ToString() is string detail) { - (imageContent.AdditionalProperties ??= [])[nameof(contentPart.ImageDetailLevel)] = detail; + (aiContent.AdditionalProperties ??= [])[nameof(contentPart.ImageDetailLevel)] = detail; } } diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatMessage.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatMessage.cs index c051c208f1e..8d9195b0953 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatMessage.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatMessage.cs @@ -205,11 +205,11 @@ private static List FromOpenAIChatContent(IList ToOpenAIChatContent(IList parts.Add(ChatMessageContentPart.CreateTextPart(textContent.Text)); break; - case DataContent dataContent when dataContent.MediaTypeStartsWith("image/"): - if (dataContent.Data.HasValue) - { - parts.Add(ChatMessageContentPart.CreateImagePart(BinaryData.FromBytes(dataContent.Data.Value), dataContent.MediaType)); - } - else if (dataContent.Uri is string uri) - { - parts.Add(ChatMessageContentPart.CreateImagePart(new Uri(uri))); - } + case UriContent uriContent when uriContent.HasTopLevelMediaType("image"): + parts.Add(ChatMessageContentPart.CreateImagePart(uriContent.Uri)); + break; + case DataContent dataContent when dataContent.HasTopLevelMediaType("image"): + parts.Add(ChatMessageContentPart.CreateImagePart(BinaryData.FromBytes(dataContent.Data), dataContent.MediaType)); break; - case DataContent dataContent when dataContent.MediaTypeStartsWith("audio/") && dataContent.Data.HasValue: - var audioData = BinaryData.FromBytes(dataContent.Data.Value); - if (dataContent.MediaTypeStartsWith("audio/mpeg")) + case DataContent dataContent when dataContent.HasTopLevelMediaType("audio"): + var audioData = BinaryData.FromBytes(dataContent.Data); + if (dataContent.MediaType.Equals("audio/mpeg", StringComparison.OrdinalIgnoreCase)) { parts.Add(ChatMessageContentPart.CreateInputAudioPart(audioData, ChatInputAudioFormat.Mp3)); } - else if (dataContent.MediaTypeStartsWith("audio/wav")) + else if (dataContent.MediaType.Equals("audio/wav", StringComparison.OrdinalIgnoreCase)) { parts.Add(ChatMessageContentPart.CreateInputAudioPart(audioData, ChatInputAudioFormat.Wav)); } diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs 
b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs index b84e8ac6e60..ebc7e3d26af 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs @@ -53,6 +53,8 @@ public static EmbeddingGeneratorBuilder AddEmbeddingGenerato var builder = new EmbeddingGeneratorBuilder(innerGeneratorFactory); serviceCollection.Add(new ServiceDescriptor(typeof(IEmbeddingGenerator), builder.Build, lifetime)); + serviceCollection.Add(new ServiceDescriptor(typeof(IEmbeddingGenerator), + static services => services.GetRequiredService>(), lifetime)); return builder; } @@ -103,6 +105,8 @@ public static EmbeddingGeneratorBuilder AddKeyedEmbeddingGen var builder = new EmbeddingGeneratorBuilder(innerGeneratorFactory); serviceCollection.Add(new ServiceDescriptor(typeof(IEmbeddingGenerator), serviceKey, factory: (services, serviceKey) => builder.Build(services), lifetime)); + serviceCollection.Add(new ServiceDescriptor(typeof(IEmbeddingGenerator), serviceKey, + static (services, serviceKey) => services.GetRequiredKeyedService>(serviceKey), lifetime)); return builder; } } diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs index 24770df1052..90553ca5411 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs @@ -52,7 +52,7 @@ public override async Task> GenerateAsync(IEnume { if (_logger.IsEnabled(LogLevel.Trace)) { - LogInvokedSensitive(AsJson(values), AsJson(options), AsJson(this.GetService())); + LogInvokedSensitive(AsJson(values), AsJson(options), AsJson(this.GetService())); } else { diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs index 3fd92a103aa..26ead720a1c 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs @@ -50,7 +50,7 @@ public OpenTelemetryEmbeddingGenerator(IEmbeddingGenerator i { Debug.Assert(innerGenerator is not null, "Should have been validated by the base ctor."); - if (innerGenerator!.GetService() is EmbeddingGeneratorMetadata metadata) + if (innerGenerator!.GetService() is EmbeddingGeneratorMetadata metadata) { _system = metadata.ProviderName; _modelId = metadata.ModelId; diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs index 7174d2a70c8..c449f064255 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs @@ -141,8 +141,8 @@ public void Text_ConcatsAllTextContent() { ChatMessage message = new(ChatRole.User, [ - new DataContent("http://localhost/audio"), - new DataContent("http://localhost/image"), + new DataContent("data:text/image;base64,aGVsbG8="), + new DataContent("data:text/plain;base64,aGVsbG8="), new FunctionCallContent("callId1", "fc1"), new TextContent("text-1"), new 
TextContent("text-2"), @@ -240,7 +240,7 @@ public void ItCanBeSerializeAndDeserialized() { AdditionalProperties = new() { ["metadata-key-1"] = "metadata-value-1" } }, - new DataContent(new Uri("https://fake-random-test-host:123"), "mime-type/2") + new DataContent(new Uri("data:text/plain;base64,aGVsbG8="), "mime-type/2") { AdditionalProperties = new() { ["metadata-key-2"] = "metadata-value-2" } }, @@ -286,7 +286,7 @@ public void ItCanBeSerializeAndDeserialized() var dataContent = deserializedMessage.Contents[1] as DataContent; Assert.NotNull(dataContent); - Assert.Equal("https://fake-random-test-host:123/", dataContent.Uri); + Assert.Equal("data:mime-type/2;base64,aGVsbG8=", dataContent.Uri); Assert.Equal("mime-type/2", dataContent.MediaType); Assert.NotNull(dataContent.AdditionalProperties); Assert.Single(dataContent.AdditionalProperties); @@ -294,7 +294,7 @@ public void ItCanBeSerializeAndDeserialized() dataContent = deserializedMessage.Contents[2] as DataContent; Assert.NotNull(dataContent); - Assert.True(dataContent.Data!.Value.Span.SequenceEqual(new BinaryData(new[] { 1, 2, 3 }, TestJsonSerializerContext.Default.Options))); + Assert.True(dataContent.Data.Span.SequenceEqual(new BinaryData(new[] { 1, 2, 3 }, TestJsonSerializerContext.Default.Options))); Assert.Equal("mime-type/3", dataContent.MediaType); Assert.NotNull(dataContent.AdditionalProperties); Assert.Single(dataContent.AdditionalProperties); diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs index 454c3c3cad3..00e074ab276 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs @@ -124,7 +124,7 @@ void AddGap() { for (int i = 0; i < gapLength; i++) { - updates.Add(new() { Contents = [new DataContent("https://uri", mediaType: "image/png")] }); + updates.Add(new() { Contents = [new DataContent("data:image/png;base64,aGVsbG8=")] }); } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateTests.cs index 7e5ff6b1e84..cc406929aa1 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateTests.cs @@ -84,8 +84,8 @@ public void Text_Get_UsesAllTextContent() Role = ChatRole.User, Contents = [ - new DataContent("http://localhost/audio"), - new DataContent("http://localhost/image"), + new DataContent("data:image/audio;base64,aGVsbG8="), + new DataContent("data:image/image;base64,aGVsbG8="), new FunctionCallContent("callId1", "fc1"), new TextContent("text-1"), new TextContent("text-2"), @@ -114,9 +114,9 @@ public void JsonSerialization_Roundtrips() Contents = [ new TextContent("text-1"), - new DataContent("http://localhost/image"), + new DataContent("data:image/png;base64,aGVsbG8="), new FunctionCallContent("callId1", "fc1"), - new DataContent("data"u8.ToArray()), + new DataContent("data"u8.ToArray(), "text/plain"), new TextContent("text-2"), ], RawRepresentation = new object(), @@ -137,13 +137,13 @@ public void JsonSerialization_Roundtrips() Assert.Equal("text-1", 
((TextContent)result.Contents[0]).Text); Assert.IsType(result.Contents[1]); - Assert.Equal("http://localhost/image", ((DataContent)result.Contents[1]).Uri); + Assert.Equal("data:image/png;base64,aGVsbG8=", ((DataContent)result.Contents[1]).Uri); Assert.IsType(result.Contents[2]); Assert.Equal("fc1", ((FunctionCallContent)result.Contents[2]).Name); Assert.IsType(result.Contents[3]); - Assert.Equal("data"u8.ToArray(), ((DataContent)result.Contents[3]).Data?.ToArray()); + Assert.Equal("data"u8.ToArray(), ((DataContent)result.Contents[3]).Data.ToArray()); Assert.IsType(result.Contents[4]); Assert.Equal("text-2", ((TextContent)result.Contents[4]).Text); diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/DataContentTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/DataContentTests.cs index dfa28373d48..83f09c66889 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/DataContentTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/DataContentTests.cs @@ -13,10 +13,16 @@ public sealed class DataContentTests // Invalid URI [InlineData("", typeof(ArgumentException))] - [InlineData("invalid", typeof(UriFormatException))] + [InlineData("invalid", typeof(ArgumentException))] + [InlineData("data", typeof(ArgumentException))] + + // Not a data URI + [InlineData("http://localhost/blah.png", typeof(ArgumentException))] + [InlineData("https://localhost/blah.png", typeof(ArgumentException))] + [InlineData("ftp://localhost/blah.png", typeof(ArgumentException))] + [InlineData("a://localhost/blah.png", typeof(ArgumentException))] // Format errors - [InlineData("data", typeof(UriFormatException))] // data missing colon [InlineData("data:", typeof(UriFormatException))] // data missing comma [InlineData("data:something,", typeof(UriFormatException))] // mime type without subtype [InlineData("data:something;else,data", typeof(UriFormatException))] // mime type without subtype @@ -48,7 +54,7 @@ public void Ctor_InvalidUri_Throws(string path, Type exception) [InlineData("type/subtype;key=value;another=")] public void Ctor_InvalidMediaType_Throws(string type) { - Assert.Throws("mediaType", () => new DataContent("http://localhost/test", type)); + Assert.Throws("mediaType", () => new DataContent("data:image/png;base64,aGVsbG8=", type)); } [Theory] @@ -58,7 +64,7 @@ public void Ctor_InvalidMediaType_Throws(string type) [InlineData("type/subtype;key=value;another=value;yet_another=value")] public void Ctor_ValidMediaType_Roundtrips(string mediaType) { - var content = new DataContent("http://localhost/test", mediaType); + var content = new DataContent("data:image/png;base64,aGVsbG8=", mediaType); Assert.Equal(mediaType, content.MediaType); content = new DataContent("data:,", mediaType); @@ -82,43 +88,25 @@ public void Ctor_NoMediaType_Roundtrips() { DataContent content; - foreach (string url in new[] { "http://localhost/test", "about:something", "file://c:\\path" }) - { - content = new DataContent(url); - Assert.Equal(url, content.Uri); - Assert.Null(content.MediaType); - Assert.Null(content.Data); - } - - content = new DataContent("data:,something"); - Assert.Equal("data:,something", content.Uri); - Assert.Null(content.MediaType); - Assert.Equal("something"u8.ToArray(), content.Data!.Value.ToArray()); - - content = new DataContent("data:,Hello+%3C%3E"); - Assert.Equal("data:,Hello+%3C%3E", content.Uri); - Assert.Null(content.MediaType); - Assert.Equal("Hello <>"u8.ToArray(), content.Data!.Value.ToArray()); + 
content = new DataContent("data:image/png;base64,aGVsbG8="); + Assert.Equal("data:image/png;base64,aGVsbG8=", content.Uri); + Assert.Equal("image/png", content.MediaType); + + content = new DataContent(new Uri("data:image/png;base64,aGVsbG8=")); + Assert.Equal("data:image/png;base64,aGVsbG8=", content.Uri); + Assert.Equal("image/png", content.MediaType); } [Fact] public void Serialize_MatchesExpectedJson() { Assert.Equal( - """{"uri":"data:,"}""", - JsonSerializer.Serialize(new DataContent("data:,"), TestJsonSerializerContext.Default.Options)); - - Assert.Equal( - """{"uri":"http://localhost/"}""", - JsonSerializer.Serialize(new DataContent(new Uri("http://localhost/")), TestJsonSerializerContext.Default.Options)); - - Assert.Equal( - """{"uri":"data:application/octet-stream;base64,AQIDBA==","mediaType":"application/octet-stream"}""", + """{"uri":"data:application/octet-stream;base64,AQIDBA=="}""", JsonSerializer.Serialize(new DataContent( uri: "data:application/octet-stream;base64,AQIDBA=="), TestJsonSerializerContext.Default.Options)); Assert.Equal( - """{"uri":"data:application/octet-stream;base64,AQIDBA==","mediaType":"application/octet-stream"}""", + """{"uri":"data:application/octet-stream;base64,AQIDBA=="}""", JsonSerializer.Serialize(new DataContent( new ReadOnlyMemory([0x01, 0x02, 0x03, 0x04]), "application/octet-stream"), TestJsonSerializerContext.Default.Options)); @@ -136,53 +124,43 @@ public void Deserialize_MissingUriString_Throws(string json) public void Deserialize_MatchesExpectedData() { // Data + MimeType only - var content = JsonSerializer.Deserialize("""{"mediaType":"application/octet-stream","uri":"data:;base64,AQIDBA=="}""", TestJsonSerializerContext.Default.Options)!; + var content = JsonSerializer.Deserialize("""{"uri":"data:application/octet-stream;base64,AQIDBA=="}""", TestJsonSerializerContext.Default.Options)!; Assert.Equal("data:application/octet-stream;base64,AQIDBA==", content.Uri); - Assert.NotNull(content.Data); - Assert.Equal([0x01, 0x02, 0x03, 0x04], content.Data.Value.ToArray()); + Assert.Equal([0x01, 0x02, 0x03, 0x04], content.Data.ToArray()); Assert.Equal("application/octet-stream", content.MediaType); // Uri referenced content-only - content = JsonSerializer.Deserialize("""{"mediaType":"application/octet-stream","uri":"http://localhost/"}""", TestJsonSerializerContext.Default.Options)!; + content = JsonSerializer.Deserialize("""{"uri":"data:application/octet-stream;base64,AQIDBA=="}""", TestJsonSerializerContext.Default.Options)!; - Assert.Null(content.Data); - Assert.Equal("http://localhost/", content.Uri); + Assert.Equal("data:application/octet-stream;base64,AQIDBA==", content.Uri); Assert.Equal("application/octet-stream", content.MediaType); // Using extra metadata content = JsonSerializer.Deserialize(""" { - "uri": "data:;base64,AQIDBA==", + "uri": "data:audio/wav;base64,AQIDBA==", "modelId": "gpt-4", "additionalProperties": { "key": "value" - }, - "mediaType": "text/plain" + } } """, TestJsonSerializerContext.Default.Options)!; - Assert.Equal("data:text/plain;base64,AQIDBA==", content.Uri); - Assert.NotNull(content.Data); - Assert.Equal([0x01, 0x02, 0x03, 0x04], content.Data.Value.ToArray()); - Assert.Equal("text/plain", content.MediaType); + Assert.Equal("data:audio/wav;base64,AQIDBA==", content.Uri); + Assert.Equal([0x01, 0x02, 0x03, 0x04], content.Data.ToArray()); + Assert.Equal("audio/wav", content.MediaType); Assert.Equal("value", content.AdditionalProperties!["key"]!.ToString()); } [Theory] [InlineData( - """{"uri": 
"data:;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=","mediaType": "text/plain"}""", - """{"uri":"data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=","mediaType":"text/plain"}""")] - [InlineData( - """{"uri": "data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=","mediaType": "text/plain"}""", - """{"uri":"data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=","mediaType":"text/plain"}""")] + """{"uri": "data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8="}""", + """{"uri":"data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8="}""")] [InlineData( // Does not support non-readable content """{"uri": "data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=", "unexpected": true}""", - """{"uri":"data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=","mediaType":"text/plain"}""")] - [InlineData( // Uri comes before mimetype - """{"mediaType": "text/plain", "uri": "http://localhost/" }""", - """{"uri":"http://localhost/","mediaType":"text/plain"}""")] + """{"uri":"data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8="}""")] public void Serialize_Deserialize_Roundtrips(string serialized, string expectedToString) { var content = JsonSerializer.Deserialize(serialized, TestJsonSerializerContext.Default.Options)!; @@ -222,30 +200,28 @@ public void MediaType_Roundtrips(string mediaType) } [Theory] - [InlineData("image/gif", "image/")] + [InlineData("image/gif", "image")] [InlineData("IMAGE/JPEG", "image")] - [InlineData("image/vnd.microsoft.icon", "ima")] - [InlineData("image/svg+xml", "IMAGE/")] + [InlineData("image/vnd.microsoft.icon", "imAge")] + [InlineData("image/svg+xml", "IMAGE")] [InlineData("image/nonexistentimagemimetype", "IMAGE")] - [InlineData("audio/mpeg", "aUdIo/")] - [InlineData("application/json", "")] - [InlineData("application/pdf", "application/pdf")] - public void HasMediaTypePrefix_ReturnsTrue(string? mediaType, string prefix) + [InlineData("audio/mpeg", "aUdIo")] + public void HasMediaTypePrefix_ReturnsTrue(string mediaType, string prefix) { - var content = new DataContent("http://localhost/image.png", mediaType); - Assert.True(content.MediaTypeStartsWith(prefix)); + var content = new DataContent("data:application/octet-stream;base64,AQIDBA==", mediaType); + Assert.True(content.HasTopLevelMediaType(prefix)); } [Theory] - [InlineData("audio/mpeg", "image/")] + [InlineData("audio/mpeg", "audio/")] + [InlineData("audio/mpeg", "image")] + [InlineData("audio/mpeg", "audio/mpeg")] [InlineData("text/css", "text/csv")] + [InlineData("text/css", "/csv")] [InlineData("application/json", "application/json!")] - [InlineData("", "")] // The media type will get normalized to null - [InlineData(null, "image/")] - [InlineData(null, "")] - public void HasMediaTypePrefix_ReturnsFalse(string? 
mediaType, string prefix) + public void HasMediaTypePrefix_ReturnsFalse(string mediaType, string prefix) { - var content = new DataContent("http://localhost/image.png", mediaType); - Assert.False(content.MediaTypeStartsWith(prefix)); + var content = new DataContent("data:application/octet-stream;base64,AQIDBA==", mediaType); + Assert.False(content.HasTopLevelMediaType(prefix)); } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/UriContentTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/UriContentTests.cs new file mode 100644 index 00000000000..8b4e8c6665d --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/UriContentTests.cs @@ -0,0 +1,130 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public sealed class UriContentTests +{ + [Fact] + public void Ctor_InvalidUriMediaType_Throws() + { + Assert.Throws("uri", () => new UriContent((string)null!, "image/png")); + Assert.Throws("uri", () => new UriContent((Uri)null!, "image/png")); + Assert.Throws(() => new UriContent("notauri", "image/png")); + + Assert.Throws("mediaType", () => new UriContent("data:image/png;base64,aGVsbG8=", null!)); + Assert.Throws("mediaType", () => new UriContent("data:image/png;base64,aGVsbG8=", "")); + Assert.Throws("mediaType", () => new UriContent("data:image/png;base64,aGVsbG8=", "image")); + + Assert.Throws("mediaType", () => new UriContent(new Uri("data:image/png;base64,aGVsbG8="), null!)); + Assert.Throws("mediaType", () => new UriContent(new Uri("data:image/png;base64,aGVsbG8="), "")); + Assert.Throws("mediaType", () => new UriContent(new Uri("data:image/png;base64,aGVsbG8="), "audio")); + + UriContent c = new("http://localhost/something", "image/png"); + Assert.Throws("value", () => c.Uri = null!); + } + + [Theory] + [InlineData("type")] + [InlineData("type//subtype")] + [InlineData("type/subtype/")] + [InlineData("type/subtype;key=")] + [InlineData("type/subtype;=value")] + [InlineData("type/subtype;key=value;another=")] + public void Ctor_InvalidMediaType_Throws(string type) + { + Assert.Throws("mediaType", () => new UriContent("http://localhost/something", type)); + + UriContent c = new("http://localhost/something", "image/png"); + Assert.Throws("value", () => c.MediaType = type); + Assert.Throws("value", () => c.MediaType = null!); + } + + [Theory] + [InlineData("type/subtype")] + [InlineData("type/subtype;key=value")] + [InlineData("type/subtype;key=value;another=value")] + [InlineData("type/subtype;key=value;another=value;yet_another=value")] + public void Ctor_ValidMediaType_Roundtrips(string mediaType) + { + var content = new UriContent("http://localhost/something", mediaType); + Assert.Equal(mediaType, content.MediaType); + + content.MediaType = "image/png"; + Assert.Equal("image/png", content.MediaType); + + content.MediaType = mediaType; + Assert.Equal(mediaType, content.MediaType); + } + + [Fact] + public void Serialize_MatchesExpectedJson() + { + Assert.Equal( + """{"uri":"http://localhost/something","mediaType":"image/png"}""", + JsonSerializer.Serialize( + new UriContent("http://localhost/something", "image/png"), + TestJsonSerializerContext.Default.Options)); + } + + [Theory] + [InlineData("application/json")] + [InlineData("application/octet-stream")] + [InlineData("application/pdf")] + 
[InlineData("application/xml")] + [InlineData("audio/mpeg")] + [InlineData("audio/ogg")] + [InlineData("audio/wav")] + [InlineData("image/apng")] + [InlineData("image/avif")] + [InlineData("image/bmp")] + [InlineData("image/gif")] + [InlineData("image/jpeg")] + [InlineData("image/png")] + [InlineData("image/svg+xml")] + [InlineData("image/tiff")] + [InlineData("image/webp")] + [InlineData("text/css")] + [InlineData("text/csv")] + [InlineData("text/html")] + [InlineData("text/javascript")] + [InlineData("text/plain")] + [InlineData("text/plain;charset=UTF-8")] + [InlineData("text/xml")] + [InlineData("custom/mediatypethatdoesntexists")] + public void MediaType_Roundtrips(string mediaType) + { + UriContent c = new("http://localhost", mediaType); + Assert.Equal(mediaType, c.MediaType); + } + + [Theory] + [InlineData("image/gif", "image")] + [InlineData("IMAGE/JPEG", "image")] + [InlineData("image/vnd.microsoft.icon", "imAge")] + [InlineData("image/svg+xml", "IMAGE")] + [InlineData("image/nonexistentimagemimetype", "IMAGE")] + [InlineData("audio/mpeg", "aUdIo")] + public void HasMediaTypePrefix_ReturnsTrue(string mediaType, string prefix) + { + var content = new UriContent("http://localhost", mediaType); + Assert.True(content.HasTopLevelMediaType(prefix)); + } + + [Theory] + [InlineData("audio/mpeg", "audio/")] + [InlineData("audio/mpeg", "image")] + [InlineData("audio/mpeg", "audio/mpeg")] + [InlineData("text/css", "text/csv")] + [InlineData("text/css", "/csv")] + [InlineData("application/json", "application/json!")] + public void HasMediaTypePrefix_ReturnsFalse(string mediaType, string prefix) + { + var content = new UriContent("http://localhost", mediaType); + Assert.False(content.HasTopLevelMediaType(prefix)); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs index fe4af33cf23..dfe970b23ca 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs @@ -14,15 +14,12 @@ public class EmbeddingGeneratorExtensionsTests public void GetService_InvalidArgs_Throws() { Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetService(null!)); - Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetService, object>(null!)); } [Fact] public void GetRequiredService_InvalidArgs_Throws() { Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetRequiredService(null!)); - Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetRequiredService>(null!, typeof(string))); - Assert.Throws("generator", () => EmbeddingGeneratorExtensions.GetRequiredService, object>(null!)); using var generator = new TestEmbeddingGenerator(); Assert.Throws("serviceType", () => generator.GetRequiredService(null!)); @@ -51,41 +48,31 @@ public void GetService_ValidService_Returned() Assert.Equal("null key", generator.GetService(typeof(string))); Assert.Equal("null key", generator.GetService()); - Assert.Equal("null key", generator.GetService, string>()); Assert.Equal("non-null key", generator.GetService(typeof(string), "key")); Assert.Equal("non-null key", generator.GetService("key")); - Assert.Equal("non-null key", generator.GetService, string>("key")); Assert.Null(generator.GetService(typeof(object))); Assert.Null(generator.GetService()); - 
Assert.Null(generator.GetService, object>()); Assert.Null(generator.GetService(typeof(object), "key")); Assert.Null(generator.GetService("key")); - Assert.Null(generator.GetService, object>("key")); Assert.Null(generator.GetService()); - Assert.Null(generator.GetService, int?>()); Assert.Equal("null key", generator.GetRequiredService(typeof(string))); Assert.Equal("null key", generator.GetRequiredService()); - Assert.Equal("null key", generator.GetRequiredService, string>()); Assert.Equal("non-null key", generator.GetRequiredService(typeof(string), "key")); Assert.Equal("non-null key", generator.GetRequiredService("key")); - Assert.Equal("non-null key", generator.GetRequiredService, string>("key")); Assert.Throws(() => generator.GetRequiredService(typeof(object))); Assert.Throws(() => generator.GetRequiredService()); - Assert.Throws(() => generator.GetRequiredService, object>()); Assert.Throws(() => generator.GetRequiredService(typeof(object), "key")); Assert.Throws(() => generator.GetRequiredService("key")); - Assert.Throws(() => generator.GetRequiredService, object>("key")); Assert.Throws(() => generator.GetRequiredService()); - Assert.Throws(() => generator.GetRequiredService, int?>()); } [Fact] diff --git a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs index b8a68c913ed..d0167c8778b 100644 --- a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs @@ -615,7 +615,7 @@ public async Task MultipleContent_NonStreaming() Assert.NotNull(await client.GetResponseAsync([new(ChatRole.User, [ new TextContent("Describe this picture."), - new DataContent("http://dot.net/someimage.png", mediaType: "image/png"), + new UriContent("http://dot.net/someimage.png", mediaType: "image/*"), ])])); } diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs index 0b8aca0785e..0a499ab644d 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs @@ -149,7 +149,7 @@ public async Task FailureUsage_NullJson() [Fact] public async Task FailureUsage_NoJsonInResponse() { - var expectedResponse = new ChatResponse(new ChatMessage(ChatRole.Assistant, [new DataContent("https://example.com")])); + var expectedResponse = new ChatResponse(new ChatMessage(ChatRole.Assistant, [new UriContent("https://example.com", "image/*")])); using var client = new TestChatClient { GetResponseAsyncCallback = (messages, options, cancellationToken) => Task.FromResult(expectedResponse), diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DependencyInjectionPatterns.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DependencyInjectionPatterns.cs index c2f288165cb..3ff20afaad1 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DependencyInjectionPatterns.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DependencyInjectionPatterns.cs @@ -167,7 +167,8 @@ public void AddEmbeddingGenerator_RegistersExpectedLifetime(ServiceLifetime? lif ? 
sc.AddEmbeddingGenerator(services => new TestEmbeddingGenerator(), lifetime.Value) : sc.AddEmbeddingGenerator(services => new TestEmbeddingGenerator()); - ServiceDescriptor sd = Assert.Single(sc); + Assert.Equal(2, sc.Count); + ServiceDescriptor sd = sc[0]; Assert.Equal(typeof(IEmbeddingGenerator>), sd.ServiceType); Assert.False(sd.IsKeyedService); Assert.Null(sd.ImplementationInstance); @@ -176,6 +177,28 @@ public void AddEmbeddingGenerator_RegistersExpectedLifetime(ServiceLifetime? lif Assert.Equal(expectedLifetime, sd.Lifetime); } + [Theory] + [InlineData(null)] + [InlineData(ServiceLifetime.Singleton)] + [InlineData(ServiceLifetime.Scoped)] + [InlineData(ServiceLifetime.Transient)] + public void AddEmbeddingGenerator_RegistersNonGeneric(ServiceLifetime? lifetime) + { + ServiceCollection sc = new(); + ServiceLifetime expectedLifetime = lifetime ?? ServiceLifetime.Singleton; + var builder = lifetime.HasValue + ? sc.AddEmbeddingGenerator(services => new TestEmbeddingGenerator(), lifetime.Value) + : sc.AddEmbeddingGenerator(services => new TestEmbeddingGenerator()); + IServiceProvider sp = sc.BuildServiceProvider(); + + IEmbeddingGenerator>? g = sp.GetService>>(); + IEmbeddingGenerator? ng = sp.GetService(); + + Assert.NotNull(g); + Assert.NotNull(ng); + Assert.Equal(lifetime != ServiceLifetime.Transient, ReferenceEquals(g, ng)); + } + [Theory] [InlineData(null)] [InlineData(ServiceLifetime.Singleton)] @@ -189,7 +212,8 @@ public void AddKeyedEmbeddingGenerator_RegistersExpectedLifetime(ServiceLifetime ? sc.AddKeyedEmbeddingGenerator("key", services => new TestEmbeddingGenerator(), lifetime.Value) : sc.AddKeyedEmbeddingGenerator("key", services => new TestEmbeddingGenerator()); - ServiceDescriptor sd = Assert.Single(sc); + Assert.Equal(2, sc.Count); + ServiceDescriptor sd = sc[0]; Assert.Equal(typeof(IEmbeddingGenerator>), sd.ServiceType); Assert.True(sd.IsKeyedService); Assert.Equal("key", sd.ServiceKey); @@ -199,6 +223,28 @@ public void AddKeyedEmbeddingGenerator_RegistersExpectedLifetime(ServiceLifetime Assert.Equal(expectedLifetime, sd.Lifetime); } + [Theory] + [InlineData(null)] + [InlineData(ServiceLifetime.Singleton)] + [InlineData(ServiceLifetime.Scoped)] + [InlineData(ServiceLifetime.Transient)] + public void AddKeyedEmbeddingGenerator_RegistersNonGeneric(ServiceLifetime? lifetime) + { + ServiceCollection sc = new(); + ServiceLifetime expectedLifetime = lifetime ?? ServiceLifetime.Singleton; + var builder = lifetime.HasValue + ? sc.AddKeyedEmbeddingGenerator("key", services => new TestEmbeddingGenerator(), lifetime.Value) + : sc.AddKeyedEmbeddingGenerator("key", services => new TestEmbeddingGenerator()); + IServiceProvider sp = sc.BuildServiceProvider(); + + IEmbeddingGenerator>? g = sp.GetKeyedService>>("key"); + IEmbeddingGenerator? 
ng = sp.GetKeyedService("key"); + + Assert.NotNull(g); + Assert.NotNull(ng); + Assert.Equal(lifetime != ServiceLifetime.Transient, ReferenceEquals(g, ng)); + } + public class SingletonMiddleware(IChatClient inner, IServiceProvider services) : DelegatingChatClient(inner) { public new IChatClient InnerClient => base.InnerClient; From f663faff909fd0a1455eddf735bdd8af571680d5 Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Tue, 11 Mar 2025 14:11:44 +0000 Subject: [PATCH 7/7] Fix grouping of ChatResponseUpdate into ChatMessage (#6074) * For Ollama client, ensure ToChatResponseAsync coalesces text chunks into a single message * Fix OpenAI case by not treating empty-string response IDs as message boundaries --- .../ChatCompletion/ChatResponseExtensions.cs | 4 ++-- .../ChatCompletion/ChatResponseUpdate.cs | 6 ++++++ .../OllamaChatClient.cs | 5 ++++- .../ChatResponseUpdateExtensionsTests.cs | 2 +- .../ChatClientIntegrationTests.cs | 19 +++++++++++++++++++ .../OllamaChatClientTests.cs | 9 +++++++-- 6 files changed, 39 insertions(+), 6 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseExtensions.cs index 16eed49db93..e6fb9d4dafb 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseExtensions.cs @@ -171,7 +171,7 @@ private static void ProcessUpdate(ChatResponseUpdate update, ChatResponse respon // response ID than the newest update, create a new message. ChatMessage message; if (response.Messages.Count == 0 || - (update.ResponseId is string updateId && response.ResponseId is string responseId && updateId != responseId)) + (update.ResponseId is { Length: > 0 } updateId && response.ResponseId is string responseId && updateId != responseId)) { message = new ChatMessage(ChatRole.Assistant, []); response.Messages.Add(message); @@ -213,7 +213,7 @@ private static void ProcessUpdate(ChatResponseUpdate update, ChatResponse respon // Other members on a ChatResponseUpdate map to members of the ChatResponse. // Update the response object with those, preferring the values from later updates. - if (update.ResponseId is not null) + if (update.ResponseId is { Length: > 0 }) { // Note that this must come after the message checks earlier, as they depend // on this value for change detection. diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs index 24610ac76fc..346a5ed0d65 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseUpdate.cs @@ -99,6 +99,12 @@ public IList Contents public AdditionalPropertiesDictionary? AdditionalProperties { get; set; } /// Gets or sets the ID of the response of which this update is a part. + /// + /// This value is used when + /// groups instances into instances. + /// The value must be unique to each call to the underlying provider, and must be shared by + /// all updates that are part of the same response. + /// public string? ResponseId { get; set; } /// Gets or sets the chat thread ID associated with the chat response of which this update is a part. 
diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs index 0af538b9802..4fede2b6ceb 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs @@ -132,6 +132,9 @@ public async IAsyncEnumerable GetStreamingResponseAsync( await OllamaUtilities.ThrowUnsuccessfulOllamaResponseAsync(httpResponse, cancellationToken).ConfigureAwait(false); } + // Ollama doesn't set a response ID on streamed chunks, so we need to generate one. + var responseId = Guid.NewGuid().ToString("N"); + using var httpResponseStream = await httpResponse.Content #if NET .ReadAsStreamAsync(cancellationToken) @@ -160,7 +163,7 @@ public async IAsyncEnumerable GetStreamingResponseAsync( CreatedAt = DateTimeOffset.TryParse(chunk.CreatedAt, CultureInfo.InvariantCulture, DateTimeStyles.None, out DateTimeOffset createdAt) ? createdAt : null, FinishReason = ToFinishReason(chunk), ModelId = modelId, - ResponseId = chunk.CreatedAt, + ResponseId = responseId, Role = chunk.Message?.Role is not null ? new ChatRole(chunk.Message.Role) : null, }; diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs index 00e074ab276..4c20074301c 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseUpdateExtensionsTests.cs @@ -51,7 +51,7 @@ public async Task ToChatResponse_SuccessfullyCreatesResponse(bool useAsync) Assert.Equal("123", response.ChatThreadId); - ChatMessage message = response.Messages.Last(); + ChatMessage message = response.Messages.Single(); Assert.Equal(new ChatRole("human"), message.Role); Assert.Equal("Someone", message.AuthorName); Assert.Null(message.AdditionalProperties); diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs index 55b840eea5f..81e4d1044e3 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs @@ -130,6 +130,25 @@ public virtual async Task GetStreamingResponseAsync_UsageDataAvailable() Assert.Equal(usage.Details.InputTokenCount + usage.Details.OutputTokenCount, usage.Details.TotalTokenCount); } + [ConditionalFact] + public virtual async Task GetStreamingResponseAsync_AppendToHistory() + { + SkipIfNotEnabled(); + + List history = [new(ChatRole.User, "Explain in 100 words how AI works")]; + + var streamingResponse = _chatClient.GetStreamingResponseAsync(history); + + Assert.Single(history); + await history.AddMessagesAsync(streamingResponse); + Assert.Equal(2, history.Count); + Assert.Equal(ChatRole.Assistant, history[1].Role); + + var singleTextContent = (TextContent)history[1].Contents.Single(); + Assert.NotEmpty(singleTextContent.Text); + Assert.Equal(history[1].Text, singleTextContent.Text); + } + protected virtual string? 
GetModel_MultiModal_DescribeImage() => null; [ConditionalFact] diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs index 8f7499aa272..16df3bc52ea 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs @@ -171,11 +171,12 @@ public async Task BasicRequestResponse_Streaming() using IChatClient client = new OllamaChatClient("http://localhost:11434", "llama3.1", httpClient); List updates = []; - await foreach (var update in client.GetStreamingResponseAsync("hello", new() + var streamingResponse = client.GetStreamingResponseAsync("hello", new() { MaxOutputTokens = 20, Temperature = 0.5f, - })) + }); + await foreach (var update in streamingResponse) { updates.Add(update); } @@ -201,6 +202,10 @@ public async Task BasicRequestResponse_Streaming() Assert.Equal(11, usage.Details.InputTokenCount); Assert.Equal(20, usage.Details.OutputTokenCount); Assert.Equal(31, usage.Details.TotalTokenCount); + + var chatResponse = await streamingResponse.ToChatResponseAsync(); + Assert.Single(Assert.Single(chatResponse.Messages).Contents); + Assert.Equal("Hello! How are you today? Is there something I can help you with or would you like to", chatResponse.Text); } [Fact]
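For the streaming-history pattern exercised by GetStreamingResponseAsync_AppendToHistory and the Ollama streaming test above, a minimal usage sketch; it assumes an IChatClient instance named client (for example an OllamaChatClient), and the variable names are illustrative rather than part of the patch.

    using Microsoft.Extensions.AI;

    List<ChatMessage> history = [new(ChatRole.User, "Explain in 100 words how AI works")];

    // Stream the reply and append the resulting assistant message(s) to the history;
    // with this patch the streamed text chunks land in one coalesced TextContent.
    await history.AddMessagesAsync(client.GetStreamingResponseAsync(history));

    Console.WriteLine(history[^1].Text); // the assistant reply
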