Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/EFCore.Relational/Query/ISqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@ public interface ISqlExpressionFactory
/// <param name="operand">A <see cref="SqlExpression" /> to apply unary operator on.</param>
/// <param name="type">The type of the created expression.</param>
/// <param name="typeMapping">A type mapping to be assigned to the created expression.</param>
/// <param name="existingExpr">An optional expression that can be re-used if it matches the new expression.</param>
/// <returns>A <see cref="SqlExpression" /> with the given arguments.</returns>
SqlExpression? MakeUnary(
ExpressionType operatorType,
SqlExpression operand,
Type type,
RelationalTypeMapping? typeMapping = null);
RelationalTypeMapping? typeMapping = null,
SqlExpression? existingExpr = null);

/// <summary>
/// Creates a new <see cref="SqlExpression" /> with the given arguments.
Expand Down Expand Up @@ -275,11 +277,11 @@ SqlExpression Convert(
/// Creates a new <see cref="CaseExpression" /> which represent a CASE statement in a SQL tree.
/// </summary>
/// <param name="operand">An expression to compare with <see cref="CaseWhenClause.Test" /> in <paramref name="whenClauses" />.</param>
/// <param name="whenClauses">A list of <see cref="CaseWhenClause" /> to compare and get result from.</param>
/// <param name="whenClauses">A list of <see cref="CaseWhenClause" /> to compare or evaluate and get result from.</param>
/// <param name="elseResult">A value to return if no <paramref name="whenClauses" /> matches, if any.</param>
/// <returns>An expression representing a CASE statement in a SQL tree.</returns>
SqlExpression Case(
SqlExpression operand,
SqlExpression? operand,
IReadOnlyList<CaseWhenClause> whenClauses,
SqlExpression? elseResult);

Expand Down
118 changes: 88 additions & 30 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -599,10 +599,15 @@ public virtual SqlExpression Coalesce(SqlExpression left, SqlExpression right, R
ExpressionType operatorType,
SqlExpression operand,
Type type,
RelationalTypeMapping? typeMapping = null)
=> SqlUnaryExpression.IsValidOperator(operatorType)
? ApplyTypeMapping(new SqlUnaryExpression(operatorType, operand, type, null), typeMapping)
: null;
RelationalTypeMapping? typeMapping = null,
SqlExpression? existingExpr = null)
=> operatorType switch
{
ExpressionType.Not => ApplyTypeMapping(Not(operand, existingExpr), typeMapping),
_ when SqlUnaryExpression.IsValidOperator(operatorType)
=> ApplyTypeMapping(new SqlUnaryExpression(operatorType, operand, type, null), typeMapping),
_ => null,
};

