diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d949b8f1d6696..c6143e5e1f45b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -801,12 +801,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => val wrapper = ctx.freshName("wrapper") - ctx.addMutableState("UTF8String.IntWrapper", wrapper, + val wrapperAccessor = ctx.addMutableState("UTF8String.IntWrapper", wrapper, s"$wrapper = new UTF8String.IntWrapper();") (c, evPrim, evNull) => s""" - if ($c.toByte($wrapper)) { - $evPrim = (byte) $wrapper.value; + if ($c.toByte($wrapperAccessor)) { + $evPrim = (byte) $wrapperAccessor.value; } else { $evNull = true; } @@ -828,12 +828,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String ctx: CodegenContext): CastFunction = from match { case StringType => val wrapper = ctx.freshName("wrapper") - ctx.addMutableState("UTF8String.IntWrapper", wrapper, + val wrapperAccessor = ctx.addMutableState("UTF8String.IntWrapper", wrapper, s"$wrapper = new UTF8String.IntWrapper();") (c, evPrim, evNull) => s""" - if ($c.toShort($wrapper)) { - $evPrim = (short) $wrapper.value; + if ($c.toShort($wrapperAccessor)) { + $evPrim = (short) $wrapperAccessor.value; } else { $evNull = true; } @@ -853,12 +853,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => val wrapper = ctx.freshName("wrapper") - ctx.addMutableState("UTF8String.IntWrapper", wrapper, + val wrapperAccessor = ctx.addMutableState("UTF8String.IntWrapper", wrapper, s"$wrapper = new UTF8String.IntWrapper();") (c, evPrim, evNull) => s""" - if ($c.toInt($wrapper)) { - $evPrim = $wrapper.value; + if ($c.toInt($wrapperAccessor)) { + $evPrim = $wrapperAccessor.value; } else { $evNull = true; } @@ -878,13 +878,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => val wrapper = ctx.freshName("wrapper") - ctx.addMutableState("UTF8String.LongWrapper", wrapper, + val wrapperAccessor = ctx.addMutableState("UTF8String.LongWrapper", wrapper, s"$wrapper = new UTF8String.LongWrapper();") (c, evPrim, evNull) => s""" - if ($c.toLong($wrapper)) { - $evPrim = $wrapper.value; + if ($c.toLong($wrapperAccessor)) { + $evPrim = $wrapperAccessor.value; } else { $evNull = true; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 84027b53dca27..079c271912180 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -67,14 +67,15 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val countTerm = 
ctx.freshName("count") val partitionMaskTerm = ctx.freshName("partitionMask") - ctx.addMutableState(ctx.JAVA_LONG, countTerm, "") - ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "") - ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") - ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") + val countTermAccessor = ctx.addMutableState(ctx.JAVA_LONG, countTerm, "") + val partitionMaskTermAccessor = ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "") + ctx.addPartitionInitializationStatement(s"$countTermAccessor = 0L;") + ctx.addPartitionInitializationStatement( + s"$partitionMaskTermAccessor = ((long) partitionIndex) << 33;") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; - $countTerm++;""", isNull = "false") + final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTermAccessor + $countTermAccessor; + $countTermAccessor++;""", isNull = "false") } override def prettyName: String = "monotonically_increasing_id" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 527f1670c25e1..1729bd95988b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -991,11 +991,11 @@ case class ScalaUDF( val converterTerm = ctx.freshName("converter") val expressionIdx = ctx.references.size - 1 - ctx.addMutableState(converterClassName, converterTerm, + val converterTermAccessor = ctx.addMutableState(converterClassName, converterTerm, s"$converterTerm = ($converterClassName)$typeConvertersClassName" + s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" + s"references[$expressionIdx]).getChildren().apply($index))).dataType());") - converterTerm + converterTermAccessor } override def doGenCode( @@ -1008,8 +1008,9 @@ case class ScalaUDF( // Generate codes used to convert the returned value of user-defined functions to Catalyst type val catalystConverterTerm = ctx.freshName("catalystConverter") - ctx.addMutableState(converterClassName, catalystConverterTerm, - s"$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + + val catalystConverterTermAccessor = + ctx.addMutableState(converterClassName, catalystConverterTerm, + s"$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + s".createToCatalystConverter($scalaUDF.dataType());") val resultTerm = ctx.freshName("result") @@ -1022,7 +1023,7 @@ case class ScalaUDF( val funcClassName = s"scala.Function${children.size}" val funcTerm = ctx.freshName("udf") - ctx.addMutableState(funcClassName, funcTerm, + val funcTermAccessor = ctx.addMutableState(funcClassName, funcTerm, s"$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();") // codegen for children expressions @@ -1040,12 +1041,13 @@ case class ScalaUDF( (convert, argTerm) }.unzip - val getFuncResult = s"$funcTerm.apply(${funcArguments.mkString(", ")})" + val getFuncResult = s"$funcTermAccessor.apply(${funcArguments.mkString(", ")})" val callFunc = s""" ${ctx.boxedType(dataType)} $resultTerm = null; try { - $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult); + $resultTerm = (${ctx.boxedType(dataType)}) $catalystConverterTermAccessor + .apply($getFuncResult); } catch (Exception e) { throw new 
org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e);
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
index 8db7efdbb5dd4..ad429273efcc9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
@@ -44,8 +44,9 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {

   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val idTerm = ctx.freshName("partitionId")
-    ctx.addMutableState(ctx.JAVA_INT, idTerm, "")
-    ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
-    ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false")
+    val idTermAccessor = ctx.addMutableState(ctx.JAVA_INT, idTerm, "")
+    ctx.addPartitionInitializationStatement(s"$idTermAccessor = partitionIndex;")
+    ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTermAccessor;",
+      isNull = "false")
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index f9c5ef8439085..36701803c3b0d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -113,8 +113,9 @@ class CodegenContext {
     val idx = references.length
     references += obj
     val clsName = Option(className).getOrElse(obj.getClass.getName)
-    addMutableState(clsName, term, s"$term = ($clsName) references[$idx];")
-    term
+    val termAccessor = addMutableState(clsName, term, s"$term = ($clsName) references[$idx];")
+
+    termAccessor
   }

   /**
@@ -148,44 +149,150 @@ class CodegenContext {
    *
    * They will be kept as member variables in generated classes like `SpecificProjection`.
    */
-  val mutableStates: mutable.ArrayBuffer[(String, String, String)] =
-    mutable.ArrayBuffer.empty[(String, String, String)]
+  val mutableStates: mutable.ListBuffer[(String, String, String)] =
+    mutable.ListBuffer.empty[(String, String, String)]
+
+  // A map keyed by a mutable state's (java type, initialization code) pair, holding the
+  // current max index of the array into which states with that key are compacted
+  var mutableStateArrayIdx: mutable.Map[(String, String), Int] =
+    mutable.Map.empty[(String, String), Int]
+
+  // A map keyed by a mutable state's (java type, initialization code) pair, holding the name
+  // of the mutableStateArray into which states with that key will be compacted
+  var mutableStateArrayNames: mutable.Map[(String, String), String] =
+    mutable.Map.empty[(String, String), String]
+
+  // A map keyed by a mutable state's (java type, initialization code) pair, holding the code
+  // that initializes the corresponding mutableStateArray in a loop
+  var mutableStateArrayInitCodes: mutable.Map[(String, String), String] =
+    mutable.Map.empty[(String, String), String]
+
+  /**
+   * Adds an instance of globally-accessible mutable state. Mutable state may either be inlined
+   * as a private member variable of the class, or it may be compacted into arrays of the same
+   * type and initialization, in order to avoid Constant Pool limit errors for both state
+   * declaration and initialization.
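+   *
+   * For example (an illustrative sketch mirroring the `Cast` usage earlier in this patch; the
+   * variable and array names are hypothetical):
+   * {{{
+   *   val wrapper = ctx.freshName("wrapper")  // e.g. "wrapper1"
+   *   // Compactable: fresh name plus a simple no-argument-constructor assignment. The
+   *   // returned accessor may be an array slot such as "mutableStateArray1[3]" rather than
+   *   // "wrapper1", so generated code must always reference the returned accessor.
+   *   val accessor = ctx.addMutableState("UTF8String.IntWrapper", wrapper,
+   *     s"$wrapper = new UTF8String.IntWrapper();")
+   *   // Passing inline = true instead always declares a dedicated member variable.
+   * }}}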
+   *
+   * We compact state into arrays when we anticipate that variables of the same type and
+   * `initCode` may appear many times. Variables whose names carry an integer suffix (as given
+   * by the `freshName` function) and that are either simply assigned (to null, to the
+   * no-argument constructor of the type, or not initialized at all) or are of a primitive type
+   * are workable candidates for array compaction, as such variables are likely to appear many
+   * times and can easily be initialized in loops.
+   *
+   * @param javaType the Java type of the variable
+   * @param variableName the name of the variable
+   * @param initCode the initialization code for the variable
+   * @param inline whether the declaration and initialization code should be inlined rather than
+   *               compacted
+   * @return the name of the mutable state variable, which is either the original name if the
+   *         variable is inlined into the class, or an array access if the variable is to be
+   *         stored in an array of variables of the same type and initialization.
+   */
+  def addMutableState(
+      javaType: String,
+      variableName: String,
+      initCode: String,
+      inline: Boolean = false): String = {
+    if (!inline &&
+      // identifies a `freshName`-style variable with a numerical suffix and a possible
+      // underscore-delimited prefix.
+      variableName.matches("[\\w_]+\\d+") &&
+      // identifies a simply-assigned object, or a primitive type
+      (initCode.matches("(^[\\w_]+\\d+\\s*=\\s*null;|" +
+        "^[\\w_]+\\d+\\s*=\\s*new\\s*[\\w\\.]+\\(\\);$|" +
+        "^$)")
+      || isPrimitiveType(javaType))) {
+
+      // Create a variant of the initialization code that is agnostic to the actual variable
+      // name, which we can use as part of the lookup key
+      val initCodeKey = initCode.replaceAll(variableName, "*VALUE*")
+
+      if (mutableStateArrayIdx.contains((javaType, initCodeKey))) {
+        // a mutableStateArray for the given type and initialization has already been declared;
+        // update the max index of the array and return the array-based alias for the variable
+        val arrayName = mutableStateArrayNames((javaType, initCodeKey))
+        val idx = mutableStateArrayIdx((javaType, initCodeKey)) + 1
+
+        mutableStateArrayIdx.update((javaType, initCodeKey), idx)
+
+        s"$arrayName[$idx]"
+      } else {
+        // no mutableStateArray has been declared yet for the given type and initialization code.
+        // Create a new name for the array, and add entries keeping track of the new array name,
+        // its current index, and initialization code
+        val arrayName = freshName("mutableStateArray")
+        val qualifiedInitCode = initCode.replaceAll(variableName, s"$arrayName[i]")
+        mutableStateArrayNames += (javaType, initCodeKey) -> arrayName
+        mutableStateArrayIdx += (javaType, initCodeKey) -> 0
+        mutableStateArrayInitCodes += (javaType, initCodeKey) -> qualifiedInitCode
+
+        s"$arrayName[0]"
+      }
+    } else {
+      // non-primitive and non-simply-assigned state is declared inline in the outer class
+      mutableStates += Tuple3(javaType, variableName, initCode)
-  def addMutableState(javaType: String, variableName: String, initCode: String): Unit = {
-    mutableStates += ((javaType, variableName, initCode))
+      variableName
+    }
   }
+
   /**
    * Add buffer variable which stores data coming from an [[InternalRow]]. This method guarantees
    * that the variable is safely stored, which is important for (potentially) byte array backed
    * data types like: UTF8String, ArrayData, MapData & InternalRow.
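+   * The returned `ExprCode.value` is the accessor produced by `addMutableState`, which may be
+   * an array slot rather than the original variable name.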
*/ def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = { - val value = freshName(variableName) - addMutableState(javaType(dataType), value, "") + val valueAccessor = addMutableState(javaType(dataType), freshName(variableName), "") val code = dataType match { - case StringType => s"$value = $initCode.clone();" - case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" - case _ => s"$value = $initCode;" + case StringType => s"$valueAccessor = $initCode.clone();" + case _: StructType | _: ArrayType | _: MapType => s"$valueAccessor = $initCode.copy();" + case _ => s"$valueAccessor = $initCode;" } - ExprCode(code, "false", value) + ExprCode(code, "false", valueAccessor) } def declareMutableStates(): String = { // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. - mutableStates.distinct.map { case (javaType, variableName, _) => + val inlinedStates = mutableStates.distinct.map { case (javaType, variableName, _) => s"private $javaType $variableName;" - }.mkString("\n") + } + + val arrayStates = mutableStateArrayNames.map { case ((javaType, initCode), arrayName) => + val length = mutableStateArrayIdx((javaType, initCode)) + 1 + if (javaType.matches("^.*\\[\\]$")) { + val baseType = javaType.substring(0, javaType.length - 2) + s"private $javaType[] $arrayName = new $baseType[$length][];" + } else { + s"private $javaType[] $arrayName = new $javaType[$length];" + } + } + + (inlinedStates ++ arrayStates).mkString("\n") } def initMutableStates(): String = { // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. val initCodes = mutableStates.distinct.map(_._3 + "\n") + // array state is initialized in loops + val arrayInitCodes = mutableStateArrayNames.map { case ((javaType, initCode), arrayName) => + val qualifiedInitCode = mutableStateArrayInitCodes((javaType, initCode)) + if (qualifiedInitCode.equals("")) { + "" + } else { + s""" + for (int i = 0; i < $arrayName.length; i++) { + $qualifiedInitCode + } + """ + } + } + // The generated initialization code may exceed 64kb function size limit in JVM if there are too // many mutable states, so split it into multiple functions. - splitExpressions(initCodes, "init", Nil) + splitExpressions(initCodes ++ arrayInitCodes, "init", Nil) } /** @@ -761,7 +868,7 @@ class CodegenContext { * @param arguments the list of (type, name) of the arguments of the split function. * @param returnType the return type of the split function. * @param makeSplitFunction makes split function body, e.g. add preparation or cleanup. - * @param foldFunctions folds the split function calls. + * @param transformFunctions processes the function calls with an additional transformation. 
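+   *                           The transformed calls are then re-split recursively, so the
+   *                           top-level sequence of calls itself also stays within the JVM's
+   *                           64KB method size limit. `GenerateOrdering`, for example, wraps
+   *                           each call in an early-return comparison check.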
*/ def splitExpressions( expressions: Seq[String], @@ -769,7 +876,7 @@ class CodegenContext { arguments: Seq[(String, String)], returnType: String = "void", makeSplitFunction: String => String = identity, - foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = { + transformFunctions: Seq[String] => Seq[String] = _.map(s => s + ";\n")): String = { val blocks = new ArrayBuffer[String]() val blockBuilder = new StringBuilder() for (code <- expressions) { @@ -801,7 +908,10 @@ class CodegenContext { addNewFunction(name, code) } - foldFunctions(functions.map(name => s"$name(${arguments.map(_._2).mkString(", ")})")) + val exprs = transformFunctions(functions.map(name => + s"$name(${arguments.map(_._2).mkString(", ")})")) + + splitExpressions(exprs, funcName, arguments) } } @@ -895,12 +1005,12 @@ class CodegenContext { // 2. Less code. // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with // at least two nodes) as the cost of doing it is expected to be low. - addMutableState("boolean", isNull, s"$isNull = false;") - addMutableState(javaType(expr.dataType), value, + val isNullAccessor = addMutableState("boolean", isNull, s"$isNull = false;") + val valueAccessor = addMutableState(javaType(expr.dataType), value, s"$value = ${defaultValue(expr.dataType)};") subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - val state = SubExprEliminationState(isNull, value) + val state = SubExprEliminationState(isNullAccessor, valueAccessor) e.foreach(subExprEliminationExprs.put(_, state)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index b5429fade53cf..bfb363a00731b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -63,35 +63,35 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP if (e.nullable) { val isNull = s"isNull_$i" val value = s"value_$i" - ctx.addMutableState("boolean", isNull, s"$isNull = true;") - ctx.addMutableState(ctx.javaType(e.dataType), value, + val isNullAccessor = ctx.addMutableState("boolean", isNull, s"$isNull = true;") + val valueAccessor = ctx.addMutableState(ctx.javaType(e.dataType), value, s"$value = ${ctx.defaultValue(e.dataType)};") - s""" + (s""" ${ev.code} - $isNull = ${ev.isNull}; - $value = ${ev.value}; - """ + $isNullAccessor = ${ev.isNull}; + $valueAccessor = ${ev.value}; + """, isNullAccessor, valueAccessor, i) } else { val value = s"value_$i" - ctx.addMutableState(ctx.javaType(e.dataType), value, + val valueAccessor = ctx.addMutableState(ctx.javaType(e.dataType), value, s"$value = ${ctx.defaultValue(e.dataType)};") - s""" + (s""" ${ev.code} - $value = ${ev.value}; - """ + $valueAccessor = ${ev.value}; + """, ev.isNull, valueAccessor, i) } } // Evaluate all the subexpressions. 
val evalSubexpr = ctx.subexprFunctions.mkString("\n") - val updates = validExpr.zip(index).map { - case (e, i) => - val ev = ExprCode("", s"isNull_$i", s"value_$i") + val updates = validExpr.zip(projectionCodes).map { + case (e, (_, isNullAccessor, valueAccessor, i)) => + val ev = ExprCode("", s"$isNullAccessor", s"$valueAccessor") ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } - val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes) + val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes.map(_._1)) val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates) val codeBody = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 1639d1b9dda1f..28ffe15bf2579 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -136,7 +136,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR return 0; """ }, - foldFunctions = { funCalls => + transformFunctions = { funCalls => funCalls.zipWithIndex.map { case (funCall, i) => val comp = ctx.freshName("comp") s""" @@ -145,7 +145,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR return $comp; } """ - }.mkString + } }) // make sure INPUT_ROW is declared even if splitExpressions // returns an inlined block diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 1e4ac3f2afd52..0399ce990141e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -49,7 +49,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val output = ctx.freshName("safeRow") val values = ctx.freshName("values") // These expressions could be split into multiple functions - ctx.addMutableState("Object[]", values, s"$values = null;") + val valuesAccessor = ctx.addMutableState("Object[]", values, s"$values = null;") val rowClass = classOf[GenericInternalRow].getName @@ -58,17 +58,17 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] s""" if (!$tmp.isNullAt($i)) { ${converter.code} - $values[$i] = ${converter.value}; + $valuesAccessor[$i] = ${converter.value}; } """ } val allFields = ctx.splitExpressions(tmp, fieldWriters) val code = s""" final InternalRow $tmp = $input; - $values = new Object[${schema.length}]; + $valuesAccessor = new Object[${schema.length}]; $allFields - final InternalRow $output = new $rowClass($values); - $values = null; + final InternalRow $output = new $rowClass($valuesAccessor); + $valuesAccessor = null; """ ExprCode(code, "false", output) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 4bd50aee05514..3145e918e505b 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -81,7 +81,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro isTopLevel: Boolean = false): String = { val rowWriterClass = classOf[UnsafeRowWriter].getName val rowWriter = ctx.freshName("rowWriter") - ctx.addMutableState(rowWriterClass, rowWriter, + val rowWriterAccessor = ctx.addMutableState(rowWriterClass, rowWriter, s"$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});") val resetWriter = if (isTopLevel) { @@ -93,10 +93,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // need to clear it out every time. "" } else { - s"$rowWriter.zeroOutNullBytes();" + s"$rowWriterAccessor.zeroOutNullBytes();" } } else { - s"$rowWriter.reset();" + s"$rowWriterAccessor.reset();" } val writeFields = inputs.zip(inputTypes).zipWithIndex.map { @@ -181,7 +181,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro bufferHolder: String): String = { val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.freshName("arrayWriter") - ctx.addMutableState(arrayWriterClass, arrayWriter, + val arrayWriterAccessor = ctx.addMutableState(arrayWriterClass, arrayWriter, s"$arrayWriter = new $arrayWriterClass();") val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") @@ -206,29 +206,32 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" final int $tmpCursor = $bufferHolder.cursor; ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + $arrayWriterAccessor.setOffsetAndSize($index, + $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ case a @ ArrayType(et, _) => s""" final int $tmpCursor = $bufferHolder.cursor; ${writeArrayToBuffer(ctx, element, et, bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + $arrayWriterAccessor.setOffsetAndSize($index, + $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ case m @ MapType(kt, vt, _) => s""" final int $tmpCursor = $bufferHolder.cursor; ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)} - $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); + $arrayWriterAccessor.setOffsetAndSize($index, + $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ case t: DecimalType => - s"$arrayWriter.write($index, $element, ${t.precision}, ${t.scale});" + s"$arrayWriterAccessor.write($index, $element, ${t.precision}, ${t.scale});" case NullType => "" - case _ => s"$arrayWriter.write($index, $element);" + case _ => s"$arrayWriterAccessor.write($index, $element);" } val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else "" @@ -237,11 +240,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)} } else { final int $numElements = $input.numElements(); - $arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize); + $arrayWriterAccessor.initialize($bufferHolder, $numElements, $elementOrOffsetSize); for (int $index = 0; $index < $numElements; $index++) { if ($input.isNullAt($index)) { - $arrayWriter.setNull$primitiveTypeName($index); + 
$arrayWriterAccessor.setNull$primitiveTypeName($index); } else { final $jt $element = ${ctx.getValue(input, et, index)}; $writeElement @@ -316,29 +319,31 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val result = ctx.freshName("result") - ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});") + val resultAccessor = ctx.addMutableState("UnsafeRow", + result, s"$result = new UnsafeRow(${expressions.length});") val holder = ctx.freshName("holder") val holderClass = classOf[BufferHolder].getName - ctx.addMutableState(holderClass, holder, - s"$holder = new $holderClass($result, ${numVarLenFields * 32});") + val holderAccessor = ctx.addMutableState(holderClass, holder, + s"$holder = new $holderClass($resultAccessor, ${numVarLenFields * 32});") val resetBufferHolder = if (numVarLenFields == 0) { "" } else { - s"$holder.reset();" + s"$holderAccessor.reset();" } val updateRowSize = if (numVarLenFields == 0) { "" } else { - s"$result.setTotalSize($holder.totalSize());" + s"$resultAccessor.setTotalSize($holderAccessor.totalSize());" } // Evaluate all the subexpression. val evalSubexpr = ctx.subexprFunctions.mkString("\n") val writeExpressions = - writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true) + writeExpressionsToBuffer(ctx, + ctx.INPUT_ROW, exprEvals, exprTypes, holderAccessor, isTopLevel = true) val code = s""" @@ -347,7 +352,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $writeExpressions $updateRowSize """ - ExprCode(code, "false", result) + ExprCode(code, "false", resultAccessor) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 4b6574a31424e..127cac2aaec4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -94,7 +94,7 @@ private [sql] object GenArrayData { if (!ctx.isPrimitiveType(elementType)) { val genericArrayClass = classOf[GenericArrayData].getName ctx.addMutableState("Object[]", arrayName, - s"$arrayName = new Object[$numElements];") + s"$arrayName = new Object[$numElements];", inline = true) val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { @@ -120,7 +120,7 @@ private [sql] object GenArrayData { UnsafeArrayData.calculateHeaderPortionInBytes(numElements) + ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) val baseOffset = Platform.BYTE_ARRAY_OFFSET - ctx.addMutableState("UnsafeArrayData", arrayDataName, "") + ctx.addMutableState("UnsafeArrayData", arrayDataName, "", inline = true) val primitiveValueTypeName = ctx.primitiveTypeName(elementType) val assignments = elementsCode.zipWithIndex.map { case (eval, i) => @@ -350,24 +350,24 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") - ctx.addMutableState("Object[]", values, s"$values = null;") + val valuesAccessor = ctx.addMutableState("Object[]", values, s"$values = null;") ev.copy(code = s""" - $values = new 
Object[${valExprs.size}];""" + + $valuesAccessor = new Object[${valExprs.size}];""" + ctx.splitExpressions( ctx.INPUT_ROW, valExprs.zipWithIndex.map { case (e, i) => val eval = e.genCode(ctx) eval.code + s""" if (${eval.isNull}) { - $values[$i] = null; + $valuesAccessor[$i] = null; } else { - $values[$i] = ${eval.value}; + $valuesAccessor[$i] = ${eval.value}; }""" }) + s""" - final InternalRow ${ev.value} = new $rowClass($values); - $values = null; + final InternalRow ${ev.value} = new $rowClass($valuesAccessor); + $valuesAccessor = null; """, isNull = "false") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index d95b59d5ec423..ea373098b3747 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -118,21 +118,23 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi dataType: DataType, baseFuncName: String): (String, String, String) = { val globalIsNull = ctx.freshName("isNull") - ctx.addMutableState("boolean", globalIsNull, s"$globalIsNull = false;") + val globalIsNullAccessor = + ctx.addMutableState("boolean", globalIsNull, s"$globalIsNull = false;") val globalValue = ctx.freshName("value") - ctx.addMutableState(ctx.javaType(dataType), globalValue, + val globalValueAccessor = + ctx.addMutableState(ctx.javaType(dataType), globalValue, s"$globalValue = ${ctx.defaultValue(dataType)};") val funcName = ctx.freshName(baseFuncName) val funcBody = s""" |private void $funcName(InternalRow ${ctx.INPUT_ROW}) { | ${ev.code.trim} - | $globalIsNull = ${ev.isNull}; - | $globalValue = ${ev.value}; + | $globalIsNullAccessor = ${ev.isNull}; + | $globalValueAccessor = ${ev.value}; |} """.stripMargin val fullFuncName = ctx.addNewFunction(funcName, funcBody) - (fullFuncName, globalIsNull, globalValue) + (fullFuncName, globalIsNullAccessor, globalValueAccessor) } override def toString: String = s"if ($predicate) $trueValue else $falseValue" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index eaf8788888211..4533846cbbc44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -444,10 +444,11 @@ case class DayOfWeek(child: Expression) extends UnaryExpression with ImplicitCas val cal = classOf[Calendar].getName val c = ctx.freshName("cal") val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - ctx.addMutableState(cal, c, s"""$c = $cal.getInstance($dtu.getTimeZone("UTC"));""") + val cAccessor = ctx.addMutableState(cal, c, + s"""$c = $cal.getInstance($dtu.getTimeZone("UTC"));""") s""" - $c.setTimeInMillis($time * 1000L * 3600L * 24L); - ${ev.value} = $c.get($cal.DAY_OF_WEEK); + $cAccessor.setTimeInMillis($time * 1000L * 3600L * 24L); + ${ev.value} = $cAccessor.get($cal.DAY_OF_WEEK); """ }) } @@ -486,15 +487,15 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa val cal = classOf[Calendar].getName val c = ctx.freshName("cal") val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - ctx.addMutableState(cal, c, 
+ val cAccessor = ctx.addMutableState(cal, c, s""" $c = $cal.getInstance($dtu.getTimeZone("UTC")); $c.setFirstDayOfWeek($cal.MONDAY); $c.setMinimalDaysInFirstWeek(4); """) s""" - $c.setTimeInMillis($time * 1000L * 3600L * 24L); - ${ev.value} = $c.get($cal.WEEK_OF_YEAR); + $cAccessor.setTimeInMillis($time * 1000L * 3600L * 24L); + ${ev.value} = $cAccessor.get($cal.WEEK_OF_YEAR); """ }) } @@ -1018,15 +1019,17 @@ case class FromUTCTimestamp(left: Expression, right: Expression) val utcTerm = ctx.freshName("utc") val tzClass = classOf[TimeZone].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""") - ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""") + val tzTermAccessor = + ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""") + val utcTermAccessor = + ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) ev.copy(code = s""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; |if (!${ev.isNull}) { - | ${ev.value} = $dtu.convertTz(${eval.value}, $utcTerm, $tzTerm); + | ${ev.value} = $dtu.convertTz(${eval.value}, $utcTermAccessor, $tzTermAccessor); |} """.stripMargin) } @@ -1194,15 +1197,17 @@ case class ToUTCTimestamp(left: Expression, right: Expression) val utcTerm = ctx.freshName("utc") val tzClass = classOf[TimeZone].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""") - ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""") + val tzTermAccessor = + ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""") + val utcTermAccessor = + ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) ev.copy(code = s""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; |if (!${ev.isNull}) { - | ${ev.value} = $dtu.convertTz(${eval.value}, $tzTerm, $utcTerm); + | ${ev.value} = $dtu.convertTz(${eval.value}, $tzTermAccessor, $utcTermAccessor); |} """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 8618f49086077..744f01d345d5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -200,7 +200,8 @@ case class Stack(children: Seq[Expression]) extends Generator { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Rows - we write these into an array. 
val rowData = ctx.freshName("rows") - ctx.addMutableState("InternalRow[]", rowData, s"$rowData = new InternalRow[$numRows];") + val rowDataAccessor = + ctx.addMutableState("InternalRow[]", rowData, s"$rowData = new InternalRow[$numRows];") val values = children.tail val dataTypes = values.take(numFields).map(_.dataType) val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row => @@ -209,15 +210,15 @@ case class Stack(children: Seq[Expression]) extends Generator { if (index < values.length) values(index) else Literal(null, dataTypes(col)) } val eval = CreateStruct(fields).genCode(ctx) - s"${eval.code}\n$rowData[$row] = ${eval.value};" + s"${eval.code}\n$rowDataAccessor[$row] = ${eval.value};" }) - // Create the collection. + // Create the collection. Inline to outer class. val wrapperClass = classOf[mutable.WrappedArray[_]].getName ctx.addMutableState( s"$wrapperClass", ev.value, - s"${ev.value} = $wrapperClass$$.MODULE$$.make($rowData);") + s"${ev.value} = $wrapperClass$$.MODULE$$.make($rowDataAccessor);", inline = true) ev.copy(code = code, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 85a5f7fb2c6c3..c29d328f8c769 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -270,17 +270,19 @@ abstract class HashExpression[E] extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.isNull = "false" + val evValueAccessor = ctx.addMutableState(ctx.javaType(dataType), ev.value, "") + val childrenHash = ctx.splitExpressions(ctx.INPUT_ROW, children.map { child => val childGen = child.genCode(ctx) childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) { - computeHash(childGen.value, child.dataType, ev.value, ctx) + computeHash(childGen.value, child.dataType, evValueAccessor, ctx) } }) - ctx.addMutableState(ctx.javaType(dataType), ev.value, "") ev.copy(code = s""" - ${ev.value} = $seed; - $childrenHash""") + $evValueAccessor = $seed; + $childrenHash""", + value = evValueAccessor) } protected def nullSafeElementHash( @@ -607,19 +609,23 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.isNull = "false" val childHash = ctx.freshName("childHash") + + val evValueAccessor = ctx.addMutableState(ctx.javaType(dataType), ev.value, "") + val childHashAccessor = ctx.addMutableState("int", childHash, s"$childHash = 0;") + val childrenHash = ctx.splitExpressions(ctx.INPUT_ROW, children.map { child => val childGen = child.genCode(ctx) childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) { - computeHash(childGen.value, child.dataType, childHash, ctx) - } + s"${ev.value} = (31 * ${ev.value}) + $childHash;" + - s"\n$childHash = 0;" + computeHash(childGen.value, child.dataType, childHashAccessor, ctx) + } + s"$evValueAccessor = (31 * $evValueAccessor) + $childHashAccessor;" + + s"\n$childHashAccessor = 0;" }) - ctx.addMutableState(ctx.javaType(dataType), ev.value, "") - ctx.addMutableState("int", childHash, s"$childHash = 0;") + ev.copy(code = s""" - ${ev.value} = $seed; - $childrenHash""") + $evValueAccessor = $seed; + $childrenHash""", + value = evValueAccessor) } override def eval(input: InternalRow = null): Int = { diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 9b28a18035b1c..f8c342a055633 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -62,15 +62,15 @@ trait InvokeLike extends Expression with NonSQLExpression {

     val resultIsNull = if (needNullCheck) {
       val resultIsNull = ctx.freshName("resultIsNull")
-      ctx.addMutableState("boolean", resultIsNull, "")
-      resultIsNull
+      val resultIsNullAccessor = ctx.addMutableState("boolean", resultIsNull, "")
+      resultIsNullAccessor
     } else {
       "false"
     }
     val argValues = arguments.map { e =>
       val argValue = ctx.freshName("argValue")
-      ctx.addMutableState(ctx.javaType(e.dataType), argValue, "")
-      argValue
+      val argValueAccessor = ctx.addMutableState(ctx.javaType(e.dataType), argValue, "")
+      argValueAccessor
     }

     val argCodes = if (needNullCheck) {
@@ -347,6 +347,9 @@ case class NewInstance(

     ev.isNull = resultIsNull

+    val valueAccessor = ctx.addMutableState(javaType, ev.value,
+      s"${ev.value} = ${ctx.defaultValue(javaType)};")
+
     val constructorCall = outer.map { gen =>
       s"${gen.value}.new ${cls.getSimpleName}($argString)"
     }.getOrElse {
@@ -356,9 +359,9 @@ case class NewInstance(
     val code = s"""
       $argCode
       ${outer.map(_.code).getOrElse("")}
-      final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall;
+      $valueAccessor = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall;
     """
-    ev.copy(code = code)
+    ev.copy(code = code, value = valueAccessor)
   }

   override def toString: String = s"newInstance($cls)"
@@ -545,7 +548,14 @@ case class MapObjects private(

   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val elementJavaType = ctx.javaType(loopVarDataType)
-    ctx.addMutableState(elementJavaType, loopValue, "")
+    // `loopIsNull` may be the literal "false" when the loop variable is not nullable, so only
+    // declare mutable state for it when it names a real variable.
+    val loopIsNullAccessor = if (loopIsNull != "false") {
+      ctx.addMutableState("boolean", loopIsNull, "", inline = true)
+    } else {
+      "false"
+    }
+    val loopValueAccessor = ctx.addMutableState(elementJavaType, loopValue, "", inline = true)
     val genInputData = inputData.genCode(ctx)
     val genFunction = lambdaFunction.genCode(ctx)
     val dataLength = ctx.freshName("dataLength")
@@ -616,10 +626,9 @@ case class MapObjects private(
     }

     val loopNullCheck = if (loopIsNull != "false") {
-      ctx.addMutableState("boolean", loopIsNull, "")
       inputDataType match {
-        case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
-        case _ => s"$loopIsNull = $loopValue == null;"
+        case _: ArrayType => s"$loopIsNullAccessor = ${genInputData.value}.isNullAt($loopIndex);"
+        case _ => s"$loopIsNullAccessor = $loopValueAccessor == null;"
       }
     } else {
       ""
     }
@@ -677,7 +686,7 @@ case class MapObjects private(

         int $loopIndex = 0;
         while ($loopIndex < $dataLength) {
-          $loopValue = ($elementJavaType) ($getLoopVar);
+          $loopValueAccessor = ($elementJavaType) ($getLoopVar);
           $loopNullCheck

           ${genFunction.code}
@@ -779,10 +788,19 @@ case class CatalystToExternalMap private(

     val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType]
     val keyElementJavaType = ctx.javaType(mapType.keyType)
-    ctx.addMutableState(keyElementJavaType, keyLoopValue, "")
+    val keyLoopValueAccessor = ctx.addMutableState(keyElementJavaType, keyLoopValue, "",
+      inline = true)
     val genKeyFunction = keyLambdaFunction.genCode(ctx)
     val valueElementJavaType = ctx.javaType(mapType.valueType)
-    ctx.addMutableState(valueElementJavaType, valueLoopValue, "")
+    val valueLoopValueAccessor = ctx.addMutableState(valueElementJavaType, valueLoopValue, "",
+      inline = true)
+    // As with MapObjects, `valueLoopIsNull` may be the literal "false"; only declare state
+    // for it when it names a real variable.
+    val valueLoopIsNullAccessor = if (valueLoopIsNull != "false") {
+      ctx.addMutableState("boolean", valueLoopIsNull, "", inline = true)
+    } else {
+      "false"
+    }
     val genValueFunction = valueLambdaFunction.genCode(ctx)
     val genInputData = inputData.genCode(ctx)
     val dataLength = ctx.freshName("dataLength")
@@ -815,8 +833,7 @@ case class CatalystToExternalMap private(
     val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction)
     val valueLoopNullCheck = if (valueLoopIsNull != "false") {
-      ctx.addMutableState("boolean", valueLoopIsNull, "")
-      s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);"
+      s"$valueLoopIsNullAccessor = $valueArray.isNullAt($loopIndex);"
     } else {
       ""
     }
@@ -853,8 +870,8 @@ case class CatalystToExternalMap private(

         int $loopIndex = 0;
         while ($loopIndex < $dataLength) {
-          $keyLoopValue = ($keyElementJavaType) ($getKeyLoopVar);
-          $valueLoopValue = ($valueElementJavaType) ($getValueLoopVar);
+          $keyLoopValueAccessor = ($keyElementJavaType) ($getKeyLoopVar);
+          $valueLoopValueAccessor = ($valueElementJavaType) ($getValueLoopVar);
           $valueLoopNullCheck

           ${genKeyFunction.code}
@@ -965,8 +982,8 @@ case class ExternalMapToCatalyst private(
     val keyElementJavaType = ctx.javaType(keyType)
     val valueElementJavaType = ctx.javaType(valueType)
-    ctx.addMutableState(keyElementJavaType, key, "")
-    ctx.addMutableState(valueElementJavaType, value, "")
+    val keyAccessor = ctx.addMutableState(keyElementJavaType, key, "")
+    val valueAccessor = ctx.addMutableState(valueElementJavaType, value, "")

     val (defineEntries, defineKeyValue) = child.dataType match {
       case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) =>
@@ -979,8 +996,8 @@ case class ExternalMapToCatalyst private(
         val defineKeyValue =
           s"""
             final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next();
-            $key = (${ctx.boxedType(keyType)}) $entry.getKey();
-            $value = (${ctx.boxedType(valueType)}) $entry.getValue();
+            $keyAccessor = (${ctx.boxedType(keyType)}) $entry.getKey();
+            $valueAccessor = (${ctx.boxedType(valueType)}) $entry.getValue();
           """

         defineEntries -> defineKeyValue
@@ -994,23 +1011,23 @@ case class ExternalMapToCatalyst private(
         val defineKeyValue =
           s"""
             final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next();
-            $key = (${ctx.boxedType(keyType)}) $entry._1();
-            $value = (${ctx.boxedType(valueType)}) $entry._2();
+            $keyAccessor = (${ctx.boxedType(keyType)}) $entry._1();
+            $valueAccessor = (${ctx.boxedType(valueType)}) $entry._2();
           """

         defineEntries -> defineKeyValue
     }

     val keyNullCheck = if (keyIsNull != "false") {
-      ctx.addMutableState("boolean", keyIsNull, "")
-      s"$keyIsNull = $key == null;"
+      val keyIsNullAccessor = ctx.addMutableState("boolean", keyIsNull, "")
+      s"$keyIsNullAccessor = $keyAccessor == null;"
     } else {
       ""
     }

     val valueNullCheck = if (valueIsNull != "false") {
-      ctx.addMutableState("boolean", valueIsNull, "")
-      s"$valueIsNull = $value == null;"
+      val valueIsNullAccessor = ctx.addMutableState("boolean", valueIsNull, "")
+      s"$valueIsNullAccessor = $valueAccessor == null;"
     } else {
       ""
     }
@@ -1077,15 +1094,15 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val rowClass = classOf[GenericRowWithSchema].getName
     val values = ctx.freshName("values")
-    ctx.addMutableState("Object[]", values, "")
+    val valuesAccessor = ctx.addMutableState("Object[]", values, "")
val childrenCodes = children.zipWithIndex.map { case (e, i) => val eval = e.genCode(ctx) eval.code + s""" if (${eval.isNull}) { - $values[$i] = null; + $valuesAccessor[$i] = null; } else { - $values[$i] = ${eval.value}; + $valuesAccessor[$i] = ${eval.value}; } """ } @@ -1094,9 +1099,9 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) val schemaField = ctx.addReferenceObj("schema", schema) val code = s""" - $values = new Object[${children.size}]; + $valuesAccessor = new Object[${children.size}]; $childrenCode - final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); + final ${classOf[Row].getName} ${ev.value} = new $rowClass($valuesAccessor, $schemaField); """ ev.copy(code = code, isNull = "false") } @@ -1133,12 +1138,13 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) $serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); } """ - ctx.addMutableState(serializerInstanceClass, serializer, serializerInit) + val serializerAccessor = ctx.addMutableState(serializerInstanceClass, + serializer, serializerInit) // Code to serialize. val input = child.genCode(ctx) val javaType = ctx.javaType(dataType) - val serialize = s"$serializer.serialize(${input.value}, null).array()" + val serialize = s"$serializerAccessor.serialize(${input.value}, null).array()" val code = s""" ${input.code} @@ -1179,13 +1185,14 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B $serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); } """ - ctx.addMutableState(serializerInstanceClass, serializer, serializerInit) + val serializerAccessor = ctx.addMutableState(serializerInstanceClass, + serializer, serializerInit) // Code to deserialize. 
val input = child.genCode(ctx) val javaType = ctx.javaType(dataType) val deserialize = - s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" + s"($javaType) $serializerAccessor.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" val code = s""" ${input.code} @@ -1215,26 +1222,26 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val javaBeanInstance = ctx.freshName("javaBean") val beanInstanceJavaType = ctx.javaType(beanInstance.dataType) - ctx.addMutableState(beanInstanceJavaType, javaBeanInstance, "") + val javaBeanInstanceAccessor = ctx.addMutableState(beanInstanceJavaType, javaBeanInstance, "") val initialize = setters.map { case (setterMethod, fieldValue) => val fieldGen = fieldValue.genCode(ctx) s""" ${fieldGen.code} - ${javaBeanInstance}.$setterMethod(${fieldGen.value}); + $javaBeanInstanceAccessor.$setterMethod(${fieldGen.value}); """ } val initializeCode = ctx.splitExpressions(ctx.INPUT_ROW, initialize.toSeq) val code = s""" ${instanceGen.code} - ${javaBeanInstance} = ${instanceGen.value}; + $javaBeanInstanceAccessor = ${instanceGen.value}; if (!${instanceGen.isNull}) { $initializeCode } """ - ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value) + ev.copy(code = code, isNull = instanceGen.isNull, value = javaBeanInstanceAccessor) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index efcd45fad779c..9835df3042cc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -315,14 +315,14 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } else { "" } - ctx.addMutableState(setName, setTerm, + val setTermAccessor = ctx.addMutableState(setName, setTerm, s"$setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet();") ev.copy(code = s""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; boolean ${ev.value} = false; if (!${ev.isNull}) { - ${ev.value} = $setTerm.contains(${childGen.value}); + ${ev.value} = $setTermAccessor.contains(${childGen.value}); $setNull } """) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 97051769cbf72..0ef4792f9a9e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -79,11 +79,12 @@ case class Rand(child: Expression) extends RDG { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName - ctx.addMutableState(className, rngTerm, "") + val rngTermAccessor = ctx.addMutableState(className, rngTerm, "") ctx.addPartitionInitializationStatement( - s"$rngTerm = new $className(${seed}L + partitionIndex);") + s"$rngTermAccessor = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false") + final ${ctx.javaType(dataType)} ${ev.value} = $rngTermAccessor.nextDouble();""", + isNull = "false") } } @@ -114,11 +115,12 @@ case class 
Randn(child: Expression) extends RDG { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName - ctx.addMutableState(className, rngTerm, "") + val rngTermAccessor = ctx.addMutableState(className, rngTerm, "") ctx.addPartitionInitializationStatement( - s"$rngTerm = new $className(${seed}L + partitionIndex);") + s"$rngTermAccessor = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false") + final ${ctx.javaType(dataType)} ${ev.value} = $rngTermAccessor.nextGaussian();""", + isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index d0d663f63f5db..ebb7e37c75758 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -119,7 +119,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi if (rVal != null) { val regexStr = StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) - ctx.addMutableState(patternClass, pattern, + val patternAccessor = ctx.addMutableState(patternClass, pattern, s"""$pattern = ${patternClass}.compile("$regexStr");""") // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. @@ -129,7 +129,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi boolean ${ev.isNull} = ${eval.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.value} = $pattern.matcher(${eval.value}.toString()).matches(); + ${ev.value} = $patternAccessor.matcher(${eval.value}.toString()).matches(); } """) } else { @@ -194,7 +194,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress if (rVal != null) { val regexStr = StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) - ctx.addMutableState(patternClass, pattern, + val patternAccessor = ctx.addMutableState(patternClass, pattern, s"""$pattern = ${patternClass}.compile("$regexStr");""") // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. 
@@ -204,7 +204,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress boolean ${ev.isNull} = ${eval.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.value} = $pattern.matcher(${eval.value}.toString()).find(0); + ${ev.value} = $patternAccessor.matcher(${eval.value}.toString()).find(0); } """) } else { @@ -329,13 +329,19 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio val matcher = ctx.freshName("matcher") - ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") - ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") - ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;") - ctx.addMutableState("UTF8String", + val termLastRegexAccessor = + ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") + val termPatternAccessor = + ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") + val termLastReplacementAccessor = + ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;") + val termLastReplacementInUTF8Accessor = + ctx.addMutableState("UTF8String", termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") - ctx.addMutableState(classNameStringBuffer, - termResult, s"${termResult} = new $classNameStringBuffer();") + val termResultAccessor = + ctx.addMutableState(classNameStringBuffer, + termResult, + s"${termResult} = new $classNameStringBuffer();") val setEvNotNull = if (nullable) { s"${ev.isNull} = false;" @@ -345,24 +351,24 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => { s""" - if (!$regexp.equals(${termLastRegex})) { + if (!$regexp.equals(${termLastRegexAccessor})) { // regex value changed - ${termLastRegex} = $regexp.clone(); - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + ${termLastRegexAccessor} = $regexp.clone(); + ${termPatternAccessor} = ${classNamePattern}.compile(${termLastRegexAccessor}.toString()); } - if (!$rep.equals(${termLastReplacementInUTF8})) { + if (!$rep.equals(${termLastReplacementInUTF8Accessor})) { // replacement string changed - ${termLastReplacementInUTF8} = $rep.clone(); - ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); + ${termLastReplacementInUTF8Accessor} = $rep.clone(); + ${termLastReplacementAccessor} = ${termLastReplacementInUTF8Accessor}.toString(); } - ${termResult}.delete(0, ${termResult}.length()); - java.util.regex.Matcher ${matcher} = ${termPattern}.matcher($subject.toString()); + ${termResultAccessor}.delete(0, ${termResultAccessor}.length()); + java.util.regex.Matcher ${matcher} = ${termPatternAccessor}.matcher($subject.toString()); while (${matcher}.find()) { - ${matcher}.appendReplacement(${termResult}, ${termLastReplacement}); + ${matcher}.appendReplacement(${termResultAccessor}, ${termLastReplacementAccessor}); } - ${matcher}.appendTail(${termResult}); - ${ev.value} = UTF8String.fromString(${termResult}.toString()); + ${matcher}.appendTail(${termResultAccessor}); + ${ev.value} = UTF8String.fromString(${termResultAccessor}.toString()); $setEvNotNull """ }) @@ -422,8 +428,10 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio val matcher = ctx.freshName("matcher") val matchResult = ctx.freshName("matchResult") - ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") - 
ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") + val termLastRegexAccessor = + ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") + val termPatternAccessor = + ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") val setEvNotNull = if (nullable) { s"${ev.isNull} = false;" @@ -433,13 +441,13 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { s""" - if (!$regexp.equals(${termLastRegex})) { + if (!$regexp.equals(${termLastRegexAccessor})) { // regex value changed - ${termLastRegex} = $regexp.clone(); - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + ${termLastRegexAccessor} = $regexp.clone(); + ${termPatternAccessor} = ${classNamePattern}.compile(${termLastRegexAccessor}.toString()); } java.util.regex.Matcher ${matcher} = - ${termPattern}.matcher($subject.toString()); + ${termPatternAccessor}.matcher($subject.toString()); if (${matcher}.find()) { java.util.regex.MatchResult ${matchResult} = ${matcher}.toMatchResult(); if (${matchResult}.group($idx) == null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index c341943187820..6d8c4aa05eca6 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -440,24 +440,27 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac val termDict = ctx.freshName("dict") val classNameDict = classOf[JMap[Character, Character]].getCanonicalName - ctx.addMutableState("UTF8String", termLastMatching, s"$termLastMatching = null;") - ctx.addMutableState("UTF8String", termLastReplace, s"$termLastReplace = null;") - ctx.addMutableState(classNameDict, termDict, s"$termDict = null;") + val termLastMatchingAccessor = + ctx.addMutableState("UTF8String", termLastMatching, s"$termLastMatching = null;") + val termLastReplaceAccessor = + ctx.addMutableState("UTF8String", termLastReplace, s"$termLastReplace = null;") + val termDictAccessor = ctx.addMutableState(classNameDict, termDict, s"$termDict = null;") nullSafeCodeGen(ctx, ev, (src, matching, replace) => { val check = if (matchingExpr.foldable && replaceExpr.foldable) { - s"$termDict == null" + s"$termDictAccessor == null" } else { - s"!$matching.equals($termLastMatching) || !$replace.equals($termLastReplace)" + s"!$matching.equals($termLastMatchingAccessor) || " + + s"!$replace.equals($termLastReplaceAccessor)" } s"""if ($check) { // Not all of them is literal or matching or replace value changed - $termLastMatching = $matching.clone(); - $termLastReplace = $replace.clone(); - $termDict = org.apache.spark.sql.catalyst.expressions.StringTranslate - .buildDict($termLastMatching, $termLastReplace); + $termLastMatchingAccessor = $matching.clone(); + $termLastReplaceAccessor = $replace.clone(); + $termDictAccessor = org.apache.spark.sql.catalyst.expressions.StringTranslate + .buildDict($termLastMatchingAccessor, $termLastReplaceAccessor); } - ${ev.value} = $src.translate($termDict); + ${ev.value} = $src.translate($termDictAccessor); """ }) } @@ -1965,27 +1968,27 @@ case class FormatNumber(x: Expression, d: Expression) val numberFormat = ctx.freshName("numberFormat") val i = ctx.freshName("i") val 
dFormat = ctx.freshName("dFormat") - ctx.addMutableState("int", lastDValue, s"$lastDValue = -100;") - ctx.addMutableState(sb, pattern, s"$pattern = new $sb();") - ctx.addMutableState(df, numberFormat, + val lastDValueAccessor = ctx.addMutableState("int", lastDValue, s"$lastDValue = -100;") + val patternAccessor = ctx.addMutableState(sb, pattern, s"$pattern = new $sb();") + val numberFormatAccessor = ctx.addMutableState(df, numberFormat, s"""$numberFormat = new $df("", new $dfs($l.$usLocale));""") s""" if ($d >= 0) { - $pattern.delete(0, $pattern.length()); - if ($d != $lastDValue) { - $pattern.append("#,###,###,###,###,###,##0"); + $patternAccessor.delete(0, $patternAccessor.length()); + if ($d != $lastDValueAccessor) { + $patternAccessor.append("#,###,###,###,###,###,##0"); if ($d > 0) { - $pattern.append("."); + $patternAccessor.append("."); for (int $i = 0; $i < $d; $i++) { - $pattern.append("0"); + $patternAccessor.append("0"); } } - $lastDValue = $d; - $numberFormat.applyLocalizedPattern($pattern.toString()); + $lastDValueAccessor = $d; + $numberFormatAccessor.applyLocalizedPattern($patternAccessor.toString()); } - ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); + ${ev.value} = UTF8String.fromString($numberFormatAccessor.format(${typeHelper(num)})); } else { ${ev.value} = null; ${ev.isNull} = true; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 0cd0d8859145f..2c79ac6e5714f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -133,6 +133,57 @@ class GeneratedProjectionSuite extends SparkFunSuite { assert(result === row2) } + test("SPARK-18016: generated projections on wider table requiring state compaction") { + val N = 40000 + val wideRow1 = new GenericInternalRow((0 until N).toArray[Any]) + val schema1 = StructType((1 to N).map(i => StructField("", IntegerType))) + val wideRow2 = new GenericInternalRow( + (0 until N).map(i => UTF8String.fromString(i.toString)).toArray[Any]) + val schema2 = StructType((1 to N).map(i => StructField("", StringType))) + val joined = new JoinedRow(wideRow1, wideRow2) + val joinedSchema = StructType(schema1 ++ schema2) + val nested = new JoinedRow(InternalRow(joined, joined), joined) + val nestedSchema = StructType( + Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema) + + // test generated UnsafeProjection + val unsafeProj = UnsafeProjection.create(nestedSchema) + val unsafe: UnsafeRow = unsafeProj(nested) + (0 until N).foreach { i => + val s = UTF8String.fromString(i.toString) + assert(i === unsafe.getInt(i + 2)) + assert(s === unsafe.getUTF8String(i + 2 + N)) + assert(i === unsafe.getStruct(0, N * 2).getInt(i)) + assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N)) + assert(i === unsafe.getStruct(1, N * 2).getInt(i)) + assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N)) + } + + // test generated SafeProjection + val safeProj = FromUnsafeProjection(nestedSchema) + val result = safeProj(unsafe) + // Can't compare GenericInternalRow with JoinedRow directly + (0 until N).foreach { i => + val s = UTF8String.fromString(i.toString) + assert(i === result.getInt(i + 2)) + assert(s === result.getUTF8String(i + 2 
+ N)) + assert(i === result.getStruct(0, N * 2).getInt(i)) + assert(s === result.getStruct(0, N * 2).getUTF8String(i + N)) + assert(i === result.getStruct(1, N * 2).getInt(i)) + assert(s === result.getStruct(1, N * 2).getUTF8String(i + N)) + } + + // test generated MutableProjection + val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, true) + } + val mutableProj = GenerateMutableProjection.generate(exprs) + val row1 = mutableProj(result) + assert(result === row1) + val row2 = mutableProj(result) + assert(result === row2) + } + test("generated unsafe projection with array of binary") { val row = InternalRow( Array[Byte](1, 2), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 1afe83ea3539e..01b63a3d3a739 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -73,27 +73,29 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { override protected def doProduce(ctx: CodegenContext): String = { val input = ctx.freshName("input") // PhysicalRDD always just has one input - ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + val inputAccessor = ctx.addMutableState("scala.collection.Iterator", input, + s"$input = inputs[0];") // metrics val numOutputRows = metricTerm(ctx, "numOutputRows") val scanTimeMetric = metricTerm(ctx, "scanTime") val scanTimeTotalNs = ctx.freshName("scanTime") - ctx.addMutableState("long", scanTimeTotalNs, s"$scanTimeTotalNs = 0;") + val scanTimeTotalNsAccessor = ctx.addMutableState("long", scanTimeTotalNs, + s"$scanTimeTotalNs = 0;") val columnarBatchClz = classOf[ColumnarBatch].getName val batch = ctx.freshName("batch") - ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;") + val batchAccessor = ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;") val idx = ctx.freshName("batchIdx") - ctx.addMutableState("int", idx, s"$idx = 0;") + val idxAccessor = ctx.addMutableState("int", idx, s"$idx = 0;") val colVars = output.indices.map(i => ctx.freshName("colInstance" + i)) val columnVectorClzs = vectorTypes.getOrElse( Seq.fill(colVars.size)(classOf[ColumnVector].getName)) - val columnAssigns = colVars.zip(columnVectorClzs).zipWithIndex.map { + val columnNameAssigns = colVars.zip(columnVectorClzs).zipWithIndex.map { case ((name, columnVectorClz), i) => - ctx.addMutableState(columnVectorClz, name, s"$name = null;") - s"$name = ($columnVectorClz) $batch.column($i);" + val nameAccessor = ctx.addMutableState(columnVectorClz, name, s"$name = null;") + (nameAccessor, s"$nameAccessor = ($columnVectorClz) $batchAccessor.column($i);") } val nextBatch = ctx.freshName("nextBatch") @@ -101,46 +103,46 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { s""" |private void $nextBatch() throws java.io.IOException { | long getBatchStart = System.nanoTime(); - | if ($input.hasNext()) { - | $batch = ($columnarBatchClz)$input.next(); - | $numOutputRows.add($batch.numRows()); - | $idx = 0; - | ${columnAssigns.mkString("", "\n", "\n")} + | if ($inputAccessor.hasNext()) { + | $batchAccessor = ($columnarBatchClz)$inputAccessor.next(); + | $numOutputRows.add($batchAccessor.numRows()); + | $idxAccessor = 0; + | ${columnNameAssigns.map(_._2).mkString("", "\n", "\n")} | } - | $scanTimeTotalNs += System.nanoTime() - getBatchStart; + | 
$scanTimeTotalNsAccessor += System.nanoTime() - getBatchStart; |}""".stripMargin) ctx.currentVars = null val rowidx = ctx.freshName("rowIdx") - val columnsBatchInput = (output zip colVars).map { case (attr, colVar) => + val columnsBatchInput = (output zip columnNameAssigns.map(_._1)).map { case (attr, colVar) => genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable) } val localIdx = ctx.freshName("localIdx") val localEnd = ctx.freshName("localEnd") val numRows = ctx.freshName("numRows") val shouldStop = if (isShouldStopRequired) { - s"if (shouldStop()) { $idx = $rowidx + 1; return; }" + s"if (shouldStop()) { $idxAccessor = $rowidx + 1; return; }" } else { "// shouldStop check is eliminated" } s""" - |if ($batch == null) { + |if ($batchAccessor == null) { | $nextBatchFuncName(); |} - |while ($batch != null) { - | int $numRows = $batch.numRows(); - | int $localEnd = $numRows - $idx; + |while ($batchAccessor != null) { + | int $numRows = $batchAccessor.numRows(); + | int $localEnd = $numRows - $idxAccessor; | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { - | int $rowidx = $idx + $localIdx; + | int $rowidx = $idxAccessor + $localIdx; | ${consume(ctx, columnsBatchInput).trim} | $shouldStop | } - | $idx = $numRows; - | $batch = null; + | $idxAccessor = $numRows; + | $batchAccessor = null; | $nextBatchFuncName(); |} - |$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000)); - |$scanTimeTotalNs = 0; + |$scanTimeMetric.add($scanTimeTotalNsAccessor / (1000 * 1000)); + |$scanTimeTotalNsAccessor = 0; """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 8d0fc32feac99..a337342ee7cdc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -111,7 +111,8 @@ case class RowDataSourceScanExec( val numOutputRows = metricTerm(ctx, "numOutputRows") // PhysicalRDD always just has one input val input = ctx.freshName("input") - ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + val inputAccessor = ctx.addMutableState("scala.collection.Iterator", + input, s"$input = inputs[0];") val exprRows = output.zipWithIndex.map{ case (a, i) => BoundReference(i, a.dataType, a.nullable) } @@ -120,8 +121,8 @@ case class RowDataSourceScanExec( ctx.currentVars = null val columnsRowInput = exprRows.map(_.genCode(ctx)) s""" - |while ($input.hasNext()) { - | InternalRow $row = (InternalRow) $input.next(); + |while ($inputAccessor.hasNext()) { + | InternalRow $row = (InternalRow) $inputAccessor.next(); | $numOutputRows.add(1); | ${consume(ctx, columnsRowInput, null).trim} | if (shouldStop()) return; @@ -353,7 +354,8 @@ case class FileSourceScanExec( val numOutputRows = metricTerm(ctx, "numOutputRows") // PhysicalRDD always just has one input val input = ctx.freshName("input") - ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + val inputAccessor = ctx.addMutableState("scala.collection.Iterator", + input, s"$input = inputs[0];") val exprRows = output.zipWithIndex.map{ case (a, i) => BoundReference(i, a.dataType, a.nullable) } @@ -363,8 +365,8 @@ case class FileSourceScanExec( val columnsRowInput = exprRows.map(_.genCode(ctx)) val inputRow = if (needsUnsafeRowConversion) null else row s""" - |while ($input.hasNext()) { - | InternalRow $row = (InternalRow) $input.next(); + |while 
($inputAccessor.hasNext()) { + | InternalRow $row = (InternalRow) $inputAccessor.next(); | $numOutputRows.add(1); | ${consume(ctx, columnsRowInput, inputRow).trim} | if (shouldStop()) return; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index ff71fd4dc7bb7..432b33ee995e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -126,19 +126,20 @@ case class SortExec( override protected def doProduce(ctx: CodegenContext): String = { val needToSort = ctx.freshName("needToSort") - ctx.addMutableState("boolean", needToSort, s"$needToSort = true;") + val needToSortAccessor = ctx.addMutableState("boolean", needToSort, s"$needToSort = true;") // Initialize the class member variables. This includes the instance of the Sorter and // the iterator to return sorted rows. val thisPlan = ctx.addReferenceObj("plan", this) sorterVariable = ctx.freshName("sorter") ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable, - s"$sorterVariable = $thisPlan.createSorter();") + s"$sorterVariable = $thisPlan.createSorter();", inline = true) val metrics = ctx.freshName("metrics") - ctx.addMutableState(classOf[TaskMetrics].getName, metrics, + val metricsAccessor = ctx.addMutableState(classOf[TaskMetrics].getName, metrics, s"$metrics = org.apache.spark.TaskContext.get().taskMetrics();") val sortedIterator = ctx.freshName("sortedIter") - ctx.addMutableState("scala.collection.Iterator", sortedIterator, "") + val sortedIteratorAccessor = ctx.addMutableState("scala.collection.Iterator", + sortedIterator, "") val addToSorter = ctx.freshName("addToSorter") val addToSorterFuncName = ctx.addNewFunction(addToSorter, @@ -158,19 +159,19 @@ case class SortExec( val spillSizeBefore = ctx.freshName("spillSizeBefore") val sortTime = metricTerm(ctx, "sortTime") s""" - | if ($needToSort) { - | long $spillSizeBefore = $metrics.memoryBytesSpilled(); + | if ($needToSortAccessor) { + | long $spillSizeBefore = $metricsAccessor.memoryBytesSpilled(); | $addToSorterFuncName(); - | $sortedIterator = $sorterVariable.sort(); + | $sortedIteratorAccessor = $sorterVariable.sort(); | $sortTime.add($sorterVariable.getSortTimeNanos() / 1000000); | $peakMemory.add($sorterVariable.getPeakMemoryUsage()); - | $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore); - | $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage()); - | $needToSort = false; + | $spillSize.add($metricsAccessor.memoryBytesSpilled() - $spillSizeBefore); + | $metricsAccessor.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage()); + | $needToSortAccessor = false; | } | - | while ($sortedIterator.hasNext()) { - | UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next(); + | while ($sortedIteratorAccessor.hasNext()) { + | UnsafeRow $outputRow = (UnsafeRow)$sortedIteratorAccessor.next(); | ${consume(ctx, null, outputRow)} | if (shouldStop()) return; | } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 1aaaf896692d1..108a074283808 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -258,11 +258,12 @@ case class InputAdapter(child: SparkPlan) extends 
UnaryExecNode with CodegenSupp override def doProduce(ctx: CodegenContext): String = { val input = ctx.freshName("input") // Right now, InputAdapter is only used when there is one input RDD. - ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + val inputAccessor = ctx.addMutableState("scala.collection.Iterator", + input, s"$input = inputs[0];") val row = ctx.freshName("row") s""" - | while ($input.hasNext() && !stopEarly()) { - | InternalRow $row = (InternalRow) $input.next(); + | while ($inputAccessor.hasNext() && !stopEarly()) { + | InternalRow $row = (InternalRow) $inputAccessor.next(); | ${consume(ctx, null, row).trim} | if (shouldStop()) return; | } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index f424096b330e3..ce7fdc1198fe4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -166,7 +166,7 @@ case class HashAggregateExec( private def doProduceWithoutKeys(ctx: CodegenContext): String = { val initAgg = ctx.freshName("initAgg") - ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + val initAggAccessor = ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") // generate variables for aggregation buffer val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) @@ -174,15 +174,15 @@ case class HashAggregateExec( bufVars = initExpr.map { e => val isNull = ctx.freshName("bufIsNull") val value = ctx.freshName("bufValue") - ctx.addMutableState("boolean", isNull, "") - ctx.addMutableState(ctx.javaType(e.dataType), value, "") + val isNullAccessor = ctx.addMutableState("boolean", isNull, "") + val valueAccessor = ctx.addMutableState(ctx.javaType(e.dataType), value, "") // The initial expression should not access any column val ev = e.genCode(ctx) val initVars = s""" - | $isNull = ${ev.isNull}; - | $value = ${ev.value}; + | $isNullAccessor = ${ev.isNull}; + | $valueAccessor = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, isNull, value) + ExprCode(ev.code + initVars, isNullAccessor, valueAccessor) } val initBufVar = evaluateVariables(bufVars) @@ -227,8 +227,8 @@ case class HashAggregateExec( val aggTime = metricTerm(ctx, "aggTime") val beforeAgg = ctx.freshName("beforeAgg") s""" - | while (!$initAgg) { - | $initAgg = true; + | while (!$initAggAccessor) { + | $initAggAccessor = true; | long $beforeAgg = System.nanoTime(); | $doAggFuncName(); | $aggTime.add((System.nanoTime() - $beforeAgg) / 1000000); @@ -555,7 +555,7 @@ case class HashAggregateExec( private def doProduceWithKeys(ctx: CodegenContext): String = { val initAgg = ctx.freshName("initAgg") - ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + val initAggAccessor = ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") if (sqlContext.conf.enableTwoLevelAggMap) { enableTwoLevelHashMap(ctx) } else { @@ -580,33 +580,36 @@ case class HashAggregateExec( // Create a name for iterator from vectorized HashMap val iterTermForFastHashMap = ctx.freshName("fastHashMapIter") - if (isFastHashMapEnabled) { + val iterTermForFastHashMapAccessor = if (isFastHashMapEnabled) { if (isVectorizedHashMapEnabled) { - ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, + fastHashMapTerm = ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, 
s"$fastHashMapTerm = new $fastHashMapClassName();") ctx.addMutableState( "java.util.Iterator", - iterTermForFastHashMap, "") + iterTermForFastHashMap, "", inline = true) } else { ctx.addMutableState(fastHashMapClassName, fastHashMapTerm, s"$fastHashMapTerm = new $fastHashMapClassName(" + - s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());") + s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());", + inline = true) ctx.addMutableState( "org.apache.spark.unsafe.KVIterator", - iterTermForFastHashMap, "") + iterTermForFastHashMap, "", inline = true) } } // create hashMap hashMapTerm = ctx.freshName("hashMap") val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName - ctx.addMutableState(hashMapClassName, hashMapTerm, "") + hashMapTerm = ctx.addMutableState(hashMapClassName, hashMapTerm, "") sorterTerm = ctx.freshName("sorter") - ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "") + sorterTerm = ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "", + inline = true) // Create a name for iterator from HashMap val iterTerm = ctx.freshName("mapIter") - ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") + val iterTermAccessor = ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, + iterTerm, "") def generateGenerateCode(): String = { if (isFastHashMapEnabled) { @@ -634,10 +637,10 @@ case class HashAggregateExec( ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} ${if (isFastHashMapEnabled) { - s"$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();"} else ""} + s"$iterTermForFastHashMapAccessor = $fastHashMapTerm.rowIterator();"} else ""} - $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm, $peakMemory, $spillSize, - $avgHashProbe); + $iterTermAccessor = $thisPlan.finishAggregate( + $hashMapTerm, $sorterTerm, $peakMemory, $spillSize, $avgHashProbe); } """) @@ -663,10 +666,10 @@ case class HashAggregateExec( def outputFromRowBasedMap: String = { s""" - while ($iterTermForFastHashMap.next()) { + while ($iterTermForFastHashMapAccessor.next()) { $numOutput.add(1); - UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey(); - UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue(); + UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMapAccessor.getKey(); + UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMapAccessor.getValue(); $outputFunc($keyTerm, $bufferTerm); if (shouldStop()) return; @@ -689,11 +692,11 @@ case class HashAggregateExec( .map { case (attr, i) => BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable) }) s""" - | while ($iterTermForFastHashMap.hasNext()) { + | while ($iterTermForFastHashMapAccessor.hasNext()) { | $numOutput.add(1); | org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row = | (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row) - | $iterTermForFastHashMap.next(); + | $iterTermForFastHashMapAccessor.next(); | ${generateKeyRow.code} | ${generateBufferRow.code} | $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value}); @@ -709,8 +712,8 @@ case class HashAggregateExec( val aggTime = metricTerm(ctx, "aggTime") val beforeAgg = ctx.freshName("beforeAgg") s""" - if (!$initAgg) { - $initAgg = true; + if (!$initAggAccessor) { + $initAggAccessor = true; long $beforeAgg = System.nanoTime(); $doAggFuncName(); $aggTime.add((System.nanoTime() - $beforeAgg) / 1000000); @@ -721,14 +724,14 @@ case class HashAggregateExec( 
- while ($iterTerm.next()) { + while ($iterTermAccessor.next()) { $numOutput.add(1); - UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); - UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); + UnsafeRow $keyTerm = (UnsafeRow) $iterTermAccessor.getKey(); + UnsafeRow $bufferTerm = (UnsafeRow) $iterTermAccessor.getValue(); $outputFunc($keyTerm, $bufferTerm); if (shouldStop()) return; } - $iterTerm.close(); + $iterTermAccessor.close(); if ($sorterTerm == null) { $hashMapTerm.free(); } @@ -768,9 +771,10 @@ case class HashAggregateExec( val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, incCounter) = if (testFallbackStartsAt.isDefined) { val countTerm = ctx.freshName("fallbackCounter") - ctx.addMutableState("int", countTerm, s"$countTerm = 0;") - (s"$countTerm < ${testFallbackStartsAt.get._1}", - s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;") + val countTermAccessor = ctx.addMutableState("int", countTerm, s"$countTerm = 0;") + (s"$countTermAccessor < ${testFallbackStartsAt.get._1}", + s"$countTermAccessor < ${testFallbackStartsAt.get._2}", + s"$countTermAccessor = 0;", s"$countTermAccessor += 1;") } else { ("true", "true", "", "") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 90deb20e97244..12c835fc445d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -48,15 +48,15 @@ abstract class HashMapGenerator( initExpr.map { e => val isNull = ctx.freshName("bufIsNull") val value = ctx.freshName("bufValue") - ctx.addMutableState("boolean", isNull, "") - ctx.addMutableState(ctx.javaType(e.dataType), value, "") + val isNullAccessor = ctx.addMutableState("boolean", isNull, "") + val valueAccessor = ctx.addMutableState(ctx.javaType(e.dataType), value, "") val ev = e.genCode(ctx) val initVars = s""" - | $isNull = ${ev.isNull}; - | $value = ${ev.value}; + | $isNullAccessor = ${ev.isNull}; + | $valueAccessor = ${ev.value}; """.stripMargin - ExprCode(ev.code + initVars, isNull, value) + ExprCode(ev.code + initVars, isNullAccessor, valueAccessor) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 63cd1691f4cd7..4cfc8bc6e6c3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -303,12 +303,12 @@ case class SampleExec( | } """.stripMargin.trim) - ctx.addMutableState(s"$samplerClass", sampler, + val samplerAccessor = ctx.addMutableState(s"$samplerClass", sampler, s"$initSamplerFuncName();") val samplingCount = ctx.freshName("samplingCount") s""" - | int $samplingCount = $sampler.sample(); + | int $samplingCount = $samplerAccessor.sample(); | while ($samplingCount-- > 0) { | $numOutput.add(1); | ${consume(ctx, input)} @@ -316,14 +316,14 @@ case class SampleExec( """.stripMargin.trim } else { val samplerClass = classOf[BernoulliCellSampler[UnsafeRow]].getName - ctx.addMutableState(s"$samplerClass", sampler, + val samplerAccessor = ctx.addMutableState(s"$samplerClass", sampler, s""" | $sampler = new $samplerClass($lowerBound, $upperBound, false); | $sampler.setSeed(${seed}L + partitionIndex);
""".stripMargin.trim) s""" - | if ($sampler.sample() != 0) { + | if ($samplerAccessor.sample() != 0) { | $numOutput.add(1); | ${consume(ctx, input)} | } @@ -367,19 +367,20 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) val numOutput = metricTerm(ctx, "numOutputRows") val initTerm = ctx.freshName("initRange") - ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") + val initTermAccessor = ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") val number = ctx.freshName("number") - ctx.addMutableState("long", number, s"$number = 0L;") + val numberAccessor = ctx.addMutableState("long", number, s"$number = 0L;") val value = ctx.freshName("value") val ev = ExprCode("", "false", value) val BigInt = classOf[java.math.BigInteger].getName val taskContext = ctx.freshName("taskContext") - ctx.addMutableState("TaskContext", taskContext, s"$taskContext = TaskContext.get();") + val taskContextAccessor = ctx.addMutableState("TaskContext", + taskContext, s"$taskContext = TaskContext.get();") val inputMetrics = ctx.freshName("inputMetrics") - ctx.addMutableState("InputMetrics", inputMetrics, - s"$inputMetrics = $taskContext.taskMetrics().inputMetrics();") + val inputMetricsAccessor = ctx.addMutableState("InputMetrics", inputMetrics, + s"$inputMetrics = $taskContextAccessor.taskMetrics().inputMetrics();") // In order to periodically update the metrics without inflicting performance penalty, this // operator produces elements in batches. After a batch is complete, the metrics are updated @@ -390,11 +391,12 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // Once number == batchEnd, it's time to progress to the next batch. val batchEnd = ctx.freshName("batchEnd") - ctx.addMutableState("long", batchEnd, s"$batchEnd = 0;") + val batchEndAccessor = ctx.addMutableState("long", batchEnd, s"$batchEnd = 0;") // How many values should still be generated by this range operator. val numElementsTodo = ctx.freshName("numElementsTodo") - ctx.addMutableState("long", numElementsTodo, s"$numElementsTodo = 0L;") + val numElementsTodoAccessor = ctx.addMutableState("long", + numElementsTodo, s"$numElementsTodo = 0L;") // How many values should be generated in the next batch. 
val nextBatchTodo = ctx.freshName("nextBatchTodo") @@ -414,13 +416,13 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) { - | $number = Long.MAX_VALUE; + | $numberAccessor = Long.MAX_VALUE; | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) { - | $number = Long.MIN_VALUE; + | $numberAccessor = Long.MIN_VALUE; | } else { - | $number = st.longValue(); + | $numberAccessor = st.longValue(); | } - | $batchEnd = $number; + | $batchEndAccessor = $numberAccessor; | | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice) | .multiply(step).add(start); @@ -433,62 +435,62 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | } | | $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract( - | $BigInt.valueOf($number)); - | $numElementsTodo = startToEnd.divide(step).longValue(); - | if ($numElementsTodo < 0) { - | $numElementsTodo = 0; + | $BigInt.valueOf($numberAccessor)); + | $numElementsTodoAccessor = startToEnd.divide(step).longValue(); + | if ($numElementsTodoAccessor < 0) { + | $numElementsTodoAccessor = 0; | } else if (startToEnd.remainder(step).compareTo($BigInt.valueOf(0L)) != 0) { - | $numElementsTodo++; + | $numElementsTodoAccessor++; | } | } """.stripMargin) val input = ctx.freshName("input") // Right now, Range is only used when there is one upstream. - ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];", inline = true) val localIdx = ctx.freshName("localIdx") val localEnd = ctx.freshName("localEnd") val range = ctx.freshName("range") val shouldStop = if (isShouldStopRequired) { - s"if (shouldStop()) { $number = $value + ${step}L; return; }" + s"if (shouldStop()) { $numberAccessor = $value + ${step}L; return; }" } else { "// shouldStop check is eliminated" } s""" | // initialize Range - | if (!$initTerm) { - | $initTerm = true; + | if (!$initTermAccessor) { + | $initTermAccessor = true; | $initRangeFuncName(partitionIndex); | } | | while (true) { - | long $range = $batchEnd - $number; + | long $range = $batchEndAccessor - $numberAccessor; | if ($range != 0L) { | int $localEnd = (int)($range / ${step}L); | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) { - | long $value = ((long)$localIdx * ${step}L) + $number; + | long $value = ((long)$localIdx * ${step}L) + $numberAccessor; | ${consume(ctx, Seq(ev))} | $shouldStop | } - | $number = $batchEnd; + | $numberAccessor = $batchEndAccessor; | } | - | $taskContext.killTaskIfInterrupted(); + | $taskContextAccessor.killTaskIfInterrupted(); | | long $nextBatchTodo; - | if ($numElementsTodo > ${batchSize}L) { + | if ($numElementsTodoAccessor > ${batchSize}L) { | $nextBatchTodo = ${batchSize}L; - | $numElementsTodo -= ${batchSize}L; + | $numElementsTodoAccessor -= ${batchSize}L; | } else { - | $nextBatchTodo = $numElementsTodo; - | $numElementsTodo = 0; + | $nextBatchTodo = $numElementsTodoAccessor; + | $numElementsTodoAccessor = 0; | if ($nextBatchTodo == 0) break; | } | $numOutput.add($nextBatchTodo); - | $inputMetrics.incRecordsRead($nextBatchTodo); + | $inputMetricsAccessor.incRecordsRead($nextBatchTodo); | - | $batchEnd += $nextBatchTodo * ${step}L; + | $batchEndAccessor += $nextBatchTodo * ${step}L; | } """.stripMargin } diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index ae600c1ffae8e..a501e0f3b1b20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -89,19 +89,22 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera case array: ArrayType => classOf[ArrayColumnAccessor].getName case t: MapType => classOf[MapColumnAccessor].getName } - ctx.addMutableState(accessorCls, accessorName, "") + val accessorNameAccessor = ctx.addMutableState(accessorCls, accessorName, "") val createCode = dt match { case t if ctx.isPrimitiveType(dt) => - s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" + s"$accessorNameAccessor = " + + s"new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" case NullType | StringType | BinaryType => - s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" + s"$accessorNameAccessor = " + + s"new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" case other => - s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder), - (${dt.getClass.getName}) columnTypes[$index]);""" + s"""$accessorNameAccessor = new $accessorCls( + ByteBuffer.wrap(buffers[$index]).order(nativeOrder), + (${dt.getClass.getName}) columnTypes[$index]);""" } - val extract = s"$accessorName.extractTo(mutableRow, $index);" + val extract = s"$accessorNameAccessor.extractTo(mutableRow, $index);" val patch = dt match { case DecimalType.Fixed(p, s) if p > Decimal.MAX_LONG_DIGITS => // For large Decimal, it should have 16 bytes for future update even it's null now. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index b09da9bdacb99..8ad32730546ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -124,13 +124,13 @@ case class BroadcastHashJoinExec( val avgHashProbe = metricTerm(ctx, "avgHashProbe") val addTaskListener = genTaskListener(avgHashProbe, relationTerm) - ctx.addMutableState(clsName, relationTerm, + val relationTermAccessor = ctx.addMutableState(clsName, relationTerm, s""" | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy(); | incPeakExecutionMemory($relationTerm.estimatedSize()); | $addTaskListener """.stripMargin) - (broadcastRelation, relationTerm) + (broadcastRelation, relationTermAccessor) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 4e02803552e82..38439f48de432 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -423,14 +423,14 @@ case class SortMergeJoinExec( private def genScanner(ctx: CodegenContext): (String, String) = { // Create class member for next row from both sides. 
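Two conventions recur across the scan and join hunks. First, helper generators now hand accessors, not raw term names, back to their callers: BroadcastHashJoinExec above returns (broadcastRelation, relationTermAccessor), and genScanner below returns (leftRowAccessor, matchesAccessor), so the strings callers splice into their own generated Java stay valid wherever the state actually lives. Second, a state whose raw name must stay addressable across separately generated functions is pinned with inline = true, as SortExec does for its sorter. A condensed sketch, reusing the assumed SketchContext from the earlier sketch:

// Helper generators return the accessor for their callers to splice in.
def genCounter(ctx: SketchContext): String = {
  val acc = ctx.addMutableState("int", "count", "count = 0;")
  acc // the caller emits s"$acc += 1;", never "count += 1;"
}

// A state referenced by its raw name from other generated functions must
// remain an individual field, hence inline = true (accessor == "sorter").
// "plan.createSorter()" is part of the illustrative init string only.
def genSorter(ctx: SketchContext): String =
  ctx.addMutableState("Sorter", "sorter", "sorter = plan.createSorter();",
    inline = true)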
val leftRow = ctx.freshName("leftRow") - ctx.addMutableState("InternalRow", leftRow, "") + val leftRowAccessor = ctx.addMutableState("InternalRow", leftRow, "") val rightRow = ctx.freshName("rightRow") - ctx.addMutableState("InternalRow", rightRow, s"$rightRow = null;") + val rightRowAccessor = ctx.addMutableState("InternalRow", rightRow, s"$rightRow = null;") // Create variables for join keys from both sides. - val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output) + val leftKeyVars = createJoinKey(ctx, leftRowAccessor, leftKeys, left.output) val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ") - val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output) + val rightKeyTmpVars = createJoinKey(ctx, rightRowAccessor, rightKeys, right.output) val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ") // Copy the right key as class members so they could be used in next function call. val rightKeyVars = copyKeys(ctx, rightKeyTmpVars) @@ -442,7 +442,7 @@ case class SortMergeJoinExec( val spillThreshold = getSpillThreshold val inMemoryThreshold = getInMemoryThreshold - ctx.addMutableState(clsName, matches, + val matchesAccessor = ctx.addMutableState(clsName, matches, s"$matches = new $clsName($inMemoryThreshold, $spillThreshold);") // Copy the left keys as class members so they could be used in next function call. val matchedKeyVars = copyKeys(ctx, leftKeyVars) @@ -452,58 +452,58 @@ case class SortMergeJoinExec( |private boolean findNextInnerJoinRows( | scala.collection.Iterator leftIter, | scala.collection.Iterator rightIter) { - | $leftRow = null; + | $leftRowAccessor = null; | int comp = 0; - | while ($leftRow == null) { + | while ($leftRowAccessor == null) { | if (!leftIter.hasNext()) return false; - | $leftRow = (InternalRow) leftIter.next(); + | $leftRowAccessor = (InternalRow) leftIter.next(); | ${leftKeyVars.map(_.code).mkString("\n")} | if ($leftAnyNull) { - | $leftRow = null; + | $leftRowAccessor = null; | continue; | } - | if (!$matches.isEmpty()) { + | if (!$matchesAccessor.isEmpty()) { | ${genComparison(ctx, leftKeyVars, matchedKeyVars)} | if (comp == 0) { | return true; | } - | $matches.clear(); + | $matchesAccessor.clear(); | } | | do { - | if ($rightRow == null) { + | if ($rightRowAccessor == null) { | if (!rightIter.hasNext()) { | ${matchedKeyVars.map(_.code).mkString("\n")} - | return !$matches.isEmpty(); + | return !$matchesAccessor.isEmpty(); | } - | $rightRow = (InternalRow) rightIter.next(); + | $rightRowAccessor = (InternalRow) rightIter.next(); | ${rightKeyTmpVars.map(_.code).mkString("\n")} | if ($rightAnyNull) { - | $rightRow = null; + | $rightRowAccessor = null; | continue; | } | ${rightKeyVars.map(_.code).mkString("\n")} | } | ${genComparison(ctx, leftKeyVars, rightKeyVars)} | if (comp > 0) { - | $rightRow = null; + | $rightRowAccessor = null; | } else if (comp < 0) { - | if (!$matches.isEmpty()) { + | if (!$matchesAccessor.isEmpty()) { | ${matchedKeyVars.map(_.code).mkString("\n")} | return true; | } - | $leftRow = null; + | $leftRowAccessor = null; | } else { - | $matches.add((UnsafeRow) $rightRow); - | $rightRow = null;; + | $matchesAccessor.add((UnsafeRow) $rightRowAccessor); + | $rightRowAccessor = null;; | } - | } while ($leftRow != null); + | } while ($leftRowAccessor != null); | } | return false; // unreachable |} """.stripMargin, inlineToOuterClass = true) - (leftRow, matches) + (leftRowAccessor, matchesAccessor) } /** @@ -519,18 +519,18 @@ case class SortMergeJoinExec( val value = ctx.freshName("value") val 
valueCode = ctx.getValue(leftRow, a.dataType, i.toString) // declare it as class member, so we can access the column before or in the loop. - ctx.addMutableState(ctx.javaType(a.dataType), value, "") + val valueAccessor = ctx.addMutableState(ctx.javaType(a.dataType), value, "") if (a.nullable) { val isNull = ctx.freshName("isNull") - ctx.addMutableState("boolean", isNull, "") + val isNullAccessor = ctx.addMutableState("boolean", isNull, "") val code = s""" - |$isNull = $leftRow.isNullAt($i); - |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode); + |$isNullAccessor = $leftRow.isNullAt($i); + |$valueAccessor = $isNullAccessor ? ${ctx.defaultValue(a.dataType)} : ($valueCode); """.stripMargin - ExprCode(code, isNull, value) + ExprCode(code, isNullAccessor, valueAccessor) } else { - ExprCode(s"$value = $valueCode;", "false", value) + ExprCode(s"$valueAccessor = $valueCode;", "false", valueAccessor) } } } @@ -572,9 +572,11 @@ case class SortMergeJoinExec( override def doProduce(ctx: CodegenContext): String = { ctx.copyResult = true val leftInput = ctx.freshName("leftInput") - ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];") + val leftInputAccessor = ctx.addMutableState("scala.collection.Iterator", + leftInput, s"$leftInput = inputs[0];") val rightInput = ctx.freshName("rightInput") - ctx.addMutableState("scala.collection.Iterator", rightInput, s"$rightInput = inputs[1];") + val rightInputAccessor = ctx.addMutableState("scala.collection.Iterator", + rightInput, s"$rightInput = inputs[1];") val (leftRow, matches) = genScanner(ctx) @@ -615,7 +617,7 @@ case class SortMergeJoinExec( } s""" - |while (findNextInnerJoinRows($leftInput, $rightInput)) { + |while (findNextInnerJoinRows($leftInputAccessor, $rightInputAccessor)) { | ${beforeLoop.trim} | scala.collection.Iterator $iterator = $matches.generateIterator(); | while ($iterator.hasNext()) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 13da4b26a5dcb..f5ffcb9402fb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -72,22 +72,22 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val stopEarly = ctx.freshName("stopEarly") - ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;") + val stopEarlyAccessor = ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;") ctx.addNewFunction("stopEarly", s""" @Override protected boolean stopEarly() { - return $stopEarly; + return $stopEarlyAccessor; } """, inlineToOuterClass = true) val countTerm = ctx.freshName("count") - ctx.addMutableState("int", countTerm, s"$countTerm = 0;") + val countTermAccessor = ctx.addMutableState("int", countTerm, s"$countTerm = 0;") s""" - | if ($countTerm < $limit) { - | $countTerm += 1; + | if ($countTermAccessor < $limit) { + | $countTermAccessor += 1; | ${consume(ctx, input)} | } else { - | $stopEarly = true; + | $stopEarlyAccessor = true; | } """.stripMargin }
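The wide-row test added to GeneratedProjectionSuite earlier in this patch (SPARK-18016, N = 40000 columns per side) is the motivation for all of the above: with one Java field per mutable state, a class generated over tens of thousands of states risks JVM class-file limits such as the 65535-entry constant pool and the 64KB bytecode cap on any single initializer method. A rough contrast in plain Scala (both generated shapes are assumed for illustration, not the literal output of this patch):

// Field-per-state: tens of thousands of declarations to compile and link.
val inlined = (0 until 40000).map(i => s"private int value_$i;").mkString("\n")
// Compacted: a constant number of fields regardless of the state count.
val compacted = "private int[] mutableStateArray_int = new int[40000];"
// The new test exercises exactly this regime: projections wide enough that
// only the compacted form stays comfortably within class-file limits.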