Skip to content

Commit 14900ae

Browse files
committed
Simplifies binary node pattern matching
1 parent 74dc2a9 commit 14900ae

File tree

5 files changed

+71
-77
lines changed

5 files changed

+71
-77
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 51 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -130,28 +130,28 @@ trait HiveTypeCoercion {
130130
* the appropriate numeric equivalent.
131131
*/
132132
object ConvertNaNs extends Rule[LogicalPlan] {
133-
private val stringNaN = Literal("NaN")
133+
private val StringNaN = Literal("NaN")
134134

135135
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
136136
case q: LogicalPlan => q transformExpressions {
137137
// Skip nodes whose children have not been resolved yet.
138138
case e if !e.childrenResolved => e
139139

140140
/* Double Conversions */
141-
case b: BinaryExpression if b.left == stringNaN && b.right.dataType == DoubleType =>
142-
b.makeCopy(Array(b.right, Literal(Double.NaN)))
143-
case b: BinaryExpression if b.left.dataType == DoubleType && b.right == stringNaN =>
144-
b.makeCopy(Array(Literal(Double.NaN), b.left))
145-
case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN =>
146-
b.makeCopy(Array(Literal(Double.NaN), b.left))
141+
case b @ BinaryExpression(StringNaN, r @ DoubleType()) =>
142+
b.makeCopy(Array(r, Literal(Double.NaN)))
143+
case b @ BinaryExpression(l @ DoubleType(), StringNaN) =>
144+
b.makeCopy(Array(Literal(Double.NaN), l))
147145

148146
/* Float Conversions */
149-
case b: BinaryExpression if b.left == stringNaN && b.right.dataType == FloatType =>
147+
case b @ BinaryExpression(StringNaN, r @ FloatType()) =>
150148
b.makeCopy(Array(b.right, Literal(Float.NaN)))
151-
case b: BinaryExpression if b.left.dataType == FloatType && b.right == stringNaN =>
152-
b.makeCopy(Array(Literal(Float.NaN), b.left))
153-
case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN =>
154-
b.makeCopy(Array(Literal(Float.NaN), b.left))
149+
case b @ BinaryExpression(l @ FloatType(), StringNaN) =>
150+
b.makeCopy(Array(Literal(Float.NaN), l))
151+
152+
/* Use float NaN by default to avoid unnecessary type widening */
153+
case b @ BinaryExpression(l @ StringNaN, StringNaN) =>
154+
b.makeCopy(Array(Literal(Float.NaN), l))
155155
}
156156
}
157157
}
@@ -227,12 +227,12 @@ trait HiveTypeCoercion {
227227
// Skip nodes whose children have not been resolved yet.
228228
case e if !e.childrenResolved => e
229229

230-
case b: BinaryExpression if b.left.dataType != b.right.dataType =>
231-
findTightestCommonTypeOfTwo(b.left.dataType, b.right.dataType).map { widestType =>
230+
case b @ BinaryExpression(l, r) if l.dataType != r.dataType =>
231+
findTightestCommonType(l.dataType, r.dataType).map { widestType =>
232232
val newLeft =
233-
if (b.left.dataType == widestType) b.left else Cast(b.left, widestType)
233+
if (l.dataType == widestType) l else Cast(l, widestType)
234234
val newRight =
235-
if (b.right.dataType == widestType) b.right else Cast(b.right, widestType)
235+
if (r.dataType == widestType) r else Cast(r, widestType)
236236
b.makeCopy(Array(newLeft, newRight))
237237
}.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
238238
}
@@ -247,57 +247,42 @@ trait HiveTypeCoercion {
247247
// Skip nodes whose children have not been resolved yet.
248248
case e if !e.childrenResolved => e
249249

250-
case a: BinaryArithmetic if a.left.dataType == StringType =>
251-
a.makeCopy(Array(Cast(a.left, DoubleType), a.right))
252-
case a: BinaryArithmetic if a.right.dataType == StringType =>
253-
a.makeCopy(Array(a.left, Cast(a.right, DoubleType)))
250+
case a @ BinaryArithmetic(l @ StringType(), r) =>
251+
a.makeCopy(Array(Cast(l, DoubleType), r))
252+
case a @ BinaryArithmetic(l, r @ StringType()) =>
253+
a.makeCopy(Array(l, Cast(r, DoubleType)))
254254

255255
// Coerce mixed date/timestamp/string comparisons: date vs. string and date vs. timestamp compare as strings; string vs. timestamp compares as timestamps
256-
case p: BinaryComparison if p.left.dataType == StringType &&
257-
p.right.dataType == DateType =>
258-
p.makeCopy(Array(p.left, Cast(p.right, StringType)))
259-
case p: BinaryComparison if p.left.dataType == DateType &&
260-
p.right.dataType == StringType =>
261-
p.makeCopy(Array(Cast(p.left, StringType), p.right))
262-
case p: BinaryComparison if p.left.dataType == StringType &&
263-
p.right.dataType == TimestampType =>
264-
p.makeCopy(Array(Cast(p.left, TimestampType), p.right))
265-
case p: BinaryComparison if p.left.dataType == TimestampType &&
266-
p.right.dataType == StringType =>
267-
p.makeCopy(Array(p.left, Cast(p.right, TimestampType)))
268-
case p: BinaryComparison if p.left.dataType == TimestampType &&
269-
p.right.dataType == DateType =>
270-
p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
271-
case p: BinaryComparison if p.left.dataType == DateType &&
272-
p.right.dataType == TimestampType =>
273-
p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
274-
275-
case p: BinaryComparison if p.left.dataType == StringType &&
276-
p.right.dataType != StringType =>
277-
p.makeCopy(Array(Cast(p.left, DoubleType), p.right))
278-
case p: BinaryComparison if p.left.dataType != StringType &&
279-
p.right.dataType == StringType =>
280-
p.makeCopy(Array(p.left, Cast(p.right, DoubleType)))
281-
282-
case i @ In(a, b) if a.dataType == DateType &&
283-
b.forall(_.dataType == StringType) =>
256+
case p @ BinaryComparison(l @ StringType(), r @ DateType()) =>
257+
p.makeCopy(Array(l, Cast(r, StringType)))
258+
case p @ BinaryComparison(l @ DateType(), r @ StringType()) =>
259+
p.makeCopy(Array(Cast(l, StringType), r))
260+
case p @ BinaryComparison(l @ StringType(), r @ TimestampType()) =>
261+
p.makeCopy(Array(Cast(l, TimestampType), r))
262+
case p @ BinaryComparison(l @ TimestampType(), r @ StringType()) =>
263+
p.makeCopy(Array(l, Cast(r, TimestampType)))
264+
case p @ BinaryComparison(l @ TimestampType(), r @ DateType()) =>
265+
p.makeCopy(Array(Cast(l, StringType), Cast(r, StringType)))
266+
case p @ BinaryComparison(l @ DateType(), r @ TimestampType()) =>
267+
p.makeCopy(Array(Cast(l, StringType), Cast(r, StringType)))
268+
269+
case p @ BinaryComparison(l @ StringType(), r) if r.dataType != StringType =>
270+
p.makeCopy(Array(Cast(l, DoubleType), r))
271+
case p @ BinaryComparison(l, r @ StringType()) if l.dataType != StringType =>
272+
p.makeCopy(Array(l, Cast(r, DoubleType)))
273+
274+
case i @ In(a @ DateType(), b) if b.forall(_.dataType == StringType) =>
284275
i.makeCopy(Array(Cast(a, StringType), b))
285-
case i @ In(a, b) if a.dataType == TimestampType &&
286-
b.forall(_.dataType == StringType) =>
276+
case i @ In(a @ TimestampType(), b) if b.forall(_.dataType == StringType) =>
287277
i.makeCopy(Array(a, b.map(Cast(_, TimestampType))))
288-
case i @ In(a, b) if a.dataType == DateType &&
289-
b.forall(_.dataType == TimestampType) =>
278+
case i @ In(a @ DateType(), b) if b.forall(_.dataType == TimestampType) =>
290279
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
291-
case i @ In(a, b) if a.dataType == TimestampType &&
292-
b.forall(_.dataType == DateType) =>
280+
case i @ In(a @ TimestampType(), b) if b.forall(_.dataType == DateType) =>
293281
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
294282

295-
case Sum(e) if e.dataType == StringType =>
296-
Sum(Cast(e, DoubleType))
297-
case Average(e) if e.dataType == StringType =>
298-
Average(Cast(e, DoubleType))
299-
case Sqrt(e) if e.dataType == StringType =>
300-
Sqrt(Cast(e, DoubleType))
283+
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
284+
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
285+
case Sqrt(e @ StringType()) => Sqrt(Cast(e, DoubleType))
301286
}
302287
}
303288

