Skip to content

Commit b098b48

Browse files
nsyca authored and hvanhovell committed
[SPARK-18582][SQL] Whitelist LogicalPlan operators allowed in correlated subqueries
## What changes were proposed in this pull request? This fix puts an explicit list of operators that Spark supports for correlated subqueries. ## How was this patch tested? Run sql/test, catalyst/test and add a new test case on Generate. Author: Nattavut Sutyanyong <[email protected]> Closes #16046 from nsyca/spark18455.0. (cherry picked from commit 4a3c096) Signed-off-by: Herman van Hovell <[email protected]>
1 parent 28ea432 commit b098b48

File tree

4 files changed

+129
-53
lines changed

4 files changed

+129
-53
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 108 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -952,24 +952,24 @@ class Analyzer(
952952
private def pullOutCorrelatedPredicates(sub: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
953953
val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]]
954954

955-
/** Make sure a plans' subtree does not contain a tagged predicate. */
956-
def failOnOuterReferenceInSubTree(p: LogicalPlan, msg: String): Unit = {
957-
if (p.collect(predicateMap).nonEmpty) {
958-
failAnalysis(s"Accessing outer query column is not allowed in $msg: $p")
955+
// Make sure a plan's subtree does not contain outer references
956+
def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = {
957+
if (p.collectFirst(predicateMap).nonEmpty) {
958+
failAnalysis(s"Accessing outer query column is not allowed in:\n$p")
959959
}
960960
}
961961

962-
/** Helper function for locating outer references. */
962+
// Helper function for locating outer references.
963963
def containsOuter(e: Expression): Boolean = {
964964
e.find(_.isInstanceOf[OuterReference]).isDefined
965965
}
966966

967-
/** Make sure a plans' expressions do not contain a tagged predicate. */
967+
// Make sure a plan's expressions do not contain outer references
968968
def failOnOuterReference(p: LogicalPlan): Unit = {
969969
if (p.expressions.exists(containsOuter)) {
970970
failAnalysis(
971971
"Expressions referencing the outer query are not supported outside of WHERE/HAVING " +
972-
s"clauses: $p")
972+
s"clauses:\n$p")
973973
}
974974
}
975975

@@ -1018,10 +1018,51 @@ class Analyzer(
10181018

10191019
// Simplify the predicates before pulling them out.
10201020
val transformed = BooleanSimplification(sub) transformUp {
1021-
// WARNING:
1022-
// Only Filter can host correlated expressions at this time
1023-
// Anyone adding a new "case" below needs to add the call to
1024-
// "failOnOuterReference" to disallow correlated expressions in it.
1021+
1022+
// Whitelist operators allowed in a correlated subquery
1023+
// There are 4 categories:
1024+
// 1. Operators that are allowed anywhere in a correlated subquery, and,
1025+
// by definition of the operators, they either do not contain
1026+
// any columns or cannot host outer references.
1027+
// 2. Operators that are allowed anywhere in a correlated subquery
1028+
// so long as they do not host outer references.
1029+
// 3. Operators that need special handling. These operators are
1030+
// Project, Filter, Join, Aggregate, and Generate.
1031+
//
1032+
// 4. Any other operators, not in the above list. These are allowed
1033+
// in a correlated subquery only if they are not on a correlation path.
1034+
// In other words, these operators are allowed only under a correlation point.
1035+
//
1036+
// A correlation path is defined as the sub-tree of all the operators that
1037+
// are on the path from the operator hosting the correlated expressions
1038+
// up to the operator producing the correlated values.
1039+
1040+
// Category 1:
1041+
// BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias
1042+
case p: BroadcastHint =>
1043+
p
1044+
case p: Distinct =>
1045+
p
1046+
case p: LeafNode =>
1047+
p
1048+
case p: Repartition =>
1049+
p
1050+
case p: SubqueryAlias =>
1051+
p
1052+
1053+
// Category 2:
1054+
// These operators can be anywhere in a correlated subquery
1055+
// so long as they do not host outer references in the operators.
1056+
case p: Sort =>
1057+
failOnOuterReference(p)
1058+
p
1059+
case p: RedistributeData =>
1060+
failOnOuterReference(p)
1061+
p
1062+
1063+
// Category 3:
1064+
// Filter is one of the two operators allowed to host correlated expressions.
1065+
// The other operator is Join. Filter can be anywhere in a correlated subquery.
10251066
case f @ Filter(cond, child) =>
10261067
// Find all predicates with an outer reference.
10271068
val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter)
@@ -1043,14 +1084,24 @@ class Analyzer(
10431084
predicateMap += child -> xs
10441085
child
10451086
}
1087+
1088+
// Project cannot host any correlated expressions
1089+
// but can be anywhere in a correlated subquery.
10461090
case p @ Project(expressions, child) =>
10471091
failOnOuterReference(p)
1092+
10481093
val referencesToAdd = missingReferences(p)
10491094
if (referencesToAdd.nonEmpty) {
10501095
Project(expressions ++ referencesToAdd, child)
10511096
} else {
10521097
p
10531098
}
1099+
1100+
// Aggregate cannot host any correlated expressions
1101+
// It can be on a correlation path if the correlation contains
1102+
// only equality correlated predicates.
1103+
// It cannot be on a correlation path if the correlation has
1104+
// non-equality correlated predicates.
10541105
case a @ Aggregate(grouping, expressions, child) =>
10551106
failOnOuterReference(a)
10561107
failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a)
@@ -1061,48 +1112,55 @@ class Analyzer(
10611112
} else {
10621113
a
10631114
}
1064-
case w : Window =>
1065-
failOnOuterReference(w)
1066-
failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, w)
1067-
w
1068-
case j @ Join(left, _, RightOuter, _) =>
1069-
failOnOuterReference(j)
1070-
failOnOuterReferenceInSubTree(left, "a RIGHT OUTER JOIN")
1071-
j
1072-
// SPARK-18578: Do not allow any correlated predicate
1073-
// in a Full (Outer) Join operator and its descendants
1074-
case j @ Join(_, _, FullOuter, _) =>
1075-
failOnOuterReferenceInSubTree(j, "a FULL OUTER JOIN")
1076-
j
1077-
case j @ Join(_, right, jt, _) if !jt.isInstanceOf[InnerLike] =>
1078-
failOnOuterReference(j)
1079-
failOnOuterReferenceInSubTree(right, "a LEFT (OUTER) JOIN")
1115+
1116+
// Join can host correlated expressions.
1117+
case j @ Join(left, right, joinType, _) =>
1118+
joinType match {
1119+
// Inner join, like Filter, can be anywhere.
1120+
case _: InnerLike =>
1121+
failOnOuterReference(j)
1122+
1123+
// Left outer join's right operand cannot be on a correlation path.
1124+
// LeftAnti and ExistenceJoin are special cases of LeftOuter.
1125+
// Note that ExistenceJoin cannot be expressed externally in either SQL or DataFrame
1126+
// so it should not show up here in Analysis phase. This is just a safety net.
1127+
//
1128+
// LeftSemi does not allow output from the right operand.
1129+
// Any correlated references in the subplan
1130+
// of the right operand cannot be pulled up.
1131+
case LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) =>
1132+
failOnOuterReference(j)
1133+
failOnOuterReferenceInSubTree(right)
1134+
1135+
// Likewise, Right outer join's left operand cannot be on a correlation path.
1136+
case RightOuter =>
1137+
failOnOuterReference(j)
1138+
failOnOuterReferenceInSubTree(left)
1139+
1140+
// Any other join types not explicitly listed above,
1141+
// including Full outer join, are treated as Category 4.
1142+
case _ =>
1143+
failOnOuterReferenceInSubTree(j)
1144+
}
10801145
j
1081-
case u: Union =>
1082-
failOnOuterReferenceInSubTree(u, "a UNION")
1083-
u
1084-
case s: SetOperation =>
1085-
failOnOuterReferenceInSubTree(s.right, "an INTERSECT/EXCEPT")
1086-
s
1087-
case e: Expand =>
1088-
failOnOuterReferenceInSubTree(e, "an EXPAND")
1089-
e
1090-
case l : LocalLimit =>
1091-
failOnOuterReferenceInSubTree(l, "a LIMIT")
1092-
l
1093-
// Since LIMIT <n> is represented as GlobalLimit(<n>, (LocalLimit (<n>, child))
1094-
// and we are walking bottom up, we will fail on LocalLimit before
1095-
// reaching GlobalLimit.
1096-
// The code below is just a safety net.
1097-
case g : GlobalLimit =>
1098-
failOnOuterReferenceInSubTree(g, "a LIMIT")
1099-
g
1100-
case s : Sample =>
1101-
failOnOuterReferenceInSubTree(s, "a TABLESAMPLE")
1102-
s
1103-
case p =>
1146+
1147+
// Generator with join=true, i.e., expressed with
1148+
// LATERAL VIEW [OUTER], similar to inner join,
1149+
// allows correlation under it
1150+
// but must not host any outer references.
1151+
// Note:
1152+
// Generator with join=false is treated as Category 4.
1153+
case p @ Generate(generator, true, _, _, _, _) =>
11041154
failOnOuterReference(p)
11051155
p
1156+
1157+
// Category 4: Any other operators not in the above 3 categories
1158+
// cannot be on a correlation path, that is they are allowed only
1159+
// under a correlation point but they and their descendant operators
1160+
// are not allowed to have any correlated expressions.
1161+
case p =>
1162+
failOnOuterReferenceInSubTree(p)
1163+
p
11061164
}
11071165
(transformed, predicateMap.values.flatten.toSeq)
11081166
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -932,7 +932,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
932932
split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)
933933

