Skip to content

Commit e69f126

Browse files
committed
fix failures for ImputerSuite
1 parent b6030d9 commit e69f126

File tree

2 files changed

+24
-10
lines changed

2 files changed

+24
-10
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -279,21 +279,21 @@ case class CaseWhenCodegen(
279279
"""
280280
}
281281

282-
var numIfthen = 0
283282
var isGlobalVariable = false
284-
var generatedCode = if (cases.map(s => s.length).sum <= 1024) {
285-
cases.mkString("\nelse {\n")
283+
val (generatedIfThenElse, numBrankets) = if (cases.map(s => s.length).sum <= 1024) {
284+
(cases.mkString("\nelse {\n"), cases.length - 1)
286285
} else {
286+
var numIfThen = 0
287287
var code = ""
288-
cases.foreach { ifthen =>
289-
code += ifthen + "\nelse {\n"
290-
numIfthen += 1
288+
cases.foreach { ifThen =>
289+
code += ifThen + "\nelse {\n"
290+
numIfThen += 1
291291

292292
if (code.length > 1024 &&
293293
// Split these expressions only if they are created from a row object
294294
(ctx.INPUT_ROW != null && ctx.currentVars == null)) {
295295
val flag = "flag"
296-
code += s" $flag = false;\n" + "}\n" * numIfthen
296+
code += s" $flag = false;\n" + "}\n" * numIfThen
297297
val funcName = ctx.freshName("caseWhenNestedIf")
298298
val funcBody =
299299
s"""
@@ -307,12 +307,13 @@ case class CaseWhenCodegen(
307307
isGlobalVariable = true
308308

309309
code = s"if ($fullFuncName(${ctx.INPUT_ROW})) {\n// do nothing\n} else {\n"
310-
numIfthen = 1
310+
numIfThen = 1
311311
}
312312
}
313-
code
313+
(code, numIfThen)
314314
}
315315

316+
var generatedCode = generatedIfThenElse
316317
elseValue.foreach { elseExpr =>
317318
val (resFunc, resIsNull, resValue) = genCodeForExpression(ctx, elseExpr)
318319
generatedCode += s"""
@@ -322,7 +323,7 @@ case class CaseWhenCodegen(
322323
"""
323324
}
324325

325-
generatedCode += "}\n" * numIfthen
326+
generatedCode += "}\n" * numBrankets
326327

327328
if (!isGlobalVariable) {
328329
ev.copy(s"""

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,5 +430,18 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
430430
if (!checkResult(result3, expectedStr, expr3.dataType)) {
431431
fail(s"Incorrect Evaluation: expressions: $expr3, actual: $result3, expected: $expectedStr")
432432
}
433+
434+
// total code size is small
435+
val cases4 = Seq((EqualTo(exprStr, Literal("def")), Literal("xyz")))
436+
val expr4 = CaseWhen(cases4, exprStr).toCodegen()
437+
val plan4 = GenerateMutableProjection.generate(Seq(expr4))
438+
val row4 = new GenericInternalRow(Array[Any](1))
439+
row4.update(0, expectedStr)
440+
val actual4 = plan4(row4).toSeq(Seq(expr4.dataType))
441+
assert(actual4.length == 1)
442+
val result4 = actual4(0)
443+
if (!checkResult(result4, expectedStr, expr4.dataType)) {
444+
fail(s"Incorrect Evaluation: expressions: $expr4, actual: $result4, expected: $expectedStr")
445+
}
433446
}
434447
}

0 commit comments

Comments
 (0)