Skip to content

Commit b5b5f8d

Browse files
committed
Add better documentation
1 parent a0fd790 commit b5b5f8d

File tree

2 files changed

+107
-2
lines changed

2 files changed

+107
-2
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
118
package org.apache.spark.mllib.clustering
219

320
import breeze.linalg.{Vector => BV}
@@ -13,12 +30,45 @@ import org.apache.spark.SparkContext._
1330
import org.apache.spark.streaming.dstream.DStream
1431
import org.apache.spark.streaming.StreamingContext._
1532

33+
/**
34+
* :: DeveloperApi ::
35+
*
36+
* StreamingKMeansModel extends MLlib's KMeansModel for streaming
37+
* algorithms, so it can keep track of the number of points assigned
38+
* to each cluster, and also update the model by doing a single iteration
39+
* of the standard KMeans algorithm.
40+
*
41+
* The update algorithm uses the "mini-batch" KMeans rule,
42+
* generalized to incorporate forgetfulness (i.e. decay).
43+
* The basic update rule (for each cluster) is:
44+
*
45+
* c_t+1 = [(c_t * n_t) + (x_t * m_t)] / [n_t + m_t]
46+
* n_t+1 = n_t + m_t
47+
*
48+
* Where c_t is the previously estimated centroid for that cluster,
49+
* n_t is the number of points assigned to it thus far, x_t is the centroid
50+
* estimated on the current batch, and m_t is the number of points assigned
51+
* to that centroid in the current batch.
52+
*
53+
* This update rule is modified with a decay factor 'a' that scales
54+
* the contribution of the clusters as estimated thus far.
55+
* If a=1, all batches are weighted equally. If a=0, new centroids
56+
* are determined entirely by recent data. Lower values correspond to
57+
* more forgetting.
58+
*
59+
* Decay can optionally be specified as a decay fraction 'q',
60+
* which corresponds to the fraction of batches (or points)
61+
* after which the past will be reduced to a contribution of 0.5.
62+
* This decay fraction can be specified in units of 'points' or 'batches'.
63+
* if 'batches', behavior will be independent of the number of points per batch;
64+
* if 'points', the expected number of points per batch must be specified.
65+
*/
1666
@DeveloperApi
1767
class StreamingKMeansModel(
1868
override val clusterCenters: Array[Vector],
1969
val clusterCounts: Array[Long]) extends KMeansModel(clusterCenters) {
2070

21-
/** do a sequential KMeans update on a batch of data **/
71+
// do a sequential KMeans update on a batch of data
2272
def update(data: RDD[Vector], a: Double, units: String): StreamingKMeansModel = {
2373

2474
val centers = clusterCenters
@@ -70,67 +120,104 @@ class StreamingKMeans(
70120

71121
def this() = this(2, 1.0, "batches")
72122

123+
/** Set the number of clusters to estimate. */
def setK(k: Int): this.type = { this.k = k; this }
77128

129+
/** Set the decay factor directly (for forgetful algorithms). */
def setDecayFactor(a: Double): this.type = { this.a = a; this }
82134

135+
/**
 * Set the decay units for forgetful algorithms.
 * Only "batches" and "points" are accepted; anything else is rejected.
 */
def setUnits(units: String): this.type = {
  units match {
    case "batches" | "points" =>
      this.units = units
    case other =>
      throw new IllegalArgumentException("Invalid units for decay: " + other)
  }
  this
}
87143

144+
/**
 * Set the decay fraction in units of batches: the past is reduced
 * to a contribution of 0.5 after this fraction of batches.
 */
def setDecayFractionBatches(q: Double): this.type = {
  this.units = "batches"
  this.a = math.log(1 - q) / math.log(0.5)
  this
}
93150

151+
/**
 * Set the decay fraction in units of points.
 *
 * @param q decay fraction after which the past contributes 0.5
 * @param m expected number of points per batch (must be specified
 *          so the per-point decay can be converted to a per-batch factor)
 */
def setDecayFractionPoints(q: Double, m: Double): this.type = {
  val perBatch = math.log(1 - q) / math.log(0.5)
  this.a = math.pow(perBatch, 1 / m)
  this.units = "points"
  this
}
99157

158+
/** Specify the initial cluster centers explicitly; counts start at zero. */
def setInitialCenters(initialCenters: Array[Vector]): this.type = {
  val clusterCounts = new Array[Long](this.k) // zero-initialized, one slot per cluster
  this.model = new StreamingKMeansModel(initialCenters, clusterCounts)
  this
}
105164

165+
/**
 * Initialize random centers, requiring only the number of dimensions.
 * Each of the k centers is drawn from a standard Gaussian in d dimensions,
 * and every cluster's point count starts at zero.
 *
 * @param d number of dimensions of the input vectors
 */
def setRandomCenters(d: Int): this.type = {
  val initialCenters = (0 until k).map(_ => Vectors.dense(Array.fill(d)(nextGaussian()))).toArray
  // Fix: counts must have one zero entry per cluster (length k).
  // The previous Array.fill(0)(d) produced an EMPTY array (0 copies of d),
  // which disagrees with setInitialCenters and would break per-cluster updates.
  val clusterCounts = Array.fill(this.k)(0L)
  this.model = new StreamingKMeansModel(initialCenters, clusterCounts)
  this
}
112172

173+
/** @return the most recently updated model. */
def latestModel(): StreamingKMeansModel = model
116177

178+
/**
 * Update the clustering model by training on batches of data from a DStream.
 * Registers the stream for training, checks that cluster centers have been
 * initialized, and folds each incoming batch into the model.
 *
 * @param data DStream containing vector data
 */
def trainOn(data: DStream[Vector]) {
  // NOTE(review): the Boolean result is discarded here; isInitialized only
  // logs an error and does not halt training — confirm this is intended.
  this.isInitialized
  data.foreachRDD { (batch, _) =>
    model = model.update(batch, this.a, this.units)
  }
}
123192

193+
/**
 * Use the clustering model to make predictions on batches of data from a DStream.
 *
 * @param data DStream containing vector data
 * @return DStream containing the predicted cluster index for each vector
 */
def predictOn(data: DStream[Vector]): DStream[Int] = {
  this.isInitialized
  data.map(point => model.predict(point))
}
128203

204+
/**
 * Use the model to make predictions on the values of a DStream,
 * carrying the keys through unchanged.
 *
 * @param data DStream of (key, feature vector) pairs
 * @tparam K key type
 * @return DStream of the input keys paired with predicted cluster indices
 */
def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = {
  this.isInitialized
  data.mapValues(point => model.predict(point))
}
133215

216+
/**
217+
* Check whether cluster centers have been initialized.
218+
*
219+
* @return Boolean, true if cluster centers have been initialized
220+
*/
134221
def isInitialized: Boolean = {
135222
if (Option(model.clusterCenters) == None) {
136223
logError("Initial cluster centers must be set before starting predictions")

mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
118
package org.apache.spark.mllib.clustering
219

320
import scala.collection.mutable.ArrayBuffer
@@ -43,6 +60,7 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase {
4360
assert(centers(0) ~== model.latestModel().clusterCenters(0) absTol 1E-1)
4461

4562
// estimated center from streaming should exactly match the arithmetic mean of all data points
63+
// because the decay factor is set to 1.0
4664
val grandMean = input.flatten.map(x => x.toBreeze).reduce(_+_) / (numBatches * numPoints).toDouble
4765
assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5)
4866

@@ -74,7 +92,7 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase {
7492
runStreams(ssc, numBatches, numBatches)
7593

7694
// check that estimated centers are close to true centers
77-
// NOTE this depends on the initialization! allow for binary flip
95+
// NOTE exact assignment depends on the initialization!
7896
assert(centers(0) ~== model.latestModel().clusterCenters(0) absTol 1E-1)
7997
assert(centers(1) ~== model.latestModel().clusterCenters(1) absTol 1E-1)
8098

0 commit comments

Comments
 (0)