@@ -226,17 +226,10 @@ class PowerIterationClustering private[clustering] (
     val predictionsSchema = StructType(Seq(
       StructField($(idCol), LongType, nullable = false),
       StructField($(predictionCol), IntegerType, nullable = false)))
-    val predictions = {
-      val uncastPredictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema)
-      dataset.schema($(idCol)).dataType match {
-        case _: LongType =>
-          uncastPredictions
-        case otherType =>
-          uncastPredictions.select(col($(idCol)).cast(otherType).alias($(idCol)))
-      }
-    }
-
-    dataset.join(predictions, $(idCol))
+    val predictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema)
+
+    predictions
   }

   @Since("2.4.0")
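With the join removed, `transform` now returns only the two-column `(id, prediction)` DataFrame built from `predictionsSchema`, one row per vertex in the affinity graph, instead of the input rows with a prediction column appended. A caller that still needs its original columns can join the result back on the id column. A minimal sketch, assuming a `dataset` shaped like the suite's test data (the variable names are illustrative, not part of this patch):

```scala
import org.apache.spark.ml.clustering.PowerIterationClustering

// Recover the pre-change output shape by joining the (id, prediction)
// result back onto the input columns manually.
val assignments = new PowerIterationClustering()
  .setK(2)
  .setMaxIter(10)
  .transform(dataset)                 // columns: id (LongType), prediction

// Spark's equi-join coerces an IntType input id to match the LongType output id.
val withInputCols = dataset.join(assignments, "id")
```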
@@ -84,7 +84,7 @@ class PowerIterationClusteringSuite extends SparkFunSuite
result.select("id", "prediction").collect().foreach {
case Row(id: Long, cluster: Integer) => predictions(cluster) += id
}
assert(predictions.toSet == Set((1 until n1).toSet, (n1 until n).toSet))
assert(predictions.toSet == Set((0 until n1).toSet, (n1 until n).toSet))

val result2 = new PowerIterationClustering()
.setK(2)
@@ -95,7 +95,7 @@ class PowerIterationClusteringSuite extends SparkFunSuite
     result2.select("id", "prediction").collect().foreach {
       case Row(id: Long, cluster: Integer) => predictions2(cluster) += id
     }
-    assert(predictions2.toSet == Set((1 until n1).toSet, (n1 until n).toSet))
+    assert(predictions2.toSet == Set((0 until n1).toSet, (n1 until n).toSet))
   }

   test("supported input types") {
@@ -183,6 +183,30 @@ class PowerIterationClusteringSuite extends SparkFunSuite
     assert(msg.contains(s"Row for ID ${model.getIdCol}=1"))
   }

test("valid input : When ID is IntType") {

val data = spark.createDataFrame(Seq(
(0, Array(1), Array(0.9)),
(1, Array(2), Array(0.9)),
(2, Array(3), Array(0.9)),
(3, Array(4), Array(0.1)),
(4, Array(5), Array(0.9))
)).toDF("id", "neighbors", "similarities")

val result = new PowerIterationClustering()
.setK(2)
.setMaxIter(10)
.setInitMode("random")
.transform(data)

val predictions = Array.fill(2)(mutable.Set.empty[Long])
result.select("id", "prediction").collect().foreach {
case Row(id: Long, cluster: Integer) => predictions(cluster) += id
}
assert(predictions.toSet == Set((0 until 4).toSet, Set(4, 5)))
assert(result.columns(1).equals("prediction"))
}

test("read/write") {
val t = new PowerIterationClustering()
.setK(4)
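One detail worth noting in the new test: the match on `Row(id: Long, cluster: Integer)` succeeds even though the input id column is `IntType`, because `predictionsSchema` in the first hunk pins the output id column to `LongType` and the result is no longer cast back to the input's id type. A minimal sketch making that explicit, assuming the `data` and `result` values from the new test (not part of this patch):

```scala
import org.apache.spark.sql.types.{IntegerType, LongType}

// The input id column is IntegerType...
assert(data.schema("id").dataType == IntegerType)
// ...but the returned predictions always carry LongType ids now.
assert(result.schema("id").dataType == LongType)
```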