Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
1df9943
Add API for handling expression code generation.
viirya Apr 30, 2018
5fe425c
Add new abstraction for expression codegen.
viirya May 1, 2018
00bef6b
Add basic tests.
viirya May 3, 2018
5d9c454
Merge remote-tracking branch 'upstream/master' into SPARK-24121
viirya May 3, 2018
162deb2
Deal merging conflict.
viirya May 3, 2018
d138ee0
Address comments and add more tests.
viirya May 4, 2018
ee9a4c0
Address comments.
viirya May 5, 2018
e7cfa28
Remove JavaCode.block. We should always use code string interpolator …
viirya May 5, 2018
5945c15
We should not implicitly convert code block to string. Otherwise we m…
viirya May 5, 2018
2b30654
Address comment and trim expected code string.
viirya May 5, 2018
aff411b
Remove unused import.
viirya May 8, 2018
53b329a
Address comments.
viirya May 8, 2018
72faac3
Address some comments.
viirya May 9, 2018
ffbf4ab
Merge remote-tracking branch 'upstream/master' into SPARK-24121
viirya May 17, 2018
d040676
Use code block for newly merged codegen.
viirya May 17, 2018
c378ce2
Use Set as method exprValues method returning type.
viirya May 17, 2018
2ca9741
Address comments.
viirya May 19, 2018
d91f111
Merge remote-tracking branch 'upstream/master' into SPARK-24121
viirya May 19, 2018
4b49e8a
Merge remote-tracking branch 'upstream/master' into SPARK-24121
viirya May 22, 2018
96c594a
Use Java call style and use JavaCode instead of Any.
viirya May 22, 2018
00cc564
Merge remote-tracking branch 'upstream/master' into SPARK-24121
viirya May 22, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -56,13 +57,13 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
if (nullable) {
ev.copy(code =
s"""
code"""
|boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
|$javaType ${ev.value} = ${ev.isNull} ?
| ${CodeGenerator.defaultValue(dataType)} : ($value);
""".stripMargin)
} else {
ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
Expand Down Expand Up @@ -623,8 +624,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
ev.copy(code = eval.code +
castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast))

ev.copy(code =
code"""
${eval.code}
// This comment is added for manually tracking reference of ${eval.value}, ${eval.isNull}
${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)}
""")
}

// The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.util.Locale
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
Expand Down Expand Up @@ -108,9 +109,9 @@ abstract class Expression extends TreeNode[Expression] {
JavaCode.isNullVariable(isNull),
JavaCode.variable(value, dataType)))
reduceCodeSize(ctx, eval)
if (eval.code.nonEmpty) {
if (eval.code.toString.nonEmpty) {
// Add `this` in the comment.
eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim)
eval.copy(code = ctx.registerComment(this.toString) + eval.code)
} else {
eval
}
Expand All @@ -119,7 +120,7 @@ abstract class Expression extends TreeNode[Expression] {

private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = {
// TODO: support whole stage codegen too
if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
if (eval.code.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) {
val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull")
val localIsNull = eval.isNull
Expand All @@ -136,14 +137,14 @@ abstract class Expression extends TreeNode[Expression] {
val funcFullName = ctx.addNewFunction(funcName,
s"""
|private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) {
| ${eval.code.trim}
| ${eval.code}
| $setIsNull
| return ${eval.value};
|}
""".stripMargin)

eval.value = JavaCode.variable(newValue, dataType)
eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
}
}

Expand Down Expand Up @@ -437,15 +438,14 @@ abstract class UnaryExpression extends Expression {

if (nullable) {
val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode)
ev.copy(code = s"""
ev.copy(code = code"""
${childGen.code}
boolean ${ev.isNull} = ${childGen.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval
""")
} else {
ev.copy(code = s"""
boolean ${ev.isNull} = false;
ev.copy(code = code"""
${childGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = FalseLiteral)
Expand Down Expand Up @@ -537,14 +537,13 @@ abstract class BinaryExpression extends Expression {
}
}

ev.copy(code = s"""
ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval
""")
} else {
ev.copy(code = s"""
boolean ${ev.isNull} = false;
ev.copy(code = code"""
${leftGen.code}
${rightGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
Expand Down Expand Up @@ -681,13 +680,12 @@ abstract class TernaryExpression extends Expression {
}
}

ev.copy(code = s"""
ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval""")
} else {
ev.copy(code = s"""
boolean ${ev.isNull} = false;
ev.copy(code = code"""
${leftGen.code}
${midGen.code}
${rightGen.code}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, LongType}

/**
Expand Down Expand Up @@ -72,7 +73,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful {
ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")

ev.copy(code = s"""
ev.copy(code = code"""
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
$countTerm++;""", isNull = FalseLiteral)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.DataType

/**
Expand Down Expand Up @@ -1030,7 +1031,7 @@ case class ScalaUDF(
""".stripMargin

ev.copy(code =
s"""
code"""
|$evalCode
|${initArgs.mkString("\n")}
|$callFunc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._

Expand Down Expand Up @@ -181,7 +182,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
}

ev.copy(code = childCode.code +
s"""
code"""
|long ${ev.value} = 0L;
|boolean ${ev.isNull} = ${childCode.isNull};
|if (!${childCode.isNull}) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, IntegerType}

/**
Expand Down Expand Up @@ -46,7 +47,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
val idTerm = "partitionId"
ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm)
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
isNull = FalseLiteral)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

Expand Down Expand Up @@ -164,7 +165,7 @@ case class PreciseTimestampConversion(
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
ev.copy(code = eval.code +
s"""boolean ${ev.isNull} = ${eval.isNull};
code"""boolean ${ev.isNull} = ${eval.isNull};
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value};
""".stripMargin)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
Expand Down Expand Up @@ -259,7 +260,7 @@ trait DivModLike extends BinaryArithmetic {
s"($javaType)(${eval1.value} $symbol ${eval2.value})"
}
if (!left.nullable && !right.nullable) {
ev.copy(code = s"""
ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
Expand All @@ -270,7 +271,7 @@ trait DivModLike extends BinaryArithmetic {
${ev.value} = $operation;
}""")
} else {
ev.copy(code = s"""
ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
Expand Down Expand Up @@ -436,7 +437,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
}

if (!left.nullable && !right.nullable) {
ev.copy(code = s"""
ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
Expand All @@ -447,7 +448,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
$result
}""")
} else {
ev.copy(code = s"""
ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
Expand Down Expand Up @@ -569,7 +570,7 @@ case class Least(children: Seq[Expression]) extends Expression {
""".stripMargin,
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
ev.copy(code =
s"""
code"""
|${ev.isNull} = true;
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|$codes
Expand Down Expand Up @@ -644,7 +645,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
""".stripMargin,
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
ev.copy(code =
s"""
code"""
|${ev.isNull} = true;
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|$codes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.metrics.source.CodegenMetrics
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand All @@ -57,19 +58,19 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
* @param value A term for a (possibly primitive) value of the result of the evaluation. Not
* valid if `isNull` is set to `true`.
*/
case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue)
case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue)

