From 574d7f723ddf00170f308653c55295a60d3eae68 Mon Sep 17 00:00:00 2001 From: ajcvickers Date: Tue, 26 Nov 2024 17:04:15 +0000 Subject: [PATCH 1/3] Return null when the type is nullable for Cosmos Max/Min/Average (#35173) * Return null when the type is nullable for Cosmos Max/Min/Average Fixes #35094 This was a regression resulting from the major Cosmos query refactoring that happened in EF9. In EF8, the functions Min, Max, and Average would return null if the return type was nullable or was cast to a nullable when the collection is empty. In EF9, this started throwing, which is correct for non-nullable types, but a regression for nullable types. * Added notes --- ...yableMethodTranslatingExpressionVisitor.cs | 118 +++++--------- ...yableMethodTranslatingExpressionVisitor.cs | 56 ++----- .../AdHocMiscellaneousQueryCosmosTest.cs | 111 ++++++++++++++ ...thwindAggregateOperatorsQueryCosmosTest.cs | 144 ++++++------------ 4 files changed, 203 insertions(+), 226 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs index 9b446a78152..ba6ee9d5027 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs @@ -444,25 +444,7 @@ private ShapedQueryExpression CreateShapedQueryExpression(SelectExpression selec /// doing so can result in application failures when updating to a new Entity Framework Core release. /// protected override ShapedQueryExpression? TranslateAverage(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - { - var selectExpression = (SelectExpression)source.QueryExpression; - if (selectExpression.IsDistinct - || selectExpression.Limit != null - || selectExpression.Offset != null) - { - return null; - } - - if (selector != null) - { - source = TranslateSelect(source, selector); - } - - var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember()); - projection = _sqlExpressionFactory.Function("AVG", new[] { projection }, projection.Type, projection.TypeMapping); - - return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType); - } + => TranslateAggregate(source, selector, resultType, "AVG"); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -842,26 +824,7 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou /// doing so can result in application failures when updating to a new Entity Framework Core release. /// protected override ShapedQueryExpression? TranslateMax(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - { - var selectExpression = (SelectExpression)source.QueryExpression; - if (selectExpression.IsDistinct - || selectExpression.Limit != null - || selectExpression.Offset != null) - { - return null; - } - - if (selector != null) - { - source = TranslateSelect(source, selector); - } - - var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember()); - - projection = _sqlExpressionFactory.Function("MAX", new[] { projection }, resultType, projection.TypeMapping); - - return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType); - } + => TranslateAggregate(source, selector, resultType, "MAX"); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -870,26 +833,7 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou /// doing so can result in application failures when updating to a new Entity Framework Core release. /// protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - { - var selectExpression = (SelectExpression)source.QueryExpression; - if (selectExpression.IsDistinct - || selectExpression.Limit != null - || selectExpression.Offset != null) - { - return null; - } - - if (selector != null) - { - source = TranslateSelect(source, selector); - } - - var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember()); - - projection = _sqlExpressionFactory.Function("MIN", new[] { projection }, resultType, projection.TypeMapping); - - return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType); - } + => TranslateAggregate(source, selector, resultType, "MIN"); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -1242,7 +1186,7 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s projection = _sqlExpressionFactory.Function("SUM", new[] { projection }, serverOutputType, projection.TypeMapping); - return AggregateResultShaper(source, projection, throwOnNullResult: false, resultType); + return AggregateResultShaper(source, projection, resultType); } /// @@ -1520,6 +1464,35 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s #endregion Queryable collection support + private ShapedQueryExpression? TranslateAggregate(ShapedQueryExpression source, LambdaExpression? selector, Type resultType, string functionName) + { + var selectExpression = (SelectExpression)source.QueryExpression; + if (selectExpression.IsDistinct + || selectExpression.Limit != null + || selectExpression.Offset != null) + { + return null; + } + + if (selector != null) + { + source = TranslateSelect(source, selector); + } + + if (!_subquery && resultType.IsNullableType()) + { + // For nullable types, we want to return null from Max, Min, and Average, rather than throwing. See Issue #35094. + // Note that relational databases typically return null, which propagates. Cosmos will instead return no elements, + // and hence for Cosmos only we need to change no elements into null. + source = source.UpdateResultCardinality(ResultCardinality.SingleOrDefault); + } + + var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember()); + projection = _sqlExpressionFactory.Function(functionName, [projection], resultType, _typeMappingSource.FindMapping(resultType)); + + return AggregateResultShaper(source, projection, resultType); + } + private bool TryApplyPredicate(ShapedQueryExpression source, LambdaExpression predicate) { var select = (SelectExpression)source.QueryExpression; @@ -1700,7 +1673,6 @@ private Expression RemapLambdaBody(ShapedQueryExpression shapedQueryExpression, private static ShapedQueryExpression AggregateResultShaper( ShapedQueryExpression source, Expression projection, - bool throwOnNullResult, Type resultType) { var selectExpression = (SelectExpression)source.QueryExpression; @@ -1711,29 +1683,7 @@ private static ShapedQueryExpression AggregateResultShaper( var nullableResultType = resultType.MakeNullable(); Expression shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), nullableResultType); - if (throwOnNullResult) - { - var resultVariable = Expression.Variable(nullableResultType, "result"); - var returnValueForNull = resultType.IsNullableType() - ? (Expression)Expression.Constant(null, resultType) - : Expression.Throw( - Expression.New( - typeof(InvalidOperationException).GetConstructors() - .Single(ci => ci.GetParameters().Length == 1), - Expression.Constant(CoreStrings.SequenceContainsNoElements)), - resultType); - - shaper = Expression.Block( - new[] { resultVariable }, - Expression.Assign(resultVariable, shaper), - Expression.Condition( - Expression.Equal(resultVariable, Expression.Default(nullableResultType)), - returnValueForNull, - resultType != resultVariable.Type - ? Expression.Convert(resultVariable, resultType) - : resultVariable)); - } - else if (resultType != shaper.Type) + if (resultType != shaper.Type) { shaper = Expression.Convert(shaper, resultType); } diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index d98d0edaa1e..ffc0b1cde5e 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -518,7 +518,7 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - => TranslateAggregateWithSelector(source, selector, QueryableMethods.GetAverageWithoutSelector, throwWhenEmpty: true, resultType); + => TranslateAggregateWithSelector(source, selector, QueryableMethods.GetAverageWithoutSelector, resultType); /// protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression source, Type resultType) @@ -968,7 +968,7 @@ private SqlExpression CreateJoinPredicate(Expression outerKey, Expression innerK } return TranslateAggregateWithSelector( - source, selector, t => QueryableMethods.MaxWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType); + source, selector, t => QueryableMethods.MaxWithoutSelector.MakeGenericMethod(t), resultType); } /// @@ -984,7 +984,7 @@ private SqlExpression CreateJoinPredicate(Expression outerKey, Expression innerK } return TranslateAggregateWithSelector( - source, selector, t => QueryableMethods.MinWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType); + source, selector, t => QueryableMethods.MinWithoutSelector.MakeGenericMethod(t), resultType); } /// @@ -1235,7 +1235,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp /// protected override ShapedQueryExpression? TranslateSum(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - => TranslateAggregateWithSelector(source, selector, QueryableMethods.GetSumWithoutSelector, throwWhenEmpty: false, resultType); + => TranslateAggregateWithSelector(source, selector, QueryableMethods.GetSumWithoutSelector, resultType); /// protected override ShapedQueryExpression? TranslateTake(ShapedQueryExpression source, Expression count) @@ -1958,7 +1958,6 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape ShapedQueryExpression source, LambdaExpression? selectorLambda, Func methodGenerator, - bool throwWhenEmpty, Type resultType) { var selectExpression = (SelectExpression)source.QueryExpression; @@ -2004,48 +2003,13 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape new Dictionary { { new ProjectionMember(), translation } }); selectExpression.ClearOrdering(); - Expression shaper; - - if (throwWhenEmpty) - { - // Avg/Max/Min case. - // We always read nullable value - // If resultType is nullable then we always return null. Only non-null result shows throwing behavior. - // otherwise, if projection.Type is nullable then server result is passed through DefaultIfEmpty, hence we return default - // otherwise, server would return null only if it is empty, and we throw - var nullableResultType = resultType.MakeNullable(); - shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), nullableResultType); - var resultVariable = Expression.Variable(nullableResultType, "result"); - var returnValueForNull = resultType.IsNullableType() - ? (Expression)Expression.Default(resultType) - : translation.Type.IsNullableType() - ? Expression.Default(resultType) - : Expression.Throw( - Expression.New( - typeof(InvalidOperationException).GetConstructors() - .Single(ci => ci.GetParameters().Length == 1), - Expression.Constant(CoreStrings.SequenceContainsNoElements)), - resultType); - - shaper = Expression.Block( - new[] { resultVariable }, - Expression.Assign(resultVariable, shaper), - Expression.Condition( - Expression.Equal(resultVariable, Expression.Default(nullableResultType)), - returnValueForNull, - resultType != resultVariable.Type - ? Expression.Convert(resultVariable, resultType) - : resultVariable)); - } - else - { - // Sum case. Projection is always non-null. We read nullable value. - shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), translation.Type.MakeNullable()); - if (resultType != shaper.Type) - { - shaper = Expression.Convert(shaper, resultType); - } + // Sum case. Projection is always non-null. We read nullable value. + Expression shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), translation.Type.MakeNullable()); + + if (resultType != shaper.Type) + { + shaper = Expression.Convert(shaper, resultType); } return source.UpdateShaperExpression(shaper); diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/AdHocMiscellaneousQueryCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/AdHocMiscellaneousQueryCosmosTest.cs index 7fc328c16fd..e2dd2b90904 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/AdHocMiscellaneousQueryCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/AdHocMiscellaneousQueryCosmosTest.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.ComponentModel.DataAnnotations.Schema; + namespace Microsoft.EntityFrameworkCore.Query; #nullable disable @@ -50,6 +52,115 @@ public enum MemberType #endregion 34911 + #region 35094 + + // TODO: Move these tests to a better location. They require nullable properties with nulls in the database. + + [ConditionalFact] + public virtual async Task Min_over_value_type_containing_nulls() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Null(await context.Set().MinAsync(p => p.NullableVal)); + } + + [ConditionalFact] + public virtual async Task Min_over_value_type_containing_all_nulls() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Null(await context.Set().Where(e => e.NullableVal == null).MinAsync(p => p.NullableVal)); + } + + [ConditionalFact] + public virtual async Task Min_over_reference_type_containing_nulls() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Null(await context.Set().MinAsync(p => p.NullableRef)); + } + + [ConditionalFact] + public virtual async Task Min_over_reference_type_containing_all_nulls() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Null(await context.Set().Where(e => e.NullableRef == null).MinAsync(p => p.NullableRef)); + } + + [ConditionalFact] + public virtual async Task Min_over_reference_type_containing_no_data() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Null(await context.Set().Where(e => e.Id < 0).MinAsync(p => p.NullableRef)); + } + + [ConditionalFact] + public virtual async Task Max_over_value_type_containing_nulls() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Equal(3.14, await context.Set().MaxAsync(p => p.NullableVal)); + } + + [ConditionalFact] + public virtual async Task Max_over_value_type_containing_all_nulls() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Null(await context.Set().Where(e => e.NullableVal == null).MaxAsync(p => p.NullableVal)); + } + + [ConditionalFact] + public virtual async Task Max_over_reference_type_containing_nulls() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Equal("Value", await context.Set().MaxAsync(p => p.NullableRef)); + } + + [ConditionalFact] + public virtual async Task Max_over_reference_type_containing_all_nulls() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Null(await context.Set().Where(e => e.NullableRef == null).MaxAsync(p => p.NullableRef)); + } + + [ConditionalFact] + public virtual async Task Max_over_reference_type_containing_no_data() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Null(await context.Set().Where(e => e.Id < 0).MaxAsync(p => p.NullableRef)); + } + + [ConditionalFact] + public virtual async Task Average_over_value_type_containing_nulls() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Null(await context.Set().AverageAsync(p => p.NullableVal)); + } + + [ConditionalFact] + public virtual async Task Average_over_value_type_containing_all_nulls() + { + await using var context = (await InitializeAsync()).CreateContext(); + Assert.Null(await context.Set().Where(e => e.NullableVal == null).AverageAsync(p => p.NullableVal)); + } + + protected class Context35094(DbContextOptions options) : DbContext(options) + { + public DbSet Products { get; set; } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + => modelBuilder.Entity().HasData( + new Product { Id = 1, NullableRef = "Value", NullableVal = 3.14 }, + new Product { Id = 2, NullableVal = 3.14 }, + new Product { Id = 3, NullableRef = "Value" }); + + public class Product + { + [DatabaseGenerated(DatabaseGeneratedOption.None)] + public int Id { get; set; } + public double? NullableVal { get; set; } + public string NullableRef { get; set; } + } + } + + #endregion 35094 + protected override string StoreName => "AdHocMiscellaneousQueryTests"; diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindAggregateOperatorsQueryCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindAggregateOperatorsQueryCosmosTest.cs index 2d30fc87c90..4def5734725 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindAggregateOperatorsQueryCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindAggregateOperatorsQueryCosmosTest.cs @@ -555,49 +555,33 @@ FROM root c } } - public override async Task Average_no_data_nullable(bool async) - { - // Sync always throws before getting to exception being tested. - if (async) - { - await Fixture.NoSyncTest( - async, async a => - { - Assert.Equal( - CoreStrings.SequenceContainsNoElements, - (await Assert.ThrowsAsync(() => base.Average_no_data_nullable(a))).Message); + public override Task Average_no_data_nullable(bool async) + => Fixture.NoSyncTest( + async, async a => + { + await base.Average_no_data_nullable(a); - AssertSql( - """ + AssertSql( + """ SELECT VALUE AVG(c["SupplierID"]) FROM root c WHERE ((c["$type"] = "Product") AND (c["SupplierID"] = -1)) """); - }); - } - } + }); - public override async Task Average_no_data_cast_to_nullable(bool async) - { - // Sync always throws before getting to exception being tested. - if (async) - { - await Fixture.NoSyncTest( - async, async a => - { - Assert.Equal( - CoreStrings.SequenceContainsNoElements, - (await Assert.ThrowsAsync(() => base.Average_no_data_cast_to_nullable(a))).Message); + public override Task Average_no_data_cast_to_nullable(bool async) + => Fixture.NoSyncTest( + async, async a => + { + await base.Average_no_data_cast_to_nullable(a); - AssertSql( - """ + AssertSql( + """ SELECT VALUE AVG(c["OrderID"]) FROM root c WHERE ((c["$type"] = "Order") AND (c["OrderID"] = -1)) """); - }); - } - } + }); public override async Task Min_no_data(bool async) { @@ -647,49 +631,33 @@ public override async Task Max_no_data_subquery(bool async) AssertSql(); } - public override async Task Max_no_data_nullable(bool async) - { - // Sync always throws before getting to exception being tested. - if (async) - { - await Fixture.NoSyncTest( - async, async a => - { - Assert.Equal( - CoreStrings.SequenceContainsNoElements, - (await Assert.ThrowsAsync(() => base.Max_no_data_nullable(a))).Message); + public override Task Max_no_data_nullable(bool async) + => Fixture.NoSyncTest( + async, async a => + { + await base.Max_no_data_nullable(a); - AssertSql( - """ + AssertSql( + """ SELECT VALUE MAX(c["SupplierID"]) FROM root c WHERE ((c["$type"] = "Product") AND (c["SupplierID"] = -1)) """); - }); - } - } + }); - public override async Task Max_no_data_cast_to_nullable(bool async) - { - // Sync always throws before getting to exception being tested. - if (async) - { - await Fixture.NoSyncTest( - async, async a => - { - Assert.Equal( - CoreStrings.SequenceContainsNoElements, - (await Assert.ThrowsAsync(() => base.Max_no_data_cast_to_nullable(a))).Message); + public override Task Max_no_data_cast_to_nullable(bool async) + => Fixture.NoSyncTest( + async, async a => + { + await base.Max_no_data_cast_to_nullable(a); - AssertSql( - """ + AssertSql( + """ SELECT VALUE MAX(c["OrderID"]) FROM root c WHERE ((c["$type"] = "Order") AND (c["OrderID"] = -1)) """); - }); - } - } + }); public override async Task Min_no_data_subquery(bool async) { @@ -874,49 +842,33 @@ FROM root c """); }); - public override async Task Min_no_data_nullable(bool async) - { - // Sync always throws before getting to exception being tested. - if (async) - { - await Fixture.NoSyncTest( - async, async a => - { - Assert.Equal( - CoreStrings.SequenceContainsNoElements, - (await Assert.ThrowsAsync(() => base.Min_no_data_nullable(a))).Message); + public override Task Min_no_data_nullable(bool async) + => Fixture.NoSyncTest( + async, async a => + { + await base.Min_no_data_nullable(a); - AssertSql( - """ + AssertSql( + """ SELECT VALUE MIN(c["SupplierID"]) FROM root c WHERE ((c["$type"] = "Product") AND (c["SupplierID"] = -1)) """); - }); - } - } + }); - public override async Task Min_no_data_cast_to_nullable(bool async) - { - // Sync always throws before getting to exception being tested. - if (async) - { - await Fixture.NoSyncTest( - async, async a => - { - Assert.Equal( - CoreStrings.SequenceContainsNoElements, - (await Assert.ThrowsAsync(() => base.Min_no_data_cast_to_nullable(a))).Message); + public override Task Min_no_data_cast_to_nullable(bool async) + => Fixture.NoSyncTest( + async, async a => + { + await base.Min_no_data_cast_to_nullable(a); - AssertSql( - """ + AssertSql( + """ SELECT VALUE MIN(c["OrderID"]) FROM root c WHERE ((c["$type"] = "Order") AND (c["OrderID"] = -1)) """); - }); - } - } + }); public override Task Min_with_coalesce(bool async) => Fixture.NoSyncTest( From a035706dfeb81866046b0dbd8579cce9bd3b8db7 Mon Sep 17 00:00:00 2001 From: ajcvickers Date: Tue, 26 Nov 2024 17:22:17 +0000 Subject: [PATCH 2/3] Added quirks --- ...yableMethodTranslatingExpressionVisitor.cs | 110 +++++++++++++++++- ...yableMethodTranslatingExpressionVisitor.cs | 56 +++++++-- 2 files changed, 150 insertions(+), 16 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs index ba6ee9d5027..19453ab8e34 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs @@ -16,6 +16,9 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal; /// public class CosmosQueryableMethodTranslatingExpressionVisitor : QueryableMethodTranslatingExpressionVisitor { + private static readonly bool UseOldBehavior35094 = + AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue35094", out var enabled) && enabled; + private readonly CosmosQueryCompilationContext _queryCompilationContext; private readonly ISqlExpressionFactory _sqlExpressionFactory; private readonly ITypeMappingSource _typeMappingSource; @@ -444,7 +447,31 @@ private ShapedQueryExpression CreateShapedQueryExpression(SelectExpression selec /// doing so can result in application failures when updating to a new Entity Framework Core release. /// protected override ShapedQueryExpression? TranslateAverage(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - => TranslateAggregate(source, selector, resultType, "AVG"); + { + if (UseOldBehavior35094) + { + var selectExpression = (SelectExpression)source.QueryExpression; + if (selectExpression.IsDistinct + || selectExpression.Limit != null + || selectExpression.Offset != null) + { + return null; + } + + if (selector != null) + { + source = TranslateSelect(source, selector); + } + + var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember()); + projection = _sqlExpressionFactory.Function("AVG", new[] { projection }, projection.Type, projection.TypeMapping); + + return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType); + + } + + return TranslateAggregate(source, selector, resultType, "AVG"); + } /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -824,7 +851,31 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou /// doing so can result in application failures when updating to a new Entity Framework Core release. /// protected override ShapedQueryExpression? TranslateMax(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - => TranslateAggregate(source, selector, resultType, "MAX"); + { + if (UseOldBehavior35094) + { + var selectExpression = (SelectExpression)source.QueryExpression; + if (selectExpression.IsDistinct + || selectExpression.Limit != null + || selectExpression.Offset != null) + { + return null; + } + + if (selector != null) + { + source = TranslateSelect(source, selector); + } + + var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember()); + + projection = _sqlExpressionFactory.Function("MAX", new[] { projection }, resultType, projection.TypeMapping); + + return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType); + } + + return TranslateAggregate(source, selector, resultType, "MAX"); + } /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -833,7 +884,31 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou /// doing so can result in application failures when updating to a new Entity Framework Core release. /// protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - => TranslateAggregate(source, selector, resultType, "MIN"); + { + if (UseOldBehavior35094) + { + var selectExpression = (SelectExpression)source.QueryExpression; + if (selectExpression.IsDistinct + || selectExpression.Limit != null + || selectExpression.Offset != null) + { + return null; + } + + if (selector != null) + { + source = TranslateSelect(source, selector); + } + + var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember()); + + projection = _sqlExpressionFactory.Function("MIN", new[] { projection }, resultType, projection.TypeMapping); + + return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType); + } + + return TranslateAggregate(source, selector, resultType, "MIN"); + } /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -1186,7 +1261,7 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s projection = _sqlExpressionFactory.Function("SUM", new[] { projection }, serverOutputType, projection.TypeMapping); - return AggregateResultShaper(source, projection, resultType); + return AggregateResultShaper(source, projection, throwOnNullResult: false, resultType); } /// @@ -1490,7 +1565,7 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember()); projection = _sqlExpressionFactory.Function(functionName, [projection], resultType, _typeMappingSource.FindMapping(resultType)); - return AggregateResultShaper(source, projection, resultType); + return AggregateResultShaper(source, projection, throwOnNullResult: false, resultType); } private bool TryApplyPredicate(ShapedQueryExpression source, LambdaExpression predicate) @@ -1673,6 +1748,7 @@ private Expression RemapLambdaBody(ShapedQueryExpression shapedQueryExpression, private static ShapedQueryExpression AggregateResultShaper( ShapedQueryExpression source, Expression projection, + bool throwOnNullResult, Type resultType) { var selectExpression = (SelectExpression)source.QueryExpression; @@ -1683,7 +1759,29 @@ private static ShapedQueryExpression AggregateResultShaper( var nullableResultType = resultType.MakeNullable(); Expression shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), nullableResultType); - if (resultType != shaper.Type) + if (throwOnNullResult) + { + var resultVariable = Expression.Variable(nullableResultType, "result"); + var returnValueForNull = resultType.IsNullableType() + ? (Expression)Expression.Constant(null, resultType) + : Expression.Throw( + Expression.New( + typeof(InvalidOperationException).GetConstructors() + .Single(ci => ci.GetParameters().Length == 1), + Expression.Constant(CoreStrings.SequenceContainsNoElements)), + resultType); + + shaper = Expression.Block( + new[] { resultVariable }, + Expression.Assign(resultVariable, shaper), + Expression.Condition( + Expression.Equal(resultVariable, Expression.Default(nullableResultType)), + returnValueForNull, + resultType != resultVariable.Type + ? Expression.Convert(resultVariable, resultType) + : resultVariable)); + } + else if (resultType != shaper.Type) { shaper = Expression.Convert(shaper, resultType); } diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index ffc0b1cde5e..d98d0edaa1e 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -518,7 +518,7 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - => TranslateAggregateWithSelector(source, selector, QueryableMethods.GetAverageWithoutSelector, resultType); + => TranslateAggregateWithSelector(source, selector, QueryableMethods.GetAverageWithoutSelector, throwWhenEmpty: true, resultType); /// protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression source, Type resultType) @@ -968,7 +968,7 @@ private SqlExpression CreateJoinPredicate(Expression outerKey, Expression innerK } return TranslateAggregateWithSelector( - source, selector, t => QueryableMethods.MaxWithoutSelector.MakeGenericMethod(t), resultType); + source, selector, t => QueryableMethods.MaxWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType); } /// @@ -984,7 +984,7 @@ private SqlExpression CreateJoinPredicate(Expression outerKey, Expression innerK } return TranslateAggregateWithSelector( - source, selector, t => QueryableMethods.MinWithoutSelector.MakeGenericMethod(t), resultType); + source, selector, t => QueryableMethods.MinWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType); } /// @@ -1235,7 +1235,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp /// protected override ShapedQueryExpression? TranslateSum(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - => TranslateAggregateWithSelector(source, selector, QueryableMethods.GetSumWithoutSelector, resultType); + => TranslateAggregateWithSelector(source, selector, QueryableMethods.GetSumWithoutSelector, throwWhenEmpty: false, resultType); /// protected override ShapedQueryExpression? TranslateTake(ShapedQueryExpression source, Expression count) @@ -1958,6 +1958,7 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape ShapedQueryExpression source, LambdaExpression? selectorLambda, Func methodGenerator, + bool throwWhenEmpty, Type resultType) { var selectExpression = (SelectExpression)source.QueryExpression; @@ -2003,13 +2004,48 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape new Dictionary { { new ProjectionMember(), translation } }); selectExpression.ClearOrdering(); - - // Sum case. Projection is always non-null. We read nullable value. - Expression shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), translation.Type.MakeNullable()); - - if (resultType != shaper.Type) + Expression shaper; + + if (throwWhenEmpty) + { + // Avg/Max/Min case. + // We always read nullable value + // If resultType is nullable then we always return null. Only non-null result shows throwing behavior. + // otherwise, if projection.Type is nullable then server result is passed through DefaultIfEmpty, hence we return default + // otherwise, server would return null only if it is empty, and we throw + var nullableResultType = resultType.MakeNullable(); + shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), nullableResultType); + var resultVariable = Expression.Variable(nullableResultType, "result"); + var returnValueForNull = resultType.IsNullableType() + ? (Expression)Expression.Default(resultType) + : translation.Type.IsNullableType() + ? Expression.Default(resultType) + : Expression.Throw( + Expression.New( + typeof(InvalidOperationException).GetConstructors() + .Single(ci => ci.GetParameters().Length == 1), + Expression.Constant(CoreStrings.SequenceContainsNoElements)), + resultType); + + shaper = Expression.Block( + new[] { resultVariable }, + Expression.Assign(resultVariable, shaper), + Expression.Condition( + Expression.Equal(resultVariable, Expression.Default(nullableResultType)), + returnValueForNull, + resultType != resultVariable.Type + ? Expression.Convert(resultVariable, resultType) + : resultVariable)); + } + else { - shaper = Expression.Convert(shaper, resultType); + // Sum case. Projection is always non-null. We read nullable value. + shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), translation.Type.MakeNullable()); + + if (resultType != shaper.Type) + { + shaper = Expression.Convert(shaper, resultType); + } } return source.UpdateShaperExpression(shaper); From 3084d6b354ecbad776dfa827de050552566e04ae Mon Sep 17 00:00:00 2001 From: ajcvickers Date: Wed, 27 Nov 2024 13:15:14 +0000 Subject: [PATCH 3/3] Fix tests. --- ...yableMethodTranslatingExpressionVisitor.cs | 2 +- ...thwindAggregateOperatorsQueryCosmosTest.cs | 38 ++++++++----------- 2 files changed, 17 insertions(+), 23 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs index 19453ab8e34..7ad49bd00c4 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs @@ -1565,7 +1565,7 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember()); projection = _sqlExpressionFactory.Function(functionName, [projection], resultType, _typeMappingSource.FindMapping(resultType)); - return AggregateResultShaper(source, projection, throwOnNullResult: false, resultType); + return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType); } private bool TryApplyPredicate(ShapedQueryExpression source, LambdaExpression predicate) diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindAggregateOperatorsQueryCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindAggregateOperatorsQueryCosmosTest.cs index 4def5734725..6f1291dbecd 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindAggregateOperatorsQueryCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindAggregateOperatorsQueryCosmosTest.cs @@ -667,22 +667,19 @@ public override async Task Min_no_data_subquery(bool async) AssertSql(); } - public override async Task Average_with_no_arg(bool async) - { - // Always throws for sync. - if (async) - { - // Average truncates. Issue #26378. - await Assert.ThrowsAsync(async () => await base.Average_with_no_arg(async)); + public override Task Average_with_no_arg(bool async) + => Fixture.NoSyncTest( + async, async a => + { + await base.Average_with_no_arg(a); - AssertSql( - """ + AssertSql( + """ SELECT VALUE AVG(c["OrderID"]) FROM root c WHERE (c["$type"] = "Order") """); - } - } + }); public override Task Average_with_binary_expression(bool async) => Fixture.NoSyncTest( @@ -698,22 +695,19 @@ FROM root c """); }); - public override async Task Average_with_arg(bool async) - { - // Always throws for sync. - if (async) - { - // Average truncates. Issue #26378. - await Assert.ThrowsAsync(async () => await base.Average_with_arg(async)); + public override Task Average_with_arg(bool async) + => Fixture.NoSyncTest( + async, async a => + { + await base.Average_with_arg(a); - AssertSql( - """ + AssertSql( + """ SELECT VALUE AVG(c["OrderID"]) FROM root c WHERE (c["$type"] = "Order") """); - } - } + }); public override Task Average_with_arg_expression(bool async) => Fixture.NoSyncTest(