diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 739bd13c5078d..1893eec22b65d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -602,23 +602,38 @@ case class Least(children: Seq[Expression]) extends Expression {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val evalChildren = children.map(_.genCode(ctx))
-    ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
-    ctx.addMutableState(ctx.javaType(dataType), ev.value)
-    def updateEval(eval: ExprCode): String = {
+    val tmpIsNull = ctx.freshName("leastTmpIsNull")
+    ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull)
+    val evals = evalChildren.map(eval =>
       s"""
-        ${eval.code}
-        if (!${eval.isNull} && (${ev.isNull} ||
-          ${ctx.genGreater(dataType, ev.value, eval.value)})) {
-          ${ev.isNull} = false;
-          ${ev.value} = ${eval.value};
-        }
-      """
-    }
-    val codes = ctx.splitExpressionsWithCurrentInputs(evalChildren.map(updateEval))
-    ev.copy(code = s"""
-      ${ev.isNull} = true;
-      ${ev.value} = ${ctx.defaultValue(dataType)};
-      $codes""")
+         |${eval.code}
+         |if (!${eval.isNull} && ($tmpIsNull ||
+         |  ${ctx.genGreater(dataType, ev.value, eval.value)})) {
+         |  $tmpIsNull = false;
+         |  ${ev.value} = ${eval.value};
+         |}
+      """.stripMargin
+    )
+
+    val resultType = ctx.javaType(dataType)
+    val codes = ctx.splitExpressionsWithCurrentInputs(
+      expressions = evals,
+      funcName = "least",
+      extraArguments = Seq(resultType -> ev.value),
+      returnType = resultType,
+      makeSplitFunction = body =>
+        s"""
+          |$body
+          |return ${ev.value};
+        """.stripMargin,
+      foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
+    ev.copy(code =
+      s"""
+         |$tmpIsNull = true;
+         |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+         |$codes
+         |final boolean ${ev.isNull} = $tmpIsNull;
+      """.stripMargin)
   }
 }
 
@@ -668,22 +683,37 @@ case class Greatest(children: Seq[Expression]) extends Expression {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val evalChildren = children.map(_.genCode(ctx))
-    ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
-    ctx.addMutableState(ctx.javaType(dataType), ev.value)
-    def updateEval(eval: ExprCode): String = {
+    val tmpIsNull = ctx.freshName("greatestTmpIsNull")
+    ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull)
+    val evals = evalChildren.map(eval =>
       s"""
-        ${eval.code}
-        if (!${eval.isNull} && (${ev.isNull} ||
-          ${ctx.genGreater(dataType, eval.value, ev.value)})) {
-          ${ev.isNull} = false;
-          ${ev.value} = ${eval.value};
-        }
-      """
-    }
-    val codes = ctx.splitExpressionsWithCurrentInputs(evalChildren.map(updateEval))
-    ev.copy(code = s"""
-      ${ev.isNull} = true;
-      ${ev.value} = ${ctx.defaultValue(dataType)};
-      $codes""")
+         |${eval.code}
+         |if (!${eval.isNull} && ($tmpIsNull ||
+         |  ${ctx.genGreater(dataType, eval.value, ev.value)})) {
+         |  $tmpIsNull = false;
+         |  ${ev.value} = ${eval.value};
+         |}
+      """.stripMargin
+    )
+
+    val resultType = ctx.javaType(dataType)
+    val codes = ctx.splitExpressionsWithCurrentInputs(
+      expressions = evals,
+      funcName = "greatest",
+      extraArguments = Seq(resultType -> ev.value),
+      returnType = resultType,
+      makeSplitFunction = body =>
+        s"""
+          |$body
+          |return ${ev.value};
+        """.stripMargin,
+      foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
+    ev.copy(code =
+      s"""
+         |$tmpIsNull = true;
+         |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+         |$codes
+         |final boolean ${ev.isNull} = $tmpIsNull;
+      """.stripMargin)
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index fb759eba6a9e2..be638d80e45d8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.types._
 
 class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -343,4 +344,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(Least(inputsExpr), "s" * 1, EmptyRow)
     checkEvaluation(Greatest(inputsExpr), "s" * N, EmptyRow)
   }
+
+  test("SPARK-22704: Least and greatest use less global variables") {
+    val ctx1 = new CodegenContext()
+    Least(Seq(Literal(1), Literal(1))).genCode(ctx1)
+    assert(ctx1.mutableStates.size == 1)
+
+    val ctx2 = new CodegenContext()
+    Greatest(Seq(Literal(1), Literal(1))).genCode(ctx2)
+    assert(ctx2.mutableStates.size == 1)
+  }
 }