/// <inheritdoc />
public virtual SqlExpression IsNull(SqlExpression operand)
Expand All @@ -620,33 +625,102 @@ public virtual SqlExpression Convert(SqlExpression operand, Type type, Relationa
public virtual SqlExpression Not(SqlExpression operand)
=> MakeUnary(ExpressionType.Not, operand, operand.Type, operand.TypeMapping)!;

private SqlExpression Not(SqlExpression operand, SqlExpression? existingExpr)
=> operand switch
{
// !(null) -> null
// ~(null) -> null (bitwise negation)
SqlConstantExpression { Value: null } => operand,

// !(true) -> false
// !(false) -> true
SqlConstantExpression { Value: bool boolValue } => Constant(!boolValue, operand.Type, operand.TypeMapping),

// !(!a) -> a
// ~(~a) -> a (bitwise negation)
SqlUnaryExpression { OperatorType: ExpressionType.Not } unary => unary.Operand,

// !(a IS NULL) -> a IS NOT NULL
SqlUnaryExpression { OperatorType: ExpressionType.Equal } unary => IsNotNull(unary.Operand),

// !(a IS NOT NULL) -> a IS NULL
SqlUnaryExpression { OperatorType: ExpressionType.NotEqual } unary => IsNull(unary.Operand),

// !(a AND b) -> !a OR !b (De Morgan)
SqlBinaryExpression { OperatorType: ExpressionType.AndAlso } binary
=> OrElse(Not(binary.Left), Not(binary.Right)),

// !(a OR b) -> !a AND !b (De Morgan)
SqlBinaryExpression { OperatorType: ExpressionType.OrElse } binary
=> AndAlso(Not(binary.Left), Not(binary.Right)),

// use equality where possible
// !(a == true) -> a == false
// !(a == false) -> a == true
SqlBinaryExpression { OperatorType: ExpressionType.Equal, Right: SqlConstantExpression { Value: bool } } binary
=> Equal(binary.Left, Not(binary.Right)),

// !(true == a) -> false == a
// !(false == a) -> true == a
SqlBinaryExpression { OperatorType: ExpressionType.Equal, Left: SqlConstantExpression { Value: bool } } binary
=> Equal(Not(binary.Left), binary.Right),

// !(a == b) -> a != b
SqlBinaryExpression { OperatorType: ExpressionType.Equal } sqlBinaryOperand => NotEqual(sqlBinaryOperand.Left, sqlBinaryOperand.Right),
// !(a != b) -> a == b
SqlBinaryExpression { OperatorType: ExpressionType.NotEqual } sqlBinaryOperand => Equal(sqlBinaryOperand.Left, sqlBinaryOperand.Right),

// !(CASE x WHEN t1 THEN r1 ... ELSE rN) -> CASE x WHEN t1 THEN !r1 ... ELSE !rN
CaseExpression caseExpression
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment to explain the cases this simplifies?

when caseExpression.Type == typeof(bool)
&& caseExpression.ElseResult is null or SqlConstantExpression
&& caseExpression.WhenClauses.All(clause => clause.Result is SqlConstantExpression)
=> Case(
caseExpression.Operand,
[.. caseExpression.WhenClauses.Select(clause => new CaseWhenClause(clause.Test, Not(clause.Result)))],
caseExpression.ElseResult is null ? null : Not(caseExpression.ElseResult)),

_ => existingExpr is SqlUnaryExpression { OperatorType: ExpressionType.Not } unaryExpr && unaryExpr.Operand == operand
? existingExpr
: new SqlUnaryExpression(ExpressionType.Not, operand, operand.Type, null),
};

/// <inheritdoc />
public virtual SqlExpression Negate(SqlExpression operand)
=> MakeUnary(ExpressionType.Negate, operand, operand.Type, operand.TypeMapping)!;

/// <inheritdoc />
public virtual SqlExpression Case(SqlExpression? operand, IReadOnlyList<CaseWhenClause> whenClauses, SqlExpression? elseResult)
{
var operandTypeMapping = operand!.TypeMapping
?? whenClauses.Select(wc => wc.Test.TypeMapping).FirstOrDefault(t => t != null)
// Since we never look at type of Operand/Test after this place,
// we need to find actual typeMapping based on non-object type.
?? new[] { operand.Type }.Concat(whenClauses.Select(wc => wc.Test.Type))
.Where(t => t != typeof(object)).Select(t => _typeMappingSource.FindMapping(t, Dependencies.Model))
.FirstOrDefault();
RelationalTypeMapping? testTypeMapping;
if (operand == null)
{
testTypeMapping = _boolTypeMapping;
}
else
{
testTypeMapping = operand.TypeMapping
?? whenClauses.Select(wc => wc.Test.TypeMapping).FirstOrDefault(t => t != null)
// Since we never look at type of Operand/Test after this place,
// we need to find actual typeMapping based on non-object type.
?? new[] { operand.Type }.Concat(whenClauses.Select(wc => wc.Test.Type))
.Where(t => t != typeof(object)).Select(t => _typeMappingSource.FindMapping(t, Dependencies.Model))
.FirstOrDefault();

operand = ApplyTypeMapping(operand, testTypeMapping);
}

var resultTypeMapping = elseResult?.TypeMapping
?? whenClauses.Select(wc => wc.Result.TypeMapping).FirstOrDefault(t => t != null);

operand = ApplyTypeMapping(operand, operandTypeMapping);
elseResult = ApplyTypeMapping(elseResult, resultTypeMapping);

var typeMappedWhenClauses = new List<CaseWhenClause>();
foreach (var caseWhenClause in whenClauses)
{
typeMappedWhenClauses.Add(
new CaseWhenClause(
ApplyTypeMapping(caseWhenClause.Test, operandTypeMapping),
ApplyTypeMapping(caseWhenClause.Test, testTypeMapping),
ApplyTypeMapping(caseWhenClause.Result, resultTypeMapping)));
}

Expand All @@ -655,23 +729,7 @@ public virtual SqlExpression Case(SqlExpression? operand, IReadOnlyList<CaseWhen

/// <inheritdoc />
public virtual SqlExpression Case(IReadOnlyList<CaseWhenClause> whenClauses, SqlExpression? elseResult)
{
var resultTypeMapping = elseResult?.TypeMapping
?? whenClauses.Select(wc => wc.Result.TypeMapping).FirstOrDefault(t => t != null);

var typeMappedWhenClauses = new List<CaseWhenClause>();
foreach (var caseWhenClause in whenClauses)
{
typeMappedWhenClauses.Add(
new CaseWhenClause(
ApplyTypeMapping(caseWhenClause.Test, _boolTypeMapping),
ApplyTypeMapping(caseWhenClause.Result, resultTypeMapping)));
}

elseResult = ApplyTypeMapping(elseResult, resultTypeMapping);

return new CaseExpression(typeMappedWhenClauses, elseResult);
}
=> Case(operand: null, whenClauses, elseResult);

/// <inheritdoc />
public virtual SqlExpression Function(
Expand Down
16 changes: 5 additions & 11 deletions src/EFCore.Relational/Query/SqlExpressions/CaseExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ public class CaseExpression : SqlExpression
/// Creates a new instance of the <see cref="CaseExpression" /> class which represents a simple CASE expression.
/// </summary>
/// <param name="operand">An expression to compare with <see cref="CaseWhenClause.Test" /> in <see cref="WhenClauses" />.</param>
/// <param name="whenClauses">A list of <see cref="CaseWhenClause" /> to compare and get result from.</param>
/// <param name="whenClauses">A list of <see cref="CaseWhenClause" /> to compare or evaluate and get result from.</param>
/// <param name="elseResult">A value to return if no <see cref="WhenClauses" /> matches, if any.</param>
public CaseExpression(
SqlExpression operand,
SqlExpression? operand,
IReadOnlyList<CaseWhenClause> whenClauses,
SqlExpression? elseResult = null)
: base(whenClauses[0].Result.Type, whenClauses[0].Result.TypeMapping)
Expand All @@ -45,10 +45,8 @@ public CaseExpression(
public CaseExpression(
IReadOnlyList<CaseWhenClause> whenClauses,
SqlExpression? elseResult = null)
: base(whenClauses[0].Result.Type, whenClauses[0].Result.TypeMapping)
: this(null, whenClauses, elseResult)
{
_whenClauses.AddRange(whenClauses);
ElseResult = elseResult;
}

/// <summary>
Expand Down Expand Up @@ -94,9 +92,7 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
changed |= elseResult != ElseResult;

return changed
? operand == null
? new CaseExpression(whenClauses, elseResult)
: new CaseExpression(operand, whenClauses, elseResult)
? new CaseExpression(operand, whenClauses, elseResult)
: this;
}

Expand All @@ -113,9 +109,7 @@ public virtual CaseExpression Update(
IReadOnlyList<CaseWhenClause> whenClauses,
SqlExpression? elseResult)
=> operand != Operand || !whenClauses.SequenceEqual(WhenClauses) || elseResult != ElseResult
? (operand == null
? new CaseExpression(whenClauses, elseResult)
: new CaseExpression(operand, whenClauses, elseResult))
? new CaseExpression(operand, whenClauses, elseResult)
: this;

/// <inheritdoc />
Expand Down
Loading