@@ -467,16 +452,16 @@ trait HiveTypeCoercion {
467452

468453
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
469454
// and fixed-precision decimals in an expression with floats / doubles to doubles
470-
case b: BinaryExpression if b.left.dataType != b.right.dataType =>
471-
(b.left.dataType, b.right.dataType) match {
455+
case b @ BinaryExpression(l, r) if l.dataType != r.dataType =>
456+
(l.dataType, r.dataType) match {
472457
case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
473-
b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right))
458+
b.makeCopy(Array(Cast(l, intTypeToFixed(t)), r))
474459
case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
475-
b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t))))
460+
b.makeCopy(Array(l, Cast(r, intTypeToFixed(t))))
476461
case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
477-
b.makeCopy(Array(b.left, Cast(b.right, DoubleType)))
462+
b.makeCopy(Array(l, Cast(r, DoubleType)))
478463
case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
479-
b.makeCopy(Array(Cast(b.left, DoubleType), b.right))
464+
b.makeCopy(Array(Cast(l, DoubleType), r))
480465
case _ =>
481466
b
482467
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
118118
override def toString: String = s"($left $symbol $right)"
119119
}
120120

121+
private[sql] object BinaryExpression {
122+
def unapply(e: BinaryExpression): Option[(Expression, Expression)] = Some((e.left, e.right))
123+
}
124+
121125
abstract class LeafExpression extends Expression with trees.LeafNode[Expression] {
122126
self: Product =>
123127
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ abstract class BinaryArithmetic extends BinaryExpression {
118118
sys.error(s"BinaryArithmetics must override either eval or evalInternal")
119119
}
120120

121+
private[sql] object BinaryArithmetic {
122+
def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right))
123+
}
124+
121125
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
122126
override def symbol: String = "+"
123127

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,10 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
218218
}
219219
}
220220

