Skip to content

Commit 4ce970d

Browse files
nsycahvanhovell
authored andcommitted
[SPARK-18874][SQL] First phase: Deferring the correlated predicate pull up to Optimizer phase
## What changes were proposed in this pull request? Currently the Analyzer, as part of ResolveSubquery, pulls up the correlated predicates to their originating SubqueryExpression. The subquery plan is then transformed to remove the correlated predicates after they are moved up to the outer plan. In this PR, the task of pulling up correlated predicates is deferred to the Optimizer. This is the initial work that will allow us to support the forms of correlated subqueries that we don't support today. The design document from nsyca can be found at the following link: [DesignDoc](https://docs.google.com/document/d/1QDZ8JwU63RwGFS6KVF54Rjj9ZJyK33d49ZWbjFBaIgU/edit#) A brief description of the code changes (hopefully to aid with code review) can be found at the following link: [CodeChanges](https://docs.google.com/document/d/18mqjhL9V1An-tNta7aVE13HkALRZ5GZ24AATA-Vqqf0/edit#) ## How was this patch tested? The test cases were submitted earlier in the following PRs: [16337](#16337) [16759](#16759) [16841](#16841) [16915](#16915) [16798](#16798) [16712](#16712) [16710](#16710) [16760](#16760) [16802](#16802) Author: Dilip Biswal <[email protected]> Closes #16954 from dilipbiswal/SPARK-18874.
1 parent f6314ea commit 4ce970d

File tree

13 files changed

+675
-300
lines changed

13 files changed

+675
-300
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 181 additions & 133 deletions
Large diffs are not rendered by default.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
2020
import org.apache.spark.sql.AnalysisException
2121
import org.apache.spark.sql.catalyst.expressions._
2222
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
23+
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
2324
import org.apache.spark.sql.catalyst.plans.logical._
2425
import org.apache.spark.sql.types._
2526

