Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -133,10 +134,8 @@ trait CheckAnalysis extends PredicateHelper {
if (conditions.isEmpty && query.output.size != 1) {
failAnalysis(
s"Scalar subquery must return only one column, but got ${query.output.size}")
} else if (conditions.nonEmpty) {
// Collect the columns from the subquery for further checking.
var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains)

}
else if (conditions.nonEmpty) {
def checkAggregate(agg: Aggregate): Unit = {
// Make sure correlated scalar subqueries contain one row for every outer row by
// enforcing that they are aggregates containing exactly one aggregate expression.
Expand All @@ -152,6 +151,9 @@ trait CheckAnalysis extends PredicateHelper {
// SPARK-18504/SPARK-18814: Block cases where GROUP BY columns
// are not part of the correlated columns.
val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references))
// Collect the local references from the correlated predicate in the subquery.
val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references)
.filterNot(conditions.flatMap(_.references).contains)
val correlatedCols = AttributeSet(subqueryColumns)
val invalidCols = groupByCols -- correlatedCols
// GROUP BY columns must be a subset of columns in the predicates
Expand All @@ -167,17 +169,7 @@ trait CheckAnalysis extends PredicateHelper {
// For projects, do the necessary mapping and skip to its child.
def cleanQuery(p: LogicalPlan): LogicalPlan = p match {
case s: SubqueryAlias => cleanQuery(s.child)
case p: Project =>
// SPARK-18814: Map any aliases to their AttributeReference children
// for the checking in the Aggregate operators below this Project.
subqueryColumns = subqueryColumns.map {
xs => p.projectList.collectFirst {
case e @ Alias(child : AttributeReference, _) if e.exprId == xs.exprId =>
child
}.getOrElse(xs)
}

cleanQuery(p.child)
case p: Project => cleanQuery(p.child)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now the de-duplication of sub plan happens in optimizer when we actually pull up the correlated predicates. Thus the project case is simplified.

case child => child
}

Expand Down Expand Up @@ -211,14 +203,9 @@ trait CheckAnalysis extends PredicateHelper {
s"filter expression '${f.condition.sql}' " +
s"of type ${f.condition.dataType.simpleString} is not a boolean.")

case Filter(condition, _) =>
splitConjunctivePredicates(condition).foreach {
case _: PredicateSubquery | Not(_: PredicateSubquery) =>
case e if PredicateSubquery.hasNullAwarePredicateWithinNot(e) =>
failAnalysis(s"Null-aware predicate sub-queries cannot be used in nested" +
s" conditions: $e")
case e =>
}
case Filter(condition, _) if hasNullAwarePredicateWithinNot(condition) =>
failAnalysis("Null-aware predicate sub-queries cannot be used in nested " +
s"conditions: $condition")

case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType =>
failAnalysis(
Expand Down Expand Up @@ -306,8 +293,11 @@ trait CheckAnalysis extends PredicateHelper {
s"Correlated scalar sub-queries can only be used in a Filter/Aggregate/Project: $p")
}

case p if p.expressions.exists(PredicateSubquery.hasPredicateSubquery) =>
failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p")
case p if p.expressions.exists(SubqueryExpression.hasInOrExistsSubquery) =>
p match {
case _: Filter => // Ok
case _ => failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p")
}

case _: Union | _: SetOperation if operator.children.length > 1 =>
def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,28 @@ object TypeCoercion {
case _ => None
}

/**
* This function determines the target type of a comparison operator when one operand
* is a String and the other is not. It also handles when one op is a Date and the
* other is a Timestamp by making the target type to be String.
*/
val findCommonTypeForBinaryComparison: (DataType, DataType) => Option[DataType] = {
// We should cast all relative timestamp/date/string comparison into string comparisons
// This behaves as a user would expect because timestamp strings sort lexicographically.
// i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true
case (StringType, DateType) => Some(StringType)
case (DateType, StringType) => Some(StringType)
case (StringType, TimestampType) => Some(StringType)
case (TimestampType, StringType) => Some(StringType)
case (TimestampType, DateType) => Some(StringType)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems weird. Is this also the current behavior?

Copy link
Contributor Author

@dilipbiswal dilipbiswal Feb 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hvanhovell Here is where i saw the code that handles the promotion between date and timestamp types. code

Please let me know if i missed something here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we factor that code out then? Now we have the same logic in two places.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hvanhovell Thanks!!. I had tried to do this before as well as this came up during the internal review. I have made another try. Please let me know what you think.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok please do this in a follow-up

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

O wait, you have changed it.

case (DateType, TimestampType) => Some(StringType)
case (StringType, NullType) => Some(StringType)
case (NullType, StringType) => Some(StringType)
case (l: StringType, r: AtomicType) if r != StringType => Some(r)
case (l: AtomicType, r: StringType) if (l != StringType) => Some(l)
case (l, r) => None
}

/**
* Case 2 type widening (see the classdoc comment above for TypeCoercion).
*
Expand Down Expand Up @@ -305,6 +327,14 @@ object TypeCoercion {
* Promotes strings that appear in arithmetic expressions.
*/
object PromoteStrings extends Rule[LogicalPlan] {
private def castExpr(expr: Expression, targetType: DataType): Expression = {
(expr.dataType, targetType) match {
case (NullType, dt) => Literal.create(null, targetType)
case (l, dt) if (l != dt) => Cast(expr, targetType)
case _ => expr
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
Expand All @@ -321,37 +351,10 @@ object TypeCoercion {
case p @ Equality(left @ TimestampType(), right @ StringType()) =>
p.makeCopy(Array(left, Cast(right, TimestampType)))

// We should cast all relative timestamp/date/string comparison into string comparisons
// This behaves as a user would expect because timestamp strings sort lexicographically.
// i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true
case p @ BinaryComparison(left @ StringType(), right @ DateType()) =>
p.makeCopy(Array(left, Cast(right, StringType)))
case p @ BinaryComparison(left @ DateType(), right @ StringType()) =>
p.makeCopy(Array(Cast(left, StringType), right))
case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) =>
p.makeCopy(Array(left, Cast(right, StringType)))
case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) =>
p.makeCopy(Array(Cast(left, StringType), right))

// Comparisons between dates and timestamps.
case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) =>
p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType)))
case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) =>
p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType)))