object ExprCode {
def apply(isNull: ExprValue, value: ExprValue): ExprCode = {
ExprCode(code = "", isNull, value)
ExprCode(code = EmptyBlock, isNull, value)
}

def forNullValue(dataType: DataType): ExprCode = {
ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType))
ExprCode(code = EmptyBlock, isNull = TrueLiteral, JavaCode.defaultLiteral(dataType))
}

def forNonNullValue(value: ExprValue): ExprCode = {
ExprCode(code = "", isNull = FalseLiteral, value = value)
ExprCode(code = EmptyBlock, isNull = FalseLiteral, value = value)
}
}

Expand Down Expand Up @@ -330,9 +331,9 @@ class CodegenContext {
def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = {
val value = addMutableState(javaType(dataType), variableName)
val code = dataType match {
case StringType => s"$value = $initCode.clone();"
case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();"
case _ => s"$value = $initCode;"
case StringType => code"$value = $initCode.clone();"
case _: StructType | _: ArrayType | _: MapType => code"$value = $initCode.copy();"
case _ => code"$value = $initCode;"
}
ExprCode(code, FalseLiteral, JavaCode.global(value, dataType))
}
Expand Down Expand Up @@ -1056,7 +1057,7 @@ class CodegenContext {
val eval = expr.genCode(this)
val state = SubExprEliminationState(eval.isNull, eval.value)
e.foreach(localSubExprEliminationExprs.put(_, state))
eval.code.trim
eval.code.toString
}
SubExprCodes(codes, localSubExprEliminationExprs.toMap)
}
Expand Down Expand Up @@ -1084,7 +1085,7 @@ class CodegenContext {
val fn =
s"""
|private void $fnName(InternalRow $INPUT_ROW) {
| ${eval.code.trim}
| ${eval.code}
| $isNull = ${eval.isNull};
| $value = ${eval.value};
|}
Expand Down Expand Up @@ -1141,7 +1142,7 @@ class CodegenContext {
def registerComment(
text: => String,
placeholderId: String = "",
force: Boolean = false): String = {
force: Boolean = false): Block = {
// By default, disable comments in generated code because computing the comments themselves can
// be extremely expensive in certain cases, such as deeply-nested expressions which operate over
// inputs with wide schemas. For more details on the performance issues that motivated this
Expand All @@ -1160,9 +1161,9 @@ class CodegenContext {
s"// $text"
}
placeHolderToComments += (name -> comment)
s"/*$name*/"
code"/*$name*/"
} else {
""
EmptyBlock
}
}
}
Expand Down
Loading