diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala index 2c30a1d9aa947..9077a5e82d1c8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala @@ -226,17 +226,9 @@ class PowerIterationClustering private[clustering] ( val predictionsSchema = StructType(Seq( StructField($(idCol), LongType, nullable = false), StructField($(predictionCol), IntegerType, nullable = false))) - val predictions = { - val uncastPredictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema) - dataset.schema($(idCol)).dataType match { - case _: LongType => - uncastPredictions - case otherType => - uncastPredictions.select(col($(idCol)).cast(otherType).alias($(idCol))) - } - } + val predictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema) - dataset.join(predictions, $(idCol)) + predictions } @Since("2.4.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala index 65328df17baff..ad07cccaab46f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala @@ -84,7 +84,7 @@ class PowerIterationClusteringSuite extends SparkFunSuite result.select("id", "prediction").collect().foreach { case Row(id: Long, cluster: Integer) => predictions(cluster) += id } - assert(predictions.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) + assert(predictions.toSet == Set((0 until n1).toSet, (n1 until n).toSet)) val result2 = new PowerIterationClustering() .setK(2) @@ -95,7 +95,30 @@ class PowerIterationClusteringSuite extends SparkFunSuite result2.select("id", "prediction").collect().foreach { case Row(id: Long, cluster: Integer) => predictions2(cluster) += id } - assert(predictions2.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) + assert(predictions2.toSet == Set((0 until n1).toSet, (n1 until n).toSet)) + } + + test("power iteration clustering: random init mode") { + + val data = spark.createDataFrame(Seq( + (0, Array(1), Array(0.9)), + (1, Array(2), Array(0.9)), + (2, Array(3), Array(0.9)), + (3, Array(4), Array(0.1)), + (4, Array(5), Array(0.9)) + )).toDF("id", "neighbors", "similarities") + + val result = new PowerIterationClustering() + .setK(2) + .setMaxIter(10) + .setInitMode("random") + .transform(data) + + val predictions = Array.fill(2)(mutable.Set.empty[Long]) + result.select("id", "prediction").collect().foreach { + case Row(id: Long, cluster: Integer) => predictions(cluster) += id + } + assert(predictions.toSet == Set((0 until 4).toSet, Set(4, 5))) } test("supported input types") {