@@ -68,36 +68,34 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
6868 case _ => Seq .empty[Attribute ]
6969 }
7070
71+ // Collect aliases from expressions, so we may avoid producing recursive constraints.
72+ private lazy val aliasMap = AttributeMap (
73+ (expressions ++ children.flatMap(_.expressions)).collect {
74+ case a : Alias => (a.toAttribute, a.child)
75+ })
76+
7177 /**
7278 * Infers an additional set of constraints from a given set of equality constraints.
7379 * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
74- * additional constraint of the form `b = 5`
80+ * additional constraint of the form `b = 5`.
81+ *
82+ * [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)`
83+ * as they are often useless and can lead to a non-converging set of constraints.
7584 */
7685 private def inferAdditionalConstraints (constraints : Set [Expression ]): Set [Expression ] = {
77- // Collect alias from expressions to avoid producing non-converging set of constraints
78- // for recursive functions.
79- //
80- // Don't apply transform on constraints if the replacement will cause an recursive deduction,
81- // when that happens a non-converging set of constraints will be created and finally throw
82- // OOM Exception.
83- // For more details, refer to https://issues.apache.org/jira/browse/SPARK-17733
84- val aliasMap = AttributeMap ((expressions ++ children.flatMap(_.expressions)).collect {
85- case a : Alias => (a.toAttribute, a.child)
86- })
87-
88- val equalExprSets = generateEqualExpressionSets(constraints, aliasMap)
86+ val constraintClasses = generateEquivalentConstraintClasses(constraints)
8987
9088 var inferredConstraints = Set .empty[Expression ]
9189 constraints.foreach {
9290 case eq @ EqualTo (l : Attribute , r : Attribute ) =>
9391 val candidateConstraints = constraints - eq
9492 inferredConstraints ++= candidateConstraints.map(_ transform {
9593 case a : Attribute if a.semanticEquals(l) &&
96- ! isRecursiveDeduction(r, aliasMap, equalExprSets ) => r
94+ ! isRecursiveDeduction(r, constraintClasses ) => r
9795 })
9896 inferredConstraints ++= candidateConstraints.map(_ transform {
9997 case a : Attribute if a.semanticEquals(r) &&
100- ! isRecursiveDeduction(l, aliasMap, equalExprSets ) => l
98+ ! isRecursiveDeduction(l, constraintClasses ) => l
10199 })
102100 case _ => // No inference
103101 }
@@ -110,58 +108,60 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
110108 * expression sets: (Set(a, b, c), Set(e, f)). This will be used to search all expressions equal
111109 * to an selected attribute.
112110 */
113- private def generateEqualExpressionSets (
114- constraints : Set [Expression ],
115- aliasMap : AttributeMap [Expression ]): Seq [Set [Expression ]] = {
116- var equalExprSets = Seq .empty[Set [Expression ]]
111+ private def generateEquivalentConstraintClasses (
112+ constraints : Set [Expression ]): Seq [Set [Expression ]] = {
113+ var constraintClasses = Seq .empty[Set [Expression ]]
117114 constraints.foreach {
118115 case eq @ EqualTo (l : Attribute , r : Attribute ) =>
119116 // Transform [[Alias]] to its child.
120117 val left = aliasMap.getOrElse(l, l)
121118 val right = aliasMap.getOrElse(r, r)
122- // Get the expression set for equivalence class of expressions .
123- val leftEqualSet = getEqualExprSet (left, equalExprSets).getOrElse( Set .empty[ Expression ] )
124- val rightEqualSet = getEqualExprSet (right, equalExprSets).getOrElse( Set .empty[ Expression ] )
125- if (! leftEqualSet.isEmpty && ! rightEqualSet.isEmpty ) {
119+ // Get the expression set for an equivalence constraint class .
120+ val leftConstraintClass = getConstraintClass (left, constraintClasses )
121+ val rightConstraintClass = getConstraintClass (right, constraintClasses )
122+ if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty ) {
126123 // Combine the two sets.
127- equalExprSets = equalExprSets.diff(leftEqualSet :: rightEqualSet :: Nil ) :+
128- (leftEqualSet ++ rightEqualSet)
129- } else if (! leftEqualSet.isEmpty) { // && rightEqualSet.isEmpty
124+ constraintClasses = constraintClasses
125+ .diff(leftConstraintClass :: rightConstraintClass :: Nil ) :+
126+ (leftConstraintClass ++ rightConstraintClass)
127+ } else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty
130128 // Update equivalence class of `left` expression.
131- equalExprSets = equalExprSets.diff(leftEqualSet :: Nil ) :+ (leftEqualSet + right)
132- } else if (! rightEqualSet.isEmpty) { // && leftEqualSet.isEmpty
129+ constraintClasses = constraintClasses
130+ .diff(leftConstraintClass :: Nil ) :+ (leftConstraintClass + right)
131+ } else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty
133132 // Update equivalence class of `right` expression.
134- equalExprSets = equalExprSets.diff(rightEqualSet :: Nil ) :+ (rightEqualSet + left)
135- } else { // leftEqualSet.isEmpty && rightEqualSet.isEmpty
136- // Create new equivalence class since both expression don't present in any classes.
137- equalExprSets = equalExprSets :+ Set (left, right)
133+ constraintClasses = constraintClasses
134+ .diff(rightConstraintClass :: Nil ) :+ (rightConstraintClass + left)
135+ } else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty
136+ // Create new equivalence constraint class since neither expression presents
137+ // in any classes.
138+ constraintClasses = constraintClasses :+ Set (left, right)
138139 }
139140 case _ => // Skip
140141 }
141142
142- equalExprSets
143+ constraintClasses
143144 }
144145
145146 /*
146147 * Get all expressions equivalent to the selected expression.
147148 */
148- private def getEqualExprSet (
149+ private def getConstraintClass (
149150 expr : Expression ,
150- equalExprSets : Seq [Set [Expression ]]): Option [ Set [Expression ] ] =
151- equalExprSets.filter (_.contains(expr)).headOption
151+ constraintClasses : Seq [Set [Expression ]]): Set [Expression ] =
152+ constraintClasses.find (_.contains(expr)).getOrElse( Set .empty[ Expression ])
152153
153154 /*
154- * Check whether replace by an [[Attribute]] will cause an recursive deduction. Generally it
155- * has an form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is an function.
155+ * Check whether replace by an [[Attribute]] will cause a recursive deduction. Generally it
156+ * has the form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is a function.
156157 * Here we first get all expressions equal to `attr` and then check whether at least one of them
157- * is child of the referenced expression.
158+ * is a child of the referenced expression.
158159 */
159160 private def isRecursiveDeduction (
160161 attr : Attribute ,
161- aliasMap : AttributeMap [Expression ],
162- equalExprSets : Seq [Set [Expression ]]): Boolean = {
162+ constraintClasses : Seq [Set [Expression ]]): Boolean = {
163163 val expr = aliasMap.getOrElse(attr, attr)
164- getEqualExprSet (expr, equalExprSets).getOrElse( Set .empty[ Expression ] ).exists { e =>
164+ getConstraintClass (expr, constraintClasses ).exists { e =>
165165 expr.children.exists(_.semanticEquals(e))
166166 }
167167 }
0 commit comments