Skip to content

Commit 99626a4

Browse files
committed
Rewrite the code block that compares the equivalency of
Seq[Expression] in semanticEquals.
1 parent 4af3622 commit 99626a4

File tree

4 files changed

+32
-11
lines changed

4 files changed

+32
-11
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ import org.apache.spark.sql.types._
4848
* the same output data type.
4949
*
5050
*/
51-
abstract class Expression extends TreeNode[Expression] with PredicateHelper{
51+
abstract class Expression extends TreeNode[Expression]{
5252

5353
/**
5454
* Returns true when an expression is a candidate for static evaluation before the query is
@@ -160,6 +160,18 @@ abstract class Expression extends TreeNode[Expression] with PredicateHelper{
160160
checkSemantic(elements1, elements2)
161161
}
162162

163+
/**
164+
* Returns a sequence of expressions by removing from q the first expression that is semantically
165+
* equivalent to e.
166+
*/
167+
def removeFirstSemanticEquivalent(seq: Seq[Expression], e: Expression): Seq[Expression] = {
168+
seq match {
169+
case Seq() => Seq()
170+
case x +: rest if x semanticEquals e => rest
171+
case x +: rest => x +: removeFirstSemanticEquivalent(rest, e)
172+
}
173+
}
174+
163175
/**
164176
* Returns the hash for this expression. Expressions that compute the same result, even if
165177
* they differ cosmetically should return the same hash.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
228228
}
229229
}
230230

231-
case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate {
231+
case class And(left: Expression, right: Expression) extends BinaryOperator
232+
with Predicate with PredicateHelper{
232233

233234
override def inputType: AbstractDataType = BooleanType
234235

@@ -256,10 +257,12 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
256257
// Non-deterministic expressions cannot be semantic equal
257258
if (!deterministic || !other.deterministic) return false
258259

259-
// we know both expressions are And, so we can tolerate ordering different
260-
val elements1 = splitConjunctivePredicates(this).toSet.toSeq
261-
val elements2 = splitConjunctivePredicates(other).toSet.toSeq
262-
checkSemantic(elements1, elements2)
260+
// We already know both expressions are And, so we can tolerate ordering different
261+
// Recursively call semanticEquals on subexpressions to check the equivalency of two seqs.
262+
var elements1 = splitConjunctivePredicates(this)
263+
val elements2 = splitConjunctivePredicates(other)
264+
for (e <- elements2) elements1 = removeFirstSemanticEquivalent(elements1, e)
265+
elements1.isEmpty
263266
}
264267

265268
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -287,7 +290,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
287290
}
288291

289292

290-
case class Or(left: Expression, right: Expression) extends BinaryOperator with Predicate {
293+
case class Or(left: Expression, right: Expression) extends BinaryOperator
294+
with Predicate with PredicateHelper {
291295

292296
override def inputType: AbstractDataType = BooleanType
293297

@@ -316,9 +320,11 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
316320
if (!deterministic || !other.deterministic) return false
317321

318322
// we know both expressions are Or, so we can tolerate ordering different
319-
val elements1 = splitDisjunctivePredicates(this).toSet.toSeq
320-
val elements2 = splitDisjunctivePredicates(other).toSet.toSeq
321-
checkSemantic(elements1, elements2)
323+
// Recursively call semanticEquals on subexpressions to check the equivalency of two seqs.
324+
var elements1 = splitDisjunctivePredicates(this)
325+
val elements2 = splitDisjunctivePredicates(other)
326+
for (e <- elements2) elements1 = removeFirstSemanticEquivalent(elements1, e)
327+
elements1.isEmpty
322328
}
323329

324330
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
2525
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode}
2626

2727

28-
abstract class LogicalPlan extends QueryPlan[LogicalPlan] with PredicateHelper with Logging {
28+
abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
2929

3030
private var _analyzed: Boolean = false
3131

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ class SameResultSuite extends SparkFunSuite {
6363
assertSameResult(testRelation.where('a === 'b || 'c === 'd),
6464
testRelation2.where('c === 'd || 'a === 'b )
6565
)
66+
assertSameResult(testRelation.where(('a === 'b || 'c === 'd) && ('e === 'f || 'g === 'h)),
67+
testRelation2.where(('g === 'h || 'e === 'f) && ('c === 'd || 'a === 'b ))
68+
)
6669

6770
assertSameResult(testRelation.where('a === 'b && 'c === 'd),
6871
testRelation2.where('a === 'c && 'b === 'd),

0 commit comments

Comments
 (0)