diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 98eda2a1ba92c..90d2b56675348 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -790,23 +790,7 @@ class CodegenContext {
       returnType: String = "void",
       makeSplitFunction: String => String = identity,
       foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = {
-    val blocks = new ArrayBuffer[String]()
-    val blockBuilder = new StringBuilder()
-    var length = 0
-    for (code <- expressions) {
-      // We can't know how many bytecode will be generated, so use the length of source code
-      // as metric. A method should not go beyond 8K, otherwise it will not be JITted, should
-      // also not be too small, or it will have many function calls (for wide table), see the
-      // results in BenchmarkWideTable.
-      if (length > 1024) {
-        blocks += blockBuilder.toString()
-        blockBuilder.clear()
-        length = 0
-      }
-      blockBuilder.append(code)
-      length += CodeFormatter.stripExtraNewLinesAndComments(code).length
-    }
-    blocks += blockBuilder.toString()
+    val blocks = buildCodeBlocks(expressions)
 
     if (blocks.length == 1) {
       // inline execution if only one block
@@ -841,6 +825,32 @@ class CodegenContext {
     }
   }
 
+  /**
+   * Splits the generated code of expressions into multiple String blocks, cutting a new
+   * block once the accumulated source-code length exceeds a threshold.
+   *
+   * @param expressions the generated code of the expressions to group into blocks.
+   */
+  def buildCodeBlocks(expressions: Seq[String]): Seq[String] = {
+    val blocks = new ArrayBuffer[String]()
+    val blockBuilder = new StringBuilder()
+    var length = 0
+    for (code <- expressions) {
+      // We can't know how many bytecode will be generated, so use the length of source code
+      // as metric. A method should not go beyond 8K, otherwise it will not be JITted, should
+      // also not be too small, or it will have many function calls (for wide table), see the
+      // results in BenchmarkWideTable.
+      if (length > 1024) {
+        blocks += blockBuilder.toString()
+        blockBuilder.clear()
+        length = 0
+      }
+      blockBuilder.append(code)
+      length += CodeFormatter.stripExtraNewLinesAndComments(code).length
+    }
+    blocks += blockBuilder.toString()
+  }
+
   /**
    * Here we handle all the methods which have been added to the inner classes and
    * not to the outer class.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index c341943187820..120b9b9ce4a30 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -224,22 +224,52 @@ case class Elt(children: Seq[Expression])
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val index = indexExpr.genCode(ctx)
     val strings = stringExprs.map(_.genCode(ctx))
+    val indexVal = ctx.freshName("index")
+    val stringVal = ctx.freshName("stringVal")
     val assignStringValue = strings.zipWithIndex.map { case (eval, index) =>
       s"""
         case ${index + 1}:
-          ${ev.value} = ${eval.isNull} ? null : ${eval.value};
+          ${eval.code}
+          $stringVal = ${eval.isNull} ? null : ${eval.value};
           break;
       """
-    }.mkString("\n")
-    val indexVal = ctx.freshName("index")
-    val stringArray = ctx.freshName("strings");
+    }
 
-    ev.copy(index.code + "\n" + strings.map(_.code).mkString("\n") + s"""
-      final int $indexVal = ${index.value};
-      UTF8String ${ev.value} = null;
-      switch ($indexVal) {
-        $assignStringValue
+    val cases = ctx.buildCodeBlocks(assignStringValue)
+    val codes = if (cases.length == 1) {
+      s"""
+        UTF8String $stringVal = null;
+        switch ($indexVal) {
+          ${cases.head}
+        }
+      """
+    } else {
+      var prevFunc = "null"
+      for (c <- cases.reverse) {
+        val funcName = ctx.freshName("eltFunc")
+        val funcBody = s"""
+          private UTF8String $funcName(InternalRow ${ctx.INPUT_ROW}, int $indexVal) {
+            UTF8String $stringVal = null;
+            switch ($indexVal) {
+              $c
+              default:
+                return $prevFunc;
+            }
+            return $stringVal;
+          }
+        """
+        val fullFuncName = ctx.addNewFunction(funcName, funcBody)
+        prevFunc = s"$fullFuncName(${ctx.INPUT_ROW}, $indexVal)"
       }
+      s"UTF8String $stringVal = $prevFunc;"
+    }
+
+    ev.copy(
+      s"""
+      ${index.code}
+      final int $indexVal = ${index.value};
+      $codes
+      UTF8String ${ev.value} = $stringVal;
       final boolean ${ev.isNull} = ${ev.value} == null;
     """)
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 18ef4bc37c2b5..e2f4585b4beb0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -97,6 +97,13 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(Elt(Seq(Literal(1), Literal(2))).checkInputDataTypes().isFailure)
   }
 
+  test("SPARK-22550: Elt should not generate codes beyond 64KB") {
+    val N = 10000
+    val strings = (1 to N).map(x => s"s$x")
+    val args = Literal.create(N, IntegerType) +: strings.map(Literal.create(_, StringType))
+    checkEvaluation(Elt(args), s"s$N")
+  }
+
   test("StringComparison") {
     val row = create_row("abc", null)
     val c1 = 'a.string.at(0)
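
Note on the first hunk: the extracted buildCodeBlocks helper keeps the existing splitting strategy, it just makes it reusable. The standalone Scala sketch below (not part of the patch) illustrates that strategy: snippets are appended to the current block, and a new block is started once the accumulated length passes the threshold, so no single generated method grows without bound. The object name, the explicit threshold parameter, and the use of plain String.length (instead of CodeFormatter.stripExtraNewLinesAndComments) are assumptions made up for the example.

    import scala.collection.mutable.ArrayBuffer

    object BuildCodeBlocksSketch {
      // Group generated code snippets into blocks of roughly `threshold` characters.
      def buildCodeBlocks(expressions: Seq[String], threshold: Int = 1024): Seq[String] = {
        val blocks = new ArrayBuffer[String]()
        val blockBuilder = new StringBuilder()
        var length = 0
        for (code <- expressions) {
          // Close the current block before appending once it is already over the threshold.
          if (length > threshold) {
            blocks += blockBuilder.toString()
            blockBuilder.clear()
            length = 0
          }
          blockBuilder.append(code)
          length += code.length // stand-in for the stripped source-code length
        }
        blocks += blockBuilder.toString() // flush the last, possibly short, block
        blocks.toSeq
      }

      def main(args: Array[String]): Unit = {
        // 50 snippets of ~120 characters each should land in blocks a bit over 1024 characters.
        val snippets = (1 to 50).map(i => s"case $i: value = expr$i; break;".padTo(120, ' ') + "\n")
        val blocks = buildCodeBlocks(snippets)
        println(s"${snippets.length} snippets grouped into ${blocks.length} blocks")
        blocks.foreach(b => println(s"block length = ${b.length}"))
      }
    }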
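
Note on the Elt hunk: the multi-block branch is what keeps each generated method under the 64KB limit. The case blocks are visited in reverse, each one is wrapped in its own private function, and each function's default: branch delegates to the function built just before it, so the single call emitted into the expression code covers every case. The sketch below (not part of the patch) mimics that string-building with hard-coded names such as eltFunc0/eltFunc1 and two toy blocks, just to show the shape of the chained code; the real implementation obtains names and registers functions via ctx.freshName and ctx.addNewFunction.

    object EltFunctionChainSketch {
      def main(args: Array[String]): Unit = {
        // Two already-split blocks of switch cases (what buildCodeBlocks would return).
        val blocks = Seq(
          "case 1: s = str1; break; case 2: s = str2; break;",
          "case 3: s = str3; break; case 4: s = str4; break;")

        var prevCall = "null" // the innermost default simply returns null
        var counter = 0
        val functions = blocks.reverse.map { cases =>
          val name = s"eltFunc$counter"
          counter += 1
          val body =
            s"""private UTF8String $name(InternalRow i, int index) {
               |  UTF8String s = null;
               |  switch (index) {
               |    $cases
               |    default:
               |      return $prevCall; // delegate to the previously built function
               |  }
               |  return s;
               |}""".stripMargin
          prevCall = s"$name(i, index)" // the next (outer) function delegates here
          body
        }

        functions.foreach(println)
        // What the expression's own code block would finally contain:
        println(s"UTF8String s = $prevCall;")
      }
    }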