@@ -108,6 +108,28 @@ object TypeCoercion {
108108 case _ => None
109109 }
110110
111+ /**
112+ * This function determines the target type of a comparison operator when one operand
113+ * is a String and the other is not. It also handles when one op is a Date and the
114+ * other is a Timestamp by making the target type to be String.
115+ */
116+ val findCommonTypeForBinaryComparison : (DataType , DataType ) => Option [DataType ] = {
117+ // We should cast all relative timestamp/date/string comparison into string comparisons
118+ // This behaves as a user would expect because timestamp strings sort lexicographically.
119+ // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true
120+ case (StringType , DateType ) => Some (StringType )
121+ case (DateType , StringType ) => Some (StringType )
122+ case (StringType , TimestampType ) => Some (StringType )
123+ case (TimestampType , StringType ) => Some (StringType )
124+ case (TimestampType , DateType ) => Some (StringType )
125+ case (DateType , TimestampType ) => Some (StringType )
126+ case (StringType , NullType ) => Some (StringType )
127+ case (NullType , StringType ) => Some (StringType )
128+ case (l : StringType , r : AtomicType ) if r != StringType => Some (r)
129+ case (l : AtomicType , r : StringType ) if (l != StringType ) => Some (l)
130+ case (l, r) => None
131+ }
132+
111133 /**
112134 * Case 2 type widening (see the classdoc comment above for TypeCoercion).
113135 *
@@ -305,6 +327,14 @@ object TypeCoercion {
305327 * Promotes strings that appear in arithmetic expressions.
306328 */
307329 object PromoteStrings extends Rule [LogicalPlan ] {
330+ private def castExpr (expr : Expression , targetType : DataType ): Expression = {
331+ (expr.dataType, targetType) match {
332+ case (NullType , dt) => Literal .create(null , targetType)
333+ case (l, dt) if (l != dt) => Cast (expr, targetType)
334+ case _ => expr
335+ }
336+ }
337+
308338 def apply (plan : LogicalPlan ): LogicalPlan = plan resolveExpressions {
309339 // Skip nodes who's children have not been resolved yet.
310340 case e if ! e.childrenResolved => e
@@ -321,37 +351,10 @@ object TypeCoercion {
321351 case p @ Equality (left @ TimestampType (), right @ StringType ()) =>
322352 p.makeCopy(Array (left, Cast (right, TimestampType )))
323353
324- // We should cast all relative timestamp/date/string comparison into string comparisons
325- // This behaves as a user would expect because timestamp strings sort lexicographically.
326- // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true
327- case p @ BinaryComparison (left @ StringType (), right @ DateType ()) =>
328- p.makeCopy(Array (left, Cast (right, StringType )))
329- case p @ BinaryComparison (left @ DateType (), right @ StringType ()) =>
330- p.makeCopy(Array (Cast (left, StringType ), right))
331- case p @ BinaryComparison (left @ StringType (), right @ TimestampType ()) =>
332- p.makeCopy(Array (left, Cast (right, StringType )))
333- case p @ BinaryComparison (left @ TimestampType (), right @ StringType ()) =>
334- p.makeCopy(Array (Cast (left, StringType ), right))
335-
336- // Comparisons between dates and timestamps.
337- case p @ BinaryComparison (left @ TimestampType (), right @ DateType ()) =>
338- p.makeCopy(Array (Cast (left, StringType ), Cast (right, StringType )))
339- case p @ BinaryComparison (left @ DateType (), right @ TimestampType ()) =>
340- p.makeCopy(Array (Cast (left, StringType ), Cast (right, StringType )))
341-
342- // Checking NullType
343- case p @ BinaryComparison (left @ StringType (), right @ NullType ()) =>
344- p.makeCopy(Array (left, Literal .create(null , StringType )))
345- case p @ BinaryComparison (left @ NullType (), right @ StringType ()) =>
346- p.makeCopy(Array (Literal .create(null , StringType ), right))
347-
348- // When compare string with atomic type, case string to that type.
349- case p @ BinaryComparison (left @ StringType (), right @ AtomicType ())
350- if right.dataType != StringType =>
351- p.makeCopy(Array (Cast (left, right.dataType), right))
352- case p @ BinaryComparison (left @ AtomicType (), right @ StringType ())
353- if left.dataType != StringType =>
354- p.makeCopy(Array (left, Cast (right, left.dataType)))
354+ case p @ BinaryComparison (left, right)
355+ if findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined =>
356+ val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get
357+ p.makeCopy(Array (castExpr(left, commonType), castExpr(right, commonType)))
355358
356359 case Sum (e @ StringType ()) => Sum (Cast (e, DoubleType ))
357360 case Average (e @ StringType ()) => Average (Cast (e, DoubleType ))
@@ -365,17 +368,72 @@ object TypeCoercion {
365368 }
366369
367370 /**
368- * Convert the value and in list expressions to the common operator type
369- * by looking at all the argument types and finding the closest one that
370- * all the arguments can be cast to. When no common operator type is found
371- * the original expression will be returned and an Analysis Exception will
372- * be raised at type checking phase.
371+ * Handles type coercion for both IN expression with subquery and IN
372+ * expressions without subquery.
373+ * 1. In the first case, find the common type by comparing the left hand side (LHS)
374+ * expression types against corresponding right hand side (RHS) expression derived
375+ * from the subquery expression's plan output. Inject appropriate casts in the
376+ * LHS and RHS side of IN expression.
377+ *
378+ * 2. In the second case, convert the value and in list expressions to the
379+ * common operator type by looking at all the argument types and finding
380+ * the closest one that all the arguments can be cast to. When no common
381+ * operator type is found the original expression will be returned and an
382+ * Analysis Exception will be raised at the type checking phase.
373383 */
374384 object InConversion extends Rule [LogicalPlan ] {
385+ private def flattenExpr (expr : Expression ): Seq [Expression ] = {
386+ expr match {
387+ // Multi columns in IN clause is represented as a CreateNamedStruct.
388+ // flatten the named struct to get the list of expressions.
389+ case cns : CreateNamedStruct => cns.valExprs
390+ case expr => Seq (expr)
391+ }
392+ }
393+
375394 def apply (plan : LogicalPlan ): LogicalPlan = plan resolveExpressions {
376395 // Skip nodes who's children have not been resolved yet.
377396 case e if ! e.childrenResolved => e
378397
398+ // Handle type casting required between value expression and subquery output
399+ // in IN subquery.
400+ case i @ In (a, Seq (ListQuery (sub, children, exprId)))
401+ if ! i.resolved && flattenExpr(a).length == sub.output.length =>
402+ // LHS is the value expression of IN subquery.
403+ val lhs = flattenExpr(a)
404+
405+ // RHS is the subquery output.
406+ val rhs = sub.output
407+
408+ val commonTypes = lhs.zip(rhs).flatMap { case (l, r) =>
409+ findCommonTypeForBinaryComparison(l.dataType, r.dataType)
410+ .orElse(findTightestCommonType(l.dataType, r.dataType))
411+ }
412+
413+ // The number of columns/expressions must match between LHS and RHS of an
414+ // IN subquery expression.
415+ if (commonTypes.length == lhs.length) {
416+ val castedRhs = rhs.zip(commonTypes).map {
417+ case (e, dt) if e.dataType != dt => Alias (Cast (e, dt), e.name)()
418+ case (e, _) => e
419+ }
420+ val castedLhs = lhs.zip(commonTypes).map {
421+ case (e, dt) if e.dataType != dt => Cast (e, dt)
422+ case (e, _) => e
423+ }
424+
425+ // Before constructing the In expression, wrap the multi values in LHS
426+ // in a CreatedNamedStruct.
427+ val newLhs = castedLhs match {
428+ case Seq (lhs) => lhs
429+ case _ => CreateStruct (castedLhs)
430+ }
431+
432+ In (newLhs, Seq (ListQuery (Project (castedRhs, sub), children, exprId)))
433+ } else {
434+ i
435+ }
436+
379437 case i @ In (a, b) if b.exists(_.dataType != a.dataType) =>
380438 findWiderCommonType(i.children.map(_.dataType)) match {
381439 case Some (finalDataType) => i.withNewChildren(i.children.map(Cast (_, finalDataType)))
0 commit comments