@@ -130,28 +130,28 @@ trait HiveTypeCoercion {
130130 * the appropriate numeric equivalent.
131131 */
object ConvertNaNs extends Rule[LogicalPlan] {
  // String literal that stands in for a floating-point NaN. The name is capitalised so that
  // it is treated as a stable identifier (compared by equality) when used inside a pattern,
  // rather than being bound as a fresh pattern variable.
  private val StringNaN = Literal("NaN")

  /**
   * Rewrites every binary expression in `plan` that has the string literal "NaN" as one of
   * its operands so that the literal becomes a numeric NaN of the other operand's type.
   *
   * The NaN literal always takes the position the string literal occupied; the other operand
   * stays where it was, so non-commutative expressions keep their meaning.
   *
   * @param plan the logical plan to rewrite
   * @return the plan with the matching string literals replaced
   */
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case q: LogicalPlan => q transformExpressions {
      // Skip nodes whose children have not been resolved yet.
      case e if !e.childrenResolved => e

      /* Double Conversions */
      case b @ BinaryExpression(StringNaN, r @ DoubleType()) =>
        b.makeCopy(Array(Literal(Double.NaN), r))
      case b @ BinaryExpression(l @ DoubleType(), StringNaN) =>
        b.makeCopy(Array(l, Literal(Double.NaN)))

      /* Float Conversions */
      case b @ BinaryExpression(StringNaN, r @ FloatType()) =>
        b.makeCopy(Array(Literal(Float.NaN), r))
      case b @ BinaryExpression(l @ FloatType(), StringNaN) =>
        b.makeCopy(Array(l, Literal(Float.NaN)))

      /* Use float NaN by default to avoid unnecessary type widening */
      // Only the left operand is converted here; the right operand is still the string
      // literal afterwards. NOTE(review): presumably a later pass of this rule picks it up
      // through the Float case above -- confirm the rule is run to a fixed point.
      case b @ BinaryExpression(l @ StringNaN, StringNaN) =>
        b.makeCopy(Array(Literal(Float.NaN), l))
    }
  }
}
@@ -227,12 +227,12 @@ trait HiveTypeCoercion {
227227 // Skip nodes whose children have not been resolved yet.
228228 case e if ! e.childrenResolved => e
229229
230- case b : BinaryExpression if b.left. dataType != b.right .dataType =>
231- findTightestCommonTypeOfTwo(b.left. dataType, b.right .dataType).map { widestType =>
230+ case b @ BinaryExpression (l, r) if l. dataType != r .dataType =>
231+ findTightestCommonType(l. dataType, r .dataType).map { widestType =>
232232 val newLeft =
233- if (b.left. dataType == widestType) b.left else Cast (b.left , widestType)
233+ if (l. dataType == widestType) l else Cast (l , widestType)
234234 val newRight =
235- if (b.right. dataType == widestType) b.right else Cast (b.right , widestType)
235+ if (r. dataType == widestType) r else Cast (r , widestType)
236236 b.makeCopy(Array (newLeft, newRight))
237237 }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
238238 }
@@ -247,57 +247,42 @@ trait HiveTypeCoercion {
247247 // Skip nodes whose children have not been resolved yet.
248248 case e if ! e.childrenResolved => e
249249
250- case a : BinaryArithmetic if a.left.dataType == StringType =>
251- a.makeCopy(Array (Cast (a.left , DoubleType ), a.right ))
252- case a : BinaryArithmetic if a.right.dataType == StringType =>
253- a.makeCopy(Array (a.left , Cast (a.right , DoubleType )))
250+ case a @ BinaryArithmetic (l @ StringType (), r) =>
251+ a.makeCopy(Array (Cast (l , DoubleType ), r ))
252+ case a @ BinaryArithmetic (l, r @ StringType ()) =>
253+ a.makeCopy(Array (l , Cast (r , DoubleType )))
254254
255255 // Compare date/string and date/timestamp pairs as strings; compare string/timestamp pairs as timestamps.
256- case p : BinaryComparison if p.left.dataType == StringType &&
257- p.right.dataType == DateType =>
258- p.makeCopy(Array (p.left, Cast (p.right, StringType )))
259- case p : BinaryComparison if p.left.dataType == DateType &&
260- p.right.dataType == StringType =>
261- p.makeCopy(Array (Cast (p.left, StringType ), p.right))
262- case p : BinaryComparison if p.left.dataType == StringType &&
263- p.right.dataType == TimestampType =>
264- p.makeCopy(Array (Cast (p.left, TimestampType ), p.right))
265- case p : BinaryComparison if p.left.dataType == TimestampType &&
266- p.right.dataType == StringType =>
267- p.makeCopy(Array (p.left, Cast (p.right, TimestampType )))
268- case p : BinaryComparison if p.left.dataType == TimestampType &&
269- p.right.dataType == DateType =>
270- p.makeCopy(Array (Cast (p.left, StringType ), Cast (p.right, StringType )))
271- case p : BinaryComparison if p.left.dataType == DateType &&
272- p.right.dataType == TimestampType =>
273- p.makeCopy(Array (Cast (p.left, StringType ), Cast (p.right, StringType )))
274-
275- case p : BinaryComparison if p.left.dataType == StringType &&
276- p.right.dataType != StringType =>
277- p.makeCopy(Array (Cast (p.left, DoubleType ), p.right))
278- case p : BinaryComparison if p.left.dataType != StringType &&
279- p.right.dataType == StringType =>
280- p.makeCopy(Array (p.left, Cast (p.right, DoubleType )))
281-
282- case i @ In (a, b) if a.dataType == DateType &&
283- b.forall(_.dataType == StringType ) =>
256+ case p @ BinaryComparison (l @ StringType (), r @ DateType ()) =>
257+ p.makeCopy(Array (l, Cast (r, StringType )))
258+ case p @ BinaryComparison (l @ DateType (), r @ StringType ()) =>
259+ p.makeCopy(Array (Cast (l, StringType ), r))
260+ case p @ BinaryComparison (l @ StringType (), r @ TimestampType ()) =>
261+ p.makeCopy(Array (Cast (l, TimestampType ), r))
262+ case p @ BinaryComparison (l @ TimestampType (), r @ StringType ()) =>
263+ p.makeCopy(Array (l, Cast (r, TimestampType )))
264+ case p @ BinaryComparison (l @ TimestampType (), r @ DateType ()) =>
265+ p.makeCopy(Array (Cast (l, StringType ), Cast (r, StringType )))
266+ case p @ BinaryComparison (l @ DateType (), r @ TimestampType ()) =>
267+ p.makeCopy(Array (Cast (l, StringType ), Cast (r, StringType )))
268+
269+ case p @ BinaryComparison (l @ StringType (), r) if r.dataType != StringType =>
270+ p.makeCopy(Array (Cast (l, DoubleType ), r))
271+ case p @ BinaryComparison (l, r @ StringType ()) if l.dataType != StringType =>
272+ p.makeCopy(Array (l, Cast (r, DoubleType )))
273+
274+ case i @ In (a @ DateType (), b) if b.forall(_.dataType == StringType ) =>
284275 i.makeCopy(Array (Cast (a, StringType ), b))
285- case i @ In (a, b) if a.dataType == TimestampType &&
286- b.forall(_.dataType == StringType ) =>
276+ case i @ In (a @ TimestampType (), b) if b.forall(_.dataType == StringType ) =>
287277 i.makeCopy(Array (a, b.map(Cast (_, TimestampType ))))
288- case i @ In (a, b) if a.dataType == DateType &&
289- b.forall(_.dataType == TimestampType ) =>
278+ case i @ In (a @ DateType (), b) if b.forall(_.dataType == TimestampType ) =>
290279 i.makeCopy(Array (Cast (a, StringType ), b.map(Cast (_, StringType ))))
291- case i @ In (a, b) if a.dataType == TimestampType &&
292- b.forall(_.dataType == DateType ) =>
280+ case i @ In (a @ TimestampType (), b) if b.forall(_.dataType == DateType ) =>
293281 i.makeCopy(Array (Cast (a, StringType ), b.map(Cast (_, StringType ))))
294282
295- case Sum (e) if e.dataType == StringType =>
296- Sum (Cast (e, DoubleType ))
297- case Average (e) if e.dataType == StringType =>
298- Average (Cast (e, DoubleType ))
299- case Sqrt (e) if e.dataType == StringType =>
300- Sqrt (Cast (e, DoubleType ))
283+ case Sum (e @ StringType ()) => Sum (Cast (e, DoubleType ))
284+ case Average (e @ StringType ()) => Average (Cast (e, DoubleType ))
285+ case Sqrt (e @ StringType ()) => Sqrt (Cast (e, DoubleType ))
301286 }
302287 }
303288
@@ -467,16 +452,16 @@ trait HiveTypeCoercion {
467452
468453 // Promote integers inside a binary expression with fixed-precision decimals to decimals,
469454 // and fixed-precision decimals in an expression with floats / doubles to doubles
470- case b : BinaryExpression if b.left. dataType != b.right .dataType =>
471- (b.left. dataType, b.right .dataType) match {
455+ case b @ BinaryExpression (l, r) if l. dataType != r .dataType =>
456+ (l. dataType, r .dataType) match {
472457 case (t, DecimalType .Fixed (p, s)) if intTypeToFixed.contains(t) =>
473- b.makeCopy(Array (Cast (b.left , intTypeToFixed(t)), b.right ))
458+ b.makeCopy(Array (Cast (l , intTypeToFixed(t)), r ))
474459 case (DecimalType .Fixed (p, s), t) if intTypeToFixed.contains(t) =>
475- b.makeCopy(Array (b.left , Cast (b.right , intTypeToFixed(t))))
460+ b.makeCopy(Array (l , Cast (r , intTypeToFixed(t))))
476461 case (t, DecimalType .Fixed (p, s)) if isFloat(t) =>
477- b.makeCopy(Array (b.left , Cast (b.right , DoubleType )))
462+ b.makeCopy(Array (l , Cast (r , DoubleType )))
478463 case (DecimalType .Fixed (p, s), t) if isFloat(t) =>
479- b.makeCopy(Array (Cast (b.left , DoubleType ), b.right ))
464+ b.makeCopy(Array (Cast (l , DoubleType ), r ))
480465 case _ =>
481466 b
482467 }
0 commit comments