Skip to content

Commit 33ab236

Browse files
committed
Revert "[SPARK-12006][ML][PYTHON] Fix GMM failure if initialModel is not None"
This reverts commit fcd013c. Author: Yin Huai <[email protected]> Closes #10632 from yhuai/pythonStyle. (cherry picked from commit e5cde7a) Signed-off-by: Yin Huai <[email protected]>
1 parent f2bc02e commit 33ab236

File tree

2 files changed

+1
-13
lines changed

2 files changed

+1
-13
lines changed

python/pyspark/mllib/clustering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia
300300
if initialModel.k != k:
301301
raise Exception("Mismatched cluster count, initialModel.k = %s, however k = %s"
302302
% (initialModel.k, k))
303-
initialModelWeights = list(initialModel.weights)
303+
initialModelWeights = initialModel.weights
304304
initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)]
305305
initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)]
306306
java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector),

python/pyspark/mllib/tests.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -428,18 +428,6 @@ def test_gmm_deterministic(self):
428428
for c1, c2 in zip(clusters1.weights, clusters2.weights):
429429
self.assertEquals(round(c1, 7), round(c2, 7))
430430

431-
def test_gmm_with_initial_model(self):
432-
from pyspark.mllib.clustering import GaussianMixture
433-
data = self.sc.parallelize([
434-
(-10, -5), (-9, -4), (10, 5), (9, 4)
435-
])
436-
437-
gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001,
438-
maxIterations=10, seed=63)
439-
gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001,
440-
maxIterations=10, seed=63, initialModel=gmm1)
441-
self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0)
442-
443431
def test_classification(self):
444432
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
445433
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\

0 commit comments

Comments
 (0)