
Commit 2061a76

Add tests for simultaneous training and prediction
Minor style fixes

1 parent 81482fd · commit 2061a76

File tree: 4 files changed, +93 −68 lines

- docs/mllib-clustering.md
- mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
- python/pyspark/mllib/clustering.py
- python/pyspark/mllib/tests.py

docs/mllib-clustering.md — 6 additions, 14 deletions

```diff
@@ -599,48 +599,40 @@ ssc.awaitTermination()
 First we import the necessary classes.
 
 {% highlight python %}
-
 from pyspark.mllib.linalg import Vectors
 from pyspark.mllib.regression import LabeledPoint
 from pyspark.mllib.clustering import StreamingKMeans
-
 {% endhighlight %}
 
 Then we make an input stream of vectors for training, as well as a stream of labeled data
 points for testing. We assume a StreamingContext `ssc` has been created, see
 [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info.
 
 {% highlight python %}
+def parse(lp):
+    label = float(lp[lp.find('(') + 1: lp.find(',')])
+    vec = Vectors.dense(lp[lp.find('[') + 1: lp.find(']')].split(','))
+    return LabeledPoint(label, vec)
 
 trainingData = ssc.textFileStream("/training/data/dir").map(Vectors.parse)
-testData = ssc.textFileStream("/testing/data/dir").map(LabeledPoint.parse)
-
+testData = ssc.textFileStream("/testing/data/dir").map(parse)
 {% endhighlight %}
 
 We create a model with random clusters and specify the number of clusters to find
 
 {% highlight python %}
-
-numDimensions = 3
-numClusters = 2
-model = StreamingKMeans()
-model.setK(numClusters)
-model.setDecayFactor(1.0)
-model.setRandomCenters(numDimensions, 0.0)
-
+model = StreamingKMeans(k=2, decayFactor=1.0).setRandomCenters(3, 1.0, 0)
 {% endhighlight %}
 
 Now register the streams for training and testing and start the job, printing
 the predicted cluster assignments on new data points as they arrive.
 
 {% highlight python %}
-
 model.trainOn(trainingData)
 model.predictOnValues(testData.map(lambda lp: (lp.label, lp.features)))
 
 ssc.start()
 ssc.awaitTermination()
-
 {% endhighlight %}
 </div>
```
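For readers skimming the diff: assembled, the post-commit Python snippet from this docs page reads roughly as follows. This is a sketch rather than part of the patch; it assumes an already-created StreamingContext `ssc`, and it tacks on a `pprint()` call (a standard DStream method) so the predictions are actually printed.

```python
from pyspark.mllib.linalg import Vectors
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.clustering import StreamingKMeans

def parse(lp):
    # Parse lines of the form "(label, [x1, x2, ...])" into LabeledPoints.
    label = float(lp[lp.find('(') + 1: lp.find(',')])
    vec = Vectors.dense(lp[lp.find('[') + 1: lp.find(']')].split(','))
    return LabeledPoint(label, vec)

trainingData = ssc.textFileStream("/training/data/dir").map(Vectors.parse)
testData = ssc.textFileStream("/testing/data/dir").map(parse)

# k=2 clusters in 3 dimensions; random initial centers with weight 1.0, seed 0.
model = StreamingKMeans(k=2, decayFactor=1.0).setRandomCenters(3, 1.0, 0)

model.trainOn(trainingData)
model.predictOnValues(testData.map(lambda lp: (lp.label, lp.features))).pprint()

ssc.start()
ssc.awaitTermination()
```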

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala — 7 additions, 6 deletions

```diff
@@ -968,12 +968,13 @@ private[python] class PythonMLLibAPI extends Serializable {
    * Java stub for the update method of StreamingKMeansModel.
    */
   def updateStreamingKMeansModel(
-      clusterCenters: java.util.ArrayList[Vector],
-      clusterWeights: java.util.ArrayList[Double],
-      data: JavaRDD[Vector], decayFactor: Double,
-      timeUnit: String) : JList[Object] = {
-    val model = new StreamingKMeansModel(
-      clusterCenters.asScala.toArray, clusterWeights.asScala.toArray)
+      clusterCenters: JList[Vector],
+      clusterWeights: JList[Double],
+      data: JavaRDD[Vector],
+      decayFactor: Double,
+      timeUnit: String): JList[Object] = {
+    val model = new StreamingKMeansModel(
+      clusterCenters.asScala.toArray, clusterWeights.asScala.toArray)
       .update(data, decayFactor, timeUnit)
     List[AnyRef](model.clusterCenters, Vectors.dense(model.clusterWeights)).asJava
   }
```
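For context on how this stub gets hit from Python: the model's `update` method ships its centers and weights to the JVM and reads the updated pair back. A minimal sketch, assuming the round trip goes through the `callMLlibFunc` helper in `pyspark.mllib.common` (the Python-side plumbing lives in clustering.py and is not shown in this hunk):

```python
from numpy import array
from pyspark.mllib.common import callMLlibFunc
from pyspark.mllib.linalg import _convert_to_vector

def update(self, data, decayFactor, timeUnit):
    # Send current centers/weights to the JVM, run one decay-weighted
    # update there, and read the new centers and weights back.
    data = data.map(_convert_to_vector)
    updatedModel = callMLlibFunc(
        "updateStreamingKMeansModel",
        [_convert_to_vector(c) for c in self.centers],
        self._clusterWeights, data, float(decayFactor), timeUnit)
    self.centers = array(updatedModel[0])
    self._clusterWeights = list(updatedModel[1])
    return self
```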

python/pyspark/mllib/clustering.py — 32 additions, 34 deletions

```diff
@@ -275,18 +275,19 @@ class StreamingKMeansModel(KMeansModel):
     .. note:: Experimental
 
     Clustering model which can perform an online update of the centroids.
 
-    The update formula is given by
+    The update formula for each centroid is given by
 
     c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t * a + m_t]
     n_t+1 = n_t * a + m_t
 
     where
 
-    c_t: Centroid at the n_th iteration.
-    n_t: Number of weights at the n_th iteration.
-    x_t: Centroid of the new data closest to c_t
-    m_t: Number of weights of the new data closest to c_t
-    c_t+1: New centroid
+    c_t: Centroid at the n_th iteration.
+    n_t: Number of samples (or weights) associated with the centroid
+         at the n_th iteration.
+    x_t: Centroid of the new data closest to c_t.
+    m_t: Number of samples (or weights) of the new data closest to c_t.
+    c_t+1: New centroid.
     n_t+1: New number of weights.
-    a: Decay Factor, which gives the forgetfulnes
+    a: Decay Factor, which gives the forgetfulness.
 
     Note that if a is set to 1, it is the weighted mean of the previous
     and new data. If it is set to zero, the old centroids are completely
```
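As a quick sanity check on the formula, here is a standalone NumPy sketch of a single centroid update (the function name and test values are illustrative, not from the patch): a = 1 reduces to the plain weighted mean, while a = 0 discards the old centroid entirely, matching the note above.

```python
import numpy as np

def update_centroid(c_t, n_t, x_t, m_t, a):
    # c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t * a + m_t]
    # n_t+1 = n_t * a + m_t
    c_next = (c_t * n_t * a + x_t * m_t) / (n_t * a + m_t)
    return c_next, n_t * a + m_t

c_t, x_t = np.array([0.0, 0.0]), np.array([1.0, 1.0])
print(update_centroid(c_t, 3.0, x_t, 3.0, a=1.0)[0])  # [0.5 0.5], weighted mean
print(update_centroid(c_t, 3.0, x_t, 3.0, a=0.0)[0])  # [1. 1.], old centroid forgotten
```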
```diff
@@ -304,41 +305,41 @@ class StreamingKMeansModel(KMeansModel):
     True
     >>> stkm.predict([0.9, 0.9]) == stkm.predict([1.1, 1.1]) == 1
     True
-    >>> stkm.getClusterWeights
+    >>> stkm.clusterWeights
     [3.0, 3.0]
     >>> decayFactor = 0.0
     >>> data = sc.parallelize([DenseVector([1.5, 1.5]), DenseVector([0.2, 0.2])])
     >>> stkm = stkm.update(data, 0.0, u"batches")
     >>> stkm.centers
     array([[ 0.2,  0.2],
            [ 1.5,  1.5]])
-    >>> stkm.getClusterWeights
+    >>> stkm.clusterWeights
     [1.0, 1.0]
     >>> stkm.predict([0.2, 0.2])
     0
     >>> stkm.predict([1.5, 1.5])
     1
+
+    :param clusterCenters: Initial cluster centers.
+    :param clusterWeights: List of weights assigned to each cluster.
     """
     def __init__(self, clusterCenters, clusterWeights):
         super(StreamingKMeansModel, self).__init__(centers=clusterCenters)
         self._clusterWeights = list(clusterWeights)
 
     @property
-    def getClusterWeights(self):
+    def clusterWeights(self):
         """Convenience method to return the cluster weights."""
         return self._clusterWeights
 
     @ignore_unicode_prefix
     def update(self, data, decayFactor, timeUnit):
         """Update the centroids, according to data
 
-        Parameters
-        ----------
-        data: Should be a RDD that represents the new data.
-
-        decayFactor: forgetfulness of the previous centroids.
-
-        timeUnit: Can be "batches" or "points"
+        :param data: Should be a RDD that represents the new data.
+        :param decayFactor: forgetfulness of the previous centroids.
+        :param timeUnit: Can be "batches" or "points"
             If points, then the decay factor is raised to the power of
             number of new points and if batches, it is used as it is.
         """
```
```diff
@@ -365,17 +366,10 @@ class StreamingKMeans(object):
     Provides methods to set k, decayFactor, timeUnit to train and
     predict the incoming data
 
-    Parameters
-    ----------
-    k: int
-        Number of clusters
-
-    decayFactor: float
-        Forgetfulness of the previous centroid.
-
-    timeUnit: str, "batches" or "points"
-        If points, then the decayfactor is raised to the power of new
-        points.
+    :param k: int, number of clusters
+    :param decayFactor: float, forgetfulness of the previous centroids.
+    :param timeUnit: can be "batches" or "points". If points, then the
+        decayFactor is raised to the power of the number of new points.
     """
     def __init__(self, k=2, decayFactor=1.0, timeUnit="batches"):
         self._k = k
```
```diff
@@ -384,10 +378,14 @@ def __init__(self, k=2, decayFactor=1.0, timeUnit="batches"):
             raise ValueError(
                 "timeUnit should be 'batches' or 'points', got %s." % timeUnit)
         self._timeUnit = timeUnit
-        self.latestModel = None
+        self._model = None
+
+    def latestModel(self):
+        """Return the latest model"""
+        return self._model
 
     def _validate(self, dstream):
-        if self.latestModel is None:
+        if self._model is None:
             raise ValueError(
                 "Initial centers should be set either by setInitialCenters "
                 "or setRandomCenters.")
```
```diff
@@ -416,7 +414,7 @@ def setHalfLife(self, halfLife, timeUnit):
         return self
 
     def setInitialCenters(self, centers, weights):
-        self.latestModel = StreamingKMeansModel(centers, weights)
+        self._model = StreamingKMeansModel(centers, weights)
         return self
 
     def setRandomCenters(self, dim, weight, seed):
```
```diff
@@ -427,15 +425,15 @@ def setRandomCenters(self, dim, weight, seed):
         rng = random.RandomState(seed)
         clusterCenters = rng.randn(self._k, dim)
         clusterWeights = tile(weight, self._k)
-        self.latestModel = StreamingKMeansModel(clusterCenters, clusterWeights)
+        self._model = StreamingKMeansModel(clusterCenters, clusterWeights)
         return self
 
     def trainOn(self, dstream):
         """Train the model on the incoming dstream."""
         self._validate(dstream)
 
         def update(rdd):
-            self.latestModel.update(rdd, self._decayFactor, self._timeUnit)
+            self._model.update(rdd, self._decayFactor, self._timeUnit)
 
         dstream.foreachRDD(update)
```
```diff
@@ -445,15 +443,15 @@ def predictOn(self, dstream):
         Returns a transformed dstream object
         """
         self._validate(dstream)
-        return dstream.map(lambda x: self.latestModel.predict(x))
+        return dstream.map(lambda x: self._model.predict(x))
 
     def predictOnValues(self, dstream):
         """
         Make predictions on a keyed dstream.
         Returns a transformed dstream object.
         """
         self._validate(dstream)
-        return dstream.mapValues(lambda x: self.latestModel.predict(x))
+        return dstream.mapValues(lambda x: self._model.predict(x))
 
 
 def _test():
```

python/pyspark/mllib/tests.py — 48 additions, 14 deletions

```diff
@@ -79,6 +79,11 @@ def setUp(self):
     def tearDown(self):
         self.ssc.stop(False)
 
+    @staticmethod
+    def _ssc_wait(start_time, end_time, sleep_time):
+        while time() - start_time < end_time:
+            sleep(0.01)
+
 
 def _squared_distance(a, b):
     if isinstance(a, Vector):
```
```diff
@@ -878,25 +883,23 @@ def test_model_transform(self):
 
 class StreamingKMeansTest(MLLibStreamingTestCase):
     def test_model_params(self):
+        """Test that the model params are set correctly"""
         stkm = StreamingKMeans()
         stkm.setK(5).setDecayFactor(0.0)
         self.assertEquals(stkm._k, 5)
         self.assertEquals(stkm._decayFactor, 0.0)
 
         # Model not set yet.
-        self.assertIsNone(stkm.latestModel)
+        self.assertIsNone(stkm.latestModel())
         self.assertRaises(ValueError, stkm.trainOn, [0.0, 1.0])
 
         stkm.setInitialCenters([[0.0, 0.0], [1.0, 1.0]], [1.0, 1.0])
-        self.assertEquals(stkm.latestModel.centers, [[0.0, 0.0], [1.0, 1.0]])
-        self.assertEquals(stkm.latestModel.getClusterWeights, [1.0, 1.0])
-
-    @staticmethod
-    def _ssc_wait(start_time, end_time, sleep_time):
-        while time() - start_time < end_time:
-            sleep(0.01)
+        self.assertEquals(
+            stkm.latestModel().centers, [[0.0, 0.0], [1.0, 1.0]])
+        self.assertEquals(stkm.latestModel().clusterWeights, [1.0, 1.0])
 
     def test_accuracy_for_single_center(self):
+        """Test that the parameters obtained are correct for a single center."""
         numBatches, numPoints, k, d, r, seed = 5, 5, 1, 5, 0.1, 0
         centers, batches = self.streamingKMeansDataGenerator(
             numBatches, numPoints, k, d, r, seed)
```
```diff
@@ -905,13 +908,14 @@ def test_accuracy_for_single_center(self):
         input_stream = self.ssc.queueStream(
             [self.sc.parallelize(batch, 1) for batch in batches])
         stkm.trainOn(input_stream)
+
         t = time()
         self.ssc.start()
         self._ssc_wait(t, 10.0, 0.01)
-        self.assertEquals(stkm.latestModel.getClusterWeights, [25.0])
+        self.assertEquals(stkm.latestModel().clusterWeights, [25.0])
         realCenters = array_sum(array(centers), axis=0)
         for i in range(d):
-            modelCenters = stkm.latestModel.centers[0][i]
+            modelCenters = stkm.latestModel().centers[0][i]
             self.assertAlmostEqual(centers[0][i], modelCenters, 1)
             self.assertAlmostEqual(realCenters[i], modelCenters, 1)
```
```diff
@@ -927,7 +931,7 @@ def streamingKMeansDataGenerator(self, batches, numPoints,
                 for i in range(batches)]
 
     def test_trainOn_model(self):
-        # Test the model on toy data with four clusters.
+        """Test the model on toy data with four clusters."""
        stkm = StreamingKMeans()
         initCenters = [[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]]
         weights = [1.0, 1.0, 1.0, 1.0]
```
```diff
@@ -948,15 +952,16 @@ def test_trainOn_model(self):
 
         # Give enough time to train the model.
         self._ssc_wait(t, 6.0, 0.01)
-        finalModel = stkm.latestModel
+        finalModel = stkm.latestModel()
         self.assertTrue(all(finalModel.centers == array(initCenters)))
-        self.assertEquals(finalModel.getClusterWeights, [5.0, 5.0, 5.0, 5.0])
+        self.assertEquals(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0])
 
     def test_predictOn_model(self):
+        """Test that the model predicts correctly on toy data."""
         initCenters = [[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]]
         weights = [1.0, 1.0, 1.0, 1.0]
         stkm = StreamingKMeans()
-        stkm.latestModel = StreamingKMeansModel(initCenters, weights)
+        stkm._model = StreamingKMeansModel(initCenters, weights)
 
         predict_data = [[[1.5, 1.5]], [[-1.5, 1.5]], [[-1.5, -1.5]], [[1.5, -1.5]]]
         predict_data = [sc.parallelize(batch, 1) for batch in predict_data]
```
```diff
@@ -976,6 +981,35 @@ def update(rdd):
         self._ssc_wait(t, 6.0, 0.01)
         self.assertEquals(result, [[0], [1], [2], [3]])
 
+    def test_trainOn_predictOn(self):
+        """Test that prediction happens on the updated model."""
+        stkm = StreamingKMeans(decayFactor=0.0, k=2)
+        stkm.setInitialCenters([[0.0], [1.0]], [1.0, 1.0])
+
+        # Since the decay factor is set to zero, once the first batch
+        # is passed the cluster centers are updated to [-0.5, 0.7],
+        # which causes 0.2 and 0.3 to be classified as 1, even though the
+        # classification based on the initial model would have been 0,
+        # proving that the model is updated.
+        batches = [[[-0.5], [0.6], [0.8]], [[0.2], [-0.1], [0.3]]]
+        batches = [sc.parallelize(batch) for batch in batches]
+        input_stream = self.ssc.queueStream(batches)
+        predict_results = []
+
+        def collect(rdd):
+            rdd_collect = rdd.collect()
+            if rdd_collect:
+                predict_results.append(rdd_collect)
+
+        stkm.trainOn(input_stream)
+        predict_stream = stkm.predictOn(input_stream)
+        predict_stream.foreachRDD(collect)
+
+        t = time()
+        self.ssc.start()
+        self._ssc_wait(t, 6.0, 0.01)
+        self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]])
 
 
 if __name__ == "__main__":
     if not _have_scipy:
```
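To see why the expected result `[[0, 1, 1], [1, 0, 1]]` in the new `test_trainOn_predictOn` holds, the arithmetic can be replayed by hand (a standalone sketch, independent of Spark):

```python
# Batch 1 ([-0.5], [0.6], [0.8]) is scored against the initial centers
# [0.0] and [1.0]: -0.5 -> cluster 0; 0.6 and 0.8 -> cluster 1, giving [0, 1, 1].
# With decayFactor=0.0 the old centers are then fully forgotten:
# center 0 becomes -0.5 and center 1 becomes (0.6 + 0.8) / 2 = 0.7.
centers = [-0.5, 0.7]
# Batch 2 ([0.2], [-0.1], [0.3]) is scored against the updated centers:
preds = [min((abs(x - c), i) for i, c in enumerate(centers))[1]
         for x in (0.2, -0.1, 0.3)]
assert preds == [1, 0, 1]  # 0.2 and 0.3 land on cluster 1, proving the update
```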