@@ -133,10 +134,8 @@ trait CheckAnalysis extends PredicateHelper {
133134
if (conditions.isEmpty && query.output.size != 1) {
134135
failAnalysis(
135136
s"Scalar subquery must return only one column, but got ${query.output.size}")
136-
} else if (conditions.nonEmpty) {
137-
// Collect the columns from the subquery for further checking.
138-
var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains)
139-
137+
}
138+
else if (conditions.nonEmpty) {
140139
def checkAggregate(agg: Aggregate): Unit = {
141140
// Make sure correlated scalar subqueries contain one row for every outer row by
142141
// enforcing that they are aggregates containing exactly one aggregate expression.
@@ -152,6 +151,9 @@ trait CheckAnalysis extends PredicateHelper {
152151
// SPARK-18504/SPARK-18814: Block cases where GROUP BY columns
153152
// are not part of the correlated columns.
154153
val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references))
154+
// Collect the local references from the correlated predicate in the subquery.
155+
val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references)
156+
.filterNot(conditions.flatMap(_.references).contains)
155157
val correlatedCols = AttributeSet(subqueryColumns)
156158
val invalidCols = groupByCols -- correlatedCols
157159
// GROUP BY columns must be a subset of columns in the predicates
@@ -167,17 +169,7 @@ trait CheckAnalysis extends PredicateHelper {
167169
// For projects, do the necessary mapping and skip to its child.
168170
def cleanQuery(p: LogicalPlan): LogicalPlan = p match {
169171
case s: SubqueryAlias => cleanQuery(s.child)
170-
case p: Project =>
171-
// SPARK-18814: Map any aliases to their AttributeReference children
172-
// for the checking in the Aggregate operators below this Project.
173-
subqueryColumns = subqueryColumns.map {
174-
xs => p.projectList.collectFirst {
175-
case e @ Alias(child : AttributeReference, _) if e.exprId == xs.exprId =>
176-
child
177-
}.getOrElse(xs)
178-
}
179-
180-
cleanQuery(p.child)
172+
case p: Project => cleanQuery(p.child)
181173
case child => child
182174
}
183175

@@ -211,14 +203,9 @@ trait CheckAnalysis extends PredicateHelper {
211203
s"filter expression '${f.condition.sql}' " +
212204
s"of type ${f.condition.dataType.simpleString} is not a boolean.")
213205

214-
case Filter(condition, _) =>
215-
splitConjunctivePredicates(condition).foreach {
216-
case _: PredicateSubquery | Not(_: PredicateSubquery) =>
217-
case e if PredicateSubquery.hasNullAwarePredicateWithinNot(e) =>
218-
failAnalysis(s"Null-aware predicate sub-queries cannot be used in nested" +
219-
s" conditions: $e")
220-
case e =>
221-
}
206+
case Filter(condition, _) if hasNullAwarePredicateWithinNot(condition) =>
207+
failAnalysis("Null-aware predicate sub-queries cannot be used in nested " +
208+
s"conditions: $condition")
222209

223210
case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType =>
224211
failAnalysis(
@@ -306,8 +293,11 @@ trait CheckAnalysis extends PredicateHelper {
306293
s"Correlated scalar sub-queries can only be used in a Filter/Aggregate/Project: $p")
307294
}
308295

309-
case p if p.expressions.exists(PredicateSubquery.hasPredicateSubquery) =>
310-
failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p")
296+
case p if p.expressions.exists(SubqueryExpression.hasInOrExistsSubquery) =>
297+
p match {
298+
case _: Filter => // Ok
299+
case _ => failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p")
300+
}
311301

312302
case _: Union | _: SetOperation if operator.children.length > 1 =>
313303
def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 94 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,28 @@ object TypeCoercion {
108108
case _ => None
109109
}
110110

111+
/**
 * Picks the type that both operands of a binary comparison are coerced to when the
 * operands mix strings with other types. Comparisons between a date and a timestamp
 * are handled here as well; they are carried out as string comparisons.
 * Yields None for every combination that is not listed below.
 */
val findCommonTypeForBinaryComparison: (DataType, DataType) => Option[DataType] = {
  // Relative comparisons among timestamps, dates and strings are performed on strings.
  // This matches user expectations because timestamp strings sort lexicographically,
  // e.g. TimeStamp(2013-01-01 00:00 ...) < "2014" = true
  case (StringType, DateType) | (DateType, StringType) => Some(StringType)
  case (StringType, TimestampType) | (TimestampType, StringType) => Some(StringType)
  case (TimestampType, DateType) | (DateType, TimestampType) => Some(StringType)
  // A null operand compared with a string is typed as a string.
  case (StringType, NullType) | (NullType, StringType) => Some(StringType)
  // A string compared with any other atomic type adopts that other type.
  case (_: StringType, other: AtomicType) if other != StringType => Some(other)
  case (other: AtomicType, _: StringType) if other != StringType => Some(other)
  case _ => None
}
132+
111133
/**
112134
* Case 2 type widening (see the classdoc comment above for TypeCoercion).
113135
*
@@ -305,6 +327,14 @@ object TypeCoercion {
305327
* Promotes strings that appear in arithmetic expressions.
306328
*/
307329
object PromoteStrings extends Rule[LogicalPlan] {
330+
/**
 * Coerces `expr` to `targetType`: a NullType expression is replaced by a null literal of
 * the target type, an expression of a different type is wrapped in a Cast, and an
 * expression that already has the target type is returned unchanged.
 */
private def castExpr(expr: Expression, targetType: DataType): Expression = {
  val sourceType = expr.dataType
  if (sourceType == NullType) {
    Literal.create(null, targetType)
  } else if (sourceType != targetType) {
    Cast(expr, targetType)
  } else {
    expr
  }
}
337+
308338
def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
309339
// Skip nodes whose children have not been resolved yet.
310340
case e if !e.childrenResolved => e
@@ -321,37 +351,10 @@ object TypeCoercion {
321351
case p @ Equality(left @ TimestampType(), right @ StringType()) =>
322352
p.makeCopy(Array(left, Cast(right, TimestampType)))
323353

324-
// We should cast all relative timestamp/date/string comparison into string comparisons
325-
// This behaves as a user would expect because timestamp strings sort lexicographically.
326-
// i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true
327-
case p @ BinaryComparison(left @ StringType(), right @ DateType()) =>
328-
p.makeCopy(Array(left, Cast(right, StringType)))
329-
case p @ BinaryComparison(left @ DateType(), right @ StringType()) =>
330-
p.makeCopy(Array(Cast(left, StringType), right))
331-
case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) =>
332-
p.makeCopy(Array(left, Cast(right, StringType)))
333-
case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) =>
334-
p.makeCopy(Array(Cast(left, StringType), right))
335-
336-
// Comparisons between dates and timestamps.
337-
case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) =>
338-
p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType)))
339-
case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) =>
340-
p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType)))
341-
342-
// Checking NullType
343-
case p @ BinaryComparison(left @ StringType(), right @ NullType()) =>
344-
p.makeCopy(Array(left, Literal.create(null, StringType)))
345-
case p @ BinaryComparison(left @ NullType(), right @ StringType()) =>
346-
p.makeCopy(Array(Literal.create(null, StringType), right))
347-
348-
// When compare string with atomic type, case string to that type.
349-
case p @ BinaryComparison(left @ StringType(), right @ AtomicType())
350-
if right.dataType != StringType =>
351-
p.makeCopy(Array(Cast(left, right.dataType), right))
352-
case p @ BinaryComparison(left @ AtomicType(), right @ StringType())
353-
if left.dataType != StringType =>
354-
p.makeCopy(Array(left, Cast(right, left.dataType)))
354+
case p @ BinaryComparison(left, right)
355+
if findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined =>
356+
val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get
357+
p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType)))
355358

356359
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
357360
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
@@ -365,17 +368,72 @@ object TypeCoercion {
365368
}
366369

