Skip to content

Commit abfd06f

Browse files
author
ALeksander Eskilson
committed
fixing mutable state for several classes
1 parent e7bdc53 commit abfd06f

File tree

4 files changed

+24
-35
lines changed

4 files changed

+24
-35
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,12 @@ class CodegenContext {
178178
* variable is inlined to the class, or an array access if the variable is to be stored
179179
* in an array of variables of the same type and initialization.
180180
*/
181-
def addMutableState(javaType: String, variableName: String, initCode: String): String = {
182-
if (mutableStateCount > 10000 && variableName.matches(".*\\d+.*") &&
181+
def addMutableState(
182+
javaType: String,
183+
variableName: String,
184+
initCode: String,
185+
inLine: Boolean = false): String = {
186+
if (!inLine && variableName.matches(".*\\d+.*") &&
183187
(initCode.matches("(^.*\\s*=\\s*null;$|^$)") || isPrimitiveType(javaType))) {
184188
val initCodeKey = initCode.replaceAll(variableName, "*VALUE*")
185189
if (mutableStateArrayIdx.contains((javaType, initCodeKey))) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -268,17 +268,17 @@ abstract class HashExpression[E] extends Expression {
268268

269269
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
270270
ev.isNull = "false"
271+
val valueAccessor = ctx.addMutableState(ctx.javaType(dataType), ev.value, "")
271272
val childrenHash = ctx.splitExpressions(ctx.INPUT_ROW, children.map { child =>
272273
val childGen = child.genCode(ctx)
273274
childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
274-
computeHash(childGen.value, child.dataType, ev.value, ctx)
275+
computeHash(childGen.value, child.dataType, valueAccessor, ctx)
275276
}
276277
})
277278

278-
val valueAccessor = ctx.addMutableState(ctx.javaType(dataType), ev.value, "")
279279
ev.copy(code = s"""
280280
$valueAccessor = $seed;
281-
$childrenHash""")
281+
$childrenHash""", value = valueAccessor)
282282
}
283283

284284
protected def nullSafeElementHash(
@@ -612,7 +612,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
612612

613613
ev.copy(code = s"""
614614
$valueAccessor = $seed;
615-
$childrenHash""")
615+
$childrenHash""", value = valueAccessor)
616616
}
617617

618618
override def eval(input: InternalRow = null): Int = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ case class NewInstance(
339339
${outer.map(_.code).getOrElse("")}
340340
$valueAccessor = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall;
341341
"""
342-
ev.copy(code = code)
342+
ev.copy(code = code, value = valueAccessor)
343343
}
344344

345345
override def toString: String = s"newInstance($cls)"
@@ -418,27 +418,17 @@ case class WrapOption(child: Expression, optType: DataType)
418418
case class LambdaVariable(
419419
value: String,
420420
isNull: String,
421-
loopValuesMap: mutable.Map[String, String],
422421
dataType: DataType,
423422
nullable: Boolean = true) extends LeafExpression
424423
with Unevaluable with NonSQLExpression {
425424

426425
override def genCode(ctx: CodegenContext): ExprCode = {
427-
val valueAccessor = loopValuesMap.getOrElseUpdate(value,
428-
ctx.addMutableState(ctx.javaType(dataType), value, ""))
429-
val isNullAccessor = loopValuesMap.getOrElseUpdate(isNull,
430-
ctx.addMutableState("boolean", isNull, ""))
431-
ExprCode(code = "", value = valueAccessor, isNull = if (nullable) isNullAccessor else "false")
426+
ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false")
432427
}
433428
}
434429

435430
object MapObjects {
436431
private val curId = new java.util.concurrent.atomic.AtomicInteger()
437-
// Since the loopValue and loopIsNull mutable state may be compacted into an array of their
438-
// corresponding types, we keep a map between the variable name and its accessor, which is
439-
// either the same name, or an array-access, such that state may be properly assigned between
440-
// the lambdaFunction and the body of `MapObjects`
441-
private val loopValuesMap: mutable.Map[String, String] = mutable.Map.empty[String, String]
442432

443433
/**
444434
* Construct an instance of MapObjects case class.
@@ -453,8 +443,8 @@ object MapObjects {
453443
elementType: DataType): MapObjects = {
454444
val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
455445
val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
456-
val loopVar = LambdaVariable(loopValue, loopIsNull, loopValuesMap, elementType)
457-
MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData)(loopValuesMap)
446+
val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
447+
MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData)
458448
}
459449
}
460450

@@ -475,16 +465,13 @@ object MapObjects {
475465
* @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function
476466
* to handle collection elements.
477467
* @param inputData An expression that when evaluated returns a collection object.
478-
* @param loopValuesMap a map holding the name or array-accessor for the mutable state of loopValue
479-
* and loopIsNull variables.
480468
*/
481469
case class MapObjects private(
482470
loopValue: String,
483471
loopIsNull: String,
484472
loopVarDataType: DataType,
485473
lambdaFunction: Expression,
486474
inputData: Expression)
487-
(loopValuesMap: mutable.Map[String, String] = mutable.Map.empty[String, String])
488475
extends Expression with NonSQLExpression {
489476

490477
override def nullable: Boolean = inputData.nullable
@@ -497,14 +484,10 @@ case class MapObjects private(
497484
override def dataType: DataType =
498485
ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)
499486

500-
override protected def otherCopyArgs: Seq[AnyRef] = loopValuesMap :: Nil
501-
502487
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
503488
val elementJavaType = ctx.javaType(loopVarDataType)
504-
val loopIsNullAccessor = loopValuesMap.getOrElseUpdate(loopIsNull,
505-
ctx.addMutableState("boolean", loopIsNull, ""))
506-
val loopValueAccessor = loopValuesMap.getOrElseUpdate(loopValue,
507-
ctx.addMutableState(elementJavaType, loopValue, ""))
489+
val loopIsNullAccessor = ctx.addMutableState("boolean", loopIsNull, "", inLine = true)
490+
val loopValueAccessor = ctx.addMutableState(elementJavaType, loopValue, "", inLine = true)
508491
val genInputData = inputData.genCode(ctx)
509492
val genFunction = lambdaFunction.genCode(ctx)
510493
val dataLength = ctx.freshName("dataLength")
@@ -634,12 +617,12 @@ object ExternalMapToCatalyst {
634617
keyName,
635618
keyType,
636619
keyConverter(
637-
LambdaVariable(keyName, "false", mapValuesMap, keyType, false)),
620+
LambdaVariable(keyName, "false", keyType, false)),
638621
valueName,
639622
valueIsNull,
640623
valueType,
641624
valueConverter(
642-
LambdaVariable(valueName, valueIsNull, mapValuesMap, valueType, valueNullable)),
625+
LambdaVariable(valueName, valueIsNull, valueType, valueNullable)),
643626
inputMap
644627
)
645628
}
@@ -953,7 +936,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
953936
$initializeCode
954937
}
955938
"""
956-
ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value)
939+
ev.copy(code = code, isNull = instanceGen.isNull, value = javaBeanInstanceAccessor)
957940
}
958941
}
959942

sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,11 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
8888
val columnVectorClz = "org.apache.spark.sql.execution.vectorized.ColumnVector"
8989
val idx = ctx.freshName("batchIdx")
9090
val idxAccessor = ctx.addMutableState("int", idx, s"$idx = 0;")
91-
val colVars = output.indices.map(i => ctx.freshName("colInstance" + i))
92-
val columnAssigns = colVars.zipWithIndex.map { case (name, i) =>
93-
val nameAccessor = ctx.addMutableState(columnVectorClz, name, s"$name = null;")
91+
val colVars = output.indices.map(i => {
92+
val name = ctx.freshName("colInstance" + i)
93+
ctx.addMutableState(columnVectorClz, name, s"$name = null;")
94+
})
95+
val columnAssigns = colVars.zipWithIndex.map { case (nameAccessor, i) =>
9496
s"$nameAccessor = $batchAccessor.column($i);"
9597
}
9698

0 commit comments

Comments
 (0)