221+
private[sql] object BinaryComparison {
222+
def unapply(e: BinaryComparison): Option[(Expression, Expression)] = Some((e.left, e.right))
223+
}
224+
221225
case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison {
222226
override def symbol: String = "<=>"
223227

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ object NullPropagation extends Rule[LogicalPlan] {
266266
if (newChildren.length == 0) {
267267
Literal.create(null, e.dataType)
268268
} else if (newChildren.length == 1) {
269-
newChildren(0)
269+
newChildren.head
270270
} else {
271271
Coalesce(newChildren)
272272
}
@@ -280,21 +280,18 @@ object NullPropagation extends Rule[LogicalPlan] {
280280
case e: MinOf => e
281281

282282
// Put exceptional cases above if any
283-
case e: BinaryArithmetic => e.children match {
284-
case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
285-
case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
286-
case _ => e
287-
}
288-
case e: BinaryComparison => e.children match {
289-
case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
290-
case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
291-
case _ => e
292-
}
283+
case e @ BinaryArithmetic(Literal(null, _), _) => Literal.create(null, e.dataType)
284+
case e @ BinaryArithmetic(_, Literal(null, _)) => Literal.create(null, e.dataType)
285+
286+
case e @ BinaryComparison(Literal(null, _), _) => Literal.create(null, e.dataType)
287+
case e @ BinaryComparison(_, Literal(null, _)) => Literal.create(null, e.dataType)
288+
293289
case e: StringRegexExpression => e.children match {
294290
case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
295291
case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
296292
case _ => e
297293
}
294+
298295
case e: StringComparison => e.children match {
299296
case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
300297
case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)

0 commit comments

Comments
 (0)