Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -114,28 +114,28 @@ trait HiveTypeCoercion {
* the appropriate numeric equivalent.
*/
object ConvertNaNs extends Rule[LogicalPlan] {
val stringNaN = Literal("NaN", StringType)
val StringNaN = Literal("NaN", StringType)

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e

/* Double Conversions */
case b: BinaryExpression if b.left == stringNaN && b.right.dataType == DoubleType =>
b.makeCopy(Array(b.right, Literal(Double.NaN)))
case b: BinaryExpression if b.left.dataType == DoubleType && b.right == stringNaN =>
b.makeCopy(Array(Literal(Double.NaN), b.left))
case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN =>
b.makeCopy(Array(Literal(Double.NaN), b.left))
case b @ BinaryExpression(StringNaN, DoubleType(r)) =>
b.makeCopy(Array(r, Literal(Double.NaN)))
case b @ BinaryExpression(DoubleType(l), StringNaN) =>
b.makeCopy(Array(Literal(Double.NaN), l))

/* Float Conversions */
case b: BinaryExpression if b.left == stringNaN && b.right.dataType == FloatType =>
b.makeCopy(Array(b.right, Literal(Float.NaN)))
case b: BinaryExpression if b.left.dataType == FloatType && b.right == stringNaN =>
b.makeCopy(Array(Literal(Float.NaN), b.left))
case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN =>
b.makeCopy(Array(Literal(Float.NaN), b.left))
case b @ BinaryExpression(StringNaN, FloatType(r)) =>
b.makeCopy(Array(r, Literal(Float.NaN)))
case b @ BinaryExpression(FloatType(l), StringNaN) =>
b.makeCopy(Array(Literal(Float.NaN), l))

/* Use float NaN by default to avoid unnecessary type widening */
case b @ BinaryExpression(l @ StringNaN, StringNaN) =>
b.makeCopy(Array(Literal(Float.NaN), l))
}
}
}
Expand Down Expand Up @@ -168,9 +168,9 @@ trait HiveTypeCoercion {
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
val castedInput = left.output.zip(right.output).map {
// When a string is found on one side, make the other side a string too.
case (l, r) if l.dataType == StringType && r.dataType != StringType =>
case (StringType(l), r) if r.dataType != StringType =>
(l, Alias(Cast(r, StringType), r.name)())
case (l, r) if l.dataType != StringType && r.dataType == StringType =>
case (l, StringType(r)) if l.dataType != StringType =>
(Alias(Cast(l, StringType), l.name)(), r)

case (l, r) if l.dataType != r.dataType =>
Expand Down Expand Up @@ -211,12 +211,12 @@ trait HiveTypeCoercion {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e

case b: BinaryExpression if b.left.dataType != b.right.dataType =>
findTightestCommonType(b.left.dataType, b.right.dataType).map { widestType =>
case b @ BinaryExpression(l, r) if l.dataType != r.dataType =>
findTightestCommonType(l.dataType, r.dataType).map { widestType =>
val newLeft =
if (b.left.dataType == widestType) b.left else Cast(b.left, widestType)
if (l.dataType == widestType) l else Cast(l, widestType)
val newRight =
if (b.right.dataType == widestType) b.right else Cast(b.right, widestType)
if (r.dataType == widestType) r else Cast(r, widestType)
b.makeCopy(Array(newLeft, newRight))
}.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
}
Expand All @@ -231,51 +231,50 @@ trait HiveTypeCoercion {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e

case a: BinaryArithmetic if a.left.dataType == StringType =>
a.makeCopy(Array(Cast(a.left, DoubleType), a.right))
case a: BinaryArithmetic if a.right.dataType == StringType =>
a.makeCopy(Array(a.left, Cast(a.right, DoubleType)))
case a @ BinaryArithmetic(StringType(l), r) =>
a.makeCopy(Array(Cast(l, DoubleType), r))
case a @ BinaryArithmetic(l, StringType(r)) =>
a.makeCopy(Array(l, Cast(r, DoubleType)))

// we should cast all timestamp/date/string compare into string compare
case p: BinaryPredicate if p.left.dataType == StringType
&& p.right.dataType == DateType =>
p.makeCopy(Array(p.left, Cast(p.right, StringType)))
case p: BinaryPredicate if p.left.dataType == DateType
&& p.right.dataType == StringType =>
p.makeCopy(Array(Cast(p.left, StringType), p.right))
case p: BinaryPredicate if p.left.dataType == StringType
&& p.right.dataType == TimestampType =>
p.makeCopy(Array(p.left, Cast(p.right, StringType)))
case p: BinaryPredicate if p.left.dataType == TimestampType
&& p.right.dataType == StringType =>
p.makeCopy(Array(Cast(p.left, StringType), p.right))
case p: BinaryPredicate if p.left.dataType == TimestampType
&& p.right.dataType == DateType =>
p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
case p: BinaryPredicate if p.left.dataType == DateType
&& p.right.dataType == TimestampType =>
p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))

case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType =>
p.makeCopy(Array(Cast(p.left, DoubleType), p.right))
case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType =>
p.makeCopy(Array(p.left, Cast(p.right, DoubleType)))

case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == StringType) =>
case p @ BinaryPredicate(StringType(l), DateType(r)) =>
p.makeCopy(Array(l, Cast(r, StringType)))
case p @ BinaryPredicate(DateType(l), StringType(r)) =>
p.makeCopy(Array(Cast(l, StringType), r))
case p @ BinaryPredicate(TimestampType(l), DateType(r)) =>
p.makeCopy(Array(Cast(l, StringType), Cast(r, StringType)))
case p @ BinaryPredicate(DateType(l), TimestampType(r)) =>
p.makeCopy(Array(Cast(l, StringType), Cast(r, StringType)))
case p @ BinaryPredicate(StringType(l), TimestampType(r)) =>
p.makeCopy(Array(Cast(l, TimestampType), r))
case p @ BinaryPredicate(TimestampType(l), StringType(r)) =>
p.makeCopy(Array(l, Cast(r, TimestampType)))

case p @ BinaryPredicate(StringType(l), r) if r.dataType != StringType =>
p.makeCopy(Array(Cast(l, DoubleType), r))
case p @ BinaryPredicate(l, StringType(r)) if l.dataType != StringType =>
p.makeCopy(Array(l, Cast(r, DoubleType)))

case i @ In(DateType(a), b) if b.forall(_.dataType == StringType) =>
i.makeCopy(Array(Cast(a, StringType), b))
case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) =>
case i @ In(TimestampType(a), b) if b.forall(_.dataType == StringType) =>
i.makeCopy(Array(Cast(a, StringType), b))
case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == TimestampType) =>
case i @ In(DateType(a), b) if b.forall(_.dataType == TimestampType) =>
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == DateType) =>
case i @ In(TimestampType(a), b) if b.forall(_.dataType == DateType) =>
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
case i @ In(DateType(a), b) if b.forall(_.dataType == StringType) =>
i.makeCopy(Array(Cast(a, StringType), b))
case i @ In(TimestampType(a), b) if b.forall(_.dataType == StringType) =>
i.makeCopy(Array(a, b.map(Cast(_,TimestampType))))
case i @ In(DateType(a), b) if b.forall(_.dataType == TimestampType) =>
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
case i @ In(TimestampType(a), b) if b.forall(_.dataType == DateType) =>
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))