// Checking NullType
case p @ BinaryComparison(left @ StringType(), right @ NullType()) =>
p.makeCopy(Array(left, Literal.create(null, StringType)))
case p @ BinaryComparison(left @ NullType(), right @ StringType()) =>
p.makeCopy(Array(Literal.create(null, StringType), right))

// When compare string with atomic type, case string to that type.
case p @ BinaryComparison(left @ StringType(), right @ AtomicType())
if right.dataType != StringType =>
p.makeCopy(Array(Cast(left, right.dataType), right))
case p @ BinaryComparison(left @ AtomicType(), right @ StringType())
if left.dataType != StringType =>
p.makeCopy(Array(left, Cast(right, left.dataType)))
case p @ BinaryComparison(left, right)
if findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined =>
val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get
p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType)))

case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
Expand All @@ -365,17 +368,72 @@ object TypeCoercion {
}

/**
* Convert the value and in list expressions to the common operator type
* by looking at all the argument types and finding the closest one that
* all the arguments can be cast to. When no common operator type is found
* the original expression will be returned and an Analysis Exception will
* be raised at type checking phase.
* Handles type coercion for both IN expression with subquery and IN
* expressions without subquery.
* 1. In the first case, find the common type by comparing the left hand side (LHS)
* expression types against corresponding right hand side (RHS) expression derived
* from the subquery expression's plan output. Inject appropriate casts in the
* LHS and RHS side of IN expression.
*
* 2. In the second case, convert the value and in list expressions to the
* common operator type by looking at all the argument types and finding
* the closest one that all the arguments can be cast to. When no common
* operator type is found the original expression will be returned and an
* Analysis Exception will be raised at the type checking phase.
*/
object InConversion extends Rule[LogicalPlan] {
private def flattenExpr(expr: Expression): Seq[Expression] = {
expr match {
// Multi columns in IN clause is represented as a CreateNamedStruct.
// flatten the named struct to get the list of expressions.
case cns: CreateNamedStruct => cns.valExprs
case expr => Seq(expr)
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e

// Handle type casting required between value expression and subquery output
// in IN subquery.
case i @ In(a, Seq(ListQuery(sub, children, exprId)))
if !i.resolved && flattenExpr(a).length == sub.output.length =>
// LHS is the value expression of IN subquery.
val lhs = flattenExpr(a)

// RHS is the subquery output.
val rhs = sub.output

val commonTypes = lhs.zip(rhs).flatMap { case (l, r) =>
findCommonTypeForBinaryComparison(l.dataType, r.dataType)
.orElse(findTightestCommonType(l.dataType, r.dataType))
}

// The number of columns/expressions must match between LHS and RHS of an
// IN subquery expression.
if (commonTypes.length == lhs.length) {
val castedRhs = rhs.zip(commonTypes).map {
case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
case (e, _) => e
}
val castedLhs = lhs.zip(commonTypes).map {
case (e, dt) if e.dataType != dt => Cast(e, dt)
case (e, _) => e
}

// Before constructing the In expression, wrap the multi values in LHS
// in a CreatedNamedStruct.
val newLhs = castedLhs match {
case Seq(lhs) => lhs
case _ => CreateStruct(castedLhs)
}

In(newLhs, Seq(ListQuery(Project(castedRhs, sub), children, exprId)))
} else {
i
}

case i @ In(a, b) if b.exists(_.dataType != a.dataType) =>
findWiderCommonType(i.children.map(_.dataType)) match {
case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,19 +123,44 @@ case class Not(child: Expression)
*/
@ExpressionDescription(
usage = "expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr` equals to any valN.")
case class In(value: Expression, list: Seq[Expression]) extends Predicate
with ImplicitCastInputTypes {
case class In(value: Expression, list: Seq[Expression]) extends Predicate {

require(list != null, "list should not be null")
override def checkInputDataTypes(): TypeCheckResult = {
list match {
case ListQuery(sub, _, _) :: Nil =>
val valExprs = value match {
case cns: CreateNamedStruct => cns.valExprs
case expr => Seq(expr)
}

override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType)
val mismatchedColumns = valExprs.zip(sub.output).flatMap {
case (l, r) if l.dataType != r.dataType =>
s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
case _ => None
}

override def checkInputDataTypes(): TypeCheckResult = {
if (list.exists(l => l.dataType != value.dataType)) {
TypeCheckResult.TypeCheckFailure(
"Arguments must be same type")
} else {
TypeCheckResult.TypeCheckSuccess
if (mismatchedColumns.nonEmpty) {
TypeCheckResult.TypeCheckFailure(
s"""
|The data type of one or more elements in the left hand side of an IN subquery
|is not compatible with the data type of the output of the subquery
|Mismatched columns:
|[${mismatchedColumns.mkString(", ")}]
|Left side:
|[${valExprs.map(_.dataType.catalogString).mkString(", ")}].
|Right side:
|[${sub.output.map(_.dataType.catalogString).mkString(", ")}].
""".stripMargin)
} else {
TypeCheckResult.TypeCheckSuccess
}
case _ =>
if (list.exists(l => l.dataType != value.dataType)) {
TypeCheckResult.TypeCheckFailure("Arguments must be same type")
} else {
TypeCheckResult.TypeCheckSuccess
}
}
}

Expand Down
Loading