Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 0 additions & 28 deletions src/EFCore.Relational/Query/ISqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;

namespace Microsoft.EntityFrameworkCore.Query;
Expand Down Expand Up @@ -439,30 +437,4 @@ SqlExpression NiladicFunction(
/// <param name="sql">A string token to print in SQL tree.</param>
/// <returns>An expression representing a SQL token.</returns>
SqlExpression Fragment(string sql);

/// <summary>
/// Attempts to creates a new expression that returns the smallest value from a list of expressions, e.g. an invocation of the
/// <c>LEAST</c> SQL function.
/// </summary>
/// <param name="expressions">An entity type to project.</param>
/// <param name="resultType">The result CLR type for the returned expression.</param>
/// <param name="leastExpression">The expression which computes the smallest value.</param>
/// <returns><see langword="true" /> if the expression could be created, <see langword="false" /> otherwise.</returns>
bool TryCreateLeast(
IReadOnlyList<SqlExpression> expressions,
Type resultType,
[NotNullWhen(true)] out SqlExpression? leastExpression);

/// <summary>
/// Attempts to creates a new expression that returns the greatest value from a list of expressions, e.g. an invocation of the
/// <c>GREATEST</c> SQL function.
/// </summary>
/// <param name="expressions">An entity type to project.</param>
/// <param name="resultType">The result CLR type for the returned expression.</param>
/// <param name="greatestExpression">The expression which computes the greatest value.</param>
/// <returns><see langword="true" /> if the expression could be created, <see langword="false" /> otherwise.</returns>
bool TryCreateGreatest(
IReadOnlyList<SqlExpression> expressions,
Type resultType,
[NotNullWhen(true)] out SqlExpression? greatestExpression);
}
Original file line number Diff line number Diff line change
Expand Up @@ -976,19 +976,39 @@ private SqlExpression CreateJoinPredicate(Expression outerKey, Expression innerK

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateMax(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
=> TryExtractBareInlineCollectionValues(source, out var values)
&& _sqlExpressionFactory.TryCreateGreatest(values, resultType, out var greatestExpression)
? source.Update(new SelectExpression(greatestExpression, _sqlAliasManager), source.ShaperExpression)
: TranslateAggregateWithSelector(
source, selector, t => QueryableMethods.MaxWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType);
{
// For Max() over an inline array, translate to GREATEST() if possible; otherwise use the default translation of aggregate SQL
// MAX().
// Note that some providers propagate NULL arguments (SQLite, MySQL), while others only return NULL if all arguments evaluate to
// NULL (SQL Server, PostgreSQL). If the argument is a nullable value type, don't translate to GREATEST() if it propagates NULLs,
// to match the .NET behavior.
if (TryExtractBareInlineCollectionValues(source, out var values)
&& _sqlTranslator.GenerateGreatest(values, resultType.UnwrapNullableType()) is SqlFunctionExpression greatestExpression
&& (Nullable.GetUnderlyingType(resultType) is null
|| greatestExpression.ArgumentsPropagateNullability?.All(a => a == false) == true))
{
return source.Update(new SelectExpression(greatestExpression, _sqlAliasManager), source.ShaperExpression);
}

return TranslateAggregateWithSelector(
source, selector, t => QueryableMethods.MaxWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType);
}

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
=> TryExtractBareInlineCollectionValues(source, out var values)
&& _sqlExpressionFactory.TryCreateLeast(values, resultType, out var leastExpression)
? source.Update(new SelectExpression(leastExpression, _sqlAliasManager), source.ShaperExpression)
: TranslateAggregateWithSelector(
source, selector, t => QueryableMethods.MinWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType);
{
// See comments above in TranslateMax()
if (TryExtractBareInlineCollectionValues(source, out var values)
&& _sqlTranslator.GenerateLeast(values, resultType.UnwrapNullableType()) is SqlFunctionExpression leastExpression
&& (Nullable.GetUnderlyingType(resultType) is null
|| leastExpression.ArgumentsPropagateNullability?.All(a => a == false) == true))
{
return source.Update(new SelectExpression(leastExpression, _sqlAliasManager), source.ShaperExpression);
}

return TranslateAggregateWithSelector(
source, selector, t => QueryableMethods.MinWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType);
}

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateOfType(ShapedQueryExpression source, Type resultType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,6 @@ private static readonly MethodInfo StringEqualsWithStringComparison
private static readonly MethodInfo StringEqualsWithStringComparisonStatic
= typeof(string).GetRuntimeMethod(nameof(string.Equals), [typeof(string), typeof(string), typeof(StringComparison)])!;

private static readonly MethodInfo LeastMethodInfo
= typeof(RelationalDbFunctionsExtensions).GetMethod(nameof(RelationalDbFunctionsExtensions.Least))!;

private static readonly MethodInfo GreatestMethodInfo
= typeof(RelationalDbFunctionsExtensions).GetMethod(nameof(RelationalDbFunctionsExtensions.Greatest))!;

private static readonly MethodInfo GetTypeMethodInfo = typeof(object).GetTypeInfo().GetDeclaredMethod(nameof(GetType))!;

private readonly QueryCompilationContext _queryCompilationContext;
Expand Down Expand Up @@ -183,138 +177,6 @@ protected virtual void AddTranslationErrorDetails(string details)
return result;
}

/// <summary>
/// Translates Average over an expression to an equivalent SQL representation.
/// </summary>
/// <param name="sqlExpression">An expression to translate Average over.</param>
/// <returns>A SQL translation of Average over the given expression.</returns>
[Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")]
public virtual SqlExpression? TranslateAverage(SqlExpression sqlExpression)
{
var inputType = sqlExpression.Type;
if (inputType == typeof(int)
|| inputType == typeof(long))
{
sqlExpression = sqlExpression is DistinctExpression distinctExpression
? new DistinctExpression(
_sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Convert(distinctExpression.Operand, typeof(double))))
: _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Convert(sqlExpression, typeof(double)));
}

return inputType == typeof(float)
? _sqlExpressionFactory.Convert(
_sqlExpressionFactory.Function(
"AVG",
new[] { sqlExpression },
nullable: true,
argumentsPropagateNullability: new[] { false },
typeof(double)),
sqlExpression.Type,
sqlExpression.TypeMapping)
: _sqlExpressionFactory.Function(
"AVG",
new[] { sqlExpression },
nullable: true,
argumentsPropagateNullability: new[] { false },
sqlExpression.Type,
sqlExpression.TypeMapping);
}

/// <summary>
/// Translates Count over an expression to an equivalent SQL representation.
/// </summary>
/// <param name="sqlExpression">An expression to translate Count over.</param>
/// <returns>A SQL translation of Count over the given expression.</returns>
[Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")]
public virtual SqlExpression? TranslateCount(SqlExpression sqlExpression)
=> _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function(
"COUNT",
new[] { sqlExpression },
nullable: false,
argumentsPropagateNullability: new[] { false },
typeof(int)));

/// <summary>
/// Translates LongCount over an expression to an equivalent SQL representation.
/// </summary>
/// <param name="sqlExpression">An expression to translate LongCount over.</param>
/// <returns>A SQL translation of LongCount over the given expression.</returns>
[Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")]
public virtual SqlExpression? TranslateLongCount(SqlExpression sqlExpression)
=> _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function(
"COUNT",
new[] { sqlExpression },
nullable: false,
argumentsPropagateNullability: new[] { false },
typeof(long)));

/// <summary>
/// Translates Max over an expression to an equivalent SQL representation.
/// </summary>
/// <param name="sqlExpression">An expression to translate Max over.</param>
/// <returns>A SQL translation of Max over the given expression.</returns>
[Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")]
public virtual SqlExpression? TranslateMax(SqlExpression sqlExpression)
=> sqlExpression != null
? _sqlExpressionFactory.Function(
"MAX",
new[] { sqlExpression },
nullable: true,
argumentsPropagateNullability: new[] { false },
sqlExpression.Type,
sqlExpression.TypeMapping)
: null;

/// <summary>
/// Translates Min over an expression to an equivalent SQL representation.
/// </summary>
/// <param name="sqlExpression">An expression to translate Min over.</param>
/// <returns>A SQL translation of Min over the given expression.</returns>
[Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")]
public virtual SqlExpression? TranslateMin(SqlExpression sqlExpression)
=> sqlExpression != null
? _sqlExpressionFactory.Function(
"MIN",
new[] { sqlExpression },
nullable: true,
argumentsPropagateNullability: new[] { false },
sqlExpression.Type,
sqlExpression.TypeMapping)
: null;

/// <summary>
/// Translates Sum over an expression to an equivalent SQL representation.
/// </summary>
/// <param name="sqlExpression">An expression to translate Sum over.</param>
/// <returns>A SQL translation of Sum over the given expression.</returns>
[Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")]
public virtual SqlExpression? TranslateSum(SqlExpression sqlExpression)
{
var inputType = sqlExpression.Type;

return inputType == typeof(float)
? _sqlExpressionFactory.Convert(
_sqlExpressionFactory.Function(
"SUM",
new[] { sqlExpression },
nullable: true,
argumentsPropagateNullability: new[] { false },
typeof(double)),
inputType,
sqlExpression.TypeMapping)
: _sqlExpressionFactory.Function(
"SUM",
new[] { sqlExpression },
nullable: true,
argumentsPropagateNullability: new[] { false },
inputType,
sqlExpression.TypeMapping);
}

/// <inheritdoc />
protected override Expression VisitBinary(BinaryExpression binaryExpression)
{
Expand Down Expand Up @@ -937,14 +799,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
// translation.
case
{
Method:
{
Name: nameof(RelationalDbFunctionsExtensions.Least) or nameof(RelationalDbFunctionsExtensions.Greatest),
IsGenericMethod: true
},
Method.Name: nameof(RelationalDbFunctionsExtensions.Least) or nameof(RelationalDbFunctionsExtensions.Greatest),
Arguments: [_, NewArrayExpression newArray]
} when method.GetGenericMethodDefinition() is var genericMethodDefinition
&& (genericMethodDefinition == LeastMethodInfo || genericMethodDefinition == GreatestMethodInfo):
} when method.DeclaringType == typeof(RelationalDbFunctionsExtensions):
{
var values = newArray.Expressions;
var translatedValues = new SqlExpression[values.Count];
Expand All @@ -962,21 +819,55 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
translatedValues[i] = translatedValue!;
}

var elementClrType = newArray.Type.GetElementType()!;
var elementClrType = newArray.Type.GetElementType()!.UnwrapNullableType();

if (genericMethodDefinition == LeastMethodInfo
&& _sqlExpressionFactory.TryCreateLeast(translatedValues, elementClrType, out var leastExpression))
{
return leastExpression;
}
return method.Name switch
{
nameof(RelationalDbFunctionsExtensions.Greatest) => GenerateGreatest(translatedValues, elementClrType),
nameof(RelationalDbFunctionsExtensions.Least) => GenerateLeast(translatedValues, elementClrType),
_ => throw new UnreachableException()
}
?? QueryCompilationContext.NotTranslatedExpression;
}

if (genericMethodDefinition == GreatestMethodInfo
&& _sqlExpressionFactory.TryCreateGreatest(translatedValues, elementClrType, out var greatestExpression))
// Translate Math.Max/Min.
// These are here rather than in a MethodTranslator since we use TranslateGreatest/Least, and since these are very similar to
// the EF.Functions.Greatest/Least translation just above.
case
{
Method.Name: nameof(Math.Max) or nameof(Math.Min),
Arguments: [Expression argument1, Expression argument2]
} when method.DeclaringType == typeof(Math):
{
var translatedArguments = new List<SqlExpression>();
var returnType = method.ReturnType.UnwrapNullableType();

return TryFlattenVisit(argument1)
&& TryFlattenVisit(argument2)
&& method.Name switch
{
nameof(Math.Max) => GenerateGreatest(translatedArguments, returnType),
nameof(Math.Min) => GenerateLeast(translatedArguments, returnType),
_ => throw new UnreachableException()
} is SqlExpression translatedFunctionCall
? translatedFunctionCall
: QueryCompilationContext.NotTranslatedExpression;

bool TryFlattenVisit(Expression argument)
{
return greatestExpression;
}
if (argument is MethodCallExpression nestedCall && nestedCall.Method == method)
{
return TryFlattenVisit(nestedCall.Arguments[0]) && TryFlattenVisit(nestedCall.Arguments[1]);
}

throw new UnreachableException();
if (TranslationFailed(argument, Visit(argument), out var translatedArgument))
{
return false;
}

translatedArguments.Add(translatedArgument!);
return true;
}
}

// For queryable methods, either we translate the whole aggregate or we go to subquery mode
Expand Down Expand Up @@ -1531,6 +1422,18 @@ when QueryableMethods.IsSumWithSelector(genericMethod):
return false;
}

/// <summary>
/// Generates a SQL GREATEST expression over the given expressions.
/// </summary>
public virtual SqlExpression? GenerateGreatest(IReadOnlyList<SqlExpression> expressions, Type resultType)
=> null;

/// <summary>
/// Generates a SQL GREATEST expression over the given expressions.
/// </summary>
public virtual SqlExpression? GenerateLeast(IReadOnlyList<SqlExpression> expressions, Type resultType)
=> null;

private bool TryTranslateAsEnumerableExpression(
Expression? expression,
[NotNullWhen(true)] out EnumerableExpression? enumerableExpression)
Expand Down
Loading