1+ /*
2+ * Licensed to the Apache Software Foundation (ASF) under one or more
3+ * contributor license agreements. See the NOTICE file distributed with
4+ * this work for additional information regarding copyright ownership.
5+ * The ASF licenses this file to You under the Apache License, Version 2.0
6+ * (the "License"); you may not use this file except in compliance with
7+ * the License. You may obtain a copy of the License at
8+ *
9+ * http://www.apache.org/licenses/LICENSE-2.0
10+ *
11+ * Unless required by applicable law or agreed to in writing, software
12+ * distributed under the License is distributed on an "AS IS" BASIS,
13+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+ * See the License for the specific language governing permissions and
15+ * limitations under the License.
16+ */
17+
118package org .apache .spark .mllib .clustering
219
320import breeze .linalg .{Vector => BV }
@@ -13,12 +30,45 @@ import org.apache.spark.SparkContext._
1330import org .apache .spark .streaming .dstream .DStream
1431import org .apache .spark .streaming .StreamingContext ._
1532
33+ /**
34+ * :: DeveloperApi ::
35+ *
36+ * StreamingKMeansModel extends MLlib's KMeansModel for streaming
37+ * algorithms, so it can keep track of the number of points assigned
38+ * to each cluster, and also update the model by doing a single iteration
39+ * of the standard KMeans algorithm.
40+ *
41+ * The update algorithm uses the "mini-batch" KMeans rule,
 * generalized to incorporate forgetfulness (i.e. decay).
43+ * The basic update rule (for each cluster) is:
44+ *
45+ * c_t+1 = [(c_t * n_t) + (x_t * m_t)] / [n_t + m_t]
 * n_t+1 = n_t + m_t
47+ *
48+ * Where c_t is the previously estimated centroid for that cluster,
49+ * n_t is the number of points assigned to it thus far, x_t is the centroid
50+ * estimated on the current batch, and m_t is the number of points assigned
51+ * to that centroid in the current batch.
52+ *
53+ * This update rule is modified with a decay factor 'a' that scales
54+ * the contribution of the clusters as estimated thus far.
55+ * If a=1, all batches are weighted equally. If a=0, new centroids
56+ * are determined entirely by recent data. Lower values correspond to
57+ * more forgetting.
58+ *
59+ * Decay can optionally be specified as a decay fraction 'q',
60+ * which corresponds to the fraction of batches (or points)
61+ * after which the past will be reduced to a contribution of 0.5.
62+ * This decay fraction can be specified in units of 'points' or 'batches'.
63+ * if 'batches', behavior will be independent of the number of points per batch;
64+ * if 'points', the expected number of points per batch must be specified.
65+ */
1666@ DeveloperApi
1767class StreamingKMeansModel (
1868 override val clusterCenters : Array [Vector ],
1969 val clusterCounts : Array [Long ]) extends KMeansModel (clusterCenters) {
2070
21- /** do a sequential KMeans update on a batch of data * */
71+ // do a sequential KMeans update on a batch of data
2272 def update (data : RDD [Vector ], a : Double , units : String ): StreamingKMeansModel = {
2373
2474 val centers = clusterCenters
@@ -70,67 +120,104 @@ class StreamingKMeans(
70120
71121 def this () = this (2 , 1.0 , " batches" )
72122
123+ /** Set the number of clusters. */
73124 def setK (k : Int ): this .type = {
74125 this .k = k
75126 this
76127 }
77128
129+ /** Set the decay factor directly (for forgetful algorithms). */
78130 def setDecayFactor (a : Double ): this .type = {
79131 this .a = a
80132 this
81133 }
82134
135+ /** Set the decay units for forgetful algorithms ("batches" or "points"). */
83136 def setUnits (units : String ): this .type = {
137+ if (units != " batches" && units != " points" ) {
138+ throw new IllegalArgumentException (" Invalid units for decay: " + units)
139+ }
84140 this .units = units
85141 this
86142 }
87143
144+ /** Set decay fraction in units of batches. */
88145 def setDecayFractionBatches (q : Double ): this .type = {
89146 this .a = math.log(1 - q) / math.log(0.5 )
90147 this .units = " batches"
91148 this
92149 }
93150
151+ /** Set decay fraction in units of points. Must specify expected number of points per batch. */
94152 def setDecayFractionPoints (q : Double , m : Double ): this .type = {
95153 this .a = math.pow(math.log(1 - q) / math.log(0.5 ), 1 / m)
96154 this .units = " points"
97155 this
98156 }
99157
158+ /** Specify initial explicitly directly. */
100159 def setInitialCenters (initialCenters : Array [Vector ]): this .type = {
101160 val clusterCounts = Array .fill(this .k)(0 ).map(_.toLong)
102161 this .model = new StreamingKMeansModel (initialCenters, clusterCounts)
103162 this
104163 }
105164
165+ /** Initialize random centers, requiring only the number of dimensions. */
106166 def setRandomCenters (d : Int ): this .type = {
107167 val initialCenters = (0 until k).map(_ => Vectors .dense(Array .fill(d)(nextGaussian()))).toArray
108168 val clusterCounts = Array .fill(0 )(d).map(_.toLong)
109169 this .model = new StreamingKMeansModel (initialCenters, clusterCounts)
110170 this
111171 }
112172
173+ /** Return the latest model. */
113174 def latestModel (): StreamingKMeansModel = {
114175 model
115176 }
116177
178+ /**
179+ * Update the clustering model by training on batches of data from a DStream.
180+ * This operation registers a DStream for training the model,
181+ * checks whether the cluster centers have been initialized,
182+ * and updates the model using each batch of data from the stream.
183+ *
184+ * @param data DStream containing vector data
185+ */
117186 def trainOn (data : DStream [Vector ]) {
118187 this .isInitialized
119188 data.foreachRDD { (rdd, time) =>
120189 model = model.update(rdd, this .a, this .units)
121190 }
122191 }
123192
193+ /**
194+ * Use the clustering model to make predictions on batches of data from a DStream.
195+ *
196+ * @param data DStream containing vector data
197+ * @return DStream containing predictions
198+ */
124199 def predictOn (data : DStream [Vector ]): DStream [Int ] = {
125200 this .isInitialized
126201 data.map(model.predict)
127202 }
128203
204+ /**
205+ * Use the model to make predictions on the values of a DStream and carry over its keys.
206+ *
207+ * @param data DStream containing (key, feature vector) pairs
208+ * @tparam K key type
209+ * @return DStream containing the input keys and the predictions as values
210+ */
129211 def predictOnValues [K : ClassTag ](data : DStream [(K , Vector )]): DStream [(K , Int )] = {
130212 this .isInitialized
131213 data.mapValues(model.predict)
132214 }
133215
216+ /**
217+ * Check whether cluster centers have been initialized.
218+ *
 * @return true if cluster centers have been initialized
220+ */
134221 def isInitialized : Boolean = {
135222 if (Option (model.clusterCenters) == None ) {
136223 logError(" Initial cluster centers must be set before starting predictions" )
0 commit comments