@@ -270,17 +270,36 @@ abstract class HashExpression[E] extends Expression {
270270
// Diff hunk for HashExpression.doGenCode. Lines marked with a minus sign are the removed
// version, lines marked with a plus sign are the added version, and the rest are unchanged
// context. The method marks the result as never null and returns a copy of ev whose code
// seeds ev.value with seed and then runs the per-child hash snippets.
271271 override def doGenCode (ctx : CodegenContext , ev : ExprCode ): ExprCode = {
272272 ev.isNull = " false"
// Each child contributes one snippet: the code generated for the child followed by a
// null-safe call to computeHash that writes into ev.value. The removed version handed
// these snippets straight to ctx.splitExpressions; the added version keeps them as a
// collection so the splitting decision can be made below.
273- val childrenHash = ctx.splitExpressions(children.map { child =>
273+
274+ val childrenHash = children.map { child =>
274275 val childGen = child.genCode(ctx)
275276 childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
276277 computeHash(childGen.value, child.dataType, ev.value, ctx)
277278 }
278- })
279+ }
280+
281+ val hashResultType = ctx.javaType(dataType)
// Added version: when ctx.INPUT_ROW is null or ctx.currentVars is non-null the snippets are
// joined with newlines and emitted inline. Otherwise they go through ctx.splitExpressions,
// passing the input row and the running hash as arguments; every split function returns
// ev.value and every call result is assigned back to ev.value by foldFunctions.
// NOTE(review): the reason this condition disables splitting is not visible here -- confirm
// against CodegenContext.
282+ val codes = if (ctx.INPUT_ROW == null || ctx.currentVars != null ) {
283+ childrenHash.mkString(" \n " )
284+ } else {
285+ ctx.splitExpressions(
286+ expressions = childrenHash,
287+ funcName = " computeHash" ,
288+ arguments = Seq (" InternalRow" -> ctx.INPUT_ROW , hashResultType -> ev.value),
289+ returnType = hashResultType,
290+ makeSplitFunction = body =>
291+ s """
292+ | $body
293+ |return ${ev.value};
294+ """ .stripMargin,
295+ foldFunctions = _.map(funcCall => s " ${ev.value} = $funcCall; " ).mkString(" \n " ))
296+ }
279297
// The removed version registered ev.value through ctx.addMutableState and only assigned
// seed to it. The added version no longer calls ctx.addMutableState; it declares ev.value
// in the emitted code with its Java type and initializes it to seed.
280- ctx.addMutableState(ctx.javaType(dataType), ev.value)
281- ev.copy(code = s """
282- ${ev.value} = $seed;
283- $childrenHash""" )
298+ ev.copy(code =
299+ s """
300+ | $hashResultType ${ev.value} = $seed;
301+ | $codes
302+ """ .stripMargin)
284303 }
285304
286305 protected def nullSafeElementHash (
@@ -389,13 +408,21 @@ abstract class HashExpression[E] extends Expression {
389408 input : String ,
390409 result : String ,
391410 fields : Array [StructField ]): String = {
392- val hashes = fields.zipWithIndex.map { case (field, index) =>
411+ val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
393412 nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
394413 }
414+ val hashResultType = ctx.javaType(dataType)
395415 ctx.splitExpressions(
396- expressions = hashes,
397- funcName = " getHash" ,
398- arguments = Seq (" InternalRow" -> input))
416+ expressions = fieldsHash,
417+ funcName = " computeHashForStruct" ,
418+ arguments = Seq (" InternalRow" -> input, hashResultType -> result),
419+ returnType = hashResultType,
420+ makeSplitFunction = body =>
421+ s """
422+ | $body
423+ |return $result;
424+ """ .stripMargin,
425+ foldFunctions = _.map(funcCall => s " $result = $funcCall; " ).mkString(" \n " ))
399426 }
400427
401428 @ tailrec
@@ -610,25 +637,44 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
610637
// Diff hunk for HiveHash.doGenCode. Lines marked with a minus sign are the removed version,
// lines marked with a plus sign are the added version, and the rest are unchanged context.
// The method marks the result as never null and returns a copy of ev whose code seeds
// ev.value with seed and then folds in one hash per child.
611638 override def doGenCode (ctx : CodegenContext , ev : ExprCode ): ExprCode = {
612639 ev.isNull = " false"
640+
613641 val childHash = ctx.freshName(" childHash" )
// Each child contributes one snippet: run the code generated for the child, compute its hash
// null-safely into childHash, then update ev.value to (31 * ev.value) + childHash.
// The added version resets childHash to 0 before computing the child hash; the removed
// version reset it after folding it into ev.value.
614- val childrenHash = ctx.splitExpressions( children.map { child =>
642+ val childrenHash = children.map { child =>
615643 val childGen = child.genCode(ctx)
616644 val codeToComputeHash = ctx.nullSafeExec(child.nullable, childGen.isNull) {
617645 computeHash(childGen.value, child.dataType, childHash, ctx)
618646 }
619647 s """
620648 | ${childGen.code}
649+ | $childHash = 0;
621650 | $codeToComputeHash
622651 | ${ev.value} = (31 * ${ev.value}) + $childHash;
623- | $childHash = 0;
624652 """ .stripMargin
625- })
653+ }
626654
// The removed version registered both ev.value and childHash through ctx.addMutableState.
// The added version does not call ctx.addMutableState: both names are declared in the
// emitted code at the end of this method, and childHash is declared again at the top of
// every split function by makeSplitFunction.
627- ctx.addMutableState(ctx.javaType(dataType), ev.value)
628- ctx.addMutableState(ctx.JAVA_INT , childHash, s " $childHash = 0; " )
629- ev.copy(code = s """
630- ${ev.value} = $seed;
631- $childrenHash""" )
// Added version: when ctx.INPUT_ROW is null or ctx.currentVars is non-null the snippets are
// joined with newlines and emitted inline. Otherwise they go through ctx.splitExpressions
// with the input row and the running hash as arguments; every split function returns
// ev.value and every call result is assigned back to ev.value by foldFunctions.
655+ val codes = if (ctx.INPUT_ROW == null || ctx.currentVars != null ) {
656+ childrenHash.mkString(" \n " )
657+ } else {
658+ ctx.splitExpressions(
659+ expressions = childrenHash,
660+ funcName = " computeHash" ,
661+ arguments = Seq (" InternalRow" -> ctx.INPUT_ROW , ctx.JAVA_INT -> ev.value),
662+ returnType = ctx.JAVA_INT ,
663+ makeSplitFunction = body =>
664+ s """
665+ | ${ctx.JAVA_INT } $childHash = 0;
666+ | $body
667+ |return ${ev.value};
668+ """ .stripMargin,
669+ foldFunctions = _.map(funcCall => s " ${ev.value} = $funcCall; " ).mkString(" \n " ))
670+ }
671+
672+ ev.copy(code =
673+ s """
674+ | ${ctx.JAVA_INT } ${ev.value} = $seed;
675+ | ${ctx.JAVA_INT } $childHash = 0;
676+ | $codes
677+ """ .stripMargin)
632678 }
633679
634680 override def eval (input : InternalRow = null ): Int = {
@@ -730,23 +776,29 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
730776 input : String ,
731777 result : String ,
732778 fields : Array [StructField ]): String = {
733- val localResult = ctx.freshName(" localResult" )
734779 val childResult = ctx.freshName(" childResult" )
735- fields.zipWithIndex.map { case (field, index) =>
780+ val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
781+ val computeFieldHash = nullSafeElementHash(
782+ input, index.toString, field.nullable, field.dataType, childResult, ctx)
736783 s """
737- $childResult = 0;
738- ${nullSafeElementHash(input, index.toString, field.nullable, field.dataType,
739- childResult, ctx)}
740- $localResult = (31 * $localResult) + $childResult;
741- """
742- }.mkString(
743- s """
744- int $localResult = 0;
745- int $childResult = 0;
746- """ ,
747- " " ,
748- s " $result = (31 * $result) + $localResult; "
749- )
784+ | $childResult = 0;
785+ | $computeFieldHash
786+ | $result = (31 * $result) + $childResult;
787+ """ .stripMargin
788+ }
789+
790+ s " ${ctx.JAVA_INT } $childResult = 0; \n " + ctx.splitExpressions(
791+ expressions = fieldsHash,
792+ funcName = " computeHashForStruct" ,
793+ arguments = Seq (" InternalRow" -> input, ctx.JAVA_INT -> result),
794+ returnType = ctx.JAVA_INT ,
795+ makeSplitFunction = body =>
796+ s """
797+ | ${ctx.JAVA_INT } $childResult = 0;
798+ | $body
799+ |return $result;
800+ """ .stripMargin,
801+ foldFunctions = _.map(funcCall => s " $result = $funcCall; " ).mkString(" \n " ))
750802 }
751803}
752804
0 commit comments