Skip to content

Commit f1181d9

Browse files
committed
Nullability-related fixes to LEAST/GREATEST
Fixup to #32338
1 parent d41ba67 commit f1181d9

16 files changed

+432
-392
lines changed

src/EFCore.Relational/Query/ISqlExpressionFactory.cs

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System.Diagnostics.CodeAnalysis;
5-
using System.Runtime.CompilerServices;
6-
using Microsoft.EntityFrameworkCore.Query.Internal;
75
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
86

97
namespace Microsoft.EntityFrameworkCore.Query;
@@ -439,30 +437,4 @@ SqlExpression NiladicFunction(
439437
/// <param name="sql">A string token to print in SQL tree.</param>
440438
/// <returns>An expression representing a SQL token.</returns>
441439
SqlExpression Fragment(string sql);
442-
443-
/// <summary>
444-
/// Attempts to creates a new expression that returns the smallest value from a list of expressions, e.g. an invocation of the
445-
/// <c>LEAST</c> SQL function.
446-
/// </summary>
447-
/// <param name="expressions">An entity type to project.</param>
448-
/// <param name="resultType">The result CLR type for the returned expression.</param>
449-
/// <param name="leastExpression">The expression which computes the smallest value.</param>
450-
/// <returns><see langword="true" /> if the expression could be created, <see langword="false" /> otherwise.</returns>
451-
bool TryCreateLeast(
452-
IReadOnlyList<SqlExpression> expressions,
453-
Type resultType,
454-
[NotNullWhen(true)] out SqlExpression? leastExpression);
455-
456-
/// <summary>
457-
/// Attempts to creates a new expression that returns the greatest value from a list of expressions, e.g. an invocation of the
458-
/// <c>GREATEST</c> SQL function.
459-
/// </summary>
460-
/// <param name="expressions">An entity type to project.</param>
461-
/// <param name="resultType">The result CLR type for the returned expression.</param>
462-
/// <param name="greatestExpression">The expression which computes the greatest value.</param>
463-
/// <returns><see langword="true" /> if the expression could be created, <see langword="false" /> otherwise.</returns>
464-
bool TryCreateGreatest(
465-
IReadOnlyList<SqlExpression> expressions,
466-
Type resultType,
467-
[NotNullWhen(true)] out SqlExpression? greatestExpression);
468440
}

