From a26aafee913ac1974a538f8b1d5bc372386e7b00 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 15 Jul 2017 17:25:20 +0900 Subject: [PATCH 01/12] initial commit --- .../expressions/conditionalExpressions.scala | 30 +++++++++++++++---- .../expressions/CodeGenerationSuite.scala | 19 ++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 18 +++++++++++ 3 files changed, 61 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index c41a10c7b0f87..6f6cbf542e6af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -111,7 +111,6 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi ev.copy(code = generatedCode) } - override def toString: String = s"if ($predicate) $trueValue else $falseValue" override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))" @@ -265,12 +264,31 @@ case class CaseWhenCodegen( val cases = branches.map { case (condExpr, valueExpr) => val cond = condExpr.genCode(ctx) val res = valueExpr.genCode(ctx) + val (condFunc, condIsNull, condValue) = if ((cond.code.length >= 512) && + // Split these expressions only if they are created from a row object + (ctx.INPUT_ROW != null && ctx.currentVars == null)) { + val (funcName, globalIsNull, globalValue) = + CondExpression.createAndAddFunction(ctx, cond, condExpr.dataType, "caseWhenCondExpr") + (s"$funcName(${ctx.INPUT_ROW});", globalIsNull, globalValue) + } else { + (cond.code, cond.isNull, cond.value) + } + val (resFunc, resIsNull, resValue) = if ((res.code.length >= 512) && + // Split these expressions only if they are created from a row object + (ctx.INPUT_ROW != null && ctx.currentVars == null)) { + val (funcName, globalIsNull, globalValue) = + CondExpression.createAndAddFunction(ctx, res, valueExpr.dataType, "caseWhenResExpr") + (s"$funcName(${ctx.INPUT_ROW});", globalIsNull, globalValue) + } else { + (res.code, res.isNull, res.value) + } + s""" - ${cond.code} - if (!${cond.isNull} && ${cond.value}) { - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.value} = ${res.value}; + ${condFunc} + if (!${condIsNull} && ${condValue}) { + ${resFunc} + ${ev.isNull} = ${resIsNull}; + ${ev.value} = ${resValue}; } """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 8f6289f00571c..76e02a41a271b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -342,6 +342,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { projection(row) } +<<<<<<< HEAD test("SPARK-21720: split large predications into blocks due to JVM code size limit") { val length = 600 @@ -380,4 +381,22 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { s"Incorrect Evaluation: expressions: $exprAnd, actual: $actualAnd, expected: $expectedAnd") } } + + test("SPARK-21413: split large case when into blocks due to JVM code size limit") { + val expected = 2 + var expr: Expression = BoundReference(0, IntegerType, true) + for (_ <- 1 to 10) { + expr = CaseWhen(Seq((EqualTo(expr, Literal(0)), Literal(-1))), expr) + .toCodegen() + } + val plan = GenerateMutableProjection.generate(Seq(expr)) + val row = new GenericInternalRow(Array[Any](1)) + row.setInt(0, expected) + val actual = plan(row).toSeq(Seq(expr.dataType)) + assert(actual.length == 1) + + if (!checkResult(actual(0), expected, expr.dataType)) { + fail(s"Incorrect Evaluation: expressions: $expr, actual: ${actual(0)}, expected: $expected") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 644e72c893ceb..9bf6ffb1ee059 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2132,6 +2132,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } +<<<<<<< HEAD test("order-by ordinal.") { checkAnswer( testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)), @@ -2158,4 +2159,21 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val mean = result.select("DecimalCol").where($"summary" === "mean") assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000"))) } + + testQuietly("SPARK-21413: Multiple projections with CASE WHEN fails") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + val df = spark.createDataFrame(sparkContext.parallelize(Seq(Row(1))), schema) + val df1 = + df.withColumn("a", when($"a" === 0, null).otherwise($"a")) + .withColumn("a", when($"a" === 0, null).otherwise($"a")) + .withColumn("a", when($"a" === 0, null).otherwise($"a")) + .withColumn("a", when($"a" === 0, null).otherwise($"a")) + .withColumn("a", when($"a" === 0, null).otherwise($"a")) + .withColumn("a", when($"a" === 0, null).otherwise($"a")) + .withColumn("a", when($"a" === 0, null).otherwise($"a")) + .withColumn("a", when($"a" === 0, null).otherwise($"a")) + .withColumn("a", when($"a" === 0, null).otherwise($"a")) + .withColumn("a", when($"a" === 0, null).otherwise($"a")) + checkAnswer(df1, Row(1)) + } } From 6f0de18dff021f856ae79363c3d0947ba14140df Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 16 Jul 2017 01:15:07 +0900 Subject: [PATCH 02/12] remove end-to-end test since sbt test does not go to fallback path in whole-stage codegen --- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 9bf6ffb1ee059..52cee5f797446 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2132,7 +2132,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } -<<<<<<< HEAD test("order-by ordinal.") { checkAnswer( testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)), From 7c78cf2b0199998b0c399832dd14d25f58850c6f Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 17 Jul 2017 15:25:23 +0900 Subject: [PATCH 03/12] Changed a decision logic to split it into methods --- .../expressions/conditionalExpressions.scala | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 6f6cbf542e6af..dfd6b657c216c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -264,23 +264,18 @@ case class CaseWhenCodegen( val cases = branches.map { case (condExpr, valueExpr) => val cond = condExpr.genCode(ctx) val res = valueExpr.genCode(ctx) - val (condFunc, condIsNull, condValue) = if ((cond.code.length >= 512) && + val (condFunc, condIsNull, condValue, resFunc, resIsNull, resValue ) = + if ((cond.code.length + res.code.length) > 1024 && // Split these expressions only if they are created from a row object (ctx.INPUT_ROW != null && ctx.currentVars == null)) { - val (funcName, globalIsNull, globalValue) = + val (condFuncName, condGlobalIsNull, condGlobalValue) = CondExpression.createAndAddFunction(ctx, cond, condExpr.dataType, "caseWhenCondExpr") - (s"$funcName(${ctx.INPUT_ROW});", globalIsNull, globalValue) - } else { - (cond.code, cond.isNull, cond.value) - } - val (resFunc, resIsNull, resValue) = if ((res.code.length >= 512) && - // Split these expressions only if they are created from a row object - (ctx.INPUT_ROW != null && ctx.currentVars == null)) { - val (funcName, globalIsNull, globalValue) = + val (resFuncName, resGlobalIsNull, resGlobalValue) = CondExpression.createAndAddFunction(ctx, res, valueExpr.dataType, "caseWhenResExpr") - (s"$funcName(${ctx.INPUT_ROW});", globalIsNull, globalValue) + (s"$condFuncName(${ctx.INPUT_ROW});", condGlobalIsNull, condGlobalValue, + s"$resFuncName(${ctx.INPUT_ROW});", resGlobalIsNull, resGlobalValue) } else { - (res.code, res.isNull, res.value) + (cond.code, cond.isNull, cond.value, res.code, res.isNull, res.value) } s""" From 2e01427537ffdae5a3163fb88b16e18bf3df621a Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 19 Jul 2017 18:12:59 +0900 Subject: [PATCH 04/12] address review comment --- .../expressions/conditionalExpressions.scala | 102 ++++++++++++------ .../expressions/CodeGenerationSuite.scala | 55 ++++++++-- 2 files changed, 115 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index dfd6b657c216c..1111ec937ca88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -261,51 +261,91 @@ case class CaseWhenCodegen( // } // } // } - val cases = branches.map { case (condExpr, valueExpr) => - val cond = condExpr.genCode(ctx) - val res = valueExpr.genCode(ctx) - val (condFunc, condIsNull, condValue, resFunc, resIsNull, resValue ) = - if ((cond.code.length + res.code.length) > 1024 && - // Split these expressions only if they are created from a row object - (ctx.INPUT_ROW != null && ctx.currentVars == null)) { - val (condFuncName, condGlobalIsNull, condGlobalValue) = - CondExpression.createAndAddFunction(ctx, cond, condExpr.dataType, "caseWhenCondExpr") - val (resFuncName, resGlobalIsNull, resGlobalValue) = - CondExpression.createAndAddFunction(ctx, res, valueExpr.dataType, "caseWhenResExpr") - (s"$condFuncName(${ctx.INPUT_ROW});", condGlobalIsNull, condGlobalValue, - s"$resFuncName(${ctx.INPUT_ROW});", resGlobalIsNull, resGlobalValue) - } else { - (cond.code, cond.isNull, cond.value, res.code, res.isNull, res.value) - } + val isNull = ctx.freshName("caseWhenIsNull") + val value = ctx.freshName("caseWhenValue") + // Split these expressions only if they are created from a row object + + val cases = branches.map { case (condExpr, valueExpr) => + val (condFunc, condIsNull, condValue) = genCodeForExpression(ctx, condExpr) + val (resFunc, resIsNull, resValue) = genCodeForExpression(ctx, valueExpr) s""" ${condFunc} if (!${condIsNull} && ${condValue}) { ${resFunc} - ${ev.isNull} = ${resIsNull}; - ${ev.value} = ${resValue}; + $isNull = ${resIsNull}; + $value = ${resValue}; } """ } - var generatedCode = cases.mkString("", "\nelse {\n", "\nelse {\n") + var isGlobalVariable = false + var generatedCode = "" + var numIfthen = 0 + cases.foreach { ifthen => + generatedCode += ifthen + "\nelse {\n" + numIfthen += 1 + + if (generatedCode.length > 1024) { + val flag = "flag" + generatedCode += s" $flag = false;\n" + "}\n" * numIfthen + val funcName = ctx.freshName("caseWhenNestedIf") + val funcBody = + s""" + |private boolean $funcName(InternalRow ${ctx.INPUT_ROW}) { + | boolean $flag = true; + | $generatedCode + | return $flag; + |} + """.stripMargin + val fullFuncName = ctx.addNewFunction(funcName, funcBody) + isGlobalVariable = true + + generatedCode = s"if ($funcName(${ctx.INPUT_ROW})) {\n// do nothing\n} else {\n" + numIfthen = 1 + } + } elseValue.foreach { elseExpr => - val res = elseExpr.genCode(ctx) - generatedCode += - s""" - ${res.code} - ${ev.isNull} = ${res.isNull}; - ${ev.value} = ${res.value}; - """ + val (resFunc, resIsNull, resValue) = genCodeForExpression(ctx, elseExpr) + generatedCode += s""" + ${resFunc} + $isNull = ${resIsNull}; + $value = ${resValue}; + """ } - generatedCode += "}\n" * cases.size + generatedCode += "}\n" * numIfthen - ev.copy(code = s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $generatedCode""") + if (!isGlobalVariable) { + ev.copy(s""" + boolean $isNull = true; + ${ctx.javaType(dataType)} $value = ${ctx.defaultValue(dataType)}; + $generatedCode + """, isNull, value) + } else { + ctx.addMutableState("boolean", isNull, s"$isNull = false;") + ctx.addMutableState(ctx.javaType(dataType), value, + s"$value = ${ctx.defaultValue(dataType)};") + ev.copy(code = s""" + $generatedCode + boolean ${ev.isNull} = $isNull; + ${ctx.javaType(dataType)} ${ev.value} = $value; + """) + } + } + + def genCodeForExpression(ctx: CodegenContext, expression: Expression): + (String, String, String) = { + val ev = expression.genCode(ctx) + if (ev.code.length > 1024 && (ctx.INPUT_ROW != null && ctx.currentVars == null)) { + val (funcName, globalIsNull, globalValue) = + CondExpression.createAndAddFunction(ctx, ev, expression.dataType, + "caseWhenElseExpr") + (s"$funcName(${ctx.INPUT_ROW});", globalIsNull, globalValue) + } else { + (ev.code, ev.isNull, ev.value) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 76e02a41a271b..1ef17d71e3ebb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -383,20 +383,53 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { } test("SPARK-21413: split large case when into blocks due to JVM code size limit") { - val expected = 2 - var expr: Expression = BoundReference(0, IntegerType, true) - for (_ <- 1 to 10) { - expr = CaseWhen(Seq((EqualTo(expr, Literal(0)), Literal(-1))), expr) + val expectedInt = -2 + var exprInt: Expression = BoundReference(0, IntegerType, true) + val expectedStr = UTF8String.fromString("abc") + val exprStr: Expression = BoundReference(0, StringType, true) + + // Code size of condition or then expression is large + var expr1 = exprInt + for (i <- 1 to 10) { + expr1 = CaseWhen(Seq((EqualTo(expr1, Literal(i)), Literal(-1))), expr1).toCodegen() + } + val plan1 = GenerateMutableProjection.generate(Seq(expr1)) + val row1 = new GenericInternalRow(Array[Any](1)) + row1.setInt(0, expectedInt) + val actual1 = plan1(row1).toSeq(Seq(expr1.dataType)) + assert(actual1.length == 1) + val result1 = actual1(0) + if (!checkResult(result1, expectedInt, expr1.dataType)) { + fail(s"Incorrect Evaluation: expressions: $expr1, actual: $result1, expected: $expectedInt") + } + + // Code size of condition or then expression is large + var expr2 = exprStr + for (i <- 1 to 512) { + expr2 = CaseWhen(Seq((EqualTo(exprStr, Literal(s"def$i")), Literal(s"xyz$i"))), expr2) .toCodegen() } - val plan = GenerateMutableProjection.generate(Seq(expr)) - val row = new GenericInternalRow(Array[Any](1)) - row.setInt(0, expected) - val actual = plan(row).toSeq(Seq(expr.dataType)) - assert(actual.length == 1) + val plan2 = GenerateMutableProjection.generate(Seq(expr2)) + val row2 = new GenericInternalRow(Array[Any](1)) + row2.update(0, expectedStr) + val actual2 = plan2(row2).toSeq(Seq(expr2.dataType)) + assert(actual2.length == 1) + val result2 = actual2(0) + if (!checkResult(result2, expectedStr, expr2.dataType)) { + fail(s"Incorrect Evaluation: expressions: $expr2, actual: $result2, expected: $expectedStr") + } - if (!checkResult(actual(0), expected, expr.dataType)) { - fail(s"Incorrect Evaluation: expressions: $expr, actual: ${actual(0)}, expected: $expected") + // Code size of total conditional branches is large + val cases = (1 to 512).map(i => (EqualTo(exprStr, Literal(s"def$i")), Literal(s"xyz$i"))) + val expr3 = CaseWhen(cases, exprStr).toCodegen() + val plan3 = GenerateMutableProjection.generate(Seq(expr3)) + val row3 = new GenericInternalRow(Array[Any](1)) + row3.update(0, expectedStr) + val actual3 = plan3(row3).toSeq(Seq(expr3.dataType)) + assert(actual3.length == 1) + val result3 = actual3(0) + if (!checkResult(result3, expectedStr, expr3.dataType)) { + fail(s"Incorrect Evaluation: expressions: $expr3, actual: $result3, expected: $expectedStr") } } } From fb62f3c57f9868a693b3149d33351dd44cf5409a Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 19 Jul 2017 19:38:19 +0900 Subject: [PATCH 05/12] fix comments --- .../spark/sql/catalyst/expressions/CodeGenerationSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 1ef17d71e3ebb..6e85c6e8ca141 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -403,7 +403,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { fail(s"Incorrect Evaluation: expressions: $expr1, actual: $result1, expected: $expectedInt") } - // Code size of condition or then expression is large + // Code size of else expression is large var expr2 = exprStr for (i <- 1 to 512) { expr2 = CaseWhen(Seq((EqualTo(exprStr, Literal(s"def$i")), Literal(s"xyz$i"))), expr2) @@ -419,7 +419,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { fail(s"Incorrect Evaluation: expressions: $expr2, actual: $result2, expected: $expectedStr") } - // Code size of total conditional branches is large + // total code size of conditional branches is large val cases = (1 to 512).map(i => (EqualTo(exprStr, Literal(s"def$i")), Literal(s"xyz$i"))) val expr3 = CaseWhen(cases, exprStr).toCodegen() val plan3 = GenerateMutableProjection.generate(Seq(expr3)) From 85bee25d1adcaf527ca7695937864e7f7ac110d8 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 19 Jul 2017 23:30:56 +0900 Subject: [PATCH 06/12] fix failure of HiveCompatibilitySuite.ppr_allchildsarenull --- .../spark/sql/catalyst/expressions/conditionalExpressions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 1111ec937ca88..e7332a95b87f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -286,7 +286,7 @@ case class CaseWhenCodegen( generatedCode += ifthen + "\nelse {\n" numIfthen += 1 - if (generatedCode.length > 1024) { + if (generatedCode.length > 1024 && (ctx.INPUT_ROW != null && ctx.currentVars == null)) { val flag = "flag" generatedCode += s" $flag = false;\n" + "}\n" * numIfthen val funcName = ctx.freshName("caseWhenNestedIf") From 3cb95aa2a3ac8968f1085d11cdd880d0a5179558 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sun, 30 Jul 2017 00:45:09 +0900 Subject: [PATCH 07/12] address review comment --- .../sql/catalyst/expressions/conditionalExpressions.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index e7332a95b87f6..0e96e5f2eb8fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -264,7 +264,6 @@ case class CaseWhenCodegen( val isNull = ctx.freshName("caseWhenIsNull") val value = ctx.freshName("caseWhenValue") - // Split these expressions only if they are created from a row object val cases = branches.map { case (condExpr, valueExpr) => val (condFunc, condIsNull, condValue) = genCodeForExpression(ctx, condExpr) @@ -286,7 +285,9 @@ case class CaseWhenCodegen( generatedCode += ifthen + "\nelse {\n" numIfthen += 1 - if (generatedCode.length > 1024 && (ctx.INPUT_ROW != null && ctx.currentVars == null)) { + if (generatedCode.length > 1024 && + // Split these expressions only if they are created from a row object + (ctx.INPUT_ROW != null && ctx.currentVars == null)) { val flag = "flag" generatedCode += s" $flag = false;\n" + "}\n" * numIfthen val funcName = ctx.freshName("caseWhenNestedIf") From ca0ddd5d6d12bf8ff3d2682300e51ce74e7411df Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 13 Nov 2017 19:07:01 +0000 Subject: [PATCH 08/12] use return value of addNewFunction for caller function name --- .../spark/sql/catalyst/expressions/conditionalExpressions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 0e96e5f2eb8fb..b33fd8fb92387 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -302,7 +302,7 @@ case class CaseWhenCodegen( val fullFuncName = ctx.addNewFunction(funcName, funcBody) isGlobalVariable = true - generatedCode = s"if ($funcName(${ctx.INPUT_ROW})) {\n// do nothing\n} else {\n" + generatedCode = s"if ($fullFuncName(${ctx.INPUT_ROW})) {\n// do nothing\n} else {\n" numIfthen = 1 } } From 908050007cb5a256aa71f6bbe014c1b0a595ce7d Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 14 Nov 2017 09:15:17 +0000 Subject: [PATCH 09/12] rebase with master --- .../sql/catalyst/expressions/conditionalExpressions.scala | 4 ++-- .../spark/sql/catalyst/expressions/CodeGenerationSuite.scala | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index b33fd8fb92387..2d5a864679083 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -111,6 +111,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi ev.copy(code = generatedCode) } + override def toString: String = s"if ($predicate) $trueValue else $falseValue" override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))" @@ -341,8 +342,7 @@ case class CaseWhenCodegen( val ev = expression.genCode(ctx) if (ev.code.length > 1024 && (ctx.INPUT_ROW != null && ctx.currentVars == null)) { val (funcName, globalIsNull, globalValue) = - CondExpression.createAndAddFunction(ctx, ev, expression.dataType, - "caseWhenElseExpr") + ctx.createAndAddFunction(ev, expression.dataType, "caseWhenElseExpr") (s"$funcName(${ctx.INPUT_ROW});", globalIsNull, globalValue) } else { (ev.code, ev.isNull, ev.value) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 6e85c6e8ca141..9aab9b8fb3c1a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -342,7 +342,6 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { projection(row) } -<<<<<<< HEAD test("SPARK-21720: split large predications into blocks due to JVM code size limit") { val length = 600 From b6030d95f3244a4e9040d68a54d4681f16f02760 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 14 Nov 2017 17:28:57 +0000 Subject: [PATCH 10/12] address review comment --- .../expressions/conditionalExpressions.scala | 55 ++++++++++--------- .../org/apache/spark/sql/DataFrameSuite.scala | 21 +++---- 2 files changed, 37 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 2d5a864679083..8866562841698 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -279,33 +279,38 @@ case class CaseWhenCodegen( """ } - var isGlobalVariable = false - var generatedCode = "" var numIfthen = 0 - cases.foreach { ifthen => - generatedCode += ifthen + "\nelse {\n" - numIfthen += 1 - - if (generatedCode.length > 1024 && - // Split these expressions only if they are created from a row object - (ctx.INPUT_ROW != null && ctx.currentVars == null)) { - val flag = "flag" - generatedCode += s" $flag = false;\n" + "}\n" * numIfthen - val funcName = ctx.freshName("caseWhenNestedIf") - val funcBody = - s""" - |private boolean $funcName(InternalRow ${ctx.INPUT_ROW}) { - | boolean $flag = true; - | $generatedCode - | return $flag; - |} - """.stripMargin - val fullFuncName = ctx.addNewFunction(funcName, funcBody) - isGlobalVariable = true - - generatedCode = s"if ($fullFuncName(${ctx.INPUT_ROW})) {\n// do nothing\n} else {\n" - numIfthen = 1 + var isGlobalVariable = false + var generatedCode = if (cases.map(s => s.length).sum <= 1024) { + cases.mkString("\nelse {\n") + } else { + var code = "" + cases.foreach { ifthen => + code += ifthen + "\nelse {\n" + numIfthen += 1 + + if (code.length > 1024 && + // Split these expressions only if they are created from a row object + (ctx.INPUT_ROW != null && ctx.currentVars == null)) { + val flag = "flag" + code += s" $flag = false;\n" + "}\n" * numIfthen + val funcName = ctx.freshName("caseWhenNestedIf") + val funcBody = + s""" + |private boolean $funcName(InternalRow ${ctx.INPUT_ROW}) { + | boolean $flag = true; + | $code + | return $flag; + |} + """.stripMargin + val fullFuncName = ctx.addNewFunction(funcName, funcBody) + isGlobalVariable = true + + code = s"if ($fullFuncName(${ctx.INPUT_ROW})) {\n// do nothing\n} else {\n" + numIfthen = 1 + } } + code } elseValue.foreach { elseExpr => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 52cee5f797446..f5d5fa1fbc25e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2159,20 +2159,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000"))) } - testQuietly("SPARK-21413: Multiple projections with CASE WHEN fails") { + // ignore end-to-end test since sbt test does not go to fallback path in whole-stage codegen + ignore("SPARK-21413: Multiple projections with CASE WHEN fails") { val schema = StructType(StructField("a", IntegerType) :: Nil) - val df = spark.createDataFrame(sparkContext.parallelize(Seq(Row(1))), schema) - val df1 = - df.withColumn("a", when($"a" === 0, null).otherwise($"a")) - .withColumn("a", when($"a" === 0, null).otherwise($"a")) - .withColumn("a", when($"a" === 0, null).otherwise($"a")) - .withColumn("a", when($"a" === 0, null).otherwise($"a")) - .withColumn("a", when($"a" === 0, null).otherwise($"a")) - .withColumn("a", when($"a" === 0, null).otherwise($"a")) - .withColumn("a", when($"a" === 0, null).otherwise($"a")) - .withColumn("a", when($"a" === 0, null).otherwise($"a")) - .withColumn("a", when($"a" === 0, null).otherwise($"a")) - .withColumn("a", when($"a" === 0, null).otherwise($"a")) - checkAnswer(df1, Row(1)) + var df = spark.createDataFrame(sparkContext.parallelize(Seq(Row(1))), schema) + for (i <- 1 to 10) { + df = df.withColumn("a", when($"a" === 0, null).otherwise($"a")) + } + checkAnswer(df, Row(1)) } } From e69f12636bee5f3496421d70f764976f4cb687b7 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 15 Nov 2017 05:39:07 +0000 Subject: [PATCH 11/12] fix failures for ImputerSuite --- .../expressions/conditionalExpressions.scala | 21 ++++++++++--------- .../expressions/CodeGenerationSuite.scala | 13 ++++++++++++ 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 8866562841698..b3392802c51c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -279,21 +279,21 @@ case class CaseWhenCodegen( """ } - var numIfthen = 0 var isGlobalVariable = false - var generatedCode = if (cases.map(s => s.length).sum <= 1024) { - cases.mkString("\nelse {\n") + val (generatedIfThenElse, numBrankets) = if (cases.map(s => s.length).sum <= 1024) { + (cases.mkString("\nelse {\n"), cases.length - 1) } else { + var numIfThen = 0 var code = "" - cases.foreach { ifthen => - code += ifthen + "\nelse {\n" - numIfthen += 1 + cases.foreach { ifThen => + code += ifThen + "\nelse {\n" + numIfThen += 1 if (code.length > 1024 && // Split these expressions only if they are created from a row object (ctx.INPUT_ROW != null && ctx.currentVars == null)) { val flag = "flag" - code += s" $flag = false;\n" + "}\n" * numIfthen + code += s" $flag = false;\n" + "}\n" * numIfThen val funcName = ctx.freshName("caseWhenNestedIf") val funcBody = s""" @@ -307,12 +307,13 @@ case class CaseWhenCodegen( isGlobalVariable = true code = s"if ($fullFuncName(${ctx.INPUT_ROW})) {\n// do nothing\n} else {\n" - numIfthen = 1 + numIfThen = 1 } } - code + (code, numIfThen) } + var generatedCode = generatedIfThenElse elseValue.foreach { elseExpr => val (resFunc, resIsNull, resValue) = genCodeForExpression(ctx, elseExpr) generatedCode += s""" @@ -322,7 +323,7 @@ case class CaseWhenCodegen( """ } - generatedCode += "}\n" * numIfthen + generatedCode += "}\n" * numBrankets if (!isGlobalVariable) { ev.copy(s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 9aab9b8fb3c1a..d4cc9a6943666 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -430,5 +430,18 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { if (!checkResult(result3, expectedStr, expr3.dataType)) { fail(s"Incorrect Evaluation: expressions: $expr3, actual: $result3, expected: $expectedStr") } + + // total code size is small + val cases4 = Seq((EqualTo(exprStr, Literal("def")), Literal("xyz"))) + val expr4 = CaseWhen(cases4, exprStr).toCodegen() + val plan4 = GenerateMutableProjection.generate(Seq(expr4)) + val row4 = new GenericInternalRow(Array[Any](1)) + row4.update(0, expectedStr) + val actual4 = plan4(row4).toSeq(Seq(expr4.dataType)) + assert(actual4.length == 1) + val result4 = actual4(0) + if (!checkResult(result4, expectedStr, expr4.dataType)) { + fail(s"Incorrect Evaluation: expressions: $expr4, actual: $result4, expected: $expectedStr") + } } } From 5466ef0914e8e702d43019995067ef49ecb90696 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 15 Nov 2017 07:21:56 +0000 Subject: [PATCH 12/12] fix failures for ImputerSuite --- .../spark/sql/catalyst/expressions/conditionalExpressions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index b3392802c51c9..1bd70acbe1b7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -281,7 +281,7 @@ case class CaseWhenCodegen( var isGlobalVariable = false val (generatedIfThenElse, numBrankets) = if (cases.map(s => s.length).sum <= 1024) { - (cases.mkString("\nelse {\n"), cases.length - 1) + (cases.mkString("", "\nelse {\n", "\nelse {\n"), cases.length) } else { var numIfThen = 0 var code = ""