case Sum(e) if e.dataType == StringType =>
Sum(Cast(e, DoubleType))
case Average(e) if e.dataType == StringType =>
Average(Cast(e, DoubleType))
case Sqrt(e) if e.dataType == StringType =>
Sqrt(Cast(e, DoubleType))
case Sum(StringType(e)) => Sum(Cast(e, DoubleType))
case Average(StringType(e)) => Average(Cast(e, DoubleType))
case Sqrt(StringType(e)) => Sqrt(Cast(e, DoubleType))
}
}

Expand Down Expand Up @@ -395,19 +394,18 @@ trait HiveTypeCoercion {
case e if !e.childrenResolved => e

// Hive treats (true = 1) as true and (false = 0) as true.
case EqualTo(l @ BooleanType(), r) if trueValues.contains(r) => l
case EqualTo(l, r @ BooleanType()) if trueValues.contains(l) => r
case EqualTo(l @ BooleanType(), r) if falseValues.contains(r) => Not(l)
case EqualTo(l, r @ BooleanType()) if falseValues.contains(l) => Not(r)
case EqualTo(BooleanType(l), r) if trueValues.contains(r) => l
case EqualTo(l, BooleanType(r)) if trueValues.contains(l) => r
case EqualTo(BooleanType(l), r) if falseValues.contains(r) => Not(l)
case EqualTo(l, BooleanType(r)) if falseValues.contains(l) => Not(r)

// No need to change other EqualTo operators as that actually makes sense for boolean types.
case e: EqualTo => e
// No need to change the EqualNullSafe operators, too
case e: EqualNullSafe => e
// Otherwise turn them to Byte types so that there exists and ordering.
case p: BinaryComparison
if p.left.dataType == BooleanType && p.right.dataType == BooleanType =>
p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType)))
case p @ BinaryComparison(BooleanType(l), BooleanType(r)) =>
p.makeCopy(Array(Cast(l, ByteType), Cast(r, ByteType)))
}
}

