From 18ec598ee5477e4e77454c244acc7f5fff50a8e1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 19 Dec 2017 23:18:13 +0800 Subject: [PATCH 1/4] Ensure no global variables in arguments of method split by CodegenContext.splitExpressions() --- .../sql/catalyst/expressions/arithmetic.scala | 18 ++++++++---------- .../expressions/codegen/CodeGenerator.scala | 12 ++++++++++++ .../expressions/conditionalExpressions.scala | 8 +++----- .../catalyst/expressions/nullExpressions.scala | 9 ++++----- .../sql/catalyst/expressions/predicates.scala | 2 +- 5 files changed, 28 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d3a8cb5804717..8bb14598a6d7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -602,13 +602,13 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "leastTmpIsNull") + ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) val evals = evalChildren.map(eval => s""" |${eval.code} - |if (!${eval.isNull} && ($tmpIsNull || + |if (!${eval.isNull} && (${ev.isNull} || | ${ctx.genGreater(dataType, ev.value, eval.value)})) { - | $tmpIsNull = false; + | ${ev.isNull} = false; | ${ev.value} = ${eval.value}; |} """.stripMargin @@ -628,10 +628,9 @@ case class Least(children: Seq[Expression]) extends Expression { foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = s""" - |$tmpIsNull = true; + |${ev.isNull} = true; |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; |$codes - |final boolean ${ev.isNull} = $tmpIsNull; """.stripMargin) } } @@ -682,13 +681,13 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "greatestTmpIsNull") + ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) val evals = evalChildren.map(eval => s""" |${eval.code} - |if (!${eval.isNull} && ($tmpIsNull || + |if (!${eval.isNull} && (${ev.isNull} || | ${ctx.genGreater(dataType, eval.value, ev.value)})) { - | $tmpIsNull = false; + | ${ev.isNull} = false; | ${ev.value} = ${eval.value}; |} """.stripMargin @@ -708,10 +707,9 @@ case class Greatest(children: Seq[Expression]) extends Expression { foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = s""" - |$tmpIsNull = true; + |${ev.isNull} = true; |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; |$codes - |final boolean ${ev.isNull} = $tmpIsNull; """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 41a920ba3d677..473df7d10cd85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -930,6 +930,18 @@ class CodegenContext { // inline execution if only one block blocks.head } else { + if (Utils.isTesting) { + // Passing global variables to the split method is dangerous, as any mutating to it is + // ignored and may lead to unexpected behavior. + // We don't need to check `arrayCompactedMutableStates` here, as it results to array access + // code and will raise compile error if we use it in parameter list. + val mutableStateNames = inlinedMutableStates.map(_._2).toSet + arguments.foreach { case (_, name) => + assert(!mutableStateNames.contains(name), + s"split function argument $name cannot be a global variable.") + } + } + val func = freshName(funcName) val argString = arguments.map { case (t, name) => s"$t $name" }.mkString(", ") val functions = blocks.zipWithIndex.map { case (body, i) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 1a9b68222a7f4..142dfb02be0a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -190,7 +190,7 @@ case class CaseWhen( // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // We won't go on anymore on the computation. val resultState = ctx.freshName("caseWhenResultState") - val tmpResult = ctx.addMutableState(ctx.javaType(dataType), "caseWhenTmpResult") + ev.value = ctx.addMutableState(ctx.javaType(dataType), ev.value) // these blocks are meant to be inside a // do { @@ -205,7 +205,7 @@ case class CaseWhen( |if (!${cond.isNull} && ${cond.value}) { | ${res.code} | $resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL); - | $tmpResult = ${res.value}; + | ${ev.value} = ${res.value}; | continue; |} """.stripMargin @@ -216,7 +216,7 @@ case class CaseWhen( s""" |${res.code} |$resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL); - |$tmpResult = ${res.value}; + |${ev.value} = ${res.value}; """.stripMargin } @@ -264,13 +264,11 @@ case class CaseWhen( ev.copy(code = s""" |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; - |$tmpResult = ${ctx.defaultValue(dataType)}; |do { | $codes |} while (false); |// TRUE if any condition is met and the result is null, or no any condition is met. |final boolean ${ev.isNull} = ($resultState != $HAS_NONNULL); - |final ${ctx.javaType(dataType)} ${ev.value} = $tmpResult; """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index b4f895fffda38..470d5da041ea5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,7 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "coalesceTmpIsNull") + ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -80,7 +80,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { s""" |${eval.code} |if (!${eval.isNull}) { - | $tmpIsNull = false; + | ${ev.isNull} = false; | ${ev.value} = ${eval.value}; | continue; |} @@ -103,7 +103,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { foldFunctions = _.map { funcCall => s""" |${ev.value} = $funcCall; - |if (!$tmpIsNull) { + |if (!${ev.isNull}) { | continue; |} """.stripMargin @@ -112,12 +112,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = s""" - |$tmpIsNull = true; + |${ev.isNull} = true; |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; |do { | $codes |} while (false); - |final boolean ${ev.isNull} = $tmpIsNull; """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index ac9f56f78eb2e..f4ee3d10f3f43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -285,7 +285,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { |${valueGen.code} |byte $tmpResult = $HAS_NULL; |if (!${valueGen.isNull}) { - | $tmpResult = 0; + | $tmpResult = $NOT_MATCHED; | $javaDataType $valueArg = ${valueGen.value}; | do { | $codes From f1afb9245438e5fbeb1b36c6e4cd5e065d7d2b82 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 20 Dec 2017 01:21:49 +0800 Subject: [PATCH 2/4] fix stack --- .../spark/sql/catalyst/expressions/generators.scala | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 1cd73a92a8635..3f00d204a03ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -201,6 +201,11 @@ case class Stack(children: Seq[Expression]) extends Generator { // Rows - we write these into an array. val rowData = ctx.addMutableState("InternalRow[]", "rows", v => s"$v = new InternalRow[$numRows];") + // Create the collection. + val wrapperClass = classOf[mutable.WrappedArray[_]].getName + ev.value = ctx.addMutableState(s"$wrapperClass", ev.value, + v => s"$v = $wrapperClass$$.MODULE$$.make($rowData);") + val values = children.tail val dataTypes = values.take(numFields).map(_.dataType) val code = ctx.splitExpressionsWithCurrentInputs(Seq.tabulate(numRows) { row => @@ -212,12 +217,6 @@ case class Stack(children: Seq[Expression]) extends Generator { s"${eval.code}\n$rowData[$row] = ${eval.value};" }) - // Create the collection. - val wrapperClass = classOf[mutable.WrappedArray[_]].getName - ctx.addMutableState( - s"$wrapperClass", - ev.value, - v => s"$v = $wrapperClass$$.MODULE$$.make($rowData);", useFreshName = false) ev.copy(code = code, isNull = "false") } } From 3d44195f48c1688d7dc5b87fd0c9f07c1535000b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 20 Dec 2017 12:29:39 +0800 Subject: [PATCH 3/4] address comments --- .../expressions/codegen/CodeGenerator.scala | 38 +++++++++++-------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 473df7d10cd85..d594f03405dfc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -128,7 +128,7 @@ class CodegenContext { * `currentVars` to null, or set `currentVars(i)` to null for certain columns, before calling * `Expression.genCode`. */ - final var INPUT_ROW = "i" + var INPUT_ROW = "i" /** * Holding a list of generated columns as input of current operator, will be used by @@ -146,22 +146,30 @@ class CodegenContext { * as a member variable * * They will be kept as member variables in generated classes like `SpecificProjection`. + * + * Exposed for tests only. */ - val inlinedMutableStates: mutable.ArrayBuffer[(String, String)] = + private[catalyst] val inlinedMutableStates: mutable.ArrayBuffer[(String, String)] = mutable.ArrayBuffer.empty[(String, String)] /** * The mapping between mutable state types and corrseponding compacted arrays. * The keys are java type string. The values are [[MutableStateArrays]] which encapsulates * the compacted arrays for the mutable states with the same java type. + * + * Exposed for tests only. */ - val arrayCompactedMutableStates: mutable.Map[String, MutableStateArrays] = + private[catalyst] val arrayCompactedMutableStates: mutable.Map[String, MutableStateArrays] = mutable.Map.empty[String, MutableStateArrays] // An array holds the code that will initialize each state - val mutableStateInitCode: mutable.ArrayBuffer[String] = + // Exposed for tests only. + private[catalyst] val mutableStateInitCode: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty[String] + // Tracks the names of all the mutable states. + private val mutableStateNames: mutable.HashSet[String] = mutable.HashSet.empty + /** * This class holds a set of names of mutableStateArrays that is used for compacting mutable * states for a certain type, and holds the next available slot of the current compacted array. @@ -172,7 +180,11 @@ class CodegenContext { private[this] var currentIndex = 0 - private def createNewArray() = arrayNames.append(freshName("mutableStateArray")) + private def createNewArray() = { + val newArrayName = freshName("mutableStateArray") + mutableStateNames += newArrayName + arrayNames.append(newArrayName) + } def getCurrentIndex: Int = currentIndex @@ -241,6 +253,7 @@ class CodegenContext { val initCode = initFunc(varName) inlinedMutableStates += ((javaType, varName)) mutableStateInitCode += initCode + mutableStateNames += varName varName } else { val arrays = arrayCompactedMutableStates.getOrElseUpdate(javaType, new MutableStateArrays) @@ -930,16 +943,11 @@ class CodegenContext { // inline execution if only one block blocks.head } else { - if (Utils.isTesting) { - // Passing global variables to the split method is dangerous, as any mutating to it is - // ignored and may lead to unexpected behavior. - // We don't need to check `arrayCompactedMutableStates` here, as it results to array access - // code and will raise compile error if we use it in parameter list. - val mutableStateNames = inlinedMutableStates.map(_._2).toSet - arguments.foreach { case (_, name) => - assert(!mutableStateNames.contains(name), - s"split function argument $name cannot be a global variable.") - } + // Passing global variables to the split method is dangerous, as any mutating to it is + // ignored and may lead to unexpected behavior. + arguments.foreach { case (_, name) => + assert(!mutableStateNames.contains(name), + s"[BUG] split function argument $name cannot be a global variable.") } val func = freshName(funcName) From 900f246579f86f187b61454df6b0f76c8d5052c8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 21 Dec 2017 15:11:53 +0800 Subject: [PATCH 4/4] only check in test --- .../catalyst/expressions/codegen/CodeGenerator.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d594f03405dfc..9adf632ddcde8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -943,11 +943,13 @@ class CodegenContext { // inline execution if only one block blocks.head } else { - // Passing global variables to the split method is dangerous, as any mutating to it is - // ignored and may lead to unexpected behavior. - arguments.foreach { case (_, name) => - assert(!mutableStateNames.contains(name), - s"[BUG] split function argument $name cannot be a global variable.") + if (Utils.isTesting) { + // Passing global variables to the split method is dangerous, as any mutating to it is + // ignored and may lead to unexpected behavior. + arguments.foreach { case (_, name) => + assert(!mutableStateNames.contains(name), + s"split function argument $name cannot be a global variable.") + } } val func = freshName(funcName)