1717
1818package org .apache .spark .sql .execution
1919
20+ import scala .collection .mutable
21+ import scala .collection .mutable .ArrayBuffer
22+
2023import org .apache .spark .sql .SparkSession
21- import org .apache .spark .sql .catalyst .expressions
22- import org .apache .spark .sql .catalyst .InternalRow
23- import org .apache .spark .sql .catalyst .expressions .{Expression , ExprId , Literal , SubqueryExpression }
24+ import org .apache .spark .sql .catalyst .{expressions , InternalRow }
25+ import org .apache .spark .sql .catalyst .expressions ._
2426import org .apache .spark .sql .catalyst .expressions .codegen .{CodegenContext , ExprCode }
2527import org .apache .spark .sql .catalyst .plans .logical .LogicalPlan
2628import org .apache .spark .sql .catalyst .rules .Rule
27- import org .apache .spark .sql .types .DataType
29+ import org .apache .spark .sql .internal .SQLConf
30+ import org .apache .spark .sql .types .{BooleanType , DataType , StructType }
31+
/**
 * Common base for the subquery expressions that live inside a SparkPlan (as opposed to the
 * logical subquery expressions used during analysis and optimization).
 */
trait ExecSubqueryExpression extends SubqueryExpression {

  /** The physical subquery plan backing this expression. */
  val executedPlan: SubqueryExec

  /** Builds the counterpart of this expression that is backed by the given `plan`. */
  def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression

  /** The plan of a physical subquery expression is its executed plan. */
  override def plan: SparkPlan = executedPlan

  // A physical subquery expression does not have a logical plan, so both logical-plan
  // accessors are unsupported.
  override def query: LogicalPlan = {
    throw new UnsupportedOperationException
  }

  override def withNewPlan(plan: LogicalPlan): SubqueryExpression = {
    throw new UnsupportedOperationException
  }

