Skip to content

Commit b9ca4ff

Browse files
committed
drop changes for concat_ws and elt
1 parent 22019b1 commit b9ca4ff

File tree

3 files changed

+39
-125
lines changed

3 files changed

+39
-125
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,23 @@ class CodegenContext {
790790
returnType: String = "void",
791791
makeSplitFunction: String => String = identity,
792792
foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = {
793-
val blocks = splitCodes(expressions)
793+
val blocks = new ArrayBuffer[String]()
794+
val blockBuilder = new StringBuilder()
795+
var length = 0
796+
for (code <- expressions) {
797+
// We can't know how much bytecode will be generated, so use the length of source code
798+
// as metric. A method should not go beyond 8K, otherwise it will not be JITted, should
799+
// also not be too small, or it will have many function calls (for wide table), see the
800+
// results in BenchmarkWideTable.
801+
if (length > 1024) {
802+
blocks += blockBuilder.toString()
803+
blockBuilder.clear()
804+
length = 0
805+
}
806+
blockBuilder.append(code)
807+
length += CodeFormatter.stripExtraNewLinesAndComments(code).length
808+
}
809+
blocks += blockBuilder.toString()
794810

795811
if (blocks.length == 1) {
796812
// inline execution if only one block
@@ -825,27 +841,6 @@ class CodegenContext {
825841
}
826842
}
827843

828-
def splitCodes(expressions: Seq[String]): Seq[String] = {
829-
val blocks = new ArrayBuffer[String]()
830-
val blockBuilder = new StringBuilder()
831-
var length = 0
832-
for (code <- expressions) {
833-
// We can't know how much bytecode will be generated, so use the length of source code
834-
// as metric. A method should not go beyond 8K, otherwise it will not be JITted, should
835-
// also not be too small, or it will have many function calls (for wide table), see the
836-
// results in BenchmarkWideTable.
837-
if (length > 1024) {
838-
blocks += blockBuilder.toString()
839-
blockBuilder.clear()
840-
length = 0
841-
}
842-
blockBuilder.append(code)
843-
length += CodeFormatter.stripExtraNewLinesAndComments(code).length
844-
}
845-
blocks += blockBuilder.toString()
846-
blocks
847-
}
848-
849844
/**
850845
* Here we handle all the methods which have been added to the inner classes and
851846
* not to the outer class.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala

Lines changed: 22 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -137,43 +137,19 @@ case class ConcatWs(children: Seq[Expression])
137137
if (children.forall(_.dataType == StringType)) {
138138
// All children are strings. In that case we can construct a fixed size array.
139139
val evals = children.map(_.genCode(ctx))
140-
val separator = evals.head
141-
val strings = evals.tail
142-
val numArgs = strings.length
143-
val args = ctx.freshName("args")
144-
145-
val inputs = strings.zipWithIndex.map { case (eval, index) =>
146-
if (eval.isNull != "true") {
147-
s"""
148-
${eval.code}
149-
if (!${eval.isNull}) {
150-
$args[$index] = ${eval.value};
151-
}
152-
"""
153-
} else {
154-
""
155-
}
156-
}
157-
val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
158-
ctx.splitExpressions(inputs, "valueConcatWs",
159-
("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil)
160-
} else {
161-
inputs.mkString("\n")
162-
}
163-
ev.copy(s"""
164-
UTF8String[] $args = new UTF8String[$numArgs];
165-
${separator.code}
166-
$codes
167-
UTF8String ${ev.value} = UTF8String.concatWs(${separator.value}, $args);
140+
141+
val inputs = evals.map { eval =>
142+
s"${eval.isNull} ? (UTF8String) null : ${eval.value}"
143+
}.mkString(", ")
144+
145+
ev.copy(evals.map(_.code).mkString("\n") + s"""
146+
UTF8String ${ev.value} = UTF8String.concatWs($inputs);
168147
boolean ${ev.isNull} = ${ev.value} == null;
169148
""")
170149
} else {
171150
val array = ctx.freshName("array")
172-
ctx.addMutableState("UTF8String[]", array, "")
173151
val varargNum = ctx.freshName("varargNum")
174-
ctx.addMutableState("int", varargNum, "")
175152
val idxInVararg = ctx.freshName("idxInVararg")
176-
ctx.addMutableState("int", idxInVararg, "")
177153

178154
val evals = children.map(_.genCode(ctx))
179155
val (varargCount, varargBuild) = children.tail.zip(evals.tail).map { case (child, eval) =>
@@ -199,17 +175,13 @@ case class ConcatWs(children: Seq[Expression])
199175
}
200176
}.unzip
201177

202-
val codes = ctx.splitExpressions(ctx.INPUT_ROW, evals.map(_.code))
203-
val varargCounts = ctx.splitExpressions(ctx.INPUT_ROW, varargCount)
204-
val varargBuilds = ctx.splitExpressions(ctx.INPUT_ROW, varargBuild)
205-
ev.copy(
178+
ev.copy(evals.map(_.code).mkString("\n") +
206179
s"""
207-
$codes
208-
$varargNum = ${children.count(_.dataType == StringType) - 1};
209-
$idxInVararg = 0;
210-
$varargCounts
211-
$array = new UTF8String[$varargNum];
212-
$varargBuilds
180+
int $varargNum = ${children.count(_.dataType == StringType) - 1};
181+
int $idxInVararg = 0;
182+
${varargCount.mkString("\n")}
183+
UTF8String[] $array = new UTF8String[$varargNum];
184+
${varargBuild.mkString("\n")}
213185
UTF8String ${ev.value} = UTF8String.concatWs(${evals.head.value}, $array);
214186
boolean ${ev.isNull} = ${ev.value} == null;
215187
""")
@@ -264,55 +236,22 @@ case class Elt(children: Seq[Expression])
264236
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
265237
val index = indexExpr.genCode(ctx)
266238
val strings = stringExprs.map(_.genCode(ctx))
267-
val indexVal = ctx.freshName("index")
268-
val stringVal = ctx.freshName("stringVal")
269239
val assignStringValue = strings.zipWithIndex.map { case (eval, index) =>
270240
s"""
271241
case ${index + 1}:
272-
${eval.code}
273-
$stringVal = ${eval.isNull} ? null : ${eval.value};
242+
${ev.value} = ${eval.isNull} ? null : ${eval.value};
274243
break;
275244
"""
276-
}
277-
278-
val cases = ctx.splitCodes(assignStringValue)
279-
val codes = if (cases.length == 1) {
280-
s"""
281-
UTF8String $stringVal = null;
282-
switch ($indexVal) {
283-
${cases.head}
284-
}
285-
"""
286-
} else {
287-
var fullFuncName = ""
288-
cases.reverse.zipWithIndex.map { case (s, index) =>
289-
val prevFunc = if (index == 0) {
290-
"null"
291-
} else {
292-
s"$fullFuncName(${ctx.INPUT_ROW}, $indexVal)"
293-
}
294-
val funcName = ctx.freshName("eltFunc")
295-
val funcBody = s"""
296-
private UTF8String $funcName(InternalRow ${ctx.INPUT_ROW}, int $indexVal) {
297-
UTF8String $stringVal = null;
298-
switch ($indexVal) {
299-
$s
300-
default:
301-
return $prevFunc;
302-
}
303-
return $stringVal;
304-
}
305-
"""
306-
fullFuncName = ctx.addNewFunction(funcName, funcBody)
307-
}
308-
s"UTF8String $stringVal = $fullFuncName(${ctx.INPUT_ROW}, ${indexVal});"
309-
}
245+
}.mkString("\n")
246+
val indexVal = ctx.freshName("index")
247+
val stringArray = ctx.freshName("strings");
310248

311-
ev.copy(index.code + "\n" +
312-
s"""
249+
ev.copy(index.code + "\n" + strings.map(_.code).mkString("\n") + s"""
313250
final int $indexVal = ${index.value};
314-
$codes
315-
UTF8String ${ev.value} = $stringVal;
251+
UTF8String ${ev.value} = null;
252+
switch ($indexVal) {
253+
$assignStringValue
254+
}
316255
final boolean ${ev.isNull} = ${ev.value} == null;
317256
""")
318257
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -80,19 +80,6 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
8080
// scalastyle:on
8181
}
8282

83-
test("SPARK-22498: ConcatWs should not generate codes beyond 64KB") {
84-
val N = 5000
85-
val sepExpr = Literal.create("#", StringType)
86-
val strings1 = (1 to N).map(x => s"s$x")
87-
val inputsExpr1 = strings1.map(Literal.create(_, StringType))
88-
checkEvaluation(ConcatWs(sepExpr +: inputsExpr1), strings1.mkString("#"), EmptyRow)
89-
90-
val strings2 = (1 to N).map(x => Seq(s"s$x"))
91-
val inputsExpr2 = strings2.map(Literal.create(_, ArrayType(StringType)))
92-
checkEvaluation(
93-
ConcatWs(sepExpr +: inputsExpr2), strings2.map(s => s(0)).mkString("#"), EmptyRow)
94-
}
95-
9683
test("elt") {
9784
def testElt(result: String, n: java.lang.Integer, args: String*): Unit = {
9885
checkEvaluation(
@@ -116,13 +103,6 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
116103
assert(Elt(Seq(Literal(1), Literal(2))).checkInputDataTypes().isFailure)
117104
}
118105

119-
test("SPARK-22498: Elt should not generate codes beyond 64KB") {
120-
val N = 10000
121-
val strings = (1 to N).map(x => s"s$x")
122-
val args = Literal.create(N, IntegerType) +: strings.map(Literal.create(_, StringType))
123-
checkEvaluation(Elt(args), s"s$N")
124-
}
125-
126106
test("StringComparison") {
127107
val row = create_row("abc", null)
128108
val c1 = 'a.string.at(0)

0 commit comments

Comments
 (0)