src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -976,19 +976,39 @@ private SqlExpression CreateJoinPredicate(Expression outerKey, Expression innerK
976976

977977
/// <inheritdoc />
978978
protected override ShapedQueryExpression? TranslateMax(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
979-
=> TryExtractBareInlineCollectionValues(source, out var values)
980-
&& _sqlExpressionFactory.TryCreateGreatest(values, resultType, out var greatestExpression)
981-
? source.Update(new SelectExpression(greatestExpression, _sqlAliasManager), source.ShaperExpression)
982-
: TranslateAggregateWithSelector(
983-
source, selector, t => QueryableMethods.MaxWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType);
979+
{
980+
// For Max() over an inline array, translate to GREATEST() if possible; otherwise use the default translation of aggregate SQL
981+
// MAX().
982+
// Note that some providers propagate NULL arguments (SQLite, MySQL), while others only return NULL if all arguments evaluate to
983+
// NULL (SQL Server, PostgreSQL). If the argument is a nullable value type, don't translate to GREATEST() if it propagates NULLs,
984+
// to match the .NET behavior.
985+
if (TryExtractBareInlineCollectionValues(source, out var values)
986+
&& _sqlTranslator.GenerateGreatest(values, resultType.UnwrapNullableType()) is SqlFunctionExpression greatestExpression
987+
&& (Nullable.GetUnderlyingType(resultType) is null
988+
|| greatestExpression.ArgumentsPropagateNullability?.All(a => a == false) == true))
989+
{
990+
return source.Update(new SelectExpression(greatestExpression, _sqlAliasManager), source.ShaperExpression);
991+
}
992+
993+
return TranslateAggregateWithSelector(
994+
source, selector, t => QueryableMethods.MaxWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType);
995+
}
984996

985997
/// <inheritdoc />
986998
protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
987-
=> TryExtractBareInlineCollectionValues(source, out var values)
988-
&& _sqlExpressionFactory.TryCreateLeast(values, resultType, out var leastExpression)
989-
? source.Update(new SelectExpression(leastExpression, _sqlAliasManager), source.ShaperExpression)
990-
: TranslateAggregateWithSelector(
991-
source, selector, t => QueryableMethods.MinWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType);
999+
{
1000+
// See comments above in TranslateMax()
1001+
if (TryExtractBareInlineCollectionValues(source, out var values)
1002+
&& _sqlTranslator.GenerateLeast(values, resultType.UnwrapNullableType()) is SqlFunctionExpression leastExpression
1003+
&& (Nullable.GetUnderlyingType(resultType) is null
1004+
|| leastExpression.ArgumentsPropagateNullability?.All(a => a == false) == true))
1005+
{
1006+
return source.Update(new SelectExpression(leastExpression, _sqlAliasManager), source.ShaperExpression);
1007+
}
1008+
1009+
return TranslateAggregateWithSelector(
1010+
source, selector, t => QueryableMethods.MinWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType);
1011+
}
9921012

9931013
/// <inheritdoc />
9941014
protected override ShapedQueryExpression? TranslateOfType(ShapedQueryExpression source, Type resultType)

src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs

Lines changed: 58 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,6 @@ private static readonly MethodInfo StringEqualsWithStringComparison
5252
private static readonly MethodInfo StringEqualsWithStringComparisonStatic
5353
= typeof(string).GetRuntimeMethod(nameof(string.Equals), [typeof(string), typeof(string), typeof(StringComparison)])!;
5454

55-
private static readonly MethodInfo LeastMethodInfo
56-
= typeof(RelationalDbFunctionsExtensions).GetMethod(nameof(RelationalDbFunctionsExtensions.Least))!;
57-
58-
private static readonly MethodInfo GreatestMethodInfo
59-
= typeof(RelationalDbFunctionsExtensions).GetMethod(nameof(RelationalDbFunctionsExtensions.Greatest))!;
60-
6155
private static readonly MethodInfo GetTypeMethodInfo = typeof(object).GetTypeInfo().GetDeclaredMethod(nameof(GetType))!;
6256

6357
private readonly QueryCompilationContext _queryCompilationContext;
@@ -183,138 +177,6 @@ protected virtual void AddTranslationErrorDetails(string details)
183177
return result;
184178
}
185179

186-
/// <summary>
187-
/// Translates Average over an expression to an equivalent SQL representation.
188-
/// </summary>
189-
/// <param name="sqlExpression">An expression to translate Average over.</param>
190-
/// <returns>A SQL translation of Average over the given expression.</returns>
191-
[Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")]
192-
public virtual SqlExpression? TranslateAverage(SqlExpression sqlExpression)
193-
{
194-
var inputType = sqlExpression.Type;
195-
if (inputType == typeof(int)
196-
|| inputType == typeof(long))
197-
{
198-
sqlExpression = sqlExpression is DistinctExpression distinctExpression
199-
? new DistinctExpression(
200-
_sqlExpressionFactory.ApplyDefaultTypeMapping(
201-
_sqlExpressionFactory.Convert(distinctExpression.Operand, typeof(double))))
202-
: _sqlExpressionFactory.ApplyDefaultTypeMapping(
203-
_sqlExpressionFactory.Convert(sqlExpression, typeof(double)));
204-
}
205-
206-
return inputType == typeof(float)
207-
? _sqlExpressionFactory.Convert(
208-
_sqlExpressionFactory.Function(
209-
"AVG",
210-
new[] { sqlExpression },
211-
nullable: true,
212-
argumentsPropagateNullability: new[] { false },
213-
typeof(double)),
214-
sqlExpression.Type,
215-
sqlExpression.TypeMapping)
216-
: _sqlExpressionFactory.Function(
217-
"AVG",
218-
new[] { sqlExpression },
219-
nullable: true,
220-
argumentsPropagateNullability: new[] { false },
221-
sqlExpression.Type,
222-
sqlExpression.TypeMapping);
223-
}
224-
225-
/// <summary>
226-
/// Translates Count over an expression to an equivalent SQL representation.
227-
/// </summary>
228-
/// <param name="sqlExpression">An expression to translate Count over.</param>
229-
/// <returns>A SQL translation of Count over the given expression.</returns>
230-
[Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")]
231-
public virtual SqlExpression? TranslateCount(SqlExpression sqlExpression)
232-
=> _sqlExpressionFactory.ApplyDefaultTypeMapping(
233-
_sqlExpressionFactory.Function(
234-
"COUNT",
235-
new[] { sqlExpression },
236-
nullable: false,
237-
argumentsPropagateNullability: new[] { false },
238-
typeof(int)));
239-
240-
/// <summary>
241-
/// Translates LongCount over an expression to an equivalent SQL representation.
242-
/// </summary>
243-
/// <param name="sqlExpression">An expression to translate LongCount over.</param>
244-
/// <returns>A SQL translation of LongCount over the given expression.</returns>
245-
[Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")]
246-
public virtual SqlExpression? TranslateLongCount(SqlExpression sqlExpression)
247-
=> _sqlExpressionFactory.ApplyDefaultTypeMapping(
248-
_sqlExpressionFactory.Function(
249-
"COUNT",
250-
new[] { sqlExpression },
251-
nullable: false,
252-
argumentsPropagateNullability: new[] { false },
253-
typeof(long)));
254-
255-
/// <summary>
256-
/// Translates Max over an expression to an equivalent SQL representation.
257-
/// </summary>
258-
/// <param name="sqlExpression">An expression to translate Max over.</param>
259-
/// <returns>A SQL translation of Max over the given expression.</returns>
260-
[Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")]
261-
public virtual SqlExpression? TranslateMax(SqlExpression sqlExpression)
262-
=> sqlExpression != null
263-
? _sqlExpressionFactory.Function(
264-
"MAX",
265-
new[] { sqlExpression },
266-
nullable: true,
267-
argumentsPropagateNullability: new[] { false },
268-
sqlExpression.Type,
269-
sqlExpression.TypeMapping)
270-
: null;
271-
272-
/// <summary>
273-
/// Translates Min over an expression to an equivalent SQL representation.
274-
/// </summary>
275-
/// <param name="sqlExpression">An expression to translate Min over.</param>
276-
/// <returns>A SQL translation of Min over the given expression.</returns>
277-
[Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")]
278-
public virtual SqlExpression? TranslateMin(SqlExpression sqlExpression)
279-
=> sqlExpression != null
280-
? _sqlExpressionFactory.Function(
281-
"MIN",
282-
new[] { sqlExpression },
283-
nullable: true,
284-
argumentsPropagateNullability: new[] { false },
285-
sqlExpression.Type,
286-
sqlExpression.TypeMapping)
287-
: null;
288-
289-
/// <summary>
290-
/// Translates Sum over an expression to an equivalent SQL representation.
291-
/// </summary>
292-
/// <param name="sqlExpression">An expression to translate Sum over.</param>
293-
/// <returns>A SQL translation of Sum over the given expression.</returns>
294-
[Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")]
295-
public virtual SqlExpression? TranslateSum(SqlExpression sqlExpression)
296-
{
297-
var inputType = sqlExpression.Type;
298-
299-
return inputType == typeof(float)
300-
? _sqlExpressionFactory.Convert(
301-
_sqlExpressionFactory.Function(
302-
"SUM",
303-
new[] { sqlExpression },
304-
nullable: true,
305-
argumentsPropagateNullability: new[] { false },
306-
typeof(double)),
307-
inputType,
308-
sqlExpression.TypeMapping)
309-
: _sqlExpressionFactory.Function(
310-
"SUM",
311-
new[] { sqlExpression },
312-
nullable: true,
313-
argumentsPropagateNullability: new[] { false },
314-
inputType,
315-
sqlExpression.TypeMapping);
316-
}
317-
318180
/// <inheritdoc />
319181
protected override Expression VisitBinary(BinaryExpression binaryExpression)
320182
{
@@ -937,14 +799,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
937799
// translation.
938800
case
939801
{
940-
Method:
941-
{
942-
Name: nameof(RelationalDbFunctionsExtensions.Least) or nameof(RelationalDbFunctionsExtensions.Greatest),
943-
IsGenericMethod: true
944-
},
802+
Method.Name: nameof(RelationalDbFunctionsExtensions.Least) or nameof(RelationalDbFunctionsExtensions.Greatest),
945803
Arguments: [_, NewArrayExpression newArray]
946-
} when method.GetGenericMethodDefinition() is var genericMethodDefinition
947-
&& (genericMethodDefinition == LeastMethodInfo || genericMethodDefinition == GreatestMethodInfo):
804+
} when method.DeclaringType == typeof(RelationalDbFunctionsExtensions):
948805
{
949806
var values = newArray.Expressions;
950807
var translatedValues = new SqlExpression[values.Count];
@@ -962,21 +819,54 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
962819
translatedValues[i] = translatedValue!;
963820
}
964821

965-
var elementClrType = newArray.Type.GetElementType()!;
822+
var elementClrType = newArray.Type.GetElementType()!.UnwrapNullableType();
966823

967-
if (genericMethodDefinition == LeastMethodInfo
968-
&& _sqlExpressionFactory.TryCreateLeast(translatedValues, elementClrType, out var leastExpression))
969-
{
970-
return leastExpression;
971-
}
824+
return method.Name switch
825+
{
826+
nameof(RelationalDbFunctionsExtensions.Greatest) => GenerateGreatest(translatedValues, elementClrType),
827+
nameof(RelationalDbFunctionsExtensions.Least) => GenerateLeast(translatedValues, elementClrType),
828+
_ => throw new UnreachableException()
829+
}
830+
?? QueryCompilationContext.NotTranslatedExpression;
831+
}
972832

973-
if (genericMethodDefinition == GreatestMethodInfo
974-
&& _sqlExpressionFactory.TryCreateGreatest(translatedValues, elementClrType, out var greatestExpression))
833+
// Translate Math.Max/Min.
834+
// These are here rather than in a MethodTranslator since we use TranslateGreatest/Least, and are very similar to the
835+
// EF.Functions.Greatest/Least translation just above.
836+
case
837+
{
838+
Method.Name: nameof(Math.Max) or nameof(Math.Min),
839+
Arguments: [Expression argument1, Expression argument2]
840+
} when method.DeclaringType == typeof(Math):
841+
{
842+
var translatedArguments = new List<SqlExpression>();
843+
844+
return TryFlattenVisit(argument1)
845+
&& TryFlattenVisit(argument2)
846+
&& method.Name switch
847+
{
848+
nameof(Math.Max) => GenerateGreatest(translatedArguments, argument1.Type),
849+
nameof(Math.Min) => GenerateLeast(translatedArguments, argument1.Type),
850+
_ => throw new UnreachableException()
851+
} is SqlExpression translatedFunctionCall
852+
? translatedFunctionCall
853+
: QueryCompilationContext.NotTranslatedExpression;
854+
855+
bool TryFlattenVisit(Expression argument)
975856
{
976-
return greatestExpression;
977-
}
857+
if (argument is MethodCallExpression nestedCall && nestedCall.Method == method)
858+
{
859+
return TryFlattenVisit(nestedCall.Arguments[0]) && TryFlattenVisit(nestedCall.Arguments[1]);
860+
}
978861

979-
throw new UnreachableException();
862+
if (TranslationFailed(argument, Visit(argument), out var translatedArgument))
863+
{
864+
return false;
865+
}
866+
867+
translatedArguments.Add(translatedArgument!);
868+
return true;
869+
}
980870
}
981871

982872
// For queryable methods, either we translate the whole aggregate or we go to subquery mode
@@ -1531,6 +1421,18 @@ when QueryableMethods.IsSumWithSelector(genericMethod):
15311421
return false;
15321422
}
15331423

1424+
/// <summary>
1425+
/// Generates a SQL GREATEST expression over the given expressions.
1426+
/// </summary>
1427+
public virtual SqlExpression? GenerateGreatest(IReadOnlyList<SqlExpression> expressions, Type resultType)
1428+
=> null;
1429+
1430+
/// <summary>
1431+
/// Generates a SQL GREATEST expression over the given expressions.
1432+
/// </summary>
1433+
public virtual SqlExpression? GenerateLeast(IReadOnlyList<SqlExpression> expressions, Type resultType)
1434+
=> null;
1435+
15341436
private bool TryTranslateAsEnumerableExpression(
15351437
Expression? expression,
15361438
[NotNullWhen(true)] out EnumerableExpression? enumerableExpression)

0 commit comments

Comments
 (0)