  /**
   * Fill the expression with the result collected from the executed plan.
   */
  def updateResult(): Unit
}
2852
2953/**
3054 * A subquery that will return only one row and one column.
3155 *
3256 * This is the physical copy of ScalarSubquery to be used inside SparkPlan.
3357 */
3458case class ScalarSubquery (
35- executedPlan : SparkPlan ,
59+ executedPlan : SubqueryExec ,
3660 exprId : ExprId )
37- extends SubqueryExpression {
38-
39- override def query : LogicalPlan = throw new UnsupportedOperationException
40- override def withNewPlan (plan : LogicalPlan ): SubqueryExpression = {
41- throw new UnsupportedOperationException
42- }
43- override def plan : SparkPlan = SubqueryExec (simpleString, executedPlan)
61+ extends ExecSubqueryExpression {
4462
4563 override def dataType : DataType = executedPlan.schema.fields.head.dataType
4664 override def children : Seq [Expression ] = Nil
4765 override def nullable : Boolean = true
48- override def toString : String = s " subquery# ${exprId.id}"
66+ override def toString : String = executedPlan.simpleString
67+
68+ def withExecutedPlan (plan : SubqueryExec ): ExecSubqueryExpression = copy(executedPlan = plan)
69+
/**
 * Two scalar subqueries are semantically equal when their executed plans produce the same
 * result. Any expression that is not a [[ScalarSubquery]] is never semantically equal.
 */
override def semanticEquals(other: Expression): Boolean = other match {
  // Compare against the OTHER subquery's plan; comparing the plan with itself would make
  // every pair of scalar subqueries look identical.
  case s: ScalarSubquery => executedPlan.sameResult(s.executedPlan)
  case _ => false
}
4974
5075 // the first column in first row from `query`.
5176 @ volatile private var result : Any = null
5277 @ volatile private var updated : Boolean = false
5378
54- def updateResult (v : Any ): Unit = {
55- result = v
/**
 * Runs the subquery plan, collects its rows and stores the single scalar value.
 *
 * Zero rows yield a null result, exactly one row yields its only column, and more than one
 * row is reported as an error. The expression is marked as updated afterwards.
 */
def updateResult(): Unit = {
  val collected = plan.executeCollect()
  collected.length match {
    case 0 =>
      // If there is no rows returned, the result should be null.
      result = null
    case 1 =>
      val row = collected(0)
      assert(row.numFields == 1,
        s"Expects 1 field, but got ${row.numFields}; something went wrong in analysis")
      result = row.get(0, dataType)
    case _ =>
      sys.error(s"more than one row returned by a subquery used as an expression:\n${plan}")
  }
  updated = true
}
5894
@@ -67,6 +103,51 @@ case class ScalarSubquery(
67103 }
68104}
69105
/**
 * A physical subquery expression that checks whether the value of `child` is contained in the
 * result of the subquery.
 *
 * [[updateResult]] has to be called before the expression is evaluated or code is generated
 * for it; both paths fail with an `IllegalArgumentException` otherwise.
 *
 * @param child the expression whose value is looked up in the subquery result
 * @param executedPlan the physical plan of the subquery
 * @param exprId the id of this subquery expression
 * @param result the first column of every row collected from the subquery (set by
 *               [[updateResult]])
 * @param updated whether `result` has been filled in
 */
case class InSubquery(
    child: Expression,
    executedPlan: SubqueryExec,
    exprId: ExprId,
    private var result: Array[Any] = null,
    private var updated: Boolean = false) extends ExecSubqueryExpression {

  override def dataType: DataType = BooleanType
  override def children: Seq[Expression] = child :: Nil
  // NOTE(review): the result can also be NULL when the subquery itself returns a NULL value,
  // which is not known before the subquery has run -- confirm whether this should be wider.
  override def nullable: Boolean = child.nullable
  override def toString: String = s"$child IN ${executedPlan.name}"

  def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression = copy(executedPlan = plan)

  /**
   * Two IN-subqueries are semantically equal when their children are semantically equal and
   * their executed plans produce the same result.
   */
  override def semanticEquals(other: Expression): Boolean = other match {
    case in: InSubquery =>
      child.semanticEquals(in.child) && executedPlan.sameResult(in.executedPlan)
    case _ => false
  }

  /**
   * Runs the subquery plan and keeps the first column of every returned row as the set of
   * values to test `child` against.
   */
  def updateResult(): Unit = {
    val rows = plan.executeCollect()
    result = rows.map(_.get(0, child.dataType)).asInstanceOf[Array[Any]]
    updated = true
  }

  /**
   * Interpreted evaluation following SQL three-valued logic, so that it agrees with the
   * `InSet`-based generated code:
   *  - NULL when the child value is NULL,
   *  - true when the value is found in the subquery result,
   *  - NULL when it is not found but the subquery result contains a NULL,
   *  - false otherwise.
   */
  override def eval(input: InternalRow): Any = {
    require(updated, s"$this has not finished")
    val v = child.eval(input)
    if (v == null) {
      null
    } else if (result.contains(v)) {
      true
    } else if (result.contains(null)) {
      // No match, but a NULL in the subquery result makes the answer unknown rather than
      // false.
      null
    } else {
      false
    }
  }

  /** Generates code by delegating to an `InSet` built from the collected result. */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    require(updated, s"$this has not finished")
    InSet(child, result.toSet).doGenCode(ctx, ev)
  }
}
150+
70151/**
71152 * Plans the subqueries that are present in the given [[SparkPlan ]].
72153 */
@@ -75,7 +156,39 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
75156 plan.transformAllExpressions {
76157 case subquery : expressions.ScalarSubquery =>
77158 val executedPlan = new QueryExecution (sparkSession, subquery.plan).executedPlan
78- ScalarSubquery (executedPlan, subquery.exprId)
159+ ScalarSubquery (
160+ SubqueryExec (s " subquery ${subquery.exprId.id}" , executedPlan),
161+ subquery.exprId)
162+ case expressions.PredicateSubquery (plan, Seq (e : Expression ), _, exprId) =>
163+ val executedPlan = new QueryExecution (sparkSession, plan).executedPlan
164+ InSubquery (e, SubqueryExec (s " subquery ${exprId.id}" , executedPlan), exprId)
165+ }
166+ }
167+ }
168+
169+
/**
 * Finds duplicated subqueries in the Spark plan, then uses the same subquery plan for all the
 * references.
 */
case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {

  /**
   * Rewrites every [[ExecSubqueryExpression]] whose plan produces the same result as one seen
   * earlier so that it points at the earlier [[SubqueryExec]].
   *
   * @param plan the physical plan to rewrite
   * @return `plan` unchanged when `conf.exchangeReuseEnabled` is off, otherwise the plan with
   *         duplicated subqueries replaced by references to a single instance
   */
  def apply(plan: SparkPlan): SparkPlan = {
    if (!conf.exchangeReuseEnabled) {
      plan
    } else {
      // Bucket the subqueries seen so far by schema, so that sameResult is only called on
      // candidates with an identical schema instead of doing O(N*N) sameResult calls.
      val subqueries = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]()
      plan transformAllExpressions {
        case sub: ExecSubqueryExpression =>
          val sameSchema =
            subqueries.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
          sameSchema.find(_.sameResult(sub.plan)) match {
            case Some(existing) =>
              // An equivalent subquery was already registered: share its plan.
              sub.withExecutedPlan(existing)
            case None =>
              // First occurrence: remember it as the candidate for later duplicates.
              sameSchema += sub.executedPlan
              sub
          }
      }
    }
  }
}
0 commit comments