
Commit 75192a7

Adding a test case as requested in review. The test case generates synthetic sparse data which can reproduce the exception the user encountered.
1 parent 55ab179

2 files changed, 29 insertions(+), 0 deletions(-)

mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala

Lines changed: 17 additions & 0 deletions
@@ -29,9 +29,12 @@ class BisectingKMeansSuite
   final val k = 5
   @transient var dataset: Dataset[_] = _
 
+  @transient var sparseDataset: Dataset[_] = _
+
   override def beforeAll(): Unit = {
     super.beforeAll()
     dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
+    sparseDataset = KMeansSuite.generateSparseData(spark, 100, 1000, k, 42)
   }
 
   test("default parameters") {
@@ -51,6 +54,20 @@ class BisectingKMeansSuite
     assert(copiedModel.hasSummary)
   }
 
+  test("SPARK-16473: Verify Bisecting K-Means does not fail in edge case where" +
+    "one cluster is empty after split") {
+    val bkm = new BisectingKMeans().setK(k).setMinDivisibleClusterSize(4).setMaxIter(4)
+
+    assert(bkm.getK === k)
+    assert(bkm.getFeaturesCol === "features")
+    assert(bkm.getPredictionCol === "prediction")
+    assert(bkm.getMaxIter === 4)
+    assert(bkm.getMinDivisibleClusterSize === 4)
+    // Verify fit does not fail on very sparse data
+    val model = bkm.fit(sparseDataset)
+    assert(model.hasSummary)
+  }
+
   test("setter/getter") {
     val bkm = new BisectingKMeans()
       .setK(9)
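
For readers who want to try the scenario outside the test suite, the following is a minimal, self-contained sketch of the same idea: fitting BisectingKMeans, with the parameters used in the new test, on a DataFrame of high-dimensional, mostly-zero vectors. The local SparkSession setup, the object name, and the data-generation details are illustrative assumptions and are not part of this commit; the suite itself relies on MLlibTestSparkContext and KMeansSuite.generateSparseData.

// Illustrative repro sketch for SPARK-16473 (not part of this commit).
// Object name, session setup, and data layout are assumptions made for this example.
import scala.util.Random

import org.apache.spark.ml.clustering.BisectingKMeans
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object SparseBisectingKMeansRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("SPARK-16473 repro sketch")
      .getOrCreate()

    // Build 100 sparse vectors of dimension 1000, mirroring generateSparseData above:
    // a fixed number of non-zero entries placed at randomly chosen indices per row.
    val dim = 1000
    val random = new Random(42)
    val nnz = math.max(1, random.nextInt(dim))
    val vectors = (1 to 100).map { _ =>
      val indices = random.shuffle((0 until dim).toList).take(nnz).sorted.toArray
      Vectors.sparse(dim, indices, Array.fill(nnz)(random.nextDouble()))
    }
    val df = spark.createDataFrame(vectors.map(Tuple1.apply)).toDF("features")

    // Same parameters as the new test; fit should succeed even if a split
    // leaves one of the child clusters empty.
    val model = new BisectingKMeans()
      .setK(5)
      .setMinDivisibleClusterSize(4)
      .setMaxIter(4)
      .fit(df)

    println(s"Cluster sizes: ${model.summary.clusterSizes.mkString(", ")}")
    spark.stop()
  }
}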

mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala

Lines changed: 12 additions & 0 deletions
@@ -24,6 +24,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import scala.util.Random
 
 private[clustering] case class TestRow(features: Vector)
 
@@ -160,6 +161,17 @@ object KMeansSuite {
     spark.createDataFrame(rdd)
   }
 
+  def generateSparseData(spark: SparkSession, rows: Int, dim: Int, k: Int, seed: Int): DataFrame = {
+    val sc = spark.sparkContext
+    val random = new Random(seed)
+    val nnz = random.nextInt(dim)
+    val rdd = sc.parallelize(1 to rows)
+      .map(i => Vectors.sparse(dim, random.shuffle(0 to dim - 1).slice(0, nnz).sorted.toArray,
+        Array.fill(nnz)(random.nextDouble())))
+      .map(v => new TestRow(v))
+    spark.createDataFrame(rdd)
+  }
+
   /**
    * Mapping from all Params to valid settings which differ from the defaults.
    * This is useful for tests which need to exercise all Params, such as save/load.
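
A note on the helper: because nnz is drawn once from Random(seed) before any rows are generated, every row ends up with the same number of non-zero entries, placed at freshly shuffled indices per row (and nnz can even be zero, producing all-zero vectors). The standalone snippet below is illustrative only and not part of the commit; it evaluates the same index-selection expression for a small dim so the result is easy to inspect.

// Standalone illustration of the row construction used in generateSparseData
// (not part of this commit; object name and dim chosen for the example).
import scala.util.Random

object SparseRowDemo extends App {
  val dim = 10
  val random = new Random(42)
  // Drawn once, so every generated row would have the same number of non-zero entries.
  val nnz = random.nextInt(dim)
  // Per row: shuffle all indices, keep the first nnz, sort them as Vectors.sparse expects.
  val indices = random.shuffle((0 until dim).toList).slice(0, nnz).sorted.toArray
  val values = Array.fill(nnz)(random.nextDouble())
  println(s"nnz = $nnz")
  println(s"indices = ${indices.mkString("[", ", ", "]")}")
  println(s"values  = ${values.map(v => f"$v%.3f").mkString("[", ", ", "]")}")
}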
