@@ -81,6 +81,9 @@ class GMMExpectationMaximization private (
8181 private type DenseDoubleVector = BreezeVector [Double ]
8282 private type DenseDoubleMatrix = BreezeMatrix [Double ]
8383
84+ // number of samples per cluster to use when initializing Gaussians
85+ private val nSamples = 5 ;
86+
8487 // A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold
8588 def this () = this (2 , 0.01 , 100 )
8689
@@ -118,15 +121,15 @@ class GMMExpectationMaximization private (
118121 // Get length of the input vectors
119122 val d = breezeData.first.length
120123
121- // For each Gaussian, we will initialize the mean as some random
122- // point from the data. (This could be improved)
123- val samples = breezeData.takeSample(true , k, scala.util.Random .nextInt)
124+ // For each Gaussian, we will initialize the mean as the average
125+ // of some random samples from the data
126+ val samples = breezeData.takeSample(true , k * nSamples , scala.util.Random .nextInt)
124127
125128 // C will be array of (weight, mean, covariance) tuples
126129 // we start with uniform weights, a random mean from the data, and
127130 // identity matrices for covariance
128131 var C = (0 until k).map(i => (1.0 / k,
129- samples(i ),
132+ vec_mean( samples.slice(i * nSamples, (i + 1 ) * nSamples) ),
130133 BreezeMatrix .eye[Double ](d))).toArray
131134
132135 val acc_w = new Array [Accumulator [Double ]](k)
@@ -148,7 +151,7 @@ class GMMExpectationMaximization private (
148151 }
149152
150153 val log_likelihood = ctx.accumulator(0.0 )
151-
154+
152155 // broadcast the current weights and distributions to all nodes
153156 val dists = ctx.broadcast((0 until k).map(i =>
154157 new MultivariateGaussian (C (i)._2, C (i)._3)).toArray)
@@ -164,11 +167,12 @@ class GMMExpectationMaximization private (
164167 log_likelihood += math.log(norm)
165168
166169 // accumulate weighted sums
170+ val xxt = x * new Transpose (x)
167171 for (i <- 0 until k){
168172 p(i) /= norm
169173 acc_w(i) += p(i)
170174 acc_mu(i) += x * p(i)
171- acc_sigma(i) += x * new Transpose (x) * p(i)
175+ acc_sigma(i) += xxt * p(i)
172176 }
173177 })
174178
@@ -205,6 +209,13 @@ class GMMExpectationMaximization private (
205209 s
206210 }
207211
212+ /** Average of dense breeze vectors */
213+ private def vec_mean (x : Array [DenseDoubleVector ]) : DenseDoubleVector = {
214+ val v = BreezeVector .zeros[Double ](x(0 ).length)
215+ (0 until x.length).foreach(j => v += x(j))
216+ v / x.length.asInstanceOf [Double ]
217+ }
218+
208219 /** AccumulatorParam for Dense Breeze Vectors */
209220 private object DenseDoubleVectorAccumulatorParam extends AccumulatorParam [DenseDoubleVector ] {
210221 def zero (initialVector : DenseDoubleVector ) : DenseDoubleVector = {
0 commit comments