Skip to content

Commit 14c7ce3

Browse files
committed
Fixing Scala style and test cases
1 parent 75192a7 commit 14c7ce3

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -377,12 +377,12 @@ private object BisectingKMeans extends Serializable {
377377
internalIndex -= 1
378378
val leftIndex = leftChildIndex(rawIndex)
379379
val rightIndex = rightChildIndex(rawIndex)
380-
val height = math.sqrt(Seq(leftIndex, rightIndex).map { childIndex =>
380+
val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_))
381+
val height = math.sqrt(indexes.map { childIndex =>
381382
KMeans.fastSquaredDistance(center, clusters(childIndex).center)
382383
}.max)
383-
val left = buildSubTree(leftIndex)
384-
val right = buildSubTree(rightIndex)
385-
new ClusteringTreeNode(index, size, center, cost, height, Array(left, right))
384+
val children = indexes.map(buildSubTree(_)).toArray
385+
new ClusteringTreeNode(index, size, center, cost, height, children)
386386
} else {
387387
val index = leafIndex
388388
leafIndex += 1

mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ class BisectingKMeansSuite
6666
// Verify fit does not fail on very sparse data
6767
val model = bkm.fit(sparseDataset)
6868
assert(model.hasSummary)
69+
val result = model.transform(sparseDataset)
70+
val numClusters = result.select("prediction").distinct().collect().length
71+
assert(numClusters <= k && numClusters >= 1)
6972
}
7073

7174
test("setter/getter") {

mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717

1818
package org.apache.spark.ml.clustering
1919

20+
import scala.util.Random
21+
2022
import org.apache.spark.SparkFunSuite
2123
import org.apache.spark.ml.linalg.{Vector, Vectors}
2224
import org.apache.spark.ml.param.ParamMap
2325
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
2426
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
2527
import org.apache.spark.mllib.util.MLlibTestSparkContext
2628
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
27-
import scala.util.Random
2829

2930
private[clustering] case class TestRow(features: Vector)
3031

0 commit comments

Comments (0)