Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,8 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
val newBranches = cw.branches.map { case (cond, value) =>
replaceNullWithFalse(cond) -> replaceNullWithFalse(value)
}
if (newBranches.forall(_._2 == FalseLiteral) && cw.elseValue.isEmpty) {
FalseLiteral
} else {
val newElseValue = cw.elseValue.map(replaceNullWithFalse)
CaseWhen(newBranches, newElseValue)
}
val newElseValue = cw.elseValue.map(replaceNullWithFalse).getOrElse(FalseLiteral)
CaseWhen(newBranches, newElseValue)
case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType =>
If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal))
case e if e.dataType == BooleanType =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -515,8 +515,9 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
val (h, t) = branches.span(_._1 != TrueLiteral)
CaseWhen( h :+ t.head, None)

case e @ CaseWhen(branches, Some(elseValue))
if branches.forall(_._2.semanticEquals(elseValue)) =>
case e @ CaseWhen(branches, elseOpt)
if branches.forall(_._2.semanticEquals(elseOpt.getOrElse(Literal(null, e.dataType)))) =>
val elseValue = elseOpt.getOrElse(Literal(null, e.dataType))
// For non-deterministic conditions with side effect, we can not remove it, or change
// the ordering. As a result, we try to remove the deterministic conditions from the tail.
var hitNonDeterministicCond = false
Expand All @@ -532,10 +533,6 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
} else {
e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue)))
}

case e @ CaseWhen(branches, None)
if branches.forall(_._2.semanticEquals(Literal(null, e.dataType))) =>
Literal(null, e.dataType)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,14 +260,13 @@ class PushFoldableIntoBranchesSuite
}

test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") {
Seq(a, LessThan(Rand(1), Literal(0.5))).foreach { condition =>
assertEquivalent(
EqualTo(CaseWhen(Seq((condition, Literal.create(null, IntegerType)))), Literal(2)),
Literal.create(null, BooleanType))
assertEquivalent(
EqualTo(CaseWhen(Seq((condition, Literal("str")))).cast(IntegerType), Literal(2)),
Literal.create(null, BooleanType))
}
assertEquivalent(
EqualTo(CaseWhen(Seq((a, Literal.create(null, IntegerType)))), Literal(2)),
Literal.create(null, BooleanType))
assertEquivalent(
EqualTo(CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal("str")))).cast(IntegerType),
Literal(2)),
CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal.create(null, BooleanType)))))
}

test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
val expectedBranches = Seq(
(UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral,
(UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral)
val expectedCond = CaseWhen(expectedBranches)
val expectedCond = CaseWhen(expectedBranches, FalseLiteral)

testFilter(originalCond, expectedCond)
testJoin(originalCond, expectedCond)
Expand All @@ -135,7 +135,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
(UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral,
(UnresolvedAttribute("i") > Literal(10)) -> FalseLiteral,
TrueLiteral -> TrueLiteral)
val expectedCond = CaseWhen(expectedBranches)
val expectedCond = CaseWhen(expectedBranches, FalseLiteral)

testFilter(originalCond, expectedCond)
testJoin(originalCond, expectedCond)
Expand Down Expand Up @@ -238,7 +238,8 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
FalseLiteral)
val condition = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue))
val expectedCond = CaseWhen(Seq(
(UnresolvedAttribute("i") > Literal(10), (Literal(2) === nestedCaseWhen) <=> TrueLiteral)))
(UnresolvedAttribute("i") > Literal(10), (Literal(2) === nestedCaseWhen) <=> TrueLiteral)),
FalseLiteral)
testFilter(originalCond = condition, expectedCond = expectedCond)
testJoin(originalCond = condition, expectedCond = expectedCond)
testDelete(originalCond = condition, expectedCond = expectedCond)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,13 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
}

test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") {
Seq(GreaterThan('a, 1), GreaterThan(Rand(0), 1)).foreach { condition =>
assertEquivalent(
CaseWhen((condition, Literal.create(null, IntegerType)) :: Nil, None),
Literal.create(null, IntegerType))
}
assertEquivalent(
CaseWhen((GreaterThan('a, 1), Literal.create(null, IntegerType)) :: Nil, None),
Literal.create(null, IntegerType))

assertEquivalent(
CaseWhen((GreaterThan(Rand(0), 0.5), Literal.create(null, IntegerType)) :: Nil, None),
CaseWhen((GreaterThan(Rand(0), 0.5), Literal.create(null, IntegerType)) :: Nil, None))
}

test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") {
Expand Down