Expand All @@ -421,18 +419,18 @@ trait HiveTypeCoercion {
case e if !e.childrenResolved => e
// Skip if the type is boolean type already. Note that this extra cast should be removed
// by optimizer.SimplifyCasts.
case Cast(e, BooleanType) if e.dataType == BooleanType => e
case Cast(BooleanType(e), BooleanType) => e
// DateType should be null if be cast to boolean.
case Cast(e, BooleanType) if e.dataType == DateType => Cast(e, BooleanType)
case Cast(DateType(e), BooleanType) => Cast(e, BooleanType)
// If the data type is not boolean and is being cast boolean, turn it into a comparison
// with the numeric value, i.e. x != 0. This will coerce the type into numeric type.
case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0)))
// Stringify boolean if casting to StringType.
// TODO Ensure true/false string letter casing is consistent with Hive in all cases.
case Cast(e, StringType) if e.dataType == BooleanType =>
case Cast(BooleanType(e), StringType) =>
If(e, Literal("true"), Literal("false"))
// Turn true into 1, and false into 0 if casting boolean into other types.
case Cast(e, dataType) if e.dataType == BooleanType =>
case Cast(BooleanType(e), dataType) =>
Cast(If(e, Literal(1), Literal(0)), dataType)
}
}
Expand All @@ -447,7 +445,7 @@ trait HiveTypeCoercion {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e

case Cast(e @ StringType(), t: IntegralType) =>
case Cast(StringType(e), t: IntegralType) =>
Cast(Cast(e, DecimalType.Unlimited), t)
}
}
Expand All @@ -468,20 +466,20 @@ trait HiveTypeCoercion {
children.map(c => if (c.dataType == commonType) c else Cast(c, commonType)))

// Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows.
case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType))
case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType))
case s @ Sum(DecimalType(e)) => s // Decimal is already the biggest.
case Sum(IntegralType(e)) if e.dataType != LongType => Sum(Cast(e, LongType))
case Sum(FractionalType(e)) if e.dataType != DoubleType => Sum(Cast(e, DoubleType))

case s @ SumDistinct(e @ DecimalType()) => s // Decimal is already the biggest.
case SumDistinct(e @ IntegralType()) if e.dataType != LongType =>
case s @ SumDistinct(DecimalType(e)) => s // Decimal is already the biggest.
case SumDistinct(IntegralType(e)) if e.dataType != LongType =>
SumDistinct(Cast(e, LongType))
case SumDistinct(e @ FractionalType()) if e.dataType != DoubleType =>
case SumDistinct(FractionalType(e)) if e.dataType != DoubleType =>
SumDistinct(Cast(e, DoubleType))

case s @ Average(e @ DecimalType()) => s // Decimal is already the biggest.
case Average(e @ IntegralType()) if e.dataType != LongType =>
case s @ Average(DecimalType(e)) => s // Decimal is already the biggest.
case Average(IntegralType(e)) if e.dataType != LongType =>
Average(Cast(e, LongType))
case Average(e @ FractionalType()) if e.dataType != DoubleType =>
case Average(FractionalType(e)) if e.dataType != DoubleType =>
Average(Cast(e, DoubleType))

