examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala
@@ -50,7 +50,6 @@ import org.apache.spark.streaming.{Seconds, StreamingContext}
 object StreamingKMeans {
 
   def main(args: Array[String]) {
-
     if (args.length != 5) {
       System.err.println(
         "Usage: StreamingKMeans " +
@@ -67,14 +66,12 @@ object StreamingKMeans {
     val model = new StreamingKMeans()
       .setK(args(3).toInt)
       .setDecayFactor(1.0)
-      .setRandomCenters(args(4).toInt)
+      .setRandomCenters(args(4).toInt, 0.0)
 
     model.trainOn(trainingData)
     model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()
 
     ssc.start()
     ssc.awaitTermination()
-
   }
-
 }
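Note: the example's only functional change is the second argument to setRandomCenters, the initial per-center weight (0.0 here, so the first batch fully determines the centers). For readers who want to run something like it end to end, a minimal sketch follows; the master, batch interval, and input paths are illustrative assumptions, not part of the patch:

    import org.apache.spark.SparkConf
    import org.apache.spark.mllib.clustering.StreamingKMeans
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    object StreamingKMeansSketch {
      def main(args: Array[String]) {
        val conf = new SparkConf().setMaster("local[2]").setAppName("StreamingKMeansSketch")
        val ssc = new StreamingContext(conf, Seconds(5))

        // Vectors arrive as text like "[1.0,2.0]"; labeled points as "(1.0,[1.0,2.0])".
        val trainingData = ssc.textFileStream("/tmp/kmeans/train").map(Vectors.parse)
        val testData = ssc.textFileStream("/tmp/kmeans/test").map(LabeledPoint.parse)

        val model = new StreamingKMeans()
          .setK(2)
          .setDecayFactor(1.0)
          .setRandomCenters(2, 0.0) // dim = 2, initial weight 0.0 per center

        model.trainOn(trainingData)
        model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()

        ssc.start()
        ssc.awaitTermination()
      }
    }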
mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -19,16 +19,15 @@ package org.apache.spark.mllib.clustering

 import scala.reflect.ClassTag
 
-import breeze.linalg.{Vector => BV}
-
-import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.Logging
-import org.apache.spark.mllib.linalg.{Vectors, Vector}
-import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
-import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.streaming.StreamingContext._
+import org.apache.spark.streaming.dstream.DStream
 import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
 
 /**
  * :: DeveloperApi ::
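Note: the import changes swap Breeze for Spark's internal org.apache.spark.mllib.linalg.BLAS wrappers used throughout the new update code, and bring in XORShiftRandom for center initialization. For reference, the two BLAS level-1 routines relied on below behave roughly like these plain-Scala stand-ins (illustrative only, not the mllib implementations):

    // Illustrative stand-ins for the two BLAS level-1 routines used in this diff.
    def axpy(a: Double, x: Array[Double], y: Array[Double]): Unit = {
      var i = 0
      while (i < x.length) { y(i) += a * x(i); i += 1 } // y += a * x, in place
    }

    def scal(a: Double, x: Array[Double]): Unit = {
      var i = 0
      while (i < x.length) { x(i) *= a; i += 1 } // x *= a, in place
    }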
@@ -66,55 +65,81 @@ import org.apache.spark.util.Utils
 @DeveloperApi
 class StreamingKMeansModel(
     override val clusterCenters: Array[Vector],
-    val clusterCounts: Array[Long]) extends KMeansModel(clusterCenters) with Logging {
+    val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging {
 
   /** Perform a k-means update on a batch of data. */
   def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = {
 
-    val centers = clusterCenters
-    val counts = clusterCounts
-
     // find nearest cluster to each point
-    val closest = data.map(point => (this.predict(point), (point.toBreeze, 1.toLong)))
+    val closest = data.map(point => (this.predict(point), (point, 1L)))
 
     // get sums and counts for updating each cluster
-    type WeightedPoint = (BV[Double], Long)
-    def mergeContribs(p1: WeightedPoint, p2: WeightedPoint): WeightedPoint = {
-      (p1._1 += p2._1, p1._2 + p2._2)
+    val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => {
+      BLAS.axpy(1.0, p2._1, p1._1)
+      (p1._1, p1._2 + p2._2)
     }
-    val pointStats: Array[(Int, (BV[Double], Long))] =
-      closest.reduceByKey(mergeContribs).collect()
+    val dim = clusterCenters(0).size
+    val pointStats: Array[(Int, (Vector, Long))] = closest
+      .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs)
+      .collect()
 
+    val discount = timeUnit match {
+      case StreamingKMeans.BATCHES => decayFactor
+      case StreamingKMeans.POINTS =>
+        val numNewPoints = pointStats.view.map { case (_, (_, n)) =>
+          n
+        }.sum
+        math.pow(decayFactor, numNewPoints)
+    }
+
+    // apply discount to weights
+    BLAS.scal(discount, Vectors.dense(clusterWeights))
+
     // implement update rule
-    pointStats.foreach { case (label, (mean, count)) =>
-      // store old count and centroid
-      val oldCount = counts(label)
-      val oldCentroid = centers(label).toBreeze
-      // get new count and centroid
-      val newCount = count
-      val newCentroid = mean / newCount.toDouble
-      // compute the normalized scale factor that controls forgetting
-      val lambda = timeUnit match {
-        case "batches" => newCount / (decayFactor * oldCount + newCount)
-        case "points" => newCount / (math.pow(decayFactor, newCount) * oldCount + newCount)
-      }
-      // perform the update
-      val updatedCentroid = oldCentroid + (newCentroid - oldCentroid) * lambda
-      // store the new counts and centers
-      counts(label) = oldCount + newCount
-      centers(label) = Vectors.fromBreeze(updatedCentroid)
+    pointStats.foreach { case (label, (sum, count)) =>
+      val centroid = clusterCenters(label)
+
+      val updatedWeight = clusterWeights(label) + count
+      val lambda = count / math.max(updatedWeight, 1e-16)
+
+      clusterWeights(label) = updatedWeight
+      BLAS.scal(1.0 - lambda, centroid)
+      BLAS.axpy(lambda / count, sum, centroid)
 
       // display the updated cluster centers
-      val display = centers(label).size match {
-        case x if x > 100 => centers(label).toArray.take(100).mkString("[", ",", "...")
-        case _ => centers(label).toArray.mkString("[", ",", "]")
+      val display = clusterCenters(label).size match {
+        case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...")
+        case _ => centroid.toArray.mkString("[", ",", "]")
       }
-      logInfo("Cluster %d updated: %s ".format (label, display))
+
+      logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display")
     }
-    new StreamingKMeansModel(centers, counts)
+
+    // Check whether the smallest cluster is dying. If so, split the largest cluster.
+    val weightsWithIndex = clusterWeights.view.zipWithIndex
+    val (maxWeight, largest) = weightsWithIndex.maxBy(_._1)
+    val (minWeight, smallest) = weightsWithIndex.minBy(_._1)
+    if (minWeight < 1e-8 * maxWeight) {
+      logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.")
+      val weight = (maxWeight + minWeight) / 2.0
+      clusterWeights(largest) = weight
+      clusterWeights(smallest) = weight
+      val largestClusterCenter = clusterCenters(largest)
+      val smallestClusterCenter = clusterCenters(smallest)
+      var j = 0
+      while (j < dim) {
+        val x = largestClusterCenter(j)
+        val p = 1e-14 * math.max(math.abs(x), 1.0)
+        largestClusterCenter.toBreeze(j) = x + p
+        smallestClusterCenter.toBreeze(j) = x - p
+        j += 1
+      }
+    }
+
+    this
   }
 }
 
 /**
  * :: DeveloperApi ::
  * StreamingKMeans provides methods for configuring a
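Note: the rewritten update does three things per batch. First it discounts every cluster weight, by decayFactor per batch or decayFactor^n per point; since Vectors.dense wraps rather than copies the array, the BLAS.scal mutates clusterWeights in place. Second, it folds each cluster's batch contribution in with lambda = count / (weight + count), i.e. c <- (1 - lambda) * c + (lambda / count) * sum, which with decayFactor = 1.0 reduces to an exact running mean. Third, it re-seeds a near-dead cluster by splitting the heaviest one. A self-contained sketch of the per-cluster arithmetic on plain arrays (a hypothetical helper, not the patch's code):

    // Discounted mean update for one cluster, on raw doubles.
    // Returns the new weight; the centroid is updated in place.
    def updateCluster(
        centroid: Array[Double],
        oldWeight: Double,
        sum: Array[Double],
        count: Long,
        discount: Double): Double = {
      val updatedWeight = discount * oldWeight + count
      val lambda = count / math.max(updatedWeight, 1e-16)
      var i = 0
      while (i < centroid.length) {
        // c <- (1 - lambda) * c + (lambda / count) * sum: a weighted mean of the
        // decayed old centroid and this batch's mean for the cluster.
        centroid(i) = (1.0 - lambda) * centroid(i) + (lambda / count) * sum(i)
        i += 1
      }
      updatedWeight
    }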
@@ -128,7 +153,7 @@ class StreamingKMeansModel(
  * val model = new StreamingKMeans()
  *   .setDecayFactor(0.5)
  *   .setK(3)
- *   .setRandomCenters(5)
+ *   .setRandomCenters(5, 100.0)
  *   .trainOn(DStream)
  */
 @DeveloperApi
@@ -137,9 +162,9 @@ class StreamingKMeans(
     var decayFactor: Double,
     var timeUnit: String) extends Logging {
 
-  protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)
+  def this() = this(2, 1.0, StreamingKMeans.BATCHES)
 
-  def this() = this(2, 1.0, "batches")
+  protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)
 
   /** Set the number of clusters. */
   def setK(k: Int): this.type = {
@@ -155,7 +180,7 @@ class StreamingKMeans(
 
   /** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */
   def setHalfLife(halfLife: Double, timeUnit: String): this.type = {
-    if (timeUnit != "batches" && timeUnit != "points") {
+    if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) {
       throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit)
     }
     this.decayFactor = math.exp(math.log(0.5) / halfLife)
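Note: setHalfLife solves decayFactor^halfLife = 0.5, so a batch's influence halves after halfLife time units. A quick sanity check:

    // setHalfLife stores decayFactor = exp(ln(0.5) / halfLife); after halfLife
    // time units the compounded discount is exactly one half.
    val halfLife = 10.0
    val decayFactor = math.exp(math.log(0.5) / halfLife)
    assert(math.abs(math.pow(decayFactor, halfLife) - 0.5) < 1e-12)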
@@ -165,26 +190,23 @@ class StreamingKMeans(
   }
 
   /** Specify initial centers directly. */
-  def setInitialCenters(initialCenters: Array[Vector]): this.type = {
-    val clusterCounts = new Array[Long](this.k)
-    this.model = new StreamingKMeansModel(initialCenters, clusterCounts)
+  def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = {
+    model = new StreamingKMeansModel(centers, weights)
     this
   }
 
-  /** Initialize random centers, requiring only the number of dimensions.
-   *
-   * @param dim Number of dimensions
-   * @param seed Random seed
-   * */
-  def setRandomCenters(dim: Int, seed: Long = Utils.random.nextLong): this.type = {
-
-    val random = Utils.random
-    random.setSeed(seed)
-
-    val initialCenters = (0 until k)
-      .map(_ => Vectors.dense(Array.fill(dim)(random.nextGaussian()))).toArray
-    val clusterCounts = new Array[Long](this.k)
-    this.model = new StreamingKMeansModel(initialCenters, clusterCounts)
+  /**
+   * Initialize random centers, requiring only the number of dimensions.
+   *
+   * @param dim Number of dimensions
+   * @param weight Weight for each center
+   * @param seed Random seed
+   */
+  def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = {
+    val random = new XORShiftRandom(seed)
+    val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian())))
+    val weights = Array.fill(k)(weight)
+    model = new StreamingKMeansModel(centers, weights)
     this
   }
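Note: the old code reseeded the shared Utils.random, a process-global RNG; the new code draws from a dedicated XORShiftRandom, so equal seeds reproduce equal centers without side effects on other users of Utils.random. The weight argument also matters: with weight = 0.0 the first batch overwrites the random centers entirely (lambda = count / (0 + count) = 1), while a large weight such as the 100.0 in the scaladoc example makes the random centers sticky. A small sketch of reproducible initialization, assuming spark-mllib on the classpath:

    import org.apache.spark.mllib.clustering.StreamingKMeans

    // Equal seeds drive identical XORShiftRandom sequences, hence identical centers.
    val a = new StreamingKMeans().setK(3).setRandomCenters(dim = 5, weight = 0.0, seed = 42L)
    val b = new StreamingKMeans().setK(3).setRandomCenters(dim = 5, weight = 0.0, seed = 42L)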

@@ -202,9 +224,9 @@ class StreamingKMeans(
    * @param data DStream containing vector data
    */
   def trainOn(data: DStream[Vector]) {
-    this.assertInitialized()
+    assertInitialized()
     data.foreachRDD { (rdd, time) =>
-      model = model.update(rdd, this.decayFactor, this.timeUnit)
+      model = model.update(rdd, decayFactor, timeUnit)
     }
   }
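Note: update now mutates the model in place and returns this, so trainOn's reassignment keeps the same object across batches. That also means batches can be replayed against a model without a StreamingContext; a sketch, assuming an existing SparkContext sc:

    import org.apache.spark.mllib.clustering.StreamingKMeansModel
    import org.apache.spark.mllib.linalg.Vectors

    // Start from two explicit centers with zero weight.
    val model = new StreamingKMeansModel(
      Array(Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0)),
      Array(0.0, 0.0))
    val batch = sc.parallelize(Seq(Vectors.dense(0.1, 0.1), Vectors.dense(0.9, 1.2)))
    // One decayed k-means step over the batch, as trainOn would do per RDD.
    model.update(batch, decayFactor = 1.0, timeUnit = "batches")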

@@ -215,7 +237,7 @@ class StreamingKMeans(
    * @return DStream containing predictions
    */
   def predictOn(data: DStream[Vector]): DStream[Int] = {
-    this.assertInitialized()
+    assertInitialized()
     data.map(model.predict)
   }
 
@@ -227,16 +249,20 @@ class StreamingKMeans(
    * @return DStream containing the input keys and the predictions as values
    */
   def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = {
-    this.assertInitialized()
+    assertInitialized()
     data.mapValues(model.predict)
   }
 
   /** Check whether cluster centers have been initialized. */
-  def assertInitialized(): Unit = {
-    if (Option(model.clusterCenters) == None) {
+  private[this] def assertInitialized(): Unit = {
+    if (model.clusterCenters == null) {
       throw new IllegalStateException(
         "Initial cluster centers must be set before starting predictions")
     }
   }
 }
+
+private[clustering] object StreamingKMeans {
+  final val BATCHES = "batches"
+  final val POINTS = "points"
+}
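Note: the new companion object gives setHalfLife and update shared constants for the two time-unit strings. Since it is private[clustering], callers outside the package still pass the literals "batches" or "points", which setHalfLife validates. A configuration sketch from user code:

    import org.apache.spark.mllib.clustering.StreamingKMeans

    // Half-life of 5 batches: decayFactor = exp(ln(0.5) / 5), roughly 0.87.
    val model = new StreamingKMeans()
      .setK(3)
      .setHalfLife(5.0, "batches")
      .setRandomCenters(dim = 10, weight = 0.0)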