934934
joinType match {
935-
case _: InnerLike | LeftSemi =>
935+
case _: InnerLike | LeftSemi =>
936936
// push down the single side only join filter for both sides sub queries
937937
val newLeft = leftJoinConditions.
938938
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ class AnalysisErrorSuite extends AnalysisTest {
542542
Filter(EqualTo(OuterReference(a), b), LocalRelation(b)))
543543
),
544544
LocalRelation(a))
545-
assertAnalysisError(plan4, "Accessing outer query column is not allowed in a LIMIT" :: Nil)
545+
assertAnalysisError(plan4, "Accessing outer query column is not allowed in" :: Nil)
546546

547547
val plan5 = Filter(
548548
Exists(
@@ -551,6 +551,6 @@ class AnalysisErrorSuite extends AnalysisTest {
551551
),
552552
LocalRelation(a))
553553
assertAnalysisError(plan5,
554-
"Accessing outer query column is not allowed in a TABLESAMPLE" :: Nil)
554+
"Accessing outer query column is not allowed in" :: Nil)
555555
}
556556
}

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,4 +789,22 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
789789
}
790790
}
791791
}
792+
793+
// Generate operator
794+
test("Correlated subqueries in LATERAL VIEW") {
795+
withTempView("t1", "t2") {
796+
Seq((1, 1), (2, 0)).toDF("c1", "c2").createOrReplaceTempView("t1")
797+
Seq[(Int, Array[Int])]((1, Array(1, 2)), (2, Array(-1, -3)))
798+
.toDF("c1", "arr_c2").createTempView("t2")
799+
checkAnswer(
800+
sql(
801+
"""
802+
| select c2
803+
| from t1
804+
| where exists (select *
805+
| from t2 lateral view explode(arr_c2) q as c2
806+
where t1.c1 = t2.c1)""".stripMargin),
807+
Row(1) :: Row(0) :: Nil)
808+
}
809+
}
792810
}

0 commit comments

Comments
 (0)