Skip to content

Commit 73e04ec

Browse files
sharp-pixelsrowen
authored andcommitted
[MINOR] Correct validateAndTransformSchema in GaussianMixture and AFTSurvivalRegression
## What changes were proposed in this pull request? The line SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) did not modify the variable schema, hence only the last line had any effect. A temporary variable is used to correctly append the two columns predictionCol and probabilityCol. ## How was this patch tested? Manually. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Cédric Pelvet <[email protected]> Closes #18980 from sharp-pixel/master.
1 parent 72b738d commit 73e04ec

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w
6464
*/
6565
protected def validateAndTransformSchema(schema: StructType): StructType = {
6666
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
67-
SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
68-
SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT)
67+
val schemaWithPredictionCol = SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
68+
SchemaUtils.appendColumn(schemaWithPredictionCol, $(probabilityCol), new VectorUDT)
6969
}
7070
}
7171

mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,12 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
109109
SchemaUtils.checkNumericType(schema, $(censorCol))
110110
SchemaUtils.checkNumericType(schema, $(labelCol))
111111
}
112-
if (hasQuantilesCol) {
112+
113+
val schemaWithQuantilesCol = if (hasQuantilesCol) {
113114
SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT)
114-
}
115-
SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
115+
} else schema
116+
117+
SchemaUtils.appendColumn(schemaWithQuantilesCol, $(predictionCol), DoubleType)
116118
}
117119
}
118120

0 commit comments

Comments
 (0)