Skip to content

Commit 06b103e

Browse files
committed
initial commit
1 parent d6ee69e commit 06b103e

File tree

3 files changed

+70
-26
lines changed

3 files changed

+70
-26
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -790,23 +790,7 @@ class CodegenContext {
790790
returnType: String = "void",
791791
makeSplitFunction: String => String = identity,
792792
foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = {
793-
val blocks = new ArrayBuffer[String]()
794-
val blockBuilder = new StringBuilder()
795-
var length = 0
796-
for (code <- expressions) {
797-
// We can't know how many bytecode will be generated, so use the length of source code
798-
// as metric. A method should not go beyond 8K, otherwise it will not be JITted, should
799-
// also not be too small, or it will have many function calls (for wide table), see the
800-
// results in BenchmarkWideTable.
801-
if (length > 1024) {
802-
blocks += blockBuilder.toString()
803-
blockBuilder.clear()
804-
length = 0
805-
}
806-
blockBuilder.append(code)
807-
length += CodeFormatter.stripExtraNewLinesAndComments(code).length
808-
}
809-
blocks += blockBuilder.toString()
793+
val blocks = splitCodes(expressions)
810794

811795
if (blocks.length == 1) {
812796
// inline execution if only one block
@@ -841,6 +825,26 @@ class CodegenContext {
841825
}
842826
}
843827

828+
def splitCodes(expressions: Seq[String]): Seq[String] = {
829+
val blocks = new ArrayBuffer[String]()
830+
val blockBuilder = new StringBuilder()
831+
var length = 0
832+
for (code <- expressions) {
833+
// We can't know how many bytecode will be generated, so use the length of source code
834+
// as metric. A method should not go beyond 8K, otherwise it will not be JITted, should
835+
// also not be too small, or it will have many function calls (for wide table), see the
836+
// results in BenchmarkWideTable.
837+
if (length > 1024) {
838+
blocks += blockBuilder.toString()
839+
blockBuilder.clear()
840+
length = 0
841+
}
842+
blockBuilder.append(code)
843+
length += CodeFormatter.stripExtraNewLinesAndComments(code).length
844+
}
845+
blocks += blockBuilder.toString()
846+
}
847+
844848
/**
845849
* Here we handle all the methods which have been added to the inner classes and
846850
* not to the outer class.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -224,22 +224,55 @@ case class Elt(children: Seq[Expression])
224224
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
225225
val index = indexExpr.genCode(ctx)
226226
val strings = stringExprs.map(_.genCode(ctx))
227+
val indexVal = ctx.freshName("index")
228+
val stringVal = ctx.freshName("stringVal")
227229
val assignStringValue = strings.zipWithIndex.map { case (eval, index) =>
228230
s"""
229231
case ${index + 1}:
230-
${ev.value} = ${eval.isNull} ? null : ${eval.value};
232+
${eval.code}
233+
$stringVal = ${eval.isNull} ? null : ${eval.value};
231234
break;
232235
"""
233-
}.mkString("\n")
234-
val indexVal = ctx.freshName("index")
235-
val stringArray = ctx.freshName("strings");
236+
}
236237

237-
ev.copy(index.code + "\n" + strings.map(_.code).mkString("\n") + s"""
238-
final int $indexVal = ${index.value};
239-
UTF8String ${ev.value} = null;
240-
switch ($indexVal) {
241-
$assignStringValue
238+
val cases = ctx.splitCodes(assignStringValue)
239+
val codes = if (cases.length == 1) {
240+
s"""
241+
UTF8String $stringVal = null;
242+
switch ($indexVal) {
243+
${cases.head}
244+
}
245+
"""
246+
} else {
247+
var fullFuncName = ""
248+
cases.reverse.zipWithIndex.map { case (s, index) =>
249+
val prevFunc = if (index == 0) {
250+
"null"
251+
} else {
252+
s"$fullFuncName(${ctx.INPUT_ROW}, $indexVal)"
253+
}
254+
val funcName = ctx.freshName("eltFunc")
255+
val funcBody = s"""
256+
private UTF8String $funcName(InternalRow ${ctx.INPUT_ROW}, int $indexVal) {
257+
UTF8String $stringVal = null;
258+
switch ($indexVal) {
259+
$s
260+
default:
261+
return $prevFunc;
262+
}
263+
return $stringVal;
264+
}
265+
"""
266+
fullFuncName = ctx.addNewFunction(funcName, funcBody)
242267
}
268+
s"UTF8String $stringVal = $fullFuncName(${ctx.INPUT_ROW}, ${indexVal});"
269+
}
270+
271+
ev.copy(index.code + "\n" +
272+
s"""
273+
final int $indexVal = ${index.value};
274+
$codes
275+
UTF8String ${ev.value} = $stringVal;
243276
final boolean ${ev.isNull} = ${ev.value} == null;
244277
""")
245278
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,13 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
9797
assert(Elt(Seq(Literal(1), Literal(2))).checkInputDataTypes().isFailure)
9898
}
9999

100+
test("SPARK-22498: Elt should not generate codes beyond 64KB") {
101+
val N = 10000
102+
val strings = (1 to N).map(x => s"s$x")
103+
val args = Literal.create(N, IntegerType) +: strings.map(Literal.create(_, StringType))
104+
checkEvaluation(Elt(args), s"s$N")
105+
}
106+
100107
test("StringComparison") {
101108
val row = create_row("abc", null)
102109
val c1 = 'a.string.at(0)

0 commit comments

Comments
 (0)