Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -262,37 +262,97 @@ case class CaseWhenCodegen(
// }
// }
// }

val isNull = ctx.freshName("caseWhenIsNull")
val value = ctx.freshName("caseWhenValue")

val cases = branches.map { case (condExpr, valueExpr) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we follow what we did for And and Or, and just check the code length at the beginning? TBH, I still don't understand your change after several minutes of reading.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For CaseWhen, the code bloat occurs within a single case class. CaseWhenCodegen.doGenCode can generate deeply nested if-then-else statements, as shown in the comment above. Each element of cases contains only an if-then, so it is not possible to insert the code size check here. And and Or, by contrast, generate deeply nested if-then-else by calling doGenCode many times, so for them checking the code size at that point works well.

This line generates the nested if-then-else. Thus, after this line, code size check is performed.

What do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only codegen CASE WHEN if there are fewer than 20 case branches, so I think checking the code size here is good enough.

Copy link
Member Author

@kiszk kiszk Nov 14, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Do you want to add code like the following?

val cases = ...
val genCode = if (cases.map(s => s.length).sum <= 1024) {
  cases.mkString("\nelse {\n")
} else {
  // current code
  var isGlobalVariable = false
  ...
  generatedCode
}

val cond = condExpr.genCode(ctx)
val res = valueExpr.genCode(ctx)
val (condFunc, condIsNull, condValue) = genCodeForExpression(ctx, condExpr)
val (resFunc, resIsNull, resValue) = genCodeForExpression(ctx, valueExpr)
s"""
${cond.code}
if (!${cond.isNull} && ${cond.value}) {
${res.code}
${ev.isNull} = ${res.isNull};
${ev.value} = ${res.value};
${condFunc}
if (!${condIsNull} && ${condValue}) {
${resFunc}
$isNull = ${resIsNull};
$value = ${resValue};
}
"""
}

var generatedCode = cases.mkString("", "\nelse {\n", "\nelse {\n")
var isGlobalVariable = false
val (generatedIfThenElse, numBrankets) = if (cases.map(s => s.length).sum <= 1024) {
(cases.mkString("", "\nelse {\n", "\nelse {\n"), cases.length)
} else {
var numIfThen = 0
var code = ""
cases.foreach { ifThen =>
code += ifThen + "\nelse {\n"
numIfThen += 1

if (code.length > 1024 &&
// Split these expressions only if they are created from a row object
(ctx.INPUT_ROW != null && ctx.currentVars == null)) {
val flag = "flag"
code += s" $flag = false;\n" + "}\n" * numIfThen
val funcName = ctx.freshName("caseWhenNestedIf")
val funcBody =
s"""
|private boolean $funcName(InternalRow ${ctx.INPUT_ROW}) {
| boolean $flag = true;
| $code
| return $flag;
|}
""".stripMargin
val fullFuncName = ctx.addNewFunction(funcName, funcBody)
isGlobalVariable = true

code = s"if ($fullFuncName(${ctx.INPUT_ROW})) {\n// do nothing\n} else {\n"
numIfThen = 1
}
}
(code, numIfThen)
}

var generatedCode = generatedIfThenElse
elseValue.foreach { elseExpr =>
val res = elseExpr.genCode(ctx)
generatedCode +=
s"""
${res.code}
${ev.isNull} = ${res.isNull};
${ev.value} = ${res.value};
"""
val (resFunc, resIsNull, resValue) = genCodeForExpression(ctx, elseExpr)
generatedCode += s"""
${resFunc}
$isNull = ${resIsNull};
$value = ${resValue};
"""
}

generatedCode += "}\n" * cases.size
generatedCode += "}\n" * numBrankets

ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$generatedCode""")
if (!isGlobalVariable) {
ev.copy(s"""
boolean $isNull = true;
${ctx.javaType(dataType)} $value = ${ctx.defaultValue(dataType)};
$generatedCode
""", isNull, value)
} else {
ctx.addMutableState("boolean", isNull, s"$isNull = false;")
ctx.addMutableState(ctx.javaType(dataType), value,
s"$value = ${ctx.defaultValue(dataType)};")
ev.copy(code = s"""
$generatedCode
boolean ${ev.isNull} = $isNull;
${ctx.javaType(dataType)} ${ev.value} = $value;
""")
}
}

/**
 * Generates the code for `expression`, routing it through a separately registered function
 * when the generated code is long.
 *
 * The code is handed to `ctx.createAndAddFunction` only when it is longer than 1024 characters
 * and `ctx.INPUT_ROW` is set while `ctx.currentVars` is not; in every other case it is
 * returned inline, unchanged.
 *
 * @param ctx the codegen context used to generate the code and to register the function
 * @param expression the expression to generate code for
 * @return a triple of (code to run, isNull term, value term); in the function case the code
 *         is a single call statement and the two terms are the ones returned by
 *         `ctx.createAndAddFunction`
 */
def genCodeForExpression(ctx: CodegenContext, expression: Expression):
    (String, String, String) = {
  val generated = expression.genCode(ctx)
  if (generated.code.length <= 1024 || ctx.INPUT_ROW == null || ctx.currentVars != null) {
    // Short code, or no row object available to pass to a function: keep the code inline.
    (generated.code, generated.isNull, generated.value)
  } else {
    val (name, isNullTerm, valueTerm) =
      ctx.createAndAddFunction(generated, expression.dataType, "caseWhenElseExpr")
    val call = s"$name(${ctx.INPUT_ROW});"
    (call, isNullTerm, valueTerm)
  }
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -380,4 +380,68 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
s"Incorrect Evaluation: expressions: $exprAnd, actual: $actualAnd, expected: $expectedAnd")
}
}

test("SPARK-21413: split large case when into blocks due to JVM code size limit") {
  val expectedInt = -2
  val exprInt: Expression = BoundReference(0, IntegerType, true)
  val expectedStr = UTF8String.fromString("abc")
  val exprStr: Expression = BoundReference(0, StringType, true)

  // Runs `expr` over `row` through a generated mutable projection and fails the test unless
  // the single output column matches `expected`.
  def assertProjectsTo(expr: Expression, row: GenericInternalRow, expected: Any): Unit = {
    val plan = GenerateMutableProjection.generate(Seq(expr))
    val actual = plan(row).toSeq(Seq(expr.dataType))
    assert(actual.length == 1)
    val result = actual(0)
    if (!checkResult(result, expected, expr.dataType)) {
      fail(s"Incorrect Evaluation: expressions: $expr, actual: $result, expected: $expected")
    }
  }

  // Builds a fresh single-column input row holding `expectedStr`.
  def stringRow(): GenericInternalRow = {
    val row = new GenericInternalRow(Array[Any](1))
    row.update(0, expectedStr)
    row
  }

  // Code size of condition or then expression is large
  val expr1 = (1 to 10).foldLeft(exprInt) { (nested, i) =>
    CaseWhen(Seq((EqualTo(nested, Literal(i)), Literal(-1))), nested).toCodegen()
  }
  val intRow = new GenericInternalRow(Array[Any](1))
  intRow.setInt(0, expectedInt)
  assertProjectsTo(expr1, intRow, expectedInt)

  // Code size of else expression is large
  val expr2 = (1 to 512).foldLeft(exprStr) { (elseExpr, i) =>
    CaseWhen(Seq((EqualTo(exprStr, Literal(s"def$i")), Literal(s"xyz$i"))), elseExpr)
      .toCodegen()
  }
  assertProjectsTo(expr2, stringRow(), expectedStr)

  // total code size of conditional branches is large
  val cases = (1 to 512).map(i => (EqualTo(exprStr, Literal(s"def$i")), Literal(s"xyz$i")))
  val expr3 = CaseWhen(cases, exprStr).toCodegen()
  assertProjectsTo(expr3, stringRow(), expectedStr)

  // total code size is small
  val cases4 = Seq((EqualTo(exprStr, Literal("def")), Literal("xyz")))
  val expr4 = CaseWhen(cases4, exprStr).toCodegen()
  assertProjectsTo(expr4, stringRow(), expectedStr)
}
}
10 changes: 10 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2158,4 +2158,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val mean = result.select("DecimalCol").where($"summary" === "mean")
assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000")))
}

// Ignored: this end-to-end test is skipped because sbt test does not go to the fallback path
// in whole-stage codegen.
ignore("SPARK-21413: Multiple projections with CASE WHEN fails") {
  val schema = StructType(StructField("a", IntegerType) :: Nil)
  val initial = spark.createDataFrame(sparkContext.parallelize(Seq(Row(1))), schema)
  // Stack ten projections, each re-defining column "a" with a CASE WHEN over itself.
  val projected = (1 to 10).foldLeft(initial) { (current, _) =>
    current.withColumn("a", when($"a" === 0, null).otherwise($"a"))
  }
  checkAnswer(projected, Row(1))
}
}