Skip to content

Commit a8af4da

Browse files
committed
[SPARK-22682][SQL] HashExpression does not need to create global variables
## What changes were proposed in this pull request? It turns out that `HashExpression` can pass some values around via parameters when splitting code into methods, which saves some global variable slots. This also prevents a weird case in which a global variable appears in a parameter list, which was discovered in #19865. ## How was this patch tested? Existing tests. Author: Wenchen Fan <[email protected]> Closes #19878 from cloud-fan/minor.
1 parent 295df74 commit a8af4da

File tree

2 files changed

+106
-46
lines changed

2 files changed

+106
-46
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala

Lines changed: 85 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -270,17 +270,36 @@ abstract class HashExpression[E] extends Expression {
270270

271271
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
272272
ev.isNull = "false"
273-
val childrenHash = ctx.splitExpressions(children.map { child =>
273+
274+
val childrenHash = children.map { child =>
274275
val childGen = child.genCode(ctx)
275276
childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
276277
computeHash(childGen.value, child.dataType, ev.value, ctx)
277278
}
278-
})
279+
}
280+
281+
val hashResultType = ctx.javaType(dataType)
282+
val codes = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
283+
childrenHash.mkString("\n")
284+
} else {
285+
ctx.splitExpressions(
286+
expressions = childrenHash,
287+
funcName = "computeHash",
288+
arguments = Seq("InternalRow" -> ctx.INPUT_ROW, hashResultType -> ev.value),
289+
returnType = hashResultType,
290+
makeSplitFunction = body =>
291+
s"""
292+
|$body
293+
|return ${ev.value};
294+
""".stripMargin,
295+
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
296+
}
279297

280-
ctx.addMutableState(ctx.javaType(dataType), ev.value)
281-
ev.copy(code = s"""
282-
${ev.value} = $seed;
283-
$childrenHash""")
298+
ev.copy(code =
299+
s"""
300+
|$hashResultType ${ev.value} = $seed;
301+
|$codes
302+
""".stripMargin)
284303
}
285304

286305
protected def nullSafeElementHash(
@@ -389,13 +408,21 @@ abstract class HashExpression[E] extends Expression {
389408
input: String,
390409
result: String,
391410
fields: Array[StructField]): String = {
392-
val hashes = fields.zipWithIndex.map { case (field, index) =>
411+
val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
393412
nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
394413
}
414+
val hashResultType = ctx.javaType(dataType)
395415
ctx.splitExpressions(
396-
expressions = hashes,
397-
funcName = "getHash",
398-
arguments = Seq("InternalRow" -> input))
416+
expressions = fieldsHash,
417+
funcName = "computeHashForStruct",
418+
arguments = Seq("InternalRow" -> input, hashResultType -> result),
419+
returnType = hashResultType,
420+
makeSplitFunction = body =>
421+
s"""
422+
|$body
423+
|return $result;
424+
""".stripMargin,
425+
foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
399426
}
400427

401428
@tailrec
@@ -610,25 +637,44 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
610637

611638
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
612639
ev.isNull = "false"
640+
613641
val childHash = ctx.freshName("childHash")
614-
val childrenHash = ctx.splitExpressions(children.map { child =>
642+
val childrenHash = children.map { child =>
615643
val childGen = child.genCode(ctx)
616644
val codeToComputeHash = ctx.nullSafeExec(child.nullable, childGen.isNull) {
617645
computeHash(childGen.value, child.dataType, childHash, ctx)
618646
}
619647
s"""
620648
|${childGen.code}
649+
|$childHash = 0;
621650
|$codeToComputeHash
622651
|${ev.value} = (31 * ${ev.value}) + $childHash;
623-
|$childHash = 0;
624652
""".stripMargin
625-
})
653+
}
626654

627-
ctx.addMutableState(ctx.javaType(dataType), ev.value)
628-
ctx.addMutableState(ctx.JAVA_INT, childHash, s"$childHash = 0;")
629-
ev.copy(code = s"""
630-
${ev.value} = $seed;
631-
$childrenHash""")
655+
val codes = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
656+
childrenHash.mkString("\n")
657+
} else {
658+
ctx.splitExpressions(
659+
expressions = childrenHash,
660+
funcName = "computeHash",
661+
arguments = Seq("InternalRow" -> ctx.INPUT_ROW, ctx.JAVA_INT -> ev.value),
662+
returnType = ctx.JAVA_INT,
663+
makeSplitFunction = body =>
664+
s"""
665+
|${ctx.JAVA_INT} $childHash = 0;
666+
|$body
667+
|return ${ev.value};
668+
""".stripMargin,
669+
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
670+
}
671+
672+
ev.copy(code =
673+
s"""
674+
|${ctx.JAVA_INT} ${ev.value} = $seed;
675+
|${ctx.JAVA_INT} $childHash = 0;
676+
|$codes
677+
""".stripMargin)
632678
}
633679