367370
/**
368-
* Convert the value and in list expressions to the common operator type
369-
* by looking at all the argument types and finding the closest one that
370-
* all the arguments can be cast to. When no common operator type is found
371-
* the original expression will be returned and an Analysis Exception will
372-
* be raised at type checking phase.
371+
* Handles type coercion for both IN expression with subquery and IN
372+
* expressions without subquery.
373+
* 1. In the first case, find the common type by comparing the left hand side (LHS)
374+
* expression types against corresponding right hand side (RHS) expression derived
375+
* from the subquery expression's plan output. Inject appropriate casts in the
376+
* LHS and RHS side of IN expression.
377+
*
378+
* 2. In the second case, convert the value and in list expressions to the
379+
* common operator type by looking at all the argument types and finding
380+
* the closest one that all the arguments can be cast to. When no common
381+
* operator type is found the original expression will be returned and an
382+
* Analysis Exception will be raised at the type checking phase.
373383
*/
374384
object InConversion extends Rule[LogicalPlan] {
385+
/**
 * Returns the individual value expressions of `expr`. A multi-column IN value is
 * represented as a CreateNamedStruct, whose value expressions are returned; any other
 * expression is returned as a single-element sequence.
 */
private def flattenExpr(expr: Expression): Seq[Expression] = expr match {
  case struct: CreateNamedStruct => struct.valExprs
  case single => Seq(single)
}
393+
375394
def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
376395
// Skip nodes whose children have not been resolved yet.
377396
case e if !e.childrenResolved => e
378397

398+
// Handle type casting required between value expression and subquery output
399+
// in IN subquery.
400+
case i @ In(a, Seq(ListQuery(sub, children, exprId)))
401+
if !i.resolved && flattenExpr(a).length == sub.output.length =>
402+
// LHS is the value expression of IN subquery.
403+
val lhs = flattenExpr(a)
404+
405+
// RHS is the subquery output.
406+
val rhs = sub.output
407+
408+
val commonTypes = lhs.zip(rhs).flatMap { case (l, r) =>
409+
findCommonTypeForBinaryComparison(l.dataType, r.dataType)
410+
.orElse(findTightestCommonType(l.dataType, r.dataType))
411+
}
412+
413+
// The number of columns/expressions must match between LHS and RHS of an
414+
// IN subquery expression.
415+
if (commonTypes.length == lhs.length) {
416+
val castedRhs = rhs.zip(commonTypes).map {
417+
case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
418+
case (e, _) => e
419+
}
420+
val castedLhs = lhs.zip(commonTypes).map {
421+
case (e, dt) if e.dataType != dt => Cast(e, dt)
422+
case (e, _) => e
423+
}
424+
425+
// Before constructing the In expression, wrap the multi values in LHS
426+
// in a CreatedNamedStruct.
427+
val newLhs = castedLhs match {
428+
case Seq(lhs) => lhs
429+
case _ => CreateStruct(castedLhs)
430+
}
431+
432+
In(newLhs, Seq(ListQuery(Project(castedRhs, sub), children, exprId)))
433+
} else {
434+
i
435+
}
436+
379437
case i @ In(a, b) if b.exists(_.dataType != a.dataType) =>
380438
findWiderCommonType(i.children.map(_.dataType)) match {
381439
case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType)))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -123,19 +123,44 @@ case class Not(child: Expression)
123123
*/
124124
@ExpressionDescription(
125125
usage = "expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr` equals to any valN.")
126-
case class In(value: Expression, list: Seq[Expression]) extends Predicate
127-
with ImplicitCastInputTypes {
126+
case class In(value: Expression, list: Seq[Expression]) extends Predicate {
128127

129128
require(list != null, "list should not be null")
129+
override def checkInputDataTypes(): TypeCheckResult = {
130+
list match {
131+
case ListQuery(sub, _, _) :: Nil =>
132+
val valExprs = value match {
133+
case cns: CreateNamedStruct => cns.valExprs
134+
case expr => Seq(expr)
135+
}
130136

131-
override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType)
137+
val mismatchedColumns = valExprs.zip(sub.output).flatMap {
138+
case (l, r) if l.dataType != r.dataType =>
139+
s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
140+
case _ => None
141+
}
132142

133-
override def checkInputDataTypes(): TypeCheckResult = {
134-
if (list.exists(l => l.dataType != value.dataType)) {
135-
TypeCheckResult.TypeCheckFailure(
136-
"Arguments must be same type")
137-
} else {
138-
TypeCheckResult.TypeCheckSuccess
143+
if (mismatchedColumns.nonEmpty) {
144+
TypeCheckResult.TypeCheckFailure(
145+
s"""
146+
|The data type of one or more elements in the left hand side of an IN subquery
147+
|is not compatible with the data type of the output of the subquery
148+
|Mismatched columns:
149+
|[${mismatchedColumns.mkString(", ")}]
150+
|Left side:
151+
|[${valExprs.map(_.dataType.catalogString).mkString(", ")}].
152+
|Right side:
153+
|[${sub.output.map(_.dataType.catalogString).mkString(", ")}].
154+
""".stripMargin)
155+
} else {
156+
TypeCheckResult.TypeCheckSuccess
157+
}
158+
case _ =>
159+
if (list.exists(l => l.dataType != value.dataType)) {
160+
TypeCheckResult.TypeCheckFailure("Arguments must be same type")
161+
} else {
162+
TypeCheckResult.TypeCheckSuccess
163+
}
139164
}
140165
}
141166

0 commit comments

Comments
 (0)