Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions src/EFCore.SqlServer/Query/Internal/SearchConditionConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public class SearchConditionConverter(ISqlExpressionFactory sqlExpressionFactory
/// </summary>
[return: NotNullIfNotNull(nameof(expression))]
public override Expression? Visit(Expression? expression)
=> Visit(expression, inSearchConditionContext: false);
=> Visit(expression, inSearchConditionContext: false, allowNullFalseEquivalence: false);

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand All @@ -41,12 +41,12 @@ public class SearchConditionConverter(ISqlExpressionFactory sqlExpressionFactory
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
[return: NotNullIfNotNull(nameof(expression))]
protected virtual Expression? Visit(Expression? expression, bool inSearchConditionContext)
protected virtual Expression? Visit(Expression? expression, bool inSearchConditionContext, bool allowNullFalseEquivalence)
=> expression switch
{
CaseExpression e => VisitCase(e, inSearchConditionContext),
CaseExpression e => VisitCase(e, inSearchConditionContext, allowNullFalseEquivalence),
SelectExpression e => VisitSelect(e),
SqlBinaryExpression e => VisitSqlBinary(e, inSearchConditionContext),
SqlBinaryExpression e => VisitSqlBinary(e, inSearchConditionContext, allowNullFalseEquivalence),
SqlUnaryExpression e => VisitSqlUnary(e, inSearchConditionContext),
PredicateJoinExpressionBase e => VisitPredicateJoin(e),

Expand Down Expand Up @@ -139,19 +139,19 @@ private SqlExpression SimplifyNegatedBinary(SqlExpression sqlExpression)
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected virtual Expression VisitCase(CaseExpression caseExpression, bool inSearchConditionContext)
protected virtual Expression VisitCase(CaseExpression caseExpression, bool inSearchConditionContext, bool allowNullFalseEquivalence)
{
var testIsCondition = caseExpression.Operand is null;
var operand = (SqlExpression?)Visit(caseExpression.Operand);
var whenClauses = new List<CaseWhenClause>();
foreach (var whenClause in caseExpression.WhenClauses)
{
var test = (SqlExpression)Visit(whenClause.Test, testIsCondition);
var result = (SqlExpression)Visit(whenClause.Result);
var test = (SqlExpression)Visit(whenClause.Test, testIsCondition, testIsCondition);
var result = (SqlExpression)Visit(whenClause.Result, inSearchConditionContext: false, allowNullFalseEquivalence);
whenClauses.Add(new CaseWhenClause(test, result));
}

var elseResult = (SqlExpression?)Visit(caseExpression.ElseResult);
var elseResult = (SqlExpression?)Visit(caseExpression.ElseResult, inSearchConditionContext: false, allowNullFalseEquivalence);

return ApplyConversion(
sqlExpressionFactory.Case(operand, whenClauses, elseResult, caseExpression),
Expand All @@ -168,7 +168,7 @@ protected virtual Expression VisitCase(CaseExpression caseExpression, bool inSea
protected virtual Expression VisitPredicateJoin(PredicateJoinExpressionBase join)
=> join.Update(
(TableExpressionBase)Visit(join.Table),
(SqlExpression)Visit(join.JoinPredicate, inSearchConditionContext: true));
(SqlExpression)Visit(join.JoinPredicate, inSearchConditionContext: true, allowNullFalseEquivalence: true));

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand All @@ -179,9 +179,9 @@ protected virtual Expression VisitPredicateJoin(PredicateJoinExpressionBase join
protected virtual Expression VisitSelect(SelectExpression select)
{
var tables = this.VisitAndConvert(select.Tables);
var predicate = (SqlExpression?)Visit(select.Predicate, inSearchConditionContext: true);
var predicate = (SqlExpression?)Visit(select.Predicate, inSearchConditionContext: true, allowNullFalseEquivalence: true);
var groupBy = this.VisitAndConvert(select.GroupBy);
var havingExpression = (SqlExpression?)Visit(select.Having, inSearchConditionContext: true);
var havingExpression = (SqlExpression?)Visit(select.Having, inSearchConditionContext: true, allowNullFalseEquivalence: true);
var projections = this.VisitAndConvert(select.Projection);
var orderings = this.VisitAndConvert(select.Orderings);
var offset = (SqlExpression?)Visit(select.Offset);
Expand All @@ -196,19 +196,19 @@ protected virtual Expression VisitSelect(SelectExpression select)
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected virtual Expression VisitSqlBinary(SqlBinaryExpression binary, bool inSearchConditionContext)
protected virtual Expression VisitSqlBinary(SqlBinaryExpression binary, bool inSearchConditionContext, bool allowNullFalseEquivalence)
{
// Only logical operations need conditions on both sides
var areOperandsInSearchConditionContext = binary.OperatorType is ExpressionType.AndAlso or ExpressionType.OrElse;

var newLeft = (SqlExpression)Visit(binary.Left, areOperandsInSearchConditionContext);
var newRight = (SqlExpression)Visit(binary.Right, areOperandsInSearchConditionContext);
var newLeft = (SqlExpression)Visit(binary.Left, areOperandsInSearchConditionContext, allowNullFalseEquivalence: false);
var newRight = (SqlExpression)Visit(binary.Right, areOperandsInSearchConditionContext, allowNullFalseEquivalence: false);

if (binary.OperatorType is ExpressionType.NotEqual or ExpressionType.Equal)
{
var leftType = newLeft.TypeMapping?.Converter?.ProviderClrType ?? newLeft.Type;
var rightType = newRight.TypeMapping?.Converter?.ProviderClrType ?? newRight.Type;
if (!inSearchConditionContext
if (!inSearchConditionContext && !allowNullFalseEquivalence
&& (leftType == typeof(bool) || leftType.IsInteger())
&& (rightType == typeof(bool) || rightType.IsInteger()))
{
Expand Down Expand Up @@ -309,7 +309,7 @@ protected virtual Expression VisitSqlUnary(SqlUnaryExpression sqlUnaryExpression
sqlUnaryExpression.OperatorType, typeof(SqlUnaryExpression)));
}

var operand = (SqlExpression)Visit(sqlUnaryExpression.Operand, isOperandInSearchConditionContext);
var operand = (SqlExpression)Visit(sqlUnaryExpression.Operand, isOperandInSearchConditionContext, allowNullFalseEquivalence: false);

return SimplifyNegatedBinary(
ApplyConversion(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,17 @@ public virtual void Where_not_equal_using_relational_null_semantics_complex_with
.Select(e => e.Id).ToList();
}

[ConditionalFact]
public virtual void Where_not_equal_using_relational_null_semantics_complex_in_equals()
{
using var context = CreateContext(useRelationalNulls: true);
var l = context.Entities1
.Where(e => (e.NullableBoolA != e.NullableBoolB) == e.NullableBoolC)
.Select(e => e.Id).ToList();

Assert.Equal(l.OrderBy(e => e), [1, 5, 11, 13]);
}

[ConditionalTheory, MemberData(nameof(IsAsyncData))]
public virtual async Task Where_comparison_null_constant_and_null_parameter(bool async)
{
Expand Down Expand Up @@ -1161,6 +1172,11 @@ await AssertQueryScalar(
await AssertQueryScalar(
async,
ss => ss.Set<NullSemanticsEntity1>().Select(e => (e.BoolA ? e.NullableIntA : e.IntB) > e.IntC));

await AssertQueryScalar(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(e => (e.BoolA ? e.NullableBoolB : !e.NullableBoolC) == null).Select(e => e.Id)
);
}

[ConditionalTheory, MemberData(nameof(IsAsyncData))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3252,6 +3252,18 @@ WHERE [e].[NullableBoolA] <> [e].[NullableBoolB]
""");
}

public override void Where_not_equal_using_relational_null_semantics_complex_in_equals()
{
base.Where_not_equal_using_relational_null_semantics_complex_in_equals();

AssertSql(
"""
SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE [e].[NullableBoolA] ^ [e].[NullableBoolB] = [e].[NullableBoolC]
""");
}

public override async Task Where_comparison_null_constant_and_null_parameter(bool async)
{
await base.Where_comparison_null_constant_and_null_parameter(async);
Expand Down Expand Up @@ -3583,6 +3595,15 @@ ELSE [e].[IntB]
ELSE CAST(0 AS bit)
END
FROM [Entities1] AS [e]
""",
//
"""
SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE CASE
WHEN [e].[BoolA] = CAST(1 AS bit) THEN [e].[NullableBoolB]
ELSE ~[e].[NullableBoolC]
END IS NULL
""");
}

Expand Down Expand Up @@ -5087,7 +5108,10 @@ public override async Task Is_null_on_column_followed_by_OrElse_optimizes_nullab
SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE CASE
WHEN [e].[NullableBoolA] IS NULL THEN ~([e].[BoolA] ^ [e].[BoolB])
WHEN [e].[NullableBoolA] IS NULL THEN CASE
WHEN [e].[BoolA] = [e].[BoolB] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END
WHEN [e].[NullableBoolC] IS NULL THEN CASE
WHEN ([e].[NullableBoolA] <> [e].[NullableBoolC] OR [e].[NullableBoolA] IS NULL OR [e].[NullableBoolC] IS NULL) AND ([e].[NullableBoolA] IS NOT NULL OR [e].[NullableBoolC] IS NOT NULL) THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1703,6 +1703,15 @@ WHEN CASE
ELSE 0
END
FROM "Entities1" AS "e"
""",
//
"""
SELECT "e"."Id"
FROM "Entities1" AS "e"
WHERE CASE
WHEN "e"."BoolA" THEN "e"."NullableBoolB"
ELSE NOT ("e"."NullableBoolC")
END IS NULL
""");
}

Expand Down
Loading