634680
override def eval(input: InternalRow = null): Int = {
@@ -730,23 +776,29 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
730776
input: String,
731777
result: String,
732778
fields: Array[StructField]): String = {
733-
val localResult = ctx.freshName("localResult")
734779
val childResult = ctx.freshName("childResult")
735-
fields.zipWithIndex.map { case (field, index) =>
780+
val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
781+
val computeFieldHash = nullSafeElementHash(
782+
input, index.toString, field.nullable, field.dataType, childResult, ctx)
736783
s"""
737-
$childResult = 0;
738-
${nullSafeElementHash(input, index.toString, field.nullable, field.dataType,
739-
childResult, ctx)}
740-
$localResult = (31 * $localResult) + $childResult;
741-
"""
742-
}.mkString(
743-
s"""
744-
int $localResult = 0;
745-
int $childResult = 0;
746-
""",
747-
"",
748-
s"$result = (31 * $result) + $localResult;"
749-
)
784+
|$childResult = 0;
785+
|$computeFieldHash
786+
|$result = (31 * $result) + $childResult;
787+
""".stripMargin
788+
}
789+
790+
s"${ctx.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions(
791+
expressions = fieldsHash,
792+
funcName = "computeHashForStruct",
793+
arguments = Seq("InternalRow" -> input, ctx.JAVA_INT -> result),
794+
returnType = ctx.JAVA_INT,
795+
makeSplitFunction = body =>
796+
s"""
797+
|${ctx.JAVA_INT} $childResult = 0;
798+
|$body
799+
|return $result;
800+
""".stripMargin,
801+
foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
750802
}
751803
}
752804

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import org.scalatest.exceptions.TestFailedException
2727

2828
import org.apache.spark.SparkFunSuite
2929
import org.apache.spark.sql.{RandomDataGenerator, Row}
30+
import org.apache.spark.sql.catalyst.InternalRow
3031
import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
3132
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
3233
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
@@ -620,23 +621,30 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
620621
}
621622

622623
test("SPARK-18207: Compute hash for a lot of expressions") {
624+
def checkResult(schema: StructType, input: InternalRow): Unit = {
625+
val exprs = schema.fields.zipWithIndex.map { case (f, i) =>
626+
BoundReference(i, f.dataType, true)
627+
}
628+
val murmur3HashExpr = Murmur3Hash(exprs, 42)
629+
val murmur3HashPlan = GenerateMutableProjection.generate(Seq(murmur3HashExpr))
630+
val murmursHashEval = Murmur3Hash(exprs, 42).eval(input)
631+
assert(murmur3HashPlan(input).getInt(0) == murmursHashEval)
632+
633+
val hiveHashExpr = HiveHash(exprs)
634+
val hiveHashPlan = GenerateMutableProjection.generate(Seq(hiveHashExpr))
635+
val hiveHashEval = HiveHash(exprs).eval(input)
636+
assert(hiveHashPlan(input).getInt(0) == hiveHashEval)
637+
}
638+
623639
val N = 1000
624640
val wideRow = new GenericInternalRow(
625641
Seq.tabulate(N)(i => UTF8String.fromString(i.toString)).toArray[Any])
626-
val schema = StructType((1 to N).map(i => StructField("", StringType)))
627-
628-
val exprs = schema.fields.zipWithIndex.map { case (f, i) =>
629-
BoundReference(i, f.dataType, true)
630-
}
631-
val murmur3HashExpr = Murmur3Hash(exprs, 42)
632-
val murmur3HashPlan = GenerateMutableProjection.generate(Seq(murmur3HashExpr))
633-
val murmursHashEval = Murmur3Hash(exprs, 42).eval(wideRow)
634-
assert(murmur3HashPlan(wideRow).getInt(0) == murmursHashEval)
642+
val schema = StructType((1 to N).map(i => StructField(i.toString, StringType)))
643+
checkResult(schema, wideRow)
635644

636-
val hiveHashExpr = HiveHash(exprs)
637-
val hiveHashPlan = GenerateMutableProjection.generate(Seq(hiveHashExpr))
638-
val hiveHashEval = HiveHash(exprs).eval(wideRow)
639-
assert(hiveHashPlan(wideRow).getInt(0) == hiveHashEval)
645+
val nestedRow = InternalRow(wideRow)
646+
val nestedSchema = new StructType().add("nested", schema)
647+
checkResult(nestedSchema, nestedRow)
640648
}
641649

642650
test("SPARK-22284: Compute hash for nested structs") {

0 commit comments

Comments
 (0)