From 1df9943915b2c3df0945d4def8282a3159f584bf Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 30 Apr 2018 13:29:07 +0000 Subject: [PATCH 01/20] Add API for handling expression code generation. --- .../sql/catalyst/expressions/Expression.scala | 3 - .../expressions/codegen/javaCode.scala | 69 +++++++++++++++++- .../expressions/codegen/ExprValueSuite.scala | 72 ++++++++++++++++++- 3 files changed, 139 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 97dff6ae88299..78865a6f2d233 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -445,7 +445,6 @@ abstract class UnaryExpression extends Expression { """) } else { ev.copy(code = s""" - boolean ${ev.isNull} = false; ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = FalseLiteral) @@ -544,7 +543,6 @@ abstract class BinaryExpression extends Expression { """) } else { ev.copy(code = s""" - boolean ${ev.isNull} = false; ${leftGen.code} ${rightGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -687,7 +685,6 @@ abstract class TernaryExpression extends Expression { $nullSafeEval""") } else { ev.copy(code = s""" - boolean ${ev.isNull} = false; ${leftGen.code} ${midGen.code} ${rightGen.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 74ff018488863..5f76e7b48d58a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import java.lang.{Boolean => JBool} +import scala.collection.mutable import scala.language.{existentials, implicitConversions} import org.apache.spark.sql.types.{BooleanType, DataType} @@ -120,12 +121,78 @@ object JavaCode { trait ExprValue extends JavaCode { def javaType: Class[_] def isPrimitive: Boolean = javaType.isPrimitive + + // This will be called during string interpolation. + override def toString: String = ExprValue.exprValueToString(this) } object ExprValue { - implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString + + private var currentHandlerForExprValue = + new ThreadLocal[(ExprValue) => String]() { + override def initialValue(): (ExprValue) => String = + defaultExprValueToString + } + + def defaultExprValueToString(exprValue: ExprValue): String = exprValue.code + + implicit def exprValueToString(exprValue: ExprValue): String = { + currentListeners.get.foreach(_.genExprValue(exprValue)) + currentHandlerForExprValue.get()(exprValue) + } + + protected[sql] val currentListeners = + new ThreadLocal[mutable.ArrayBuffer[ExprValueCodegenListener]]() { + override def initialValue(): mutable.ArrayBuffer[ExprValueCodegenListener] = + mutable.ArrayBuffer.empty[ExprValueCodegenListener] + } + + /** + * By adding a listener, we can catch all expr values which are generated during codegen. + */ + def withExprValueCodegenListener[T](listener: ExprValueCodegenListener) + (f: () => T): T = { + val index = currentListeners.get.length + currentListeners.get += listener + val result = f() + assert(currentListeners.get.length == index + 1, + "There are codegen listeners which are not removed after used!") + currentListeners.get.remove(index) + result + } + + /** + * Sets up a handler function used to convert `ExprValue` to string during codegen. The given + * handler can overwrite the default behavior which outputs `ExprValue` to its actual code. + */ + def withExprValueHandler[T](handlerFunc: (ExprValue) => String)(f: () => T): T = { + val previousHandler = currentHandlerForExprValue.get + currentHandlerForExprValue.set(handlerFunc) + val result = f() + currentHandlerForExprValue.set(previousHandler) + result + } +} + +/** + * A listener which gets notified when `ExprValue`s are generated to Java code. + */ +trait ExprValueCodegenListener { + def genExprValue(exprValue: ExprValue): Unit } +/** + * A codegen listener specified for `SimpleExprValue` expression values. + */ +case class SimpleExprValueCodegenListener() extends ExprValueCodegenListener { + protected[sql] val exprValues: mutable.HashSet[SimpleExprValue] = + mutable.HashSet.empty[SimpleExprValue] + + override def genExprValue(exprValue: ExprValue): Unit = exprValue match { + case s: SimpleExprValue => exprValues += s + case _ => + } +} /** * A java expression fragment. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala index 378b8bc055e34..84a960bb483a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import scala.collection.mutable + import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.catalyst.expressions.{Add, Literal} +import org.apache.spark.sql.types.{BooleanType, IntegerType} class ExprValueSuite extends SparkFunSuite { @@ -35,4 +38,71 @@ class ExprValueSuite extends SparkFunSuite { assert(trueLit === JavaCode.literal("true", BooleanType)) assert(falseLit === JavaCode.literal("false", BooleanType)) } + + test("add expr value codegen listener") { + val context = new CodegenContext() + val expr = Add(Literal(1), Literal(2)) + val listener = new TestExprValueListener() + + assert(ExprValue.currentListeners.get.size == 0) + + ExprValue.withExprValueCodegenListener(listener) { () => + assert(ExprValue.currentListeners.get.size == 1) + val genCode = expr.genCode(context) + // Materialize generated code + s"${genCode.value} ${genCode.isNull}" + } + // 1 VariableValue + 3 LiteralValue (1, 2, false) + assert(listener.exprValues.size == 4) + } + + test("Listen to statement expression generation") { + // This listener only listens statement expression. + val listener = SimpleExprValueCodegenListener() + + val context = new CodegenContext() + val expr = Add(Literal(1), Literal(2)) + + assert(ExprValue.currentListeners.get.size == 0) + + ExprValue.withExprValueCodegenListener(listener) { () => + assert(ExprValue.currentListeners.get.size == 1) + val genCode = expr.genCode(context) + // Materialize generated code + s"${genCode.value} ${genCode.isNull}" + + assert(listener.exprValues.size == 0) + + val statementValue = JavaCode.expression(s"${genCode.value} + 1", IntegerType) + s"$statementValue" + } + assert(listener.exprValues.size == 1) + assert(listener.exprValues.head.code == "(value_0 + 1)") + } + + test("Use expr value handler to control expression generated code") { + // Let all `SimpleExprValue` output special string instead of their code. + def statementHanlder(exprValue: ExprValue): String = exprValue match { + case _: SimpleExprValue => "%STATEMENT%" + case _ => ExprValue.defaultExprValueToString(exprValue) + } + + val context = new CodegenContext() + + val exprCode = ExprValue.withExprValueHandler(statementHanlder) { () => + val statementValue = JavaCode.expression(s"1 + 1", IntegerType) + ExprCode(code = s"int result = $statementValue;", + value = JavaCode.variable("result", IntegerType), + isNull = FalseLiteral) + } + assert(exprCode.code == "int result = %STATEMENT%;") + } +} + +private class TestExprValueListener extends ExprValueCodegenListener { + val exprValues: mutable.HashSet[ExprValue] = mutable.HashSet.empty[ExprValue] + + override def genExprValue(exprValue: ExprValue): Unit = { + exprValues += exprValue + } } From 5fe425c2d2837f00bdfe9ba5e6f446829fba32c1 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 1 May 2018 15:06:32 +0000 Subject: [PATCH 02/20] Add new abstraction for expression codegen. --- .../catalyst/expressions/BoundAttribute.scala | 5 +- .../spark/sql/catalyst/expressions/Cast.scala | 9 +- .../sql/catalyst/expressions/Expression.scala | 22 +-- .../MonotonicallyIncreasingID.scala | 3 +- .../sql/catalyst/expressions/ScalaUDF.scala | 3 +- .../sql/catalyst/expressions/SortOrder.scala | 3 +- .../expressions/SparkPartitionID.scala | 3 +- .../sql/catalyst/expressions/TimeWindow.scala | 3 +- .../sql/catalyst/expressions/arithmetic.scala | 17 ++- .../expressions/codegen/CodeGenerator.scala | 15 +- .../expressions/codegen/CodegenFallback.scala | 5 +- .../codegen/GenerateSafeProjection.scala | 7 +- .../codegen/GenerateUnsafeProjection.scala | 3 +- .../expressions/codegen/javaCode.scala | 143 +++++++++++------- .../expressions/collectionOperations.scala | 17 ++- .../expressions/complexTypeCreator.scala | 7 +- .../expressions/conditionalExpressions.scala | 5 +- .../expressions/datetimeExpressions.scala | 21 +-- .../expressions/decimalExpressions.scala | 5 +- .../sql/catalyst/expressions/generators.scala | 3 +- .../spark/sql/catalyst/expressions/hash.scala | 5 +- .../catalyst/expressions/inputFileBlock.scala | 13 +- .../expressions/mathExpressions.scala | 5 +- .../spark/sql/catalyst/expressions/misc.scala | 5 +- .../expressions/nullExpressions.scala | 9 +- .../expressions/objects/objects.scala | 48 +++--- .../sql/catalyst/expressions/predicates.scala | 15 +- .../expressions/randomExpressions.scala | 5 +- .../expressions/regexpExpressions.scala | 9 +- .../expressions/stringExpressions.scala | 25 +-- .../ExpressionEvalHelperSuite.scala | 3 +- .../expressions/codegen/ExprValueSuite.scala | 70 +-------- .../sql/execution/ColumnarBatchScan.scala | 9 +- .../spark/sql/execution/ExpandExec.scala | 3 +- .../spark/sql/execution/GenerateExec.scala | 5 +- .../sql/execution/WholeStageCodegenExec.scala | 9 +- .../aggregate/HashAggregateExec.scala | 3 +- .../aggregate/HashMapGenerator.scala | 3 +- .../joins/BroadcastHashJoinExec.scala | 3 +- .../execution/joins/SortMergeJoinExec.scala | 5 +- .../spark/sql/GeneratorFunctionSuite.scala | 4 +- 41 files changed, 276 insertions(+), 279 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 4cc84b27d9eb0..df3ab05e02c76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -56,13 +57,13 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { ev.copy(code = - s""" + code""" |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); |$javaType ${ev.value} = ${ev.isNull} ? | ${CodeGenerator.defaultValue(dataType)} : ($value); """.stripMargin) } else { - ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral) + ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 12330bfa55ab9..d3e2b96bc3786 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -623,8 +624,14 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) + + // Below the code comment including `eval.value` and `eval.isNull` is a trick. It makes the two + // expr values are referred by this code block. ev.copy(code = eval.code + - castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)) + code""" + // Cast from ${eval.value}, ${eval.isNull} + ${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)} + """) } // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull` diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 78865a6f2d233..000f128f63c60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -22,6 +22,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -100,7 +101,8 @@ abstract class Expression extends TreeNode[Expression] { ctx.subExprEliminationExprs.get(this).map { subExprState => // This expression is repeated which means that the code to evaluate it has already been added // as a function before. In that case, we just re-use it. - ExprCode(ctx.registerComment(this.toString), subExprState.isNull, subExprState.value) + ExprCode(JavaCode.block(ctx.registerComment(this.toString)), subExprState.isNull, + subExprState.value) }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") @@ -108,9 +110,9 @@ abstract class Expression extends TreeNode[Expression] { JavaCode.isNullVariable(isNull), JavaCode.variable(value, dataType))) reduceCodeSize(ctx, eval) - if (eval.code.nonEmpty) { + if (eval.code.toString.nonEmpty) { // Add `this` in the comment. - eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim) + eval.copy(code = JavaCode.block(s"${ctx.registerComment(this.toString)}\n") + eval.code) } else { eval } @@ -143,7 +145,7 @@ abstract class Expression extends TreeNode[Expression] { """.stripMargin) eval.value = JavaCode.variable(newValue, dataType) - eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" + eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" } } @@ -437,14 +439,14 @@ abstract class UnaryExpression extends Expression { if (nullable) { val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode) - ev.copy(code = s""" + ev.copy(code = code""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { - ev.copy(code = s""" + ev.copy(code = code""" ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = FalseLiteral) @@ -536,13 +538,13 @@ abstract class BinaryExpression extends Expression { } } - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { - ev.copy(code = s""" + ev.copy(code = code""" ${leftGen.code} ${rightGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -679,12 +681,12 @@ abstract class TernaryExpression extends Expression { } } - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${leftGen.code} ${midGen.code} ${rightGen.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index ad1e7bdb31987..a3b20f9c5e468 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, LongType} /** @@ -71,7 +72,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; $countTerm++;""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index e869258469a97..3e7ca88249737 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.DataType /** @@ -1030,7 +1031,7 @@ case class ScalaUDF( """.stripMargin ev.copy(code = - s""" + code""" |$evalCode |${initArgs.mkString("\n")} |$callFunc diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index ff7c98f714905..2ce9d072c71c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ @@ -181,7 +182,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { } ev.copy(code = childCode.code + - s""" + code""" |long ${ev.value} = 0L; |boolean ${ev.isNull} = ${childCode.isNull}; |if (!${childCode.isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 787bcaf5e81de..9856b37e53fbc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -46,7 +47,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { val idTerm = "partitionId" ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", + ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 6c4a3601c1730..84e38a8b2711e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -164,7 +165,7 @@ case class PreciseTimestampConversion( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) ev.copy(code = eval.code + - s"""boolean ${ev.isNull} = ${eval.isNull}; + code"""boolean ${ev.isNull} = ${eval.isNull}; |${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value}; """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d4e322d23b95b..2554b01c453a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -275,7 +276,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic s"($javaType)(${eval1.value} $symbol ${eval2.value})" } if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -286,7 +287,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ${ev.value} = $divide; }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -362,7 +363,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet s"($javaType)(${eval1.value} $symbol ${eval2.value})" } if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -373,7 +374,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet ${ev.value} = $remainder; }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -479,7 +480,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { } if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -490,7 +491,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { $result }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -612,7 +613,7 @@ case class Least(children: Seq[Expression]) extends Expression { """.stripMargin, foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes @@ -687,7 +688,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { """.stripMargin, foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index cf0a91ff00626..d12d464b84d08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -38,6 +38,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -56,19 +57,19 @@ import org.apache.spark.util.{ParentClassLoader, Utils} * @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. */ -case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue) +case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue) object ExprCode { def apply(isNull: ExprValue, value: ExprValue): ExprCode = { - ExprCode(code = "", isNull, value) + ExprCode(code = code"", isNull, value) } def forNullValue(dataType: DataType): ExprCode = { - ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) + ExprCode(code = code"", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) } def forNonNullValue(value: ExprValue): ExprCode = { - ExprCode(code = "", isNull = FalseLiteral, value = value) + ExprCode(code = code"", isNull = FalseLiteral, value = value) } } @@ -329,9 +330,9 @@ class CodegenContext { def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = { val value = addMutableState(javaType(dataType), variableName) val code = dataType match { - case StringType => s"$value = $initCode.clone();" - case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" - case _ => s"$value = $initCode;" + case StringType => code"$value = $initCode.clone();" + case _: StructType | _: ArrayType | _: MapType => code"$value = $initCode.copy();" + case _ => code"$value = $initCode;" } ExprCode(code, FalseLiteral, JavaCode.global(value, dataType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index a91989e129664..3f4704d287cbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ /** * A trait that can be used to provide a fallback mode for expression code generation. @@ -46,7 +47,7 @@ trait CodegenFallback extends Expression { val placeHolder = ctx.registerComment(this.toString) val javaType = CodeGenerator.javaType(this.dataType) if (nullable) { - ev.copy(code = s""" + ev.copy(code = code""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); boolean ${ev.isNull} = $objectTerm == null; @@ -55,7 +56,7 @@ trait CodegenFallback extends Expression { ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 01c350e9dbf69..39778661d1c48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -22,6 +22,7 @@ import scala.annotation.tailrec import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -71,7 +72,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] arguments = Seq("InternalRow" -> tmpInput, "Object[]" -> values) ) val code = - s""" + code""" |final InternalRow $tmpInput = $input; |final Object[] $values = new Object[${schema.length}]; |$allFields @@ -97,7 +98,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ctx, JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType), elementType) - val code = s""" + val code = code""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); final Object[] $values = new Object[$numElements]; @@ -124,7 +125,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val keyConverter = createCodeForArray(ctx, s"$tmpInput.keyArray()", keyType) val valueConverter = createCodeForArray(ctx, s"$tmpInput.valueArray()", valueType) - val code = s""" + val code = code""" final MapData $tmpInput = $input; ${keyConverter.code} ${valueConverter.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 01b4d6c4529bd..00f7b04f9a585 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -286,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true) val code = - s""" + code""" |$rowWriter.reset(); |$evalSubexpr |$writeExpressions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 5f76e7b48d58a..70fb97899a27c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen import java.lang.{Boolean => JBool} -import scala.collection.mutable import scala.language.{existentials, implicitConversions} import org.apache.spark.sql.types.{BooleanType, DataType} @@ -113,85 +112,117 @@ object JavaCode { def isNullExpression(code: String): SimpleExprValue = { expression(code, BooleanType) } + + def block(code: String): Block = { + CodeBlock(codeParts = Seq(code), exprValues = Seq.empty) + } } /** - * A typed java fragment that must be a valid java expression. + * A block of java code which involves some expressions represented by `ExprValue`. */ -trait ExprValue extends JavaCode { - def javaType: Class[_] - def isPrimitive: Boolean = javaType.isPrimitive +trait Block extends JavaCode { + def exprValues: Seq[Any] // This will be called during string interpolation. - override def toString: String = ExprValue.exprValueToString(this) -} + override def toString: String = _marginChar match { + case Some(c) => code.stripMargin(c) + case _ => code + } -object ExprValue { + var _marginChar: Option[Char] = None - private var currentHandlerForExprValue = - new ThreadLocal[(ExprValue) => String]() { - override def initialValue(): (ExprValue) => String = - defaultExprValueToString - } + def stripMargin(c: Char): this.type = { + _marginChar = Some(c) + this + } - def defaultExprValueToString(exprValue: ExprValue): String = exprValue.code + def stripMargin: this.type = { + _marginChar = Some('|') + this + } + + def + (other: Block): Block +} - implicit def exprValueToString(exprValue: ExprValue): String = { - currentListeners.get.foreach(_.genExprValue(exprValue)) - currentHandlerForExprValue.get()(exprValue) +object Block { + implicit def blockToString(block: Block): String = block.toString + + implicit def blocksToBlock(blocks: Seq[Block]): Block = Blocks(blocks) + + implicit class BlockHelper(val sc: StringContext) extends AnyVal { + def code(args: Any*): Block = { + if (sc.parts.length == 0) { + EmptyBlock + } else { + args.foreach { + case _: ExprValue => true + case _: Int | _: Long | _: Float | _: Double | _: String => true + case _: Block => true + case other => throw new IllegalArgumentException( + s"Can not interpolate ${other.getClass} into code block.") + } + CodeBlock(sc.parts, args) + } + } } +} - protected[sql] val currentListeners = - new ThreadLocal[mutable.ArrayBuffer[ExprValueCodegenListener]]() { - override def initialValue(): mutable.ArrayBuffer[ExprValueCodegenListener] = - mutable.ArrayBuffer.empty[ExprValueCodegenListener] +/** + * A block of java code. + */ +case class CodeBlock(codeParts: Seq[String], exprValues: Seq[Any]) extends Block { + override def code: String = { + val strings = codeParts.iterator + val expressions = exprValues.iterator + var buf = new StringBuffer(strings.next) + while (strings.hasNext) { + if (expressions.hasNext) { + buf append expressions.next + buf append strings.next + } } + buf.toString + } - /** - * By adding a listener, we can catch all expr values which are generated during codegen. - */ - def withExprValueCodegenListener[T](listener: ExprValueCodegenListener) - (f: () => T): T = { - val index = currentListeners.get.length - currentListeners.get += listener - val result = f() - assert(currentListeners.get.length == index + 1, - "There are codegen listeners which are not removed after used!") - currentListeners.get.remove(index) - result + override def + (other: Block): Block = other match { + case c: CodeBlock => Blocks(Seq(this, c)) + case b: Blocks => Blocks(Seq(this) ++ b.blocks) + case EmptyBlock => this } +} - /** - * Sets up a handler function used to convert `ExprValue` to string during codegen. The given - * handler can overwrite the default behavior which outputs `ExprValue` to its actual code. - */ - def withExprValueHandler[T](handlerFunc: (ExprValue) => String)(f: () => T): T = { - val previousHandler = currentHandlerForExprValue.get - currentHandlerForExprValue.set(handlerFunc) - val result = f() - currentHandlerForExprValue.set(previousHandler) - result +case class Blocks(blocks: Seq[Block]) extends Block { + override def exprValues: Seq[Any] = blocks.flatMap(_.exprValues) + override def code: String = blocks.map(_.toString).mkString + + override def + (other: Block): Block = other match { + case c: CodeBlock => Blocks(blocks :+ c) + case b: Blocks => Blocks(blocks ++ b.blocks) + case EmptyBlock => this } } -/** - * A listener which gets notified when `ExprValue`s are generated to Java code. - */ -trait ExprValueCodegenListener { - def genExprValue(exprValue: ExprValue): Unit +object EmptyBlock extends Block with Serializable { + override def code: String = "" + override def exprValues: Seq[Any] = Seq.empty + + override def + (other: Block): Block = other } /** - * A codegen listener specified for `SimpleExprValue` expression values. + * A typed java fragment that must be a valid java expression. */ -case class SimpleExprValueCodegenListener() extends ExprValueCodegenListener { - protected[sql] val exprValues: mutable.HashSet[SimpleExprValue] = - mutable.HashSet.empty[SimpleExprValue] +trait ExprValue extends JavaCode { + def javaType: Class[_] + def isPrimitive: Boolean = javaType.isPrimitive - override def genExprValue(exprValue: ExprValue): Unit = exprValue match { - case s: SimpleExprValue => exprValues += s - case _ => - } + // This will be called during string interpolation. + override def toString: String = ExprValue.exprValueToString(this) +} + +object ExprValue { + implicit def exprValueToString(exprValue: ExprValue): String = exprValue.code } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6d63a531e3b74..ee2bbe2fd6884 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -21,6 +21,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -54,7 +55,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childGen = child.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = false; ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : @@ -486,14 +487,14 @@ case class ArrayJoin( } if (nullable) { ev.copy( - s""" + code""" |boolean ${ev.isNull} = true; |UTF8String ${ev.value} = null; |$code """.stripMargin) } else { ev.copy( - s""" + code""" |UTF8String ${ev.value} = null; |$code """.stripMargin, FalseLiteral) @@ -578,11 +579,11 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast val childGen = child.genCode(ctx) val javaType = CodeGenerator.javaType(dataType) val i = ctx.freshName("i") - val item = ExprCode("", + val item = ExprCode(EmptyBlock, isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) ev.copy(code = - s""" + code""" |${childGen.code} |boolean ${ev.isNull} = true; |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -643,11 +644,11 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast val childGen = child.genCode(ctx) val javaType = CodeGenerator.javaType(dataType) val i = ctx.freshName("i") - val item = ExprCode("", + val item = ExprCode(EmptyBlock, isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) ev.copy(code = - s""" + code""" |${childGen.code} |boolean ${ev.isNull} = true; |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -938,7 +939,7 @@ case class Concat(children: Seq[Expression]) extends Expression { expressions = inputs, funcName = "valueConcat", extraArguments = (s"$javaType[]", args) :: Nil) - ev.copy(s""" + ev.copy(code""" $initCode $codes $javaType ${ev.value} = $concatenator.concat($args); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 67876a8565488..791ba05d4c228 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -63,7 +64,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { val (preprocess, assigns, postprocess, arrayData) = GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( - code = preprocess + assigns + postprocess, + code = JavaCode.block(preprocess + assigns + postprocess), value = JavaCode.variable(arrayData, dataType), isNull = FalseLiteral) } @@ -219,7 +220,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression { val (preprocessValueData, assignValues, postprocessValueData, valueArrayData) = GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, false) val code = - s""" + code""" final boolean ${ev.isNull} = false; $preprocessKeyData $assignKeys @@ -373,7 +374,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc extraArguments = "Object[]" -> values :: Nil) ev.copy(code = - s""" + code""" |Object[] $values = new Object[${valExprs.size}]; |$valuesCode |final InternalRow ${ev.value} = new $rowClass($values); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 205d77f6a9acf..77ac6c088022e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ // scalastyle:off line.size.limit @@ -66,7 +67,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi val falseEval = falseValue.genCode(ctx) val code = - s""" + code""" |${condEval.code} |boolean ${ev.isNull} = false; |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -265,7 +266,7 @@ case class CaseWhen( }.mkString) ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index d882d06cfd625..e9d1aa8403bab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -27,6 +27,7 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -717,7 +718,7 @@ abstract class UnixTime } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -746,7 +747,7 @@ abstract class UnixTime }) case TimestampType => val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -757,7 +758,7 @@ abstract class UnixTime val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -852,7 +853,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val t = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -1048,7 +1049,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) if (right.foldable) { val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { - ev.copy(code = s""" + ev.copy(code = code""" |boolean ${ev.isNull} = true; |long ${ev.value} = 0; """.stripMargin) @@ -1062,7 +1063,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; @@ -1237,7 +1238,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) if (right.foldable) { val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { - ev.copy(code = s""" + ev.copy(code = code""" |boolean ${ev.isNull} = true; |long ${ev.value} = 0; """.stripMargin) @@ -1251,7 +1252,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; @@ -1394,13 +1395,13 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes { val javaType = CodeGenerator.javaType(dataType) if (format.foldable) { if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { val t = instant.genCode(ctx) val truncFuncStr = truncFunc(t.value, truncLevel.toString) - ev.copy(code = s""" + ev.copy(code = code""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index db1579ba28671..04de83343be71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode} import org.apache.spark.sql.types._ /** @@ -72,7 +72,8 @@ case class PromotePrecision(child: Expression) extends UnaryExpression { override def eval(input: InternalRow): Any = child.eval(input) /** Just a simple pass-through for code generation. */ override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx) - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("") + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + ev.copy(EmptyBlock) override def prettyName: String = "promote_precision" override def sql: String = child.sql override lazy val canonicalized: Expression = child.canonicalized diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 3af4bfebad45e..b7c52f1d7b40a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ @@ -215,7 +216,7 @@ case class Stack(children: Seq[Expression]) extends Generator { // Create the collection. val wrapperClass = classOf[mutable.WrappedArray[_]].getName ev.copy(code = - s""" + code""" |$code |$wrapperClass ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData); """.stripMargin, isNull = FalseLiteral) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index ef790338bdd27..cec00b66f873c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -28,6 +28,7 @@ import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -293,7 +294,7 @@ abstract class HashExpression[E] extends Expression { foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |$hashResultType ${ev.value} = $seed; |$codes """.stripMargin) @@ -674,7 +675,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_INT} ${ev.value} = $seed; |${CodeGenerator.JAVA_INT} $childHash = 0; |$codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 2a3cc580273ee..6e57d9d1db6cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -42,8 +43,8 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getInputFilePath();", isNull = FalseLiteral) + ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + + code"$className.getInputFilePath();", isNull = FalseLiteral) } } @@ -65,8 +66,8 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getStartOffset();", isNull = FalseLiteral) + ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + + code"$className.getStartOffset();", isNull = FalseLiteral) } } @@ -88,7 +89,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getLength();", isNull = FalseLiteral) + ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + + code"$className.getLength();", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index bc4cfcec47425..c2e1720259b53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.NumberConverter import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1191,11 +1192,11 @@ abstract class RoundBase(child: Expression, scale: Expression, val javaType = CodeGenerator.javaType(dataType) if (scaleV == null) { // if scale is null, no need to eval its child at all - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${ce.code} boolean ${ev.isNull} = ${ce.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 7eda65a867028..677bcdb4abf08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,6 +21,7 @@ import java.util.UUID import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -88,7 +89,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the value is null or false. val errMsgField = ctx.addReferenceObj("errMsg", errMsg) - ExprCode(code = s"""${eval.code} + ExprCode(code = code"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); |}""".stripMargin, isNull = TrueLiteral, @@ -150,7 +151,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta ctx.addPartitionInitializationStatement(s"$randomGen = " + "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" + s"${randomSeed.get}L + partitionIndex);") - ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", + ev.copy(code = code"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 0787342bce6bc..2eeed3bbb2d91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -111,7 +112,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |do { @@ -232,7 +233,7 @@ case class IsNaN(child: Expression) extends UnaryExpression val eval = child.genCode(ctx) child.dataType match { case DoubleType | FloatType => - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral) @@ -278,7 +279,7 @@ case class NaNvl(left: Expression, right: Expression) val rightGen = right.genCode(ctx) left.dataType match { case DoubleType | FloatType => - ev.copy(code = s""" + ev.copy(code = code""" ${leftGen.code} boolean ${ev.isNull} = false; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -440,7 +441,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate }.mkString) ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_INT} $nonnull = 0; |do { | $codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index f974fd81fc788..2bf4203d0fec3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -269,7 +270,7 @@ case class StaticInvoke( s"${ev.value} = $callFunc;" } - val code = s""" + val code = code""" $argCode $prepareIsNull $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -385,8 +386,7 @@ case class Invoke( """ } - val code = s""" - ${obj.code} + val code = obj.code + code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${obj.isNull}) { @@ -492,7 +492,7 @@ case class NewInstance( s"new $className($argString)" } - val code = s""" + val code = code""" $argCode ${outer.map(_.code).getOrElse("")} final $javaType ${ev.value} = ${ev.isNull} ? @@ -532,9 +532,7 @@ case class UnwrapOption( val javaType = CodeGenerator.javaType(dataType) val inputObject = child.genCode(ctx) - val code = s""" - ${inputObject.code} - + val code = inputObject.code + code""" final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty(); $javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} : (${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get(); @@ -564,9 +562,7 @@ case class WrapOption(child: Expression, optType: DataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputObject = child.genCode(ctx) - val code = s""" - ${inputObject.code} - + val code = inputObject.code + code""" scala.Option ${ev.value} = ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); @@ -935,8 +931,7 @@ case class MapObjects private( ) } - val code = s""" - ${genInputData.code} + val code = genInputData.code + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { @@ -1147,8 +1142,7 @@ case class CatalystToExternalMap private( """ val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();" - val code = s""" - ${genInputData.code} + val code = genInputData.code + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { @@ -1391,9 +1385,8 @@ case class ExternalMapToCatalyst private( val mapCls = classOf[ArrayBasedMapData].getName val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType) val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType) - val code = - s""" - ${inputMap.code} + val code = inputMap.code + + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${inputMap.isNull}) { final int $length = ${inputMap.value}.size(); @@ -1471,7 +1464,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) val schemaField = ctx.addReferenceObj("schema", schema) val code = - s""" + code""" |Object[] $values = new Object[${children.size}]; |$childrenCode |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); @@ -1499,8 +1492,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) val javaType = CodeGenerator.javaType(dataType) val serialize = s"$serializer.serialize(${input.value}, null).array()" - val code = s""" - ${input.code} + val code = input.code + code""" final $javaType ${ev.value} = ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize; """ @@ -1532,8 +1524,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B val deserialize = s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" - val code = s""" - ${input.code} + val code = input.code + code""" final $javaType ${ev.value} = ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize; """ @@ -1614,9 +1605,8 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp funcName = "initializeJavaBean", extraArguments = beanInstanceJavaType -> javaBeanInstance :: Nil) - val code = - s""" - |${instanceGen.code} + val code = instanceGen.code + + code""" |$beanInstanceJavaType $javaBeanInstance = ${instanceGen.value}; |if (!${instanceGen.isNull}) { | $initializeCode @@ -1664,9 +1654,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) // because errMsgField is used only when the value is null. val errMsgField = ctx.addReferenceObj("errMsg", errMsg) - val code = s""" - ${childGen.code} - + val code = childGen.code + code""" if (${childGen.isNull}) { throw new NullPointerException($errMsgField); } @@ -1709,7 +1697,7 @@ case class GetExternalRowField( // because errMsgField is used only when the field is null. val errMsgField = ctx.addReferenceObj("errMsg", errMsg) val row = child.genCode(ctx) - val code = s""" + val code = code""" ${row.code} if (${row.isNull}) { @@ -1784,7 +1772,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" } - val code = s""" + val code = code""" ${input.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${input.isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index e195ec17f3bcf..0f3c1b3fec978 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -22,6 +22,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -282,7 +283,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { }.mkString("\n")) ev.copy(code = - s""" + code""" |${valueGen.code} |byte $tmpResult = $HAS_NULL; |if (!${valueGen.isNull}) { @@ -346,7 +347,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with "" } ev.copy(code = - s""" + code""" |${childGen.code} |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false; @@ -398,7 +399,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with // The result should be `false`, if any of them is `false` whenever the other is null or not. if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.value} = false; @@ -407,7 +408,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with ${ev.value} = ${eval2.value}; }""", isNull = FalseLiteral) } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = false; boolean ${ev.value} = false; @@ -462,7 +463,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P // The result should be `true`, if any of them is `true` whenever the other is null or not. if (!left.nullable && !right.nullable) { ev.isNull = FalseLiteral - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.value} = true; @@ -471,7 +472,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P ${ev.value} = ${eval2.value}; }""", isNull = FalseLiteral) } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = false; boolean ${ev.value} = true; @@ -613,7 +614,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp val eval1 = left.genCode(ctx) val eval2 = right.genCode(ctx) val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value) - ev.copy(code = eval1.code + eval2.code + s""" + ev.copy(code = eval1.code + eval2.code + code""" boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) || (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 70186053617f8..a3a96e0f995a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -81,7 +82,7 @@ case class Rand(child: Expression) extends RDG { val rngTerm = ctx.addMutableState(className, "rng") ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = FalseLiteral) } @@ -118,7 +119,7 @@ case class Randn(child: Expression) extends RDG { val rngTerm = ctx.addMutableState(className, "rng") ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index ad0c0791d895f..7b68bb771faf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -23,6 +23,7 @@ import java.util.regex.{MatchResult, Pattern} import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -123,7 +124,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -132,7 +133,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi } """) } else { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) @@ -198,7 +199,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -207,7 +208,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress } """) } else { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index ea005a26a4c8b..9823b2fc5ad97 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -27,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -105,7 +106,7 @@ case class ConcatWs(children: Seq[Expression]) expressions = inputs, funcName = "valueConcatWs", extraArguments = ("UTF8String[]", args) :: Nil) - ev.copy(s""" + ev.copy(code""" UTF8String[] $args = new UTF8String[$numArgs]; ${separator.code} $codes @@ -149,7 +150,7 @@ case class ConcatWs(children: Seq[Expression]) } }.unzip - val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code)) + val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code.toString)) val varargCounts = ctx.splitExpressionsWithCurrentInputs( expressions = varargCount, @@ -176,7 +177,7 @@ case class ConcatWs(children: Seq[Expression]) foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n")) ev.copy( - s""" + code""" $codes int $varargNum = ${children.count(_.dataType == StringType) - 1}; int $idxVararg = 0; @@ -288,7 +289,7 @@ case class Elt(children: Seq[Expression]) extends Expression { }.mkString) ev.copy( - s""" + code""" |${index.code} |final int $indexVal = ${index.value}; |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; @@ -654,7 +655,7 @@ case class StringTrim( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -671,7 +672,7 @@ case class StringTrim( } else { ${ev.value} = ${srcString.value}.trim(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -754,7 +755,7 @@ case class StringTrimLeft( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -771,7 +772,7 @@ case class StringTrimLeft( } else { ${ev.value} = ${srcString.value}.trimLeft(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -856,7 +857,7 @@ case class StringTrimRight( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -873,7 +874,7 @@ case class StringTrimRight( } else { ${ev.value} = ${srcString.value}.trimRight(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -1024,7 +1025,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) val substrGen = substr.genCode(ctx) val strGen = str.genCode(ctx) val startGen = start.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" int ${ev.value} = 0; boolean ${ev.isNull} = false; ${startGen.code} @@ -1350,7 +1351,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC val formatter = classOf[java.util.Formatter].getName val sb = ctx.freshName("sb") val stringBuffer = classOf[StringBuffer].getName - ev.copy(code = s""" + ev.copy(code = code""" ${pattern.code} boolean ${ev.isNull} = ${pattern.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala index 64b65e2070ed6..7c7c4cccee253 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -45,7 +46,7 @@ case class BadCodegenExpression() extends LeafExpression { override def eval(input: InternalRow): Any = 10 override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.copy(code = - s""" + code""" |int some_variable = 11; |int ${ev.value} = 10; """.stripMargin) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala index 84a960bb483a6..4045d24968002 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import scala.collection.mutable import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{Add, Literal} -import org.apache.spark.sql.types.{BooleanType, IntegerType} +import org.apache.spark.sql.types.BooleanType class ExprValueSuite extends SparkFunSuite { @@ -38,71 +37,4 @@ class ExprValueSuite extends SparkFunSuite { assert(trueLit === JavaCode.literal("true", BooleanType)) assert(falseLit === JavaCode.literal("false", BooleanType)) } - - test("add expr value codegen listener") { - val context = new CodegenContext() - val expr = Add(Literal(1), Literal(2)) - val listener = new TestExprValueListener() - - assert(ExprValue.currentListeners.get.size == 0) - - ExprValue.withExprValueCodegenListener(listener) { () => - assert(ExprValue.currentListeners.get.size == 1) - val genCode = expr.genCode(context) - // Materialize generated code - s"${genCode.value} ${genCode.isNull}" - } - // 1 VariableValue + 3 LiteralValue (1, 2, false) - assert(listener.exprValues.size == 4) - } - - test("Listen to statement expression generation") { - // This listener only listens statement expression. - val listener = SimpleExprValueCodegenListener() - - val context = new CodegenContext() - val expr = Add(Literal(1), Literal(2)) - - assert(ExprValue.currentListeners.get.size == 0) - - ExprValue.withExprValueCodegenListener(listener) { () => - assert(ExprValue.currentListeners.get.size == 1) - val genCode = expr.genCode(context) - // Materialize generated code - s"${genCode.value} ${genCode.isNull}" - - assert(listener.exprValues.size == 0) - - val statementValue = JavaCode.expression(s"${genCode.value} + 1", IntegerType) - s"$statementValue" - } - assert(listener.exprValues.size == 1) - assert(listener.exprValues.head.code == "(value_0 + 1)") - } - - test("Use expr value handler to control expression generated code") { - // Let all `SimpleExprValue` output special string instead of their code. - def statementHanlder(exprValue: ExprValue): String = exprValue match { - case _: SimpleExprValue => "%STATEMENT%" - case _ => ExprValue.defaultExprValueToString(exprValue) - } - - val context = new CodegenContext() - - val exprCode = ExprValue.withExprValueHandler(statementHanlder) { () => - val statementValue = JavaCode.expression(s"1 + 1", IntegerType) - ExprCode(code = s"int result = $statementValue;", - value = JavaCode.variable("result", IntegerType), - isNull = FalseLiteral) - } - assert(exprCode.code == "int result = %STATEMENT%;") - } -} - -private class TestExprValueListener extends ExprValueCodegenListener { - val exprValues: mutable.HashSet[ExprValue] = mutable.HashSet.empty[ExprValue] - - override def genExprValue(exprValue: ExprValue): Unit = { - exprValues += exprValue - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index fc3dbc1c5591b..067b2aa05fe62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -58,14 +59,14 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" - val code = s"${ctx.registerComment(str)}\n" + (if (nullable) { - s""" + val code = JavaCode.block(s"${ctx.registerComment(str)}\n") + (if (nullable) { + code""" boolean $isNullVar = $columnVar.isNullAt($ordinal); $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value); """ } else { - s"$javaType $valueVar = $value;" - }).trim + code"$javaType $valueVar = $value;" + }) ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index e4812f3d338fb..5b4edf5136e3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -152,7 +153,7 @@ case class ExpandExec( } else { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val code = s""" + val code = code""" |boolean $isNull = true; |${CodeGenerator.javaType(firstExpr.dataType)} $value = | ${CodeGenerator.defaultValue(firstExpr.dataType)}; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index f40c50df74ccb..2549b9e1537a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types._ @@ -313,13 +314,13 @@ case class GenerateExec( if (checks.nonEmpty) { val isNull = ctx.freshName("isNull") val code = - s""" + code""" |boolean $isNull = ${checks.mkString(" || ")}; |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, dt)) } else { - ExprCode(s"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt)) + ExprCode(code"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 828b51fa199de..fadbeddff2c4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -122,10 +123,10 @@ trait CodegenSupport extends SparkPlan { ctx.INPUT_ROW = row ctx.currentVars = colVars val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - val code = s""" + val code = code""" |$evaluateInputs |${ev.code.trim} - """.stripMargin.trim + """.stripMargin ExprCode(code, FalseLiteral, ev.value) } else { // There are no columns @@ -260,7 +261,7 @@ trait CodegenSupport extends SparkPlan { */ protected def evaluateVariables(variables: Seq[ExprCode]): String = { val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n") - variables.foreach(_.code = "") + variables.foreach(_.code = EmptyBlock) evaluate } @@ -276,7 +277,7 @@ trait CodegenSupport extends SparkPlan { variables.zipWithIndex.foreach { case (ev, i) => if (ev.code != "" && required.contains(attributes(i))) { evaluateVars.append(ev.code.trim + "\n") - ev.code = "" + ev.code = EmptyBlock } } evaluateVars.toString() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index a5dc6ebf2b0f2..90fba7b01466a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -190,7 +191,7 @@ case class HashAggregateExec( val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") // The initial expression should not access any column val ev = e.genCode(ctx) - val initVars = s""" + val initVars = code""" | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index de2d630de3fdb..e1c85823259b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -50,7 +51,7 @@ abstract class HashMapGenerator( val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") val ev = e.genCode(ctx) val initVars = - s""" + code""" | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 6fa716d9fadee..0da0e8610c392 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} @@ -183,7 +184,7 @@ case class BroadcastHashJoinExec( val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") val javaType = CodeGenerator.javaType(a.dataType) - val code = s""" + val code = code""" |boolean $isNull = true; |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)}; |if ($matched != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index d8261f0f33b61..f4b9d132122e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ @@ -521,7 +522,7 @@ case class SortMergeJoinExec( if (a.nullable) { val isNull = ctx.freshName("isNull") val code = - s""" + code""" |$isNull = $leftRow.isNullAt($i); |$value = $isNull ? $defaultValue : ($valueCode); """.stripMargin @@ -533,7 +534,7 @@ case class SortMergeJoinExec( (ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)), leftVarsDecl) } else { - val code = s"$value = $valueCode;" + val code = code"$value = $valueCode;" val leftVarsDecl = s"""$javaType $value = $defaultValue;""" (ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 109fcf90a3ec9..8280a3ce39845 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, Generator} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructType} @@ -315,6 +316,7 @@ case class EmptyGenerator() extends Generator { override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val iteratorClass = classOf[Iterator[_]].getName - ev.copy(code = s"$iteratorClass ${ev.value} = $iteratorClass$$.MODULE$$.empty();") + ev.copy(code = + code"$iteratorClass ${ev.value} = $iteratorClass$$.MODULE$$.empty();") } } From 00bef6bb2559d93e68266d2605ef4403f4d8d2af Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 3 May 2018 11:13:15 +0000 Subject: [PATCH 03/20] Add basic tests. --- .../expressions/codegen/javaCode.scala | 26 +++-- .../expressions/codegen/CodeBlockSuite.scala | 98 +++++++++++++++++++ 2 files changed, 114 insertions(+), 10 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 70fb97899a27c..9a36816c016fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -114,15 +114,17 @@ object JavaCode { } def block(code: String): Block = { - CodeBlock(codeParts = Seq(code), exprValues = Seq.empty) + CodeBlock(codeParts = Seq(code), blockInputs = Seq.empty) } } /** - * A block of java code which involves some expressions represented by `ExprValue`. + * A trait representing a block of java code. */ trait Block extends JavaCode { - def exprValues: Seq[Any] + + // The expressions to be evaluated inside this block. + def exprValues: Seq[ExprValue] // This will be called during string interpolation. override def toString: String = _marginChar match { @@ -169,16 +171,20 @@ object Block { } /** - * A block of java code. + * A block of java code. Including a sequence of code parts and some inputs to this block. + * The actual java code is generated by embedding the inputs into the code parts. */ -case class CodeBlock(codeParts: Seq[String], exprValues: Seq[Any]) extends Block { +case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[Any]) extends Block { + override def exprValues: Seq[ExprValue] = blockInputs.filter(_.isInstanceOf[ExprValue]) + .asInstanceOf[Seq[ExprValue]] + override def code: String = { val strings = codeParts.iterator - val expressions = exprValues.iterator + val inputs = blockInputs.iterator var buf = new StringBuffer(strings.next) while (strings.hasNext) { - if (expressions.hasNext) { - buf append expressions.next + if (inputs.hasNext) { + buf append inputs.next buf append strings.next } } @@ -193,7 +199,7 @@ case class CodeBlock(codeParts: Seq[String], exprValues: Seq[Any]) extends Block } case class Blocks(blocks: Seq[Block]) extends Block { - override def exprValues: Seq[Any] = blocks.flatMap(_.exprValues) + override def exprValues: Seq[ExprValue] = blocks.flatMap(_.exprValues) override def code: String = blocks.map(_.toString).mkString override def + (other: Block): Block = other match { @@ -205,7 +211,7 @@ case class Blocks(blocks: Seq[Block]) extends Block { object EmptyBlock extends Block with Serializable { override def code: String = "" - override def exprValues: Seq[Any] = Seq.empty + override def exprValues: Seq[ExprValue] = Seq.empty override def + (other: Block): Block = other } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala new file mode 100644 index 0000000000000..f9c06cfd3f4fc --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +class CodeBlockSuite extends SparkFunSuite { + + test("Block can interpolate string and ExprValue inputs") { + val isNull = JavaCode.isNullVariable("expr1_isNull") + val code = code"boolean ${isNull} = ${JavaCode.defaultLiteral(BooleanType)};" + assert(code.toString == "boolean expr1_isNull = false;") + } + + test("Block.stripMargin") { + val isNull = JavaCode.isNullVariable("expr1_isNull") + val value = JavaCode.variable("expr1", IntegerType) + val code1 = + code""" + |boolean $isNull = false; + |int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin + val expected = + s""" + |boolean expr1_isNull = false; + |int expr1 = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin + assert(code1.toString == expected) + + val code2 = + code""" + >boolean $isNull = false; + >int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin('>') + assert(code2.toString == expected) + } + + test("Block can capture input expr values") { + val isNull = JavaCode.isNullVariable("expr1_isNull") + val value = JavaCode.variable("expr1", IntegerType) + val code = + code""" + |boolean $isNull = false; + |int $value = -1; + """.stripMargin + val exprValues = code.exprValues + assert(exprValues.length == 2) + assert(exprValues(0).isInstanceOf[VariableValue] && exprValues(0).code == "expr1_isNull") + assert(exprValues(1).isInstanceOf[VariableValue] && exprValues(1).code == "expr1") + } + + test("concatenate blocks") { + val isNull1 = JavaCode.isNullVariable("expr1_isNull") + val value1 = JavaCode.variable("expr1", IntegerType) + val isNull2 = JavaCode.isNullVariable("expr2_isNull") + val value2 = JavaCode.variable("expr2", IntegerType) + val literal = JavaCode.literal("100", IntegerType) + + val code = + code""" + |boolean $isNull1 = false; + |int $value1 = -1;""".stripMargin + + code""" + |boolean $isNull2 = true; + |int $value2 = $literal;""".stripMargin + + val expected = + """ + |boolean expr1_isNull = false; + |int expr1 = -1; + |boolean expr2_isNull = true; + |int expr2 = 100;""".stripMargin + + assert(code.toString == expected) + + val exprValues = code.exprValues + assert(exprValues.length == 5) + assert(exprValues(0).isInstanceOf[VariableValue] && exprValues(0).code == "expr1_isNull") + assert(exprValues(1).isInstanceOf[VariableValue] && exprValues(1).code == "expr1") + assert(exprValues(2).isInstanceOf[VariableValue] && exprValues(2).code == "expr2_isNull") + assert(exprValues(3).isInstanceOf[VariableValue] && exprValues(3).code == "expr2") + assert(exprValues(4).isInstanceOf[LiteralValue] && exprValues(4).code == "100") + } +} From 162deb2f4de0b4d6fee3641ce79fbab424dd8e9b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 3 May 2018 14:54:10 +0000 Subject: [PATCH 04/20] Deal merging conflict. --- .../spark/sql/catalyst/expressions/datetimeExpressions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index f2969d30d4b9c..fce0f2ed6ed7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1043,7 +1043,7 @@ case class StringToTimestampWithoutTimezone(child: Expression, timeZoneId: Optio val tz = ctx.addReferenceObj("timeZone", timeZone) val longOpt = ctx.freshName("longOpt") val eval = child.genCode(ctx) - val code = s""" + val code = code""" |${eval.code} |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = true; |${CodeGenerator.JAVA_LONG} ${ev.value} = ${CodeGenerator.defaultValue(TimestampType)}; From d138ee05541bc1976442c5f1170ba7a7f3f5114c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 4 May 2018 01:12:46 +0000 Subject: [PATCH 05/20] Address comments and add more tests. --- .../expressions/codegen/javaCode.scala | 24 +++++++----- .../expressions/codegen/CodeBlockSuite.scala | 37 +++++++++++++++++++ 2 files changed, 51 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 9a36816c016fc..158aceb8c4dc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -154,15 +154,16 @@ object Block { implicit class BlockHelper(val sc: StringContext) extends AnyVal { def code(args: Any*): Block = { + sc.checkLengths(args) if (sc.parts.length == 0) { EmptyBlock } else { args.foreach { - case _: ExprValue => true - case _: Int | _: Long | _: Float | _: Double | _: String => true - case _: Block => true + case _: ExprValue => + case _: Int | _: Long | _: Float | _: Double | _: String => + case _: Block => case other => throw new IllegalArgumentException( - s"Can not interpolate ${other.getClass} into code block.") + s"Can not interpolate ${other.getClass.getName} into code block.") } CodeBlock(sc.parts, args) } @@ -175,18 +176,21 @@ object Block { * The actual java code is generated by embedding the inputs into the code parts. */ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[Any]) extends Block { - override def exprValues: Seq[ExprValue] = blockInputs.filter(_.isInstanceOf[ExprValue]) - .asInstanceOf[Seq[ExprValue]] + override def exprValues: Seq[ExprValue] = { + blockInputs.flatMap { + case b: Block => b.exprValues + case e: ExprValue => Seq(e) + case _ => Seq.empty + } + } override def code: String = { val strings = codeParts.iterator val inputs = blockInputs.iterator var buf = new StringBuffer(strings.next) while (strings.hasNext) { - if (inputs.hasNext) { - buf append inputs.next - buf append strings.next - } + buf append inputs.next + buf append strings.next } buf.toString } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala index f9c06cfd3f4fc..c266c029693cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -95,4 +95,41 @@ class CodeBlockSuite extends SparkFunSuite { assert(exprValues(3).isInstanceOf[VariableValue] && exprValues(3).code == "expr2") assert(exprValues(4).isInstanceOf[LiteralValue] && exprValues(4).code == "100") } + + test("Throws exception when interpolating unexcepted object in code block") { + val obj = TestClass(100) + val e = intercept[IllegalArgumentException] { + code"$obj" + } + assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}")) + } + + test("replace expr values in code block") { + val statement = JavaCode.expression("1 + 1", IntegerType) + val isNull = JavaCode.isNullVariable("expr1_isNull") + val exprInFunc = JavaCode.variable("expr1", IntegerType) + + val code = + code""" + |callFunc(int $statement) { + | boolean $isNull = false; + | int $exprInFunc = $statement + 1; + |}""".stripMargin + + val aliasedParam = JavaCode.variable("aliased", statement.javaType) + val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map { + case _: SimpleExprValue => aliasedParam + case other => other + } + val aliasedCode = CodeBlock(code.asInstanceOf[CodeBlock].codeParts, aliasedInputs).stripMargin + val expected = + code""" + |callFunc(int $aliasedParam) { + | boolean $isNull = false; + | int $exprInFunc = $aliasedParam + 1; + |}""".stripMargin + assert(aliasedCode.toString == expected.toString) + } } + +private case class TestClass(a: Int) From ee9a4c03a2da8a01f588c1283500f3412e356875 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 5 May 2018 04:33:21 +0000 Subject: [PATCH 06/20] Address comments. --- .../spark/sql/catalyst/expressions/Expression.scala | 2 +- .../sql/catalyst/expressions/codegen/javaCode.scala | 12 +++++++++--- .../spark/sql/execution/ColumnarBatchScan.scala | 2 +- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 000f128f63c60..e838d7b19a863 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -112,7 +112,7 @@ abstract class Expression extends TreeNode[Expression] { reduceCodeSize(ctx, eval) if (eval.code.toString.nonEmpty) { // Add `this` in the comment. - eval.copy(code = JavaCode.block(s"${ctx.registerComment(this.toString)}\n") + eval.code) + eval.copy(code = code"${ctx.registerComment(this.toString)}\n" + eval.code) } else { eval } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 158aceb8c4dc5..32682f00fef04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -132,7 +132,9 @@ trait Block extends JavaCode { case _ => code } - var _marginChar: Option[Char] = None + // The leading prefix that should be stripped from each line. + // By default we strip blanks or control characters followed by '|' from the line. + var _marginChar: Option[Char] = Some('|') def stripMargin(c: Char): this.type = { _marginChar = Some(c) @@ -148,6 +150,9 @@ trait Block extends JavaCode { } object Block { + + val CODE_BLOCK_BUFFER_LENGTH: Int = 512 + implicit def blockToString(block: Block): String = block.toString implicit def blocksToBlock(blocks: Seq[Block]): Block = Blocks(blocks) @@ -187,10 +192,11 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[Any]) extends Bloc override def code: String = { val strings = codeParts.iterator val inputs = blockInputs.iterator - var buf = new StringBuffer(strings.next) + val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) + buf append StringContext.treatEscapes(strings.next) while (strings.hasNext) { buf append inputs.next - buf append strings.next + buf append StringContext.treatEscapes(strings.next) } buf.toString } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 067b2aa05fe62..d0edd77609024 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -59,7 +59,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" - val code = JavaCode.block(s"${ctx.registerComment(str)}\n") + (if (nullable) { + val code = code"${ctx.registerComment(str)}\n" + (if (nullable) { code""" boolean $isNullVar = $columnVar.isNullAt($ordinal); $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value); From e7cfa286b6c4f812bcbae3d59ec24ae956d9e330 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 5 May 2018 13:44:44 +0000 Subject: [PATCH 07/20] Remove JavaCode.block. We should always use code string interpolator for constructing code block. --- .../apache/spark/sql/catalyst/expressions/Expression.scala | 2 +- .../spark/sql/catalyst/expressions/codegen/javaCode.scala | 4 ---- .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 2 +- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index e838d7b19a863..b348b0a2579b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -101,7 +101,7 @@ abstract class Expression extends TreeNode[Expression] { ctx.subExprEliminationExprs.get(this).map { subExprState => // This expression is repeated which means that the code to evaluate it has already been added // as a function before. In that case, we just re-use it. - ExprCode(JavaCode.block(ctx.registerComment(this.toString)), subExprState.isNull, + ExprCode(code"${ctx.registerComment(this.toString)}", subExprState.isNull, subExprState.value) }.getOrElse { val isNull = ctx.freshName("isNull") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 32682f00fef04..eac6b92837a06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -112,10 +112,6 @@ object JavaCode { def isNullExpression(code: String): SimpleExprValue = { expression(code, BooleanType) } - - def block(code: String): Block = { - CodeBlock(codeParts = Seq(code), blockInputs = Seq.empty) - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 791ba05d4c228..8d77b7d0ff1b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -64,7 +64,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { val (preprocess, assigns, postprocess, arrayData) = GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( - code = JavaCode.block(preprocess + assigns + postprocess), + code = code"$preprocess" + code"$assigns" + code"$postprocess", value = JavaCode.variable(arrayData, dataType), isNull = FalseLiteral) } From 5945c156b2df7aa2f61df57d86f07279b6959e2a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 5 May 2018 14:00:43 +0000 Subject: [PATCH 08/20] We should not implicitly convert code block to string. Otherwise we may not aware of necessary block inputs. --- .../spark/sql/catalyst/expressions/Expression.scala | 4 ++-- .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 4 ++-- .../expressions/codegen/GenerateUnsafeProjection.scala | 2 +- .../sql/catalyst/expressions/codegen/javaCode.scala | 9 ++++----- .../spark/sql/execution/WholeStageCodegenExec.scala | 6 +++--- .../sql/execution/aggregate/HashAggregateExec.scala | 4 ++-- 6 files changed, 14 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b348b0a2579b0..4836bf9c2052a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -121,7 +121,7 @@ abstract class Expression extends TreeNode[Expression] { private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { // TODO: support whole stage codegen too - if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { + if (eval.code.toString.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull @@ -138,7 +138,7 @@ abstract class Expression extends TreeNode[Expression] { val funcFullName = ctx.addNewFunction(funcName, s""" |private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) { - | ${eval.code.trim} + | ${eval.code} | $setIsNull | return ${eval.value}; |} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d12d464b84d08..116e237e86121 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -989,7 +989,7 @@ class CodegenContext { val eval = expr.genCode(this) val state = SubExprEliminationState(eval.isNull, eval.value) e.foreach(localSubExprEliminationExprs.put(_, state)) - eval.code.trim + eval.code.toString } SubExprCodes(codes, localSubExprEliminationExprs.toMap) } @@ -1017,7 +1017,7 @@ class CodegenContext { val fn = s""" |private void $fnName(InternalRow $INPUT_ROW) { - | ${eval.code.trim} + | ${eval.code} | $isNull = ${eval.isNull}; | $value = ${eval.value}; |} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 00f7b04f9a585..8f2a5a0dce943 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -344,7 +344,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | } | | public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { - | ${eval.code.trim} + | ${eval.code} | return ${eval.value}; | } | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index eac6b92837a06..b2c4e23ef2b2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -122,10 +122,10 @@ trait Block extends JavaCode { // The expressions to be evaluated inside this block. def exprValues: Seq[ExprValue] - // This will be called during string interpolation. + // Returns java code string for this code block. override def toString: String = _marginChar match { - case Some(c) => code.stripMargin(c) - case _ => code + case Some(c) => code.stripMargin(c).trim + case _ => code.trim } // The leading prefix that should be stripped from each line. @@ -142,6 +142,7 @@ trait Block extends JavaCode { this } + // Concatenates this block with other block. def + (other: Block): Block } @@ -149,8 +150,6 @@ object Block { val CODE_BLOCK_BUFFER_LENGTH: Int = 512 - implicit def blockToString(block: Block): String = block.toString - implicit def blocksToBlock(blocks: Seq[Block]): Block = Blocks(blocks) implicit class BlockHelper(val sc: StringContext) extends AnyVal { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index fadbeddff2c4e..9ff8162fd703c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -125,7 +125,7 @@ trait CodegenSupport extends SparkPlan { val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) val code = code""" |$evaluateInputs - |${ev.code.trim} + |${ev.code} """.stripMargin ExprCode(code, FalseLiteral, ev.value) } else { @@ -260,7 +260,7 @@ trait CodegenSupport extends SparkPlan { * them to be evaluated twice. */ protected def evaluateVariables(variables: Seq[ExprCode]): String = { - val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n") + val evaluate = variables.filter(_.code.toString != "").map(_.code.toString).mkString("\n") variables.foreach(_.code = EmptyBlock) evaluate } @@ -276,7 +276,7 @@ trait CodegenSupport extends SparkPlan { val evaluateVars = new StringBuilder variables.zipWithIndex.foreach { case (ev, i) => if (ev.code != "" && required.contains(attributes(i))) { - evaluateVars.append(ev.code.trim + "\n") + evaluateVars.append(ev.code.toString + "\n") ev.code = EmptyBlock } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 90fba7b01466a..a164fab815879 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -771,8 +771,8 @@ case class HashAggregateExec( val findOrInsertRegularHashMap: String = s""" |// generate grouping key - |${unsafeRowKeyCode.code.trim} - |${hashEval.code.trim} + |${unsafeRowKeyCode.code} + |${hashEval.code} |if ($checkFallbackForBytesToBytesMap) { | // try to get the buffer from hash map | $unsafeRowBuffer = From 2b30654bc50a39a8af597df68ba288280299defe Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 5 May 2018 22:54:54 +0000 Subject: [PATCH 09/20] Address comment and trim expected code string. --- .../spark/sql/catalyst/expressions/codegen/javaCode.scala | 2 +- .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 2 +- .../sql/catalyst/expressions/codegen/CodeBlockSuite.scala | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index b2c4e23ef2b2a..7932f0e04a72d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -205,7 +205,7 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[Any]) extends Bloc case class Blocks(blocks: Seq[Block]) extends Block { override def exprValues: Seq[ExprValue] = blocks.flatMap(_.exprValues) - override def code: String = blocks.map(_.toString).mkString + override def code: String = blocks.map(_.toString).mkString("\n") override def + (other: Block): Block = other match { case c: CodeBlock => Blocks(blocks :+ c) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 8d77b7d0ff1b8..a9867aaeb0cfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -64,7 +64,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { val (preprocess, assigns, postprocess, arrayData) = GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( - code = code"$preprocess" + code"$assigns" + code"$postprocess", + code = code"${preprocess}${assigns}${postprocess}", value = JavaCode.variable(arrayData, dataType), isNull = FalseLiteral) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala index c266c029693cd..8cf7e47a4ddce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -39,7 +39,7 @@ class CodeBlockSuite extends SparkFunSuite { val expected = s""" |boolean expr1_isNull = false; - |int expr1 = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin + |int expr1 = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin.trim assert(code1.toString == expected) val code2 = @@ -83,7 +83,7 @@ class CodeBlockSuite extends SparkFunSuite { |boolean expr1_isNull = false; |int expr1 = -1; |boolean expr2_isNull = true; - |int expr2 = 100;""".stripMargin + |int expr2 = 100;""".stripMargin.trim assert(code.toString == expected) From aff411bbf21f484c20588997d98682f2ec77191a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 8 May 2018 03:24:04 +0000 Subject: [PATCH 10/20] Remove unused import. --- .../spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala index 4045d24968002..378b8bc055e34 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValueSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import scala.collection.mutable - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.BooleanType From 53b329a3b4e72d423cf503af2982d545848bece8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 8 May 2018 05:33:35 +0000 Subject: [PATCH 11/20] Address comments. --- .../sql/catalyst/expressions/codegen/javaCode.scala | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 7932f0e04a72d..9c3205138dbb0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -176,12 +176,12 @@ object Block { * The actual java code is generated by embedding the inputs into the code parts. */ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[Any]) extends Block { - override def exprValues: Seq[ExprValue] = { + override lazy val exprValues: Seq[ExprValue] = { blockInputs.flatMap { case b: Block => b.exprValues case e: ExprValue => Seq(e) case _ => Seq.empty - } + }.distinct } override def code: String = { @@ -204,7 +204,7 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[Any]) extends Bloc } case class Blocks(blocks: Seq[Block]) extends Block { - override def exprValues: Seq[ExprValue] = blocks.flatMap(_.exprValues) + override lazy val exprValues: Seq[ExprValue] = blocks.flatMap(_.exprValues).distinct override def code: String = blocks.map(_.toString).mkString("\n") override def + (other: Block): Block = other match { @@ -216,7 +216,7 @@ case class Blocks(blocks: Seq[Block]) extends Block { object EmptyBlock extends Block with Serializable { override def code: String = "" - override def exprValues: Seq[ExprValue] = Seq.empty + override val exprValues: Seq[ExprValue] = Seq.empty override def + (other: Block): Block = other } @@ -227,9 +227,6 @@ object EmptyBlock extends Block with Serializable { trait ExprValue extends JavaCode { def javaType: Class[_] def isPrimitive: Boolean = javaType.isPrimitive - - // This will be called during string interpolation. - override def toString: String = ExprValue.exprValueToString(this) } object ExprValue { From 72faac3209beb8bc38938f8788de6338e9b2ffae Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 9 May 2018 13:37:29 +0000 Subject: [PATCH 12/20] Address some comments. --- .../apache/spark/sql/catalyst/expressions/Cast.scala | 7 +++---- .../spark/sql/catalyst/expressions/Expression.scala | 6 +++--- .../catalyst/expressions/codegen/CodeGenerator.scala | 12 ++++++------ .../sql/catalyst/expressions/codegen/javaCode.scala | 2 ++ 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d3e2b96bc3786..699ea53b5df0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -625,11 +625,10 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val eval = child.genCode(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) - // Below the code comment including `eval.value` and `eval.isNull` is a trick. It makes the two - // expr values are referred by this code block. - ev.copy(code = eval.code + + ev.copy(code = code""" - // Cast from ${eval.value}, ${eval.isNull} + ${eval.code} + // This comment is added for manually tracking reference of ${eval.value}, ${eval.isNull} ${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)} """) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4836bf9c2052a..85dae0b85f487 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -101,7 +101,7 @@ abstract class Expression extends TreeNode[Expression] { ctx.subExprEliminationExprs.get(this).map { subExprState => // This expression is repeated which means that the code to evaluate it has already been added // as a function before. In that case, we just re-use it. - ExprCode(code"${ctx.registerComment(this.toString)}", subExprState.isNull, + ExprCode(ctx.registerComment(this.toString), subExprState.isNull, subExprState.value) }.getOrElse { val isNull = ctx.freshName("isNull") @@ -112,7 +112,7 @@ abstract class Expression extends TreeNode[Expression] { reduceCodeSize(ctx, eval) if (eval.code.toString.nonEmpty) { // Add `this` in the comment. - eval.copy(code = code"${ctx.registerComment(this.toString)}\n" + eval.code) + eval.copy(code = ctx.registerComment(this.toString) + eval.code) } else { eval } @@ -121,7 +121,7 @@ abstract class Expression extends TreeNode[Expression] { private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { // TODO: support whole stage codegen too - if (eval.code.toString.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { + if (eval.code.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 116e237e86121..ac36f13f4311a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -61,15 +61,15 @@ case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue object ExprCode { def apply(isNull: ExprValue, value: ExprValue): ExprCode = { - ExprCode(code = code"", isNull, value) + ExprCode(code = EmptyBlock, isNull, value) } def forNullValue(dataType: DataType): ExprCode = { - ExprCode(code = code"", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) + ExprCode(code = EmptyBlock, isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) } def forNonNullValue(value: ExprValue): ExprCode = { - ExprCode(code = code"", isNull = FalseLiteral, value = value) + ExprCode(code = EmptyBlock, isNull = FalseLiteral, value = value) } } @@ -1074,7 +1074,7 @@ class CodegenContext { def registerComment( text: => String, placeholderId: String = "", - force: Boolean = false): String = { + force: Boolean = false): Block = { // By default, disable comments in generated code because computing the comments themselves can // be extremely expensive in certain cases, such as deeply-nested expressions which operate over // inputs with wide schemas. For more details on the performance issues that motivated this @@ -1093,9 +1093,9 @@ class CodegenContext { s"// $text" } placeHolderToComments += (name -> comment) - s"/*$name*/" + code"/*$name*/" } else { - "" + EmptyBlock } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 9c3205138dbb0..318a4217ebeaa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -128,6 +128,8 @@ trait Block extends JavaCode { case _ => code.trim } + def length: Int = toString.length + // The leading prefix that should be stripped from each line. // By default we strip blanks or control characters followed by '|' from the line. var _marginChar: Option[Char] = Some('|') From d04067635ca5785e262afb65781faab5ca737d93 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 17 May 2018 09:38:59 +0000 Subject: [PATCH 13/20] Use code block for newly merged codegen. --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 3b68c5a7da7e4..e433e603f2b3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1523,7 +1523,7 @@ case class ArrayRepeat(left: Expression, right: Expression) val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic) ev.copy(code = - s""" + code""" |boolean ${ev.isNull} = false; |${leftGen.code} |${rightGen.code} From c378ce2a1982b2c5453b15fbc1de06851033ac1d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 17 May 2018 12:55:31 +0000 Subject: [PATCH 14/20] Use Set as method exprValues method returning type. --- .../catalyst/expressions/codegen/javaCode.scala | 14 +++++++------- .../expressions/codegen/CodeBlockSuite.scala | 13 ++++--------- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 318a4217ebeaa..046c06a3f65af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -120,7 +120,7 @@ object JavaCode { trait Block extends JavaCode { // The expressions to be evaluated inside this block. - def exprValues: Seq[ExprValue] + def exprValues: Set[ExprValue] // Returns java code string for this code block. override def toString: String = _marginChar match { @@ -178,12 +178,12 @@ object Block { * The actual java code is generated by embedding the inputs into the code parts. */ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[Any]) extends Block { - override lazy val exprValues: Seq[ExprValue] = { + override lazy val exprValues: Set[ExprValue] = { blockInputs.flatMap { case b: Block => b.exprValues - case e: ExprValue => Seq(e) - case _ => Seq.empty - }.distinct + case e: ExprValue => Set(e) + case _ => Set.empty[ExprValue] + }.toSet } override def code: String = { @@ -206,7 +206,7 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[Any]) extends Bloc } case class Blocks(blocks: Seq[Block]) extends Block { - override lazy val exprValues: Seq[ExprValue] = blocks.flatMap(_.exprValues).distinct + override lazy val exprValues: Set[ExprValue] = blocks.flatMap(_.exprValues).toSet override def code: String = blocks.map(_.toString).mkString("\n") override def + (other: Block): Block = other match { @@ -218,7 +218,7 @@ case class Blocks(blocks: Seq[Block]) extends Block { object EmptyBlock extends Block with Serializable { override def code: String = "" - override val exprValues: Seq[ExprValue] = Seq.empty + override val exprValues: Set[ExprValue] = Set.empty override def + (other: Block): Block = other } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala index 8cf7e47a4ddce..9759b7a479164 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -58,9 +58,8 @@ class CodeBlockSuite extends SparkFunSuite { |int $value = -1; """.stripMargin val exprValues = code.exprValues - assert(exprValues.length == 2) - assert(exprValues(0).isInstanceOf[VariableValue] && exprValues(0).code == "expr1_isNull") - assert(exprValues(1).isInstanceOf[VariableValue] && exprValues(1).code == "expr1") + assert(exprValues.size == 2) + assert(exprValues === Set(value, isNull)) } test("concatenate blocks") { @@ -88,12 +87,8 @@ class CodeBlockSuite extends SparkFunSuite { assert(code.toString == expected) val exprValues = code.exprValues - assert(exprValues.length == 5) - assert(exprValues(0).isInstanceOf[VariableValue] && exprValues(0).code == "expr1_isNull") - assert(exprValues(1).isInstanceOf[VariableValue] && exprValues(1).code == "expr1") - assert(exprValues(2).isInstanceOf[VariableValue] && exprValues(2).code == "expr2_isNull") - assert(exprValues(3).isInstanceOf[VariableValue] && exprValues(3).code == "expr2") - assert(exprValues(4).isInstanceOf[LiteralValue] && exprValues(4).code == "100") + assert(exprValues.size == 5) + assert(exprValues === Set(isNull1, value1, isNull2, value2, literal)) } test("Throws exception when interpolating unexcepted object in code block") { From 2ca974190527f3ac1ade1feddf9bd2b953fabdd7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 19 May 2018 09:31:51 +0000 Subject: [PATCH 15/20] Address comments. --- .../sql/catalyst/expressions/Expression.scala | 3 +- .../expressions/codegen/javaCode.scala | 43 ++++++++++++++++--- .../catalyst/expressions/inputFileBlock.scala | 13 +++--- .../expressions/codegen/CodeBlockSuite.scala | 24 +++++++---- .../sql/execution/ColumnarBatchScan.scala | 2 +- .../sql/execution/WholeStageCodegenExec.scala | 2 +- 6 files changed, 63 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 85dae0b85f487..9b9fa41a47d0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -101,8 +101,7 @@ abstract class Expression extends TreeNode[Expression] { ctx.subExprEliminationExprs.get(this).map { subExprState => // This expression is repeated which means that the code to evaluate it has already been added // as a function before. In that case, we just re-use it. - ExprCode(ctx.registerComment(this.toString), subExprState.isNull, - subExprState.value) + ExprCode(ctx.registerComment(this.toString), subExprState.isNull, subExprState.value) }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 046c06a3f65af..b89006ea2ba17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import java.lang.{Boolean => JBool} +import scala.collection.mutable.ArrayBuffer import scala.language.{existentials, implicitConversions} import org.apache.spark.sql.types.{BooleanType, DataType} @@ -130,6 +131,8 @@ trait Block extends JavaCode { def length: Int = toString.length + def nonEmpty: Boolean = toString.nonEmpty + // The leading prefix that should be stripped from each line. // By default we strip blanks or control characters followed by '|' from the line. var _marginChar: Option[Char] = Some('|') @@ -167,9 +170,40 @@ object Block { case other => throw new IllegalArgumentException( s"Can not interpolate ${other.getClass.getName} into code block.") } - CodeBlock(sc.parts, args) + + val (codeParts, blockInputs) = foldLiteralArgs(sc.parts, args) + CodeBlock(codeParts, blockInputs) + } + } + } + + // Folds eagerly the literal args into the code parts. + private def foldLiteralArgs(parts: Seq[String], args: Seq[Any]): (Seq[String], Seq[Any]) = { + val codeParts = ArrayBuffer.empty[String] + val blockInputs = ArrayBuffer.empty[Any] + + val strings = parts.iterator + val inputs = args.iterator + val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) + + buf append strings.next + while (strings.hasNext) { + val input = inputs.next + input match { + case _: ExprValue | _: Block => + codeParts += buf.toString + buf.clear + blockInputs += input + case _ => + buf append input } + buf append strings.next + } + if (buf.nonEmpty) { + codeParts += buf.toString } + + (codeParts.toSeq, blockInputs.toSeq) } } @@ -182,11 +216,10 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[Any]) extends Bloc blockInputs.flatMap { case b: Block => b.exprValues case e: ExprValue => Set(e) - case _ => Set.empty[ExprValue] }.toSet } - override def code: String = { + override lazy val code: String = { val strings = codeParts.iterator val inputs = blockInputs.iterator val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) @@ -207,7 +240,7 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[Any]) extends Bloc case class Blocks(blocks: Seq[Block]) extends Block { override lazy val exprValues: Set[ExprValue] = blocks.flatMap(_.exprValues).toSet - override def code: String = blocks.map(_.toString).mkString("\n") + override lazy val code: String = blocks.map(_.toString).mkString("\n") override def + (other: Block): Block = other match { case c: CodeBlock => Blocks(blocks :+ c) @@ -217,7 +250,7 @@ case class Blocks(blocks: Seq[Block]) extends Block { } object EmptyBlock extends Block with Serializable { - override def code: String = "" + override val code: String = "" override val exprValues: Set[ExprValue] = Set.empty override def + (other: Block): Block = other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 6e57d9d1db6cb..3b0141ad52cc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -43,8 +43,9 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - code"$className.getInputFilePath();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getInputFilePath();", + isNull = FalseLiteral) } } @@ -66,8 +67,8 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - code"$className.getStartOffset();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getStartOffset();", isNull = FalseLiteral) } } @@ -89,7 +90,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - code"$className.getLength();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala index 9759b7a479164..d2c6420eadb20 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -23,12 +23,20 @@ import org.apache.spark.sql.types.{BooleanType, IntegerType} class CodeBlockSuite extends SparkFunSuite { - test("Block can interpolate string and ExprValue inputs") { + test("Block interpolates string and ExprValue inputs") { val isNull = JavaCode.isNullVariable("expr1_isNull") - val code = code"boolean ${isNull} = ${JavaCode.defaultLiteral(BooleanType)};" + val stringLiteral = "false" + val code = code"boolean $isNull = $stringLiteral;" assert(code.toString == "boolean expr1_isNull = false;") } + test("Literals are folded into string code parts instead of block inputs") { + val value = JavaCode.variable("expr1", IntegerType) + val intLiteral = 1 + val code = code"int $value = $intLiteral;" + assert(code.asInstanceOf[CodeBlock].blockInputs === Seq(value)) + } + test("Block.stripMargin") { val isNull = JavaCode.isNullVariable("expr1_isNull") val value = JavaCode.variable("expr1", IntegerType) @@ -92,7 +100,7 @@ class CodeBlockSuite extends SparkFunSuite { } test("Throws exception when interpolating unexcepted object in code block") { - val obj = TestClass(100) + val obj = Tuple2(1, 1) val e = intercept[IllegalArgumentException] { code"$obj" } @@ -100,18 +108,18 @@ class CodeBlockSuite extends SparkFunSuite { } test("replace expr values in code block") { - val statement = JavaCode.expression("1 + 1", IntegerType) + val expr = JavaCode.expression("1 + 1", IntegerType) val isNull = JavaCode.isNullVariable("expr1_isNull") val exprInFunc = JavaCode.variable("expr1", IntegerType) val code = code""" - |callFunc(int $statement) { + |callFunc(int $expr) { | boolean $isNull = false; - | int $exprInFunc = $statement + 1; + | int $exprInFunc = $expr + 1; |}""".stripMargin - val aliasedParam = JavaCode.variable("aliased", statement.javaType) + val aliasedParam = JavaCode.variable("aliased", expr.javaType) val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map { case _: SimpleExprValue => aliasedParam case other => other @@ -126,5 +134,3 @@ class CodeBlockSuite extends SparkFunSuite { assert(aliasedCode.toString == expected.toString) } } - -private case class TestClass(a: Int) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index d0edd77609024..48abad9078650 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -59,7 +59,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" - val code = code"${ctx.registerComment(str)}\n" + (if (nullable) { + val code = code"${ctx.registerComment(str)}" + (if (nullable) { code""" boolean $isNullVar = $columnVar.isNullAt($ordinal); $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 9ff8162fd703c..372dc3db36ce6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -260,7 +260,7 @@ trait CodegenSupport extends SparkPlan { * them to be evaluated twice. */ protected def evaluateVariables(variables: Seq[ExprCode]): String = { - val evaluate = variables.filter(_.code.toString != "").map(_.code.toString).mkString("\n") + val evaluate = variables.filter(_.code.nonEmpty).map(_.code.toString).mkString("\n") variables.foreach(_.code = EmptyBlock) evaluate } From 8d0b1b9d9dc6631ea787389bb9d3f561950c3d0b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 21 May 2018 09:19:09 +0000 Subject: [PATCH 16/20] Add API to manipulate blocks and expressions. --- .../expressions/codegen/javaCode.scala | 56 +++++++++++++++-- .../expressions/codegen/CodeBlockSuite.scala | 60 +++++++++++++++++-- 2 files changed, 105 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index b89006ea2ba17..b22cbadbf1cad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -22,14 +22,17 @@ import java.lang.{Boolean => JBool} import scala.collection.mutable.ArrayBuffer import scala.language.{existentials, implicitConversions} +import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{BooleanType, DataType} /** * Trait representing an opaque fragments of java code. */ -trait JavaCode { +trait JavaCode extends TreeNode[JavaCode] { def code: String override def toString: String = code + + override def verboseString: String = toString } /** @@ -42,7 +45,7 @@ object JavaCode { def literal(v: String, dataType: DataType): LiteralValue = dataType match { case BooleanType if v == "true" => TrueLiteral case BooleanType if v == "false" => FalseLiteral - case _ => new LiteralValue(v, CodeGenerator.javaClass(dataType)) + case _ => new LiteralExpr(v, CodeGenerator.javaClass(dataType)) } /** @@ -50,7 +53,7 @@ object JavaCode { * -1 for other primitive types. */ def defaultLiteral(dataType: DataType): LiteralValue = { - new LiteralValue( + new LiteralExpr( CodeGenerator.defaultValue(dataType, typedNull = true), CodeGenerator.javaClass(dataType)) } @@ -120,7 +123,7 @@ object JavaCode { */ trait Block extends JavaCode { - // The expressions to be evaluated inside this block. + // All expressions to be evaluated inside this block and underlying blocks. def exprValues: Set[ExprValue] // Returns java code string for this code block. @@ -149,6 +152,34 @@ trait Block extends JavaCode { // Concatenates this block with other block. def + (other: Block): Block + + /** + * Apply a map function to each java expression codes present in this java code, and return a new + * java code based on the mapped java expression codes. + */ + def transformExprValues(f: PartialFunction[ExprValue, ExprValue]): this.type = { + var changed = false + + @inline def transform(e: ExprValue): ExprValue = { + val newE = f lift e + if (!newE.isDefined || newE.get.fastEquals(e)) { + e + } else { + changed = true + newE.get + } + } + + def doTransform(arg: Any): AnyRef = arg match { + case e: ExprValue => transform(e) + case Some(value) => Some(doTransform(value)) + case seq: Traversable[_] => seq.map(doTransform) + case other: AnyRef => other + } + + val newArgs = mapProductIterator(doTransform) + if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this + } } object Block { @@ -219,6 +250,9 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[Any]) extends Bloc }.toSet } + override def children: Seq[Block] = + blockInputs.filter(_.isInstanceOf[Block]).asInstanceOf[Seq[Block]] + override lazy val code: String = { val strings = codeParts.iterator val inputs = blockInputs.iterator @@ -242,6 +276,8 @@ case class Blocks(blocks: Seq[Block]) extends Block { override lazy val exprValues: Set[ExprValue] = blocks.flatMap(_.exprValues).toSet override lazy val code: String = blocks.map(_.toString).mkString("\n") + override def children: Seq[Block] = blocks + override def + (other: Block): Block = other match { case c: CodeBlock => Blocks(blocks :+ c) case b: Blocks => Blocks(blocks ++ b.blocks) @@ -249,11 +285,13 @@ case class Blocks(blocks: Seq[Block]) extends Block { } } -object EmptyBlock extends Block with Serializable { +case object EmptyBlock extends Block with Serializable { override val code: String = "" override val exprValues: Set[ExprValue] = Set.empty override def + (other: Block): Block = other + + override def children: Seq[Block] = Seq.empty } /** @@ -262,6 +300,8 @@ object EmptyBlock extends Block with Serializable { trait ExprValue extends JavaCode { def javaType: Class[_] def isPrimitive: Boolean = javaType.isPrimitive + + override def children: Seq[ExprValue] = Seq.empty } object ExprValue { @@ -292,7 +332,8 @@ case class GlobalValue(value: String, javaType: Class[_]) extends ExprValue { /** * A literal java expression. */ -class LiteralValue(val value: String, val javaType: Class[_]) extends ExprValue with Serializable { +abstract class LiteralValue(val value: String, val javaType: Class[_]) + extends ExprValue with Serializable { override def code: String = value override def equals(arg: Any): Boolean = arg match { @@ -303,5 +344,8 @@ class LiteralValue(val value: String, val javaType: Class[_]) extends ExprValue override def hashCode(): Int = value.hashCode() * 31 + javaType.hashCode() } +case class LiteralExpr(override val value: String, override val javaType: Class[_]) + extends LiteralValue(value, javaType) + case object TrueLiteral extends LiteralValue("true", JBool.TYPE) case object FalseLiteral extends LiteralValue("false", JBool.TYPE) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala index d2c6420eadb20..89b276a2864f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -107,7 +107,7 @@ class CodeBlockSuite extends SparkFunSuite { assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}")) } - test("replace expr values in code block") { + test("transform expr in code block") { val expr = JavaCode.expression("1 + 1", IntegerType) val isNull = JavaCode.isNullVariable("expr1_isNull") val exprInFunc = JavaCode.variable("expr1", IntegerType) @@ -120,11 +120,11 @@ class CodeBlockSuite extends SparkFunSuite { |}""".stripMargin val aliasedParam = JavaCode.variable("aliased", expr.javaType) - val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map { - case _: SimpleExprValue => aliasedParam - case other => other + + // We want to replace all occurrences of `expr` with the variable `aliasedParam`. + val aliasedCode = code.transformExprValues { + case SimpleExprValue("1 + 1", _) => aliasedParam } - val aliasedCode = CodeBlock(code.asInstanceOf[CodeBlock].codeParts, aliasedInputs).stripMargin val expected = code""" |callFunc(int $aliasedParam) { @@ -133,4 +133,54 @@ class CodeBlockSuite extends SparkFunSuite { |}""".stripMargin assert(aliasedCode.toString == expected.toString) } + + test ("transform expr in nested blocks") { + val expr = JavaCode.expression("1 + 1", IntegerType) + val isNull = JavaCode.isNullVariable("expr1_isNull") + val exprInFunc = JavaCode.variable("expr1", IntegerType) + + val funcs = Seq("callFunc1", "callFunc2", "callFunc3") + val subBlocks = funcs.map { funcName => + code""" + |$funcName(int $expr) { + | boolean $isNull = false; + | int $exprInFunc = $expr + 1; + |}""".stripMargin + } + + val aliasedParam = JavaCode.variable("aliased", expr.javaType) + + val block = subBlocks.fold(EmptyBlock)(_ + _) + val transformedBlock = block.transform { + case b: Block => b.transformExprValues { + case SimpleExprValue("1 + 1", _) => aliasedParam + } + }.asInstanceOf[Blocks] + + val expected1 = + code""" + |callFunc1(int aliased) { + | boolean expr1_isNull = false; + | int expr1 = aliased + 1; + |}""".stripMargin + + val expected2 = + code""" + |callFunc2(int aliased) { + | boolean expr1_isNull = false; + | int expr1 = aliased + 1; + |}""".stripMargin + + val expected3 = + code""" + |callFunc3(int aliased) { + | boolean expr1_isNull = false; + | int expr1 = aliased + 1; + |}""".stripMargin + + assert(transformedBlock.blocks(0).toString == expected1.toString) + assert(transformedBlock.blocks(1).toString == expected2.toString) + assert(transformedBlock.blocks(2).toString == expected3.toString) + assert(transformedBlock.toString == (expected1 + expected2 + expected3).toString) + } } From e30be7a3da9061a71e8ac222006e20b4e173cb74 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 29 May 2018 10:14:54 +0000 Subject: [PATCH 17/20] Address comments. --- .../sql/catalyst/expressions/codegen/CodeBlockSuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala index 89b276a2864f7..dbd809cacaf73 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -123,7 +123,7 @@ class CodeBlockSuite extends SparkFunSuite { // We want to replace all occurrences of `expr` with the variable `aliasedParam`. val aliasedCode = code.transformExprValues { - case SimpleExprValue("1 + 1", _) => aliasedParam + case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam } val expected = code""" @@ -153,7 +153,7 @@ class CodeBlockSuite extends SparkFunSuite { val block = subBlocks.fold(EmptyBlock)(_ + _) val transformedBlock = block.transform { case b: Block => b.transformExprValues { - case SimpleExprValue("1 + 1", _) => aliasedParam + case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam } }.asInstanceOf[Blocks] @@ -182,5 +182,6 @@ class CodeBlockSuite extends SparkFunSuite { assert(transformedBlock.blocks(1).toString == expected2.toString) assert(transformedBlock.blocks(2).toString == expected3.toString) assert(transformedBlock.toString == (expected1 + expected2 + expected3).toString) + assert(transformedBlock.exprValues === Set(isNull, exprInFunc, aliasedParam)) } } From ce8d9ed2165097e82ad2b3d03da6c841d623f572 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 5 Jul 2018 05:57:35 +0000 Subject: [PATCH 18/20] Only Block extends TreeNode. --- .../sql/catalyst/expressions/codegen/javaCode.scala | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 152666ef8b7d9..cd40710305e85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -28,11 +28,9 @@ import org.apache.spark.sql.types.{BooleanType, DataType} /** * Trait representing an opaque fragments of java code. */ -trait JavaCode extends TreeNode[JavaCode] { +trait JavaCode { def code: String override def toString: String = code - - override def verboseString: String = toString } /** @@ -121,7 +119,7 @@ object JavaCode { /** * A trait representing a block of java code. */ -trait Block extends JavaCode { +trait Block extends TreeNode[Block] with JavaCode { import Block._ // All expressions to be evaluated inside this block and underlying blocks. @@ -160,7 +158,7 @@ trait Block extends JavaCode { @inline def transform(e: ExprValue): ExprValue = { val newE = f lift e - if (!newE.isDefined || newE.get.fastEquals(e)) { + if (!newE.isDefined || newE.get.equals(e)) { e } else { changed = true @@ -184,6 +182,8 @@ trait Block extends JavaCode { case EmptyBlock => this case _ => code"$this\n$other" } + + override def verboseString: String = toString } object Block { @@ -286,8 +286,6 @@ case object EmptyBlock extends Block with Serializable { trait ExprValue extends JavaCode { def javaType: Class[_] def isPrimitive: Boolean = javaType.isPrimitive - - override def children: Seq[ExprValue] = Seq.empty } object ExprValue { From 707dd6f0f5385afaa4721c91272a11ca5d12b775 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 5 Jul 2018 07:36:44 +0000 Subject: [PATCH 19/20] Remove exprValues property from Block. --- .../catalyst/expressions/codegen/javaCode.scala | 12 ------------ .../expressions/codegen/CodeBlockSuite.scala | 16 +++++++++++++--- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index cd40710305e85..33516f507ade2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -122,9 +122,6 @@ object JavaCode { trait Block extends TreeNode[Block] with JavaCode { import Block._ - // All expressions to be evaluated inside this block and underlying blocks. - def exprValues: Set[ExprValue] - // Returns java code string for this code block. override def toString: String = _marginChar match { case Some(c) => code.stripMargin(c).trim @@ -250,13 +247,6 @@ object Block { * method splitting. */ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block { - override lazy val exprValues: Set[ExprValue] = { - blockInputs.flatMap { - case b: Block => b.exprValues - case e: ExprValue => Set(e) - }.toSet - } - override def children: Seq[Block] = blockInputs.filter(_.isInstanceOf[Block]).asInstanceOf[Seq[Block]] @@ -275,8 +265,6 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends case object EmptyBlock extends Block with Serializable { override val code: String = "" - override val exprValues: Set[ExprValue] = Set.empty - override def children: Seq[Block] = Seq.empty } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala index 8f1c018110576..55569b6f2933e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -65,7 +65,9 @@ class CodeBlockSuite extends SparkFunSuite { |boolean $isNull = false; |int $value = -1; """.stripMargin - val exprValues = code.exprValues + val exprValues = code.asInstanceOf[CodeBlock].blockInputs.collect { + case e: ExprValue => e + }.toSet assert(exprValues.size == 2) assert(exprValues === Set(value, isNull)) } @@ -94,7 +96,9 @@ class CodeBlockSuite extends SparkFunSuite { assert(code.toString == expected) - val exprValues = code.exprValues + val exprValues = code.children.flatMap(_.asInstanceOf[CodeBlock].blockInputs.collect { + case e: ExprValue => e + }).toSet assert(exprValues.size == 5) assert(exprValues === Set(isNull1, value1, isNull2, value2, literal)) } @@ -178,10 +182,16 @@ class CodeBlockSuite extends SparkFunSuite { | int expr1 = aliased + 1; |}""".stripMargin + val exprValues = transformedBlock.children.flatMap { block => + block.asInstanceOf[CodeBlock].blockInputs.collect { + case e: ExprValue => e + } + }.toSet + assert(transformedBlock.children(0).toString == expected1.toString) assert(transformedBlock.children(1).toString == expected2.toString) assert(transformedBlock.children(2).toString == expected3.toString) assert(transformedBlock.toString == (expected1 + expected2 + expected3).toString) - assert(transformedBlock.exprValues === Set(isNull, exprInFunc, aliasedParam)) + assert(exprValues === Set(isNull, exprInFunc, aliasedParam)) } } From c9dee44c6b4e9fdd41b2460fbd331f8f0e9235f4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 5 Jul 2018 07:40:00 +0000 Subject: [PATCH 20/20] Revert LiteralExpr. --- .../sql/catalyst/expressions/codegen/javaCode.scala | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 33516f507ade2..2f8c853e836ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -43,7 +43,7 @@ object JavaCode { def literal(v: String, dataType: DataType): LiteralValue = dataType match { case BooleanType if v == "true" => TrueLiteral case BooleanType if v == "false" => FalseLiteral - case _ => new LiteralExpr(v, CodeGenerator.javaClass(dataType)) + case _ => new LiteralValue(v, CodeGenerator.javaClass(dataType)) } /** @@ -51,7 +51,7 @@ object JavaCode { * -1 for other primitive types. */ def defaultLiteral(dataType: DataType): LiteralValue = { - new LiteralExpr( + new LiteralValue( CodeGenerator.defaultValue(dataType, typedNull = true), CodeGenerator.javaClass(dataType)) } @@ -304,8 +304,7 @@ case class GlobalValue(value: String, javaType: Class[_]) extends ExprValue { /** * A literal java expression. */ -abstract class LiteralValue(val value: String, val javaType: Class[_]) - extends ExprValue with Serializable { +class LiteralValue(val value: String, val javaType: Class[_]) extends ExprValue with Serializable { override def code: String = value override def equals(arg: Any): Boolean = arg match { @@ -316,8 +315,5 @@ abstract class LiteralValue(val value: String, val javaType: Class[_]) override def hashCode(): Int = value.hashCode() * 31 + javaType.hashCode() } -case class LiteralExpr(override val value: String, override val javaType: Class[_]) - extends LiteralValue(value, javaType) - case object TrueLiteral extends LiteralValue("true", JBool.TYPE) case object FalseLiteral extends LiteralValue("false", JBool.TYPE)