// Hive lets you do aggregation of timestamps... for some reason
Expand All @@ -503,10 +501,8 @@ trait HiveTypeCoercion {
case d: Divide if d.resolved && d.dataType == DoubleType => d
case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d

case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] =>
Divide(l, Cast(r, DecimalType.Unlimited))
case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] =>
Divide(Cast(l, DecimalType.Unlimited), r)
case Divide(DecimalType(l), r) => Divide(l, Cast(r, DecimalType.Unlimited))
case Divide(l, DecimalType(r)) => Divide(Cast(l, DecimalType.Unlimited), r)

case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
}
Expand All @@ -519,7 +515,7 @@ trait HiveTypeCoercion {
import HiveTypeCoercion._

def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) =>
case cw @ CaseWhen(branches) if !cw.resolved && branches.forall(_.resolved) =>
val valueTypes = branches.sliding(2, 2).map {
case Seq(_, value) => value.dataType
case Seq(elseVal) => elseVal.dataType
Expand Down Expand Up @@ -547,5 +543,4 @@ trait HiveTypeCoercion {
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ abstract class Expression extends TreeNode[Expression] {
}
}

object BinaryExpression {
def unapply(a: BinaryExpression): Option[(Expression, Expression)] = Some((a.left, a.right))
}

abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
self: Product =>

Expand All @@ -243,6 +247,4 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]

abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
self: Product =>


}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.types._
import scala.math.pow

case class UnaryMinus(child: Expression) extends UnaryExpression {
type EvaluatedType = Any
Expand All @@ -43,10 +42,14 @@ case class Sqrt(child: Expression) extends UnaryExpression {
override def toString = s"SQRT($child)"

override def eval(input: Row): Any = {
n1(child, input, ((na,a) => math.sqrt(na.toDouble(a))))
n1(child, input, (na, a) => math.sqrt(na.toDouble(a)))
}
}

object BinaryArithmetic {
def unapply(a: BinaryArithmetic): Option[(Expression, Expression)] = Some((a.left, a.right))
}

abstract class BinaryArithmetic extends BinaryExpression {
self: Product =>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
val $primitiveTerm: ${termForType(dataType)} = $value
""".children

case Cast(e @ BinaryType(), StringType) =>
case Cast(BinaryType(e), StringType) =>
val eval = expressionEvaluator(e)
eval.code ++
q"""
Expand All @@ -247,16 +247,16 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
""".children

case Cast(child @ NumericType(), IntegerType) =>
case Cast(NumericType(child), IntegerType) =>
child.castOrNull(c => q"$c.toInt", IntegerType)

case Cast(child @ NumericType(), LongType) =>
case Cast(NumericType(child), LongType) =>
child.castOrNull(c => q"$c.toLong", LongType)

case Cast(child @ NumericType(), DoubleType) =>
case Cast(NumericType(child), DoubleType) =>
child.castOrNull(c => q"$c.toDouble", DoubleType)

case Cast(child @ NumericType(), FloatType) =>
case Cast(NumericType(child), FloatType) =>
child.castOrNull(c => q"$c.toFloat", IntegerType)

// Special handling required for timestamps in hive test cases since the toString function
Expand Down Expand Up @@ -301,13 +301,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
""".children
*/

case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) =>
case GreaterThan(NumericType(e1), NumericType(e2)) =>
(e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 > $eval2" }
case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) =>
case GreaterThanOrEqual(NumericType(e1), NumericType(e2)) =>
(e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 >= $eval2" }
case LessThan(e1 @ NumericType(), e2 @ NumericType()) =>
case LessThan(NumericType(e1), NumericType(e2)) =>
(e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 < $eval2" }
case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) =>
case LessThanOrEqual(NumericType(e1), NumericType(e2)) =>
(e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 <= $eval2" }

case And(e1, e2) =>
Expand Down Expand Up @@ -546,7 +546,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin

protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = {
dataType match {
case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)"
case NativeType(dt) => q"$inputRow.${accessorForType(dt)}($ordinal)"
case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]"
}
}
Expand All @@ -557,7 +557,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
ordinal: Int,
value: TermName) = {
dataType match {
case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
case NativeType(dt) => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
case _ => q"$destinationRow.update($ordinal, $value)"
}
}